如何动态编辑外部.config文件



我正在使用TensorFlow对象检测API开发一个主动学习(active learning)管道。我的目标是动态地更改网络的.config文件中的路径。

标准配置如下:

# Input settings for the training data.
train_input_reader: {
tf_record_input_reader {
# PATH_TO_CONFIGURE is a placeholder that is meant to be replaced at runtime.
input_path: "/PATH_TO_CONFIGURE/train.record"
}
# Uses the same PATH_TO_CONFIGURE placeholder as input_path above.
label_map_path: "/PATH_TO_CONFIGURE/label_map.pbtxt"
}

"PATH_TO_CONFIGURE"应该从我的jupyter笔记本单元格中动态替换。

对象检测API的配置文件采用protobuf文本格式。以下是读取、编辑和保存它们的大致方法。

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2
# Parse the existing pipeline config (protobuf text format) into a message.
pipeline = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile('config path', "r") as f:
    proto_str = f.read()
    text_format.Merge(proto_str, pipeline)

# input_path is a repeated field, so its contents are replaced through
# slice assignment instead of a plain attribute assignment.
pipeline.train_input_reader.tf_record_input_reader.input_path[:] = ['your new entry']
pipeline.train_input_reader.label_map_path = 'your new entry'

# Serialize the edited message back to text and overwrite the config file.
config_text = text_format.MessageToString(pipeline)
with tf.gfile.Open('config path', "wb") as f:
    f.write(config_text)

您需要根据自己的情况调整代码,但总体思路应该很清楚。我建议将其重构为函数,并在Jupyter中调用。

以下是在TensorFlow 2中对我有效的方法(API从tf.gfile.GFile略微更改为tf.io.gfile.GFile):

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2
def read_config(config_path='pipeline.config'):
    """Load a pipeline config file into a TrainEvalPipelineConfig message.

    Args:
        config_path: Location of the text-format config file. Defaults to
            'pipeline.config', the path this function always used before.

    Returns:
        The TrainEvalPipelineConfig message populated from the file.
    """
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(config_path, "r") as f:
        proto_str = f.read()
    # Merge fills the empty message with the fields parsed from the text.
    text_format.Merge(proto_str, pipeline)
    return pipeline
def write_config(pipeline, config_path='pipeline.config'):
    """Serialize a pipeline config message to text and write it to disk.

    Args:
        pipeline: The protobuf message to serialize.
        config_path: Destination file, overwritten if present. Defaults to
            'pipeline.config', the path this function always used before.
    """
    config_text = text_format.MessageToString(pipeline)
    with tf.io.gfile.GFile(config_path, "wb") as f:
        f.write(config_text)
def _set_first_path(paths, value):
    """Overwrite the first entry of a repeated field, or add it if empty."""
    if paths:
        paths[0] = value
    else:
        # Index assignment on an empty repeated field raises IndexError.
        paths.append(value)


def modify_config(pipeline):
    """Apply the project-specific overrides to a pipeline config in place.

    Sets the SSD class count, the fine-tune checkpoint type, and the label
    map / record paths of the train reader and of the first eval reader.

    Args:
        pipeline: The pipeline config message to edit. It must already
            contain at least one eval_input_reader entry.

    Returns:
        The same ``pipeline`` object, after modification.
    """
    pipeline.model.ssd.num_classes = 1
    pipeline.train_config.fine_tune_checkpoint_type = 'detection'

    train_reader = pipeline.train_input_reader
    train_reader.label_map_path = 'label_map.pbtxt'
    _set_first_path(train_reader.tf_record_input_reader.input_path, 'train.record')

    eval_reader = pipeline.eval_input_reader[0]
    eval_reader.label_map_path = 'label_map.pbtxt'
    _set_first_path(eval_reader.tf_record_input_reader.input_path, 'test.record')
    return pipeline

def setup_pipeline():
    """Load the config with read_config, edit it with modify_config, and save it with write_config."""
    pipeline = read_config()
    pipeline = modify_config(pipeline)
    write_config(pipeline)


# Run the full read -> modify -> write cycle as soon as this code executes.
setup_pipeline()

最新更新