是否可以使用数据集/迭代器在单个张量流操作中遍历所有小批量?



我正在使用tf.data.dataset/iterator机制并尝试提高数据加载性能。 我突然想到,从 Python 卸载整个小批量循环可能会有所帮助。 我的数据足够小,存储在 CPU 或 GPU 上没有问题。

那么,是否可以在对session.run的调用中将优化器节点循环到完整的小批量纪元上?

iterator.get_next()返回的张量每session.run只增加一次,这似乎使得无法遍历小批量数据集......但如果可以做到,我的 CPU 每个纪元只需要接触一次 Python 线程。

更新:@muskrat建议使用tf.slice可用于此目的。 请参阅我随后的非答案,并使用tf.while_loop对此进行示意图实现。 但是,问题是这是否可以使用数据集/迭代器来完成......我仍然想知道。

从描述来看,您似乎已经将数据集预加载为 CPU/GPU 上的常量,就像本例所示。 这当然是第一步。

其次,我建议使用tf.slice()来复制小批量操作的效果。 换句话说,只需手动从预加载常量(数据集)中切片小批量,即可获得所需的行为。 例如,请参阅切片文档或此相关文章。

如果这还不够详细,请编辑您的问题以包含一个代码示例(带有 mnist 或其他东西),我可以提供更多细节。

这个"答案"是Muskrattf.slice建议的实现,其中包含tf.while_loop制定的细节(在如何在tensorflow和 https://www.tensorflow.org/api_docs/python/tf/while_loop 中使用tf.while_loop()的帮助下)。

除非你的数据和模型足够小,以至于你受到Python I/O的瓶颈(像我一样!),否则这个解决方案可能是学术性的。

优势:

  • 在不返回到 Python 线程的情况下通过小批量进行训练。
  • 使用具有 GPU 实现的操作,这意味着整个图形可以放置在 GPU 中。
  • 在我的小数据集上,这可能是Python I/O的瓶颈,这个解决方案的速度是我的数据集/迭代器(每个小批量接触Python一次)的两倍,是传递小批量feed_dict速度的四倍。

弊:

  • tf.while_loop奸诈的。 了解何时评估循环体内的操作以及何时评估它们所依赖的操作是具有挑战性的,尤其是(瘦)官方文档和有限的堆栈溢出覆盖范围。
  • tf.while_loop缺失的文档是,循环主体外部的张量只被计算一次,即使内部操作依赖于它们。这意味着必须在循环中定义优化、模型和损失。 如果您希望能够在训练时期之间调用验证损失操作,这会限制灵活性。 据推测,这可以通过tf.cond语句和通过feed_dict传入的适当标志来实现。 但不像tf.data中的数据集/迭代器机制那样灵活或优雅。
  • 在每个纪元添加洗牌操作似乎在 GPU 上不可用。

这是我的原理图代码(为了简洁起见,我省略了变量和模型定义):

def buildModel(info, training_data, training_targets):
graph = tf.Graph()
with graph.as_default():
# numBatches is passed in from Python once per Epoch.
batch_size = tf.placeholder(tf.float32, name = 'batch_size')
# Initializers for loop variables for tf.while_loop
batchCounter = tf.Variable(0, dtype=tf.float32, trainable=False)
lossList =  tf.Variable(tf.zeros([0,1]), trainable=False)
# In a full example, I'd normalize my data here.  And possibly shuffle 
tf_training_data     =  tf.constant(training_data,    dtype=tf.float32)
tf_training_targets  =  tf.constant(training_targets, dtype=tf.float32)  
# For brevity, I'll spare the definitions of my variables.  Because tf.Variables
# are essentially treated as globals in the model and are manipulated directly (like with tf.apply)
# they can reside outside runMinibatch, the body of tf.while_loop.
# weights_1 =
# biases_1  = 
# etc.
def moreMinibatches(batchCount, lossList):
return (batchCount + 1) * batch_size <= len(training_data)
def runMinibatch(batchCount, lossList):
# These tensors and ops have to be defined inside runMinibatch, otherwise they're not updated as tf.wile_loop loops.  This means
# slices, model definition, loss tensor, and training op.
dat_batch  = tf.slice(tf_training_data,    [tf.cast(batchCounter * batch_size, tf.int32) , 0], [tf.cast(batch_size, tf.int32), -1])
targ_batch = tf.slice(tf_training_targets, [tf.cast(batchCounter * batch_size, tf.int32) , 0], [tf.cast(batch_size, tf.int32), -1])
# Here's where you'd define the model as a function of weights and biases above and dat_batch
# model = <insert here>
loss         = tf.reduce_mean(tf.squared_difference(model, targ_batch))
optimizer    = tf.train.AdagradOptimizer() # for example
train_op = optimizer.minimize(while_loss, name='optimizer')
# control_dependences ensures that train_op is run before return
# even though the return values don't explicitly depend on it.  
with tf.control_dependencies([train_op]):
return batchCount + 1,  tf.concat([lossList, [[while_loss]]],0)
# So, the idea is that this trains a full epoch without returning to Python.
trainMinibatches = tf.while_loop(moreMinibatches, runMinibatch, [minibatchCounter, lossList]
shape_invariants=[batchCounter.get_shape(), tf.TensorShape(None)])
return (graph, 
{'trainMinibatches'     : trainAllMinibatches,
'minibatchCounter'      : minibatchCounter,
'norm_loss'             : norm_loss,
} )
numEpochs     = 100 # e.g.
minibatchSize = 32  # 
# training_dataset = <data here>
# training_targets = <targets here>
graph, ops = buildModel(info, training_dataset, training_targets, 
minibatch_size)
with tf.Session(graph=graph, config=config) as session:
tf.global_variables_initializer().run()
for i in range(numEpochs):
# This op will train on as all minibatches that fit in the full dataset. finalBatchCount with be the number of 
# complete minibatches in the dataset.  lossList is a list of each step's minibatches.
finalBatchCount, lossList = session.run(ops['trainAllMinibatches'], 
feed_dict={'batch_size:0':minibatchSize})
print('minibatch losses at Epoch', i, ': ', lossList)

我实现了tf.slice()和上面建议tf.while_loop矢量化小批量的方法。

在我的情况下,性能比使用feed_dict的小批量快约 1.86 倍,但我发现存在一个问题,即每个 epoch 的损失值不稳定。

然后,我改为每个纪元tf.random_shuffle输入,问题得到了很大的缓解。(性能提升降低至1.68倍)

最新更新