Tensorflow:如何使用 tf.train.batch()



我使用Tensorflow(版本1.7.0和Python 3.5(作为神经网络,并且在使用tf.train.batch()函数时遇到问题。看这里。

我的数据集的维度为:

测试图像(100000、900(测试标签 (100000, 10(

所以我有 100000 张大小为 30 x 30 像素的测试图像。标签是大小为 100000 x 10 的单热矩阵。

import numpy as np
train_images = np.load("train_images.npy")
train_labels = np.load("train_labels.npy")

现在我想随机获取一个大小为 100 的批次,并想使用函数 tf.train.batch() .

我在代码中按如下方式使用该函数:

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    num_examples=100000
    batch_size = 100
    # Training cycle
    for epoch in range(training_epochs):
        total_batch = int(num_examples/batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_x, batch_y = tf.train.batch(
                [train_images, train_labels],
                batch_size=batch_size,
                allow_smaller_final_batch=True,
                )
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_x, y:batch_y})

这样做时,我收到以下错误:

    Traceback (most recent call last):
    File "my_network.py", line 124, in <module>
    _, c = sess.run([optimizer, cost], feed_dict={x: batch_x, y:batch_y})
    File "/home/samuel/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
    File "/home/samuel/tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1091, in _run
    'feed with key ' + str(feed) + '.')
    TypeError: The value of a feed cannot be a tf.Tensor object. 
    Acceptable feed values include Python scalars, strings, lists, 
    numpy ndarrays, or TensorHandles.For reference, the tensor object was
    Tensor("batch:0", shape=(?, 900), dtype=uint8) which was passed to the
feed with key Tensor("Placeholder:0", shape=(?, 900), dtype=float32).

我该怎么做才能使用tf.train.batch()以使我的网络正常工作?是否需要使用其他方法创建小批量?

tf.train.batch已被弃用,您应该改用tf.data进行批处理。

可是。。。

为了清楚起见,您应该将图形构建和图形运行步骤分开。

然后,当使用tf.train.batch时,在构建图形之后但在运行任何操作之前,您需要运行tf.queue_runner.start_queue_runners((来启动线程,该线程将数据预取为批处理。

相关内容

  • 没有找到相关文章

最新更新