TensorFlow:在使用监控培训课程时验证模型



我正在使用数据集API导入训练和验证数据。我有TF 1.2。所以我只能使用可重新初始化的迭代器,而不能使用可馈送迭代器。

1) 如果我们只想训练网络,我们可以简单地使用监控训练课程。但是,当我们想在训练中进行验证时,我们应该如何做到这一点?我们应该放弃监控培训课程并使用低级别课程吗?

train_dataset = tf.contrib.data.TFRecordDataset([FLAGS.data_dir + "train.tfrecords"])
train_dataset = train_dataset.map(_parse_records)
train_dataset.shuffle(buffer_size=1000)
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(FLAGS.batch_size)
validation_dataset = tf.contrib.data.TFRecordDataset([FLAGS.data_dir + "test.tfrecords"])
validation_dataset = test_dataset.map(_parse_records)
validation_dataset = test_dataset.batch(FLAGS.batch_size)
iterator = tf.contrib.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
train_init_op = iterator.make_initializer(train_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)
with tf.train.MonitoredTrainingSession(...) as sess:
sess.run(train_init_op)
while not sess.should_stop():
sess.run(training_op)
# HOW TO VALIDATE?

2) 是否有任何方法可以使用Reinitializable迭代器在epoch中间验证模型,因为当我们在迭代器之间切换时,它需要从数据集的一开始就初始化迭代器。通过Reinitializable迭代器,这可能吗?或者我们必须切换到可馈送迭代器才能做到这一点?

这是TF数据集教程中提供的示例。在这里,如果一个历元中可能有100次迭代,我们可以使用Reinitializable迭代器在第50次迭代时验证模型吗?(我认为使用可馈送迭代器是可能的)

# Run 20 epochs in which the training dataset is traversed, followed by the validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)

3) 在使用Reinitializable迭代器时,在epoch的最后一次迭代中,如果剩余的训练数据样本小于所需的批量大小,会发生什么?剩余的少数样品是在减少批量的情况下使用还是被忽略?

对于您的问题3,我认为TensorFlow表现不佳。对于最后一批,它可能有较少的样本数量。在训练过程中,这经常(总是?)会导致"形状不兼容"的错误。请参阅https://stackoverflow.com/a/48331954/2184122关于如何解决TensorFlow 1.4 的问题

请看一下如何使用tf.MonitoredTrainingSession在训练和验证数据集之间切换?我想你会找到答案的。你可以使用feed_ict来更改要评估的数据集,或者只是重新初始化它

...
training_iterator = training_ds.make_initializable_iterator()
validation_iterator = validation_ds.make_initializable_iterator()
...
sess.run(next_element, feed_dict={handle: training_handle})
...
sess.run(next_element, feed_dict={handle: validation_iterator })

最新更新