张量流:从 2 GB 的 numpy 数组创建小批量>



我正在尝试将小批量的numpy数组馈送到我的模型中,但我坚持批处理。使用"tf.train.shuffle_batch"引发错误,因为"images"数组大于 2 GB。我试图绕过它并创建占位符,但是当我尝试馈送数组时,它们仍然由 tf 表示。张量对象。我主要关心的是我在模型类下定义了操作,并且在运行会话之前不会调用对象。有没有人知道如何处理这个问题?

def main(mode, steps):
  config = Configuration(mode, steps)

  if config.TRAIN_MODE:
      images, labels = read_data(config.simID)
      assert images.shape[0] == labels.shape[0]
      images_placeholder = tf.placeholder(images.dtype,
                                                images.shape)
      labels_placeholder = tf.placeholder(labels.dtype,
                                                labels.shape)
      dataset = tf.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
      # shuffle
      dataset = dataset.shuffle(buffer_size=1000)
      # batch
      dataset = dataset.batch(batch_size=config.batch_size)
      iterator = dataset.make_initializable_iterator()
      image, label = iterator.get_next()
      model = Model(config, image, label)
      with tf.Session() as sess:
          sess.run(tf.global_variables_initializer())
          sess.run(iterator.initializer, 
                   feed_dict={images_placeholder: images,
                          labels_placeholder: labels})
          # ...
          for step in xrange(steps):
              sess.run(model.optimize)

您正在使用 tf.Data 的可初始化迭代器将数据馈送到您的模型。这意味着可以根据占位符参数化数据集,然后为迭代器调用初始值设定项 op 以准备使用它。

如果使用可初始化迭代器或 tf.Data 中的任何其他迭代器将输入馈送到模型,则不应使用 sess.runfeed_dict参数来尝试执行数据馈送。相反,请根据iterator.get_next()的输出定义模型,并省略sess.run中的feed_dict

大致如下:

iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()
# use get_next outputs to define model
model = Model(config, image_batch, label_batch) 
# placeholders fed in while initializing the iterator
sess.run(iterator.initializer, 
            feed_dict={images_placeholder: images,
                       labels_placeholder: labels})
for step in xrange(steps):
     # iterator will feed image and label in the background
     sess.run(model.optimize) 

迭代器将在后台将数据馈送到您的模型,不需要通过feed_dict进行额外的馈送。

相关内容

  • 没有找到相关文章

最新更新