在使用tensorflow创建数据集时,需要返回函数对象



我是机器学习的新手,我正在尝试使用Tensorflow API创建一个机器学习模型,该模型来自这里的Tensorflow文档中的教程但我很难理解代码的这一部分

def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):
def input_function():          # inner function, this will be returned
ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))     # create tf.data.Dataset object with data and its label
if shuffle:
ds = ds.shuffle(1000)   # randomize order of data
ds = ds.batch(batch_size).repeat(num_epochs)      # split dataset into batches of 32 and repeat process for number of epochs
return ds  # return a batch of the dataset
return input_function  # return a function object for use

然后将函数的输出存储在变量中

train_input_fn = make_input_fn(dftrain, y_train)

最后用数据集对模型进行训练

linear_est.train(train_input_fn)

当我只是在make_input_function中返回内部函数的函数名,而不是仅仅返回我们的数据集并将其传递给训练模型时,我没有意识到我们试图做什么。

我是Python的初学者,刚刚开始学习机器学习,我找不到合适的答案来回答我的问题,所以如果有人能以初学者友好的方式解释它,我将不胜感激。

当我只是在make_input_function中返回内部函数的函数名,而不是仅仅返回我们的数据集并将其传递给训练模型时,我没有意识到我们试图做什么。

在python编程中,这被称为Currying,它用于通过评估函数参数的增量嵌套,将多参数函数转换为单参数函数。Currying还将一个论点修正为另一个论点,从而在执行时形成相对模式。

在tensorflow中,基于文档(https://www.tensorflow.org/api_docs/python/tf/estimator/LinearClassifier#train)。

train(
input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None
)

估计器的方法序列期望参数CCD_ 1。原因是每次调用Estimator.train()时,它都会通过调用input_fnmodel_fn并将它们连接在一起来创建一个新的图。如果您提供张量或数据集,则会导致不同的错误。

最新更新