TensorFlow:如何定义dataset.train.next_batch



我正在尝试学习TensorFlow并在以下位置研究示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后我在下面的代码中有一些问题:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", "{:.9f}".format(c))

既然 mnist 只是一个数据集,mnist.train.next_batch到底是什么意思?dataset.train.next_batch是如何定义的?

谢谢!

mnist 对象从 tf.contrib.learn 模块中定义的 read_data_sets() 函数返回。这里实现了 mnist.train.next_batch(batch_size) 方法,它返回一个由两个数组组成的元组,其中第一个表示一批 batch_size MNIST 图像,第二个表示与这些图像对应的一批batch-size标签。

图像作为大小为 [batch_size, 784] 的 2-D NumPy 数组返回(

因为 MNIST 图像中有 784 像素),标签作为大小为 [batch_size] 的一维 NumPy 数组返回(如果使用 one_hot=False 调用read_data_sets())或大小为 [batch_size, 10] 的 2-D NumPy 数组(如果使用 one_hot=True 调用read_data_sets())。

相关内容

  • 没有找到相关文章

最新更新