这是使用Petastorm训练mnist数据的代码。
def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):
with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
train_readout = tf_tensors(train_reader)
train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
train_label = train_readout.digit
batch_image, batch_label = tf.train.batch(
[train_image, train_label], batch_size=batch_size
)
我不知道如何更换tf.train.batch
。你能帮忙吗?
您可以将dataset.batch
与tf.data.Dataset
一起使用,petastorm
也支持他们网站上提到的tf.data.Dataset
。
关于用petastorm
实现tf.data.Dataset
的代码,您可以在这里获得
有关dataset.batch
的详细信息,您可以在此处找到它。