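I want to get the elements that appear more than once in a 1-D tensor. More precisely, I want to create the opposite of `tf.unique`. For example, if `x = [1, 1, 2, 3, 4, 5, 6, 7, 4, 5, 4]`, I need the output to be `[1, 1, 4, 4, 4, 5, 5]`, and I also need to retrieve the indexes of those elements in the source tensor. My end goal is to get the examples in a batch whose label occurs more than once.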
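As a plain reference for the behaviour described above, here is a sketch in NumPy rather than TensorFlow that computes the duplicated values and their indexes. It is not part of the question, and the helper name `duplicated_elements` is hypothetical. A stable sort groups equal values together so the result matches the requested `[1, 1, 4, 4, 4, 5, 5]`.

```python
import numpy as np


def duplicated_elements(x):
    """Hypothetical helper: return (values, indexes) of the elements of x
    that occur more than once, grouped by value."""
    x = np.asarray(x)
    # inverse[i] is the position of x[i] in the sorted unique values,
    # counts[k] is how often the k-th unique value occurs.
    _, inverse, counts = np.unique(x, return_inverse=True, return_counts=True)
    # counts[inverse] gives, for every element, how often its value occurs.
    indexes = np.nonzero(counts[inverse] > 1)[0]
    # A stable sort groups equal values while keeping original order in a group.
    order = np.argsort(x[indexes], kind="stable")
    indexes = indexes[order]
    return x[indexes], indexes


values, indexes = duplicated_elements([1, 1, 2, 3, 4, 5, 6, 7, 4, 5, 4])
print(values.tolist())   # [1, 1, 4, 4, 4, 5, 5]
print(indexes.tolist())  # [0, 1, 4, 8, 10, 5, 9]
```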
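You can do this with existing TensorFlow operations: use `tf.unique` to build a dense index set of the unique items, then count how often each one occurs with `tf.unsorted_segment_sum`. Once you have the counts, use `tf.greater` to select the items whose count is greater than N (here N = 1), and gather them back into a dense list: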
```python
# TensorFlow 2.x, eager execution (no Session or variable initializer needed).
import tensorflow as tf

a = tf.constant([8, 7, 8, 1, 3, 4, 5, 9, 5, 0, 5])

# unique_idx[i] is the position of a[i] in unique_a_vals (a dense index set).
unique_a_vals, unique_idx = tf.unique(a)
# Sum a 1 per element into its segment to count each unique value.
count_a_unique = tf.math.unsorted_segment_sum(tf.ones_like(a),
                                              unique_idx,
                                              tf.shape(a)[0])

more_than_one = tf.greater(count_a_unique, 1)
# reshape rather than squeeze, so a single duplicate value still gives a 1-D tensor
more_than_one_idx = tf.reshape(tf.where(more_than_one), [-1])
more_than_one_vals = tf.gather(unique_a_vals, more_than_one_idx)

# If you want the original indexes:
not_duplicated, _ = tf.compat.v1.setdiff1d(a, more_than_one_vals)
dups_in_a, indexes_in_a = tf.compat.v1.setdiff1d(a, not_duplicated)

print("Input:", a.numpy())
print("Duplicate values:", more_than_one_vals.numpy())
print("Indexes of duplicates in a:", indexes_in_a.numpy())
print("Dup vals with dups:", dups_in_a.numpy())
```
Output:

```
Input: [8 7 8 1 3 4 5 9 5 0 5]
Duplicate values: [8 5]
Indexes of duplicates in a: [ 0  2  6  8 10]
Dup vals with dups: [8 8 5 5 5]
```
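A shorter route to the same values and indexes, sketched here as an alternative rather than the method above, swaps in `tf.unique_with_counts`, which returns the per-value counts directly. It assumes TensorFlow 2 with eager execution. Gathering the counts through the `idx` output gives each element's own occurrence count, so one boolean mask yields the duplicate positions, the duplicate values, and, for the end goal of selecting examples from a batch, the matching rows. The `duplicated_mask` helper and the `features` batch are hypothetical.

```python
# Assumes TensorFlow 2.x (eager execution).
import tensorflow as tf


def duplicated_mask(labels):
    """Hypothetical helper: True where labels[i] occurs more than once."""
    _, idx, counts = tf.unique_with_counts(labels)
    # counts[idx[i]] is the number of occurrences of labels[i].
    return tf.gather(counts, idx) > 1


a = tf.constant([8, 7, 8, 1, 3, 4, 5, 9, 5, 0, 5])
mask = duplicated_mask(a)
dup_indexes = tf.reshape(tf.where(mask), [-1])  # positions in a (int64)
dup_values = tf.boolean_mask(a, mask)           # values in original order

# Hypothetical batch: 11 examples with 2 features each, one row per label in a.
features = tf.reshape(tf.range(11 * 2), [11, 2])
selected = tf.boolean_mask(features, mask)      # rows whose label repeats

print(dup_indexes.numpy().tolist())  # [0, 2, 6, 8, 10]
print(dup_values.numpy().tolist())   # [8, 8, 5, 5, 5]
print(selected.numpy().tolist())     # [[0, 1], [4, 5], [12, 13], [16, 17], [20, 21]]
```

Because the mask keeps the original order, the values come out as `[8, 8, 5, 5, 5]` rather than sorted by value; a stable sort such as `tf.argsort(dup_values, stable=True)` can group them if that is needed.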