我正在尝试使用tf。数据 API 可以加速我的代码并防止 GPU 数据匮乏,但有一件事阻止了我对它的适应,它是在多次调用训练操作时使用相同的批次的能力。
假设我将数据集设置为
dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()
loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)
但是现在每当我打电话给train_op1
,iterator.get_next()
就会被召唤,所以当打电话给train_op2
时,我正在训练下一批。
从这个问题中,我知道我可以使用flat_map
和repeat(n)
的组合,其中n
是我想重复同一批次的次数,但这个n
取决于我调用的train_ops
数量,我必须手动计数。另外,我需要这两个train_ops
因为它们优化了图形的不同部分。
感谢您的帮助!
试试下面的代码。它创建输入和目标的副本,因此希望当您切换优化器/loss_op时它们不会更改。只要不传递is_new:True
标志,它们在sess.run
调用之间是持久的。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def ds_train(batch_size, num_epochs):
ds = (tf.data.Dataset.from_tensor_slices(([1.0,2.0,3.0,4.0,5.0], [-1,-2,-3,-4,-5]))
.batch(batch_size)
.repeat(num_epochs)
)
return ds
batch_size = 1
input_size = 1
num_epochs = 2
with tf.variable_scope("dataset"):
ds_t = ds_train(batch_size, num_epochs)
with tf.variable_scope("iterator"):
iterator_t = ds_t.make_initializable_iterator()
iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
iterator_t.output_types,
iterator_t.output_shapes)
def next_item():
next_elem = iterator.get_next(name="next_element")
x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]# tf.cast(next_elem[1], tf.int32)
return x, y
inputs = tf.Variable(tf.zeros(shape=[batch_size,input_size]), dtype=tf.float32, name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32), dtype=tf.int32, name="target", trainable=False,use_resource=True)
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")
def new_data(batch_size, input_size):
# run the data layer to generate a new batch
next_inputs, next_target = next_item()
next_inputs = tf.reshape(next_inputs, shape=[batch_size, input_size])
with tf.control_dependencies([tf.assign(inputs, next_inputs), tf.assign(target, next_target)]):
return tf.identity(inputs), tf.identity(target)
def old_data():
# just forward the existing batch
return inputs, target
next_inputs, next_target = next_item()
inputs, target = tf.cond(is_new, lambda:new_data(batch_size, input_size), old_data)
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
handle_t = sess.run(iterator_t.string_handle())
sess.run(iterator_t.initializer)
while True:
try:
print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: True}))
except tf.errors.OutOfRangeError:
print("End of training dataset.")
break