我有一个tfrecords文件,不幸的是其中没有标签值。它有两个值:图像和Id
因此,为了获得标签,我需要查看pandas DataFrame中的Id来驱动其值,然后根据其值创建标签,例如:
if df[df['id'] == Id]['value'] > threshold_value:
label = 1
else:
label = 0
但是,我不知道如何将张量("ParseSingleExample/ParseExample/PaseExampleV2:1",shape=((,dtype=string(转换为python字符串。
我在这里复制了解析tfrecords的代码:
def parse_tf_records(example_input):
feature_description_dict = {
IMAGE_FIELD: tf.io.FixedLenFeature(IMAGE_SIZE, tf.float32),
ID_FIELD: tf.io.FixedLenFeature([], tf.string)
}
parsed_example = tf.io.parse_single_example(example_input, feature_description_dict)
return parsed_example
和
def read_tfrecord(example_input):
parsed_example = parse_tf_records(example_input)
image_data = parsed_example[IMAGE_FIELD]
id_data = parsed_example[ID_FIELD]
# label = Look for the value of id_data in a Pandas Dataframe and compare the value to threshold_value
label_data = tf.cast(label, tf.int32)
return image_data , label_data
我使用tensorflow 2.4.1。如果有人能帮我,我真的很感激。谢谢
好的,tf.py_function就是答案。这是我的代码,它运行得很好:
def get_label(tf_id):
_id = tf_id.numpy().decode('utf-8')
if df[df['id'] == _id]['value'] > threshold_value:
label = 1
else:
label = 0
return tf.cast(label, tf.int32)
def read_tfrecord(example_input):
parsed_example = parse_tf_records(example_input)
image_data = parsed_example[IMAGE_FIELD]
id_data = parsed_example[ID_FIELD]
label = tf.py_function(get_label, [id_data], tf.int32)
return image_data , label