TF 对象检测 API 默认获取所有 GPU 内存,因此很难判断我可以进一步增加多少批处理大小。通常我只是继续增加它,直到我收到 CUDA OOM 错误。
另一方面,默认情况下,PyTorch 不会占用所有 GPU 内存,因此很容易看出我还剩下多少百分比可以使用,而无需所有的反复试验。
有没有更好的方法可以使用我缺少的 TF 对象检测 API 确定批量大小?类似于model_main.py
allow-growth
旗的东西?
我一直在查找源代码,但没有发现与此相关的FLAG。
但是,在文件model_main.py
中 https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py 您可以找到以下主函数定义:
def main(unused_argv):
flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('pipeline_config_path')
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config=config,
...
这个想法是以类似的方式修改它,例如以下方式:
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)
因此,添加config_proto
和更改config
但保持所有其他条件相同。
此外,allow_growth
使程序根据需要使用尽可能多的 GPU 内存。因此,根据您的GPU,您最终可能会吃掉所有内存。在这种情况下,您可能需要使用
config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
它定义了要使用的内存比例。
希望这有所帮助。
如果您不想修改文件,似乎应该打开一个问题,因为我没有看到任何 FLAG。 除非标志
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.')
意味着与此相关的东西。但我不这么认为,因为从model_lib.py
看来,它与训练、评估和推断配置有关,而不是 GPU 使用配置。