Manually get the next batch, or use the same batch, with the TensorFlow Data API



I am trying to use the tf.data API to speed up my code and keep my GPU from starving for data, but one thing is holding me back from adopting it: the ability to use the same batch when the training op is called multiple times.

Suppose I have set up my dataset as

dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()
loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)

But now, whenever I run train_op1, iterator.get_next() is invoked, so when I run train_op2 I am already training on the next batch.

From this question I know I could combine flat_map with repeat(n), where n is the number of times I want to repeat the same batch, but this n depends on how many train_ops I call, and I would have to keep that count in sync by hand (see the sketch below). Also, I need both train_ops because they optimize different parts of the graph.
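For reference, this is roughly what that workaround would look like, applied right after the padded_batch call above (a minimal sketch; n = 2 is assumed here because there are two train_ops, and keeping that number in sync is exactly the manual bookkeeping I want to avoid):

n = 2  # must match the number of train_ops run per batch
dataset = dataset.flat_map(
    lambda batch: tf.data.Dataset.from_tensors(batch).repeat(n))

Each batch coming out of padded_batch is then emitted n times in a row, so n consecutive iterator.get_next() calls return identical data.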

Thanks for your help!

Try the code below. It creates copies of the inputs and targets, so hopefully they will not change while you switch between optimizers/loss ops. They persist across sess.run calls as long as you do not feed the is_new: True flag.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def ds_train(batch_size, num_epochs):
    # Toy dataset: five (input, target) pairs, batched and repeated.
    ds = (tf.data.Dataset.from_tensor_slices(([1.0, 2.0, 3.0, 4.0, 5.0], [-1, -2, -3, -4, -5]))
          .batch(batch_size)
          .repeat(num_epochs)
          )
    return ds


batch_size = 1
input_size = 1
num_epochs = 2

with tf.variable_scope("dataset"):
    ds_t = ds_train(batch_size, num_epochs)

with tf.variable_scope("iterator"):
    iterator_t = ds_t.make_initializable_iterator()
    iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                                   iterator_t.output_types,
                                                   iterator_t.output_shapes)

    def next_item():
        next_elem = iterator.get_next(name="next_element")
        x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]  # tf.cast(next_elem[1], tf.int32)
        return x, y


# Non-trainable variables that cache the current batch between sess.run calls.
inputs = tf.Variable(tf.zeros(shape=[batch_size, input_size]), dtype=tf.float32,
                     name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32), dtype=tf.int32,
                     name="target", trainable=False, use_resource=True)
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")


def new_data(batch_size, input_size):
    # run the data layer to generate a new batch and store it in the variables
    next_inputs, next_target = next_item()
    next_inputs = tf.reshape(next_inputs, shape=[batch_size, input_size])
    with tf.control_dependencies([tf.assign(inputs, next_inputs), tf.assign(target, next_target)]):
        return tf.identity(inputs), tf.identity(target)


def old_data():
    # just forward the existing batch
    return inputs, target


# Pull a fresh batch only when is_new is fed as True; otherwise reuse the cached one.
inputs, target = tf.cond(is_new, lambda: new_data(batch_size, input_size), old_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    while True:
        try:
            # The first two runs return the cached batch; the third advances the iterator.
            print(sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_new: True}))
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break
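To wire this back into the question's graph, one training step might look like the following sketch (someFunctionOf, someOptimizerOf, etc. are the asker's placeholders, not real functions): build both losses from the cached inputs/target tensors, run the first train op with is_new: True so it pulls and caches a fresh batch, and run the second with the default is_new: False so it reuses exactly the same batch.

loss1 = someFunctionOf(inputs)        # placeholders from the question
loss2 = someOtherFunctionOf(inputs)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)

# one training step on a single batch
sess.run(train_op1, feed_dict={iterator_handle: handle_t, is_new: True})
sess.run(train_op2, feed_dict={iterator_handle: handle_t, is_new: False})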
