如何将一个索引列表转换成一维布尔张量



我想要一个tf。函数从标签文件返回固定长度的一维布尔张量,每个匹配的索引范围为True。索引范围可以重叠。但是我被困在把索引转换成一个平坦的布尔张量。

我正在读取一个文件,其中每一行都是一个标签和一个固定长度的开始和结束索引。对于本例,我们可以说固定长度为50。

标签文件内容示例:

frog 4.0 10.0
frog 20.0 30.0
goat 2.0 20.0
camel 4.0 15.0

这里是

[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

这是个tf。到目前为止,这个函数运行得很好,但我被如何得到最终张量所困扰。我在执行这个tf。函数使用dataset.map(load_label).

def load_label(file_path, accepted_labels=['goat', 'frog']):
label_datas = tf.io.read_file(file_path)
label_datas = tf.strings.strip(label_datas)
label_datas = tf.strings.split(label_datas, sep='n')
label_datas = tf.strings.split(label_datas, sep=' ')
label_datas = label_datas.to_tensor(default_value='0.0', shape=[None, 3])
list_of_indices = []
for label_data in label_datas:
equal = tf.math.equal(label_data[0], accepted_labels)
if tf.reduce_any(equal):
start = tf.strings.to_number(label_data[1], out_type=tf.float32)
end = tf.strings.to_number(label_data[2], out_type=tf.float32)
start = tf.cast(start, tf.int32)
end = tf.cast(end, tf.int32)
list_of_indices.append(tf.range(start, end, 1))
list_of_indices = tf.concat(list_of_indices, axis=0)
list_of_indices, idx = tf.unique(list_of_indices)

我实际上找到了使用tensorArray和reduce_sum的解决方案。

我相信它可以进一步优化。

def load_label(file_path, accepted_labels=['goat', 'frog']):
label_datas = tf.io.read_file(file_path)
label_datas = tf.strings.strip(label_datas)
label_datas = tf.strings.split(label_datas, sep='n')
label_datas = tf.strings.split(label_datas, sep=' ')
label_datas = label_datas.to_tensor(default_value='0.0', shape=[None, 3])
label_ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
ta_index_count = 0
for label_data in label_datas:
equal = tf.math.equal(label_data[0], accepted_labels)
if tf.reduce_any(equal):
start = tf.strings.to_number(label_data[1], out_type=tf.float32)
end = tf.strings.to_number(label_data[2], out_type=tf.float32)
start = tf.cast(start, tf.int32)
end = tf.cast(end, tf.int32)
label_ta = label_ta.write(ta_index_count, tf.concat(
[tf.zeros((start), dtype=tf.int32),
tf.ones((end - start), dtype=tf.int32),
tf.zeros((new_length - end), dtype=tf.int32)], 0))
else:
label_ta = label_ta.write(ta_index_count, tf.zeros((new_length), dtype=tf.int32))
ta_index_count = ta_index_count + 1
label_ta = label_ta.stack()
label_ta = tf.reduce_sum(label__ta, 0)
label_ta = tf.where(label_ta > 0, 1, 0)
return label_ta