TensorFlow:从不规则张量(RaggedTensor)中移除空元素



我有一个像这样的不规则张量(RaggedTensor):

<tf.RaggedTensor [[[1, 2]], [[12, 13]], [[16, 17], []], [[18, 19], [20]]]>

我的问题是,如何才能从中删除空元素,使结果变为:

<tf.RaggedTensor [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]]>

谢谢!

这个问题比想象中要难:

import tensorflow as tf
@tf.function
def remove_empty_lists(rt):
    """Remove empty innermost lists from a ragged tensor.

    Only exercised for inputs where ``len(rt.nested_row_lengths()) == 2``
    and where zero lengths occur only in ``rt.nested_row_lengths()[1]``.

    Args:
        rt: A ``tf.RaggedTensor`` with two ragged dimensions. The row
            lengths are combined with an int64 tensor below, so int64 row
            partitions are expected.

    Returns:
        A ``tf.RaggedTensor`` built from the same flat values, with the
        zero-length inner rows dropped and the outer row lengths reduced
        accordingly.
    """
    nrl = rt.nested_row_lengths()
    # Positions, along the inner-row axis, of the rows that are empty.
    empties = tf.squeeze(tf.where(nrl[1] == 0), axis=1)
    # diff[i, j] = (end split of outer row j) - (position of empty row i).
    diff = tf.expand_dims(rt.nested_row_splits[0][1:], axis=0) - tf.expand_dims(empties, axis=1)
    # Non-positive entries are replaced by the dtype maximum so that argmin
    # selects the smallest strictly positive difference, i.e. the outer row
    # that contains the empty inner row.
    diff_absolute = tf.where(diff <= 0, diff.dtype.limits[1], diff)
    diff_min = tf.argmin(diff_absolute, axis=1)
    # Number of empty inner rows found in each affected outer row.
    counts = tf.unique_with_counts(diff_min)
    # Use the runtime shape rather than the static ``nrl[0].shape``: the
    # static shape is not fully defined when the function is traced with an
    # unknown number of outer rows, and scatter_nd needs a concrete shape.
    to_subtract = tf.scatter_nd(
        tf.expand_dims(counts.y, 1),
        counts.count,
        tf.cast(tf.shape(nrl[0]), tf.dtypes.int64),
    )
    non_empties = tf.squeeze(tf.where(nrl[1] != 0), axis=1)
    nrl_updated = tf.gather(nrl[1], non_empties)
    result = tf.RaggedTensor.from_nested_row_lengths(
        rt.flat_values,
        (nrl[0] - tf.cast(to_subtract, tf.int64), nrl_updated),
    )
    return result

if __name__ == "__main__":
    # Sample inputs that differ only in where the empty inner lists sit.
    samples = [
        [[[1, 2]], [[12, 13]], [[16, 17], []], [[18, 19], [20]]],
        [[[1, 2]], [[12, 13], []], [[16, 17]], [[18, 19], [20]]],
        [[[1, 2]], [[], [12, 13]], [[16, 17]], [[18, 19], [20]]],
        [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20], []]],
        [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [], [20]]],
        [[[1, 2]], [[], [12, 13], [], []], [[], [], [], [], [16, 17]], [[18, 19], [20]]],
    ]
    for sample in samples:
        tf.print(remove_empty_lists(tf.ragged.constant(sample)))
    # Every sample yields [[[1, 2]], [[12, 13]], [[16, 17]], [[18, 19], [20]]]
    # NOTE: this was only tested for
    # - len(rt.nested_row_lengths()) == 2
    # - zero values occurring only in rt.nested_row_lengths()[1]

这是 Frederik Bode 给出的解决方案的修改版(我的声望不足,无法直接添加评论)。这个版本使该函数在 Eager 模式和图(graph)模式下都可以使用。


@tf.function
def remove_empty_lists(rt):
    """Return ``rt`` with its zero-length inner rows removed.

    The flat values are reused unchanged; only the two levels of row
    lengths are rebuilt. The shape passed to ``tf.scatter_nd`` is read at
    run time with ``tf.shape``.

    Args:
        rt: A ``tf.RaggedTensor`` whose ``nested_row_lengths()`` has an
            outer entry (index 0) and an inner entry (index 1).

    Returns:
        A ``tf.RaggedTensor`` created with
        ``tf.RaggedTensor.from_nested_row_lengths`` from the original flat
        values and the adjusted row lengths.
    """
    row_lengths = rt.nested_row_lengths()
    outer_lengths = row_lengths[0]
    inner_lengths = row_lengths[1]

    # Inner-row positions to keep and to drop.
    kept_positions = tf.squeeze(tf.where(inner_lengths != 0), axis=1)
    empty_positions = tf.squeeze(tf.where(inner_lengths == 0), axis=1)

    # For every empty inner row, measure the distance to each outer row's
    # end split; the smallest strictly positive distance identifies the
    # outer row that owns it.
    outer_row_ends = tf.expand_dims(rt.nested_row_splits[0][1:], axis=0)
    distance_to_end = outer_row_ends - tf.expand_dims(empty_positions, axis=1)
    masked_distance = tf.where(
        distance_to_end <= 0,
        distance_to_end.dtype.limits[1],
        distance_to_end,
    )
    owning_row = tf.argmin(masked_distance, axis=1)

    # Scatter the per-row count of empties into a vector shaped like the
    # outer row lengths.
    per_row = tf.unique_with_counts(owning_row)
    removed_per_row = tf.scatter_nd(
        tf.expand_dims(per_row.y, 1),
        per_row.count,
        tf.cast(tf.shape(outer_lengths), tf.dtypes.int64),
    )

    new_outer_lengths = outer_lengths - tf.cast(removed_per_row, tf.int64)
    new_inner_lengths = tf.gather(inner_lengths, kept_positions)
    return tf.RaggedTensor.from_nested_row_lengths(
        rt.flat_values,
        (new_outer_lengths, new_inner_lengths),
    )

最新更新