如何替换已弃用的tf.train.batch



这是使用Petastorm训练mnist数据的代码。

def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):
with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
train_readout = tf_tensors(train_reader)
train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
train_label = train_readout.digit
batch_image, batch_label = tf.train.batch(
[train_image, train_label], batch_size=batch_size
)

我不知道如何更换tf.train.batch。你能帮忙吗?

您可以将dataset.batchtf.data.Dataset一起使用,petastorm也支持他们网站上提到的tf.data.Dataset

关于用petastorm实现tf.data.Dataset的代码,您可以在这里获得
有关dataset.batch的详细信息,您可以在此处找到它。

最新更新