使用 TensorFlow Object Detection API 确定最大批处理大小



TF 对象检测 API 默认获取所有 GPU 内存,因此很难判断我可以进一步增加多少批处理大小。通常我只是继续增加它,直到我收到 CUDA OOM 错误。

另一方面,默认情况下,PyTorch 不会占用所有 GPU 内存,因此很容易看出我还剩下多少百分比可以使用,而无需所有的反复试验。

有没有更好的方法可以使用我缺少的 TF 对象检测 API 确定批量大小?类似于model_main.pyallow-growth旗的东西?

我一直在查找源代码,但没有发现与此相关的FLAG。

但是,在文件model_main.py中 https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py 您可以找到以下主函数定义:

def main(unused_argv):
flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('pipeline_config_path')
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config=config,
...

这个想法是以类似的方式修改它,例如以下方式:

config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)

因此,添加config_proto和更改config但保持所有其他条件相同。

此外,allow_growth使程序根据需要使用尽可能多的 GPU 内存。因此,根据您的GPU,您最终可能会吃掉所有内存。在这种情况下,您可能需要使用

config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9

它定义了要使用的内存比例。

希望这有所帮助。

如果您不想修改文件,似乎应该打开一个问题,因为我没有看到任何 FLAG。 除非标志

flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.')

意味着与此相关的东西。但我不这么认为,因为从model_lib.py看来,它与训练、评估和推断配置有关,而不是 GPU 使用配置。

相关内容

最新更新