目标
我想修改这个教程来使用这个时间序列数据集,而不是普通的图像数据。
方法
我已经确定了几种将数据输入tf.estimator
API的方法。最方便的(因为使用.from_generator
是猜测…(是如下使用tf.data.Dataset.from_tensor_slices(training_data_ndarray)
:
trnX, trnY, tstX, tstY = load_dataset()
trnXl = trnX.tolist()
tstXl = tstX.tolist()
tstYl = tstY.tolist()
trnYl = trnY.tolist()
trndataset = tf.data.Dataset.from_tensor_slices((trnXl, trnYl))
tstdataset = tf.data.Dataset.from_tensor_slices((tstXl, tstYl))
...
def _input_fn(partition):
if partition == "train":
dst = trndataset
elif partition == "predict":
dst = tstdataset
else:
dst = tstdataset
return dst
错误/问题
TypeError:
input_fn
必须是可调用的,给定:DatasetV1Adapter形状:((128,9(,(6,((,类型:(tf.float32,tf.floot32(>
复制
我使用paperspace实例。如果您有帐户,可以在此处查看。
- 如果没有,请获取以下要点中的完整代码:
- 来自2的数据集
设置:
- Tensorflow 1.15
- Python 3.6.8
- 急切的执行:关闭(adanet库无法处理(
调用时出现错误:
tf.estimator.train_and_evaluate(
estimator,
train_spec=tf.estimator.TrainSpec(
input_fn=_input_fn("train"),
max_steps=TRAIN_STEPS),
eval_spec=tf.estimator.EvalSpec(
input_fn=_input_fn("test"),
steps=None,
start_delay_secs=1,
throttle_secs=1,
))
现在怎么办
使用创建tf.data.Dataset的方法,我看不到前进的道路,因为输入函数("创建返回数据集的输入函数"(似乎不正确。
这可能会在未来对某人有所帮助,但并不能解决我更深层次的问题。
为了使input_fn可调用但仍然接受参数,有两个选项:将函数调用封装在lambda运算符中,或者将实际input_fn封装在另一个返回内部函数的input_fn中
classifier.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=5000)
请注意,您将input_fn调用封装在lambda中以捕获参数,同时提供一个不接受参数的输入函数,正如Estimator 所期望的那样
来源:https://www.tensorflow.org/tutorials/estimator/premade