一次获取张量流中多个元素的索引



我是Tensorflow的新手。

我有一个问题。

这里有一维数组。

values = [101,103,105,109,107]
target_values = [105, 103]

我想一次从values那里获得有关target_values的索引。

从上面的例子中提取的索引将如下所示。

indices = [2, 1]

当我使用tf.map_fn功能时。 这个问题很容易解决。

# if you do not change data type from int64 to int32. TypeError will riase
values = tf.cast(tf.constant([100, 101, 102, 103, 104]), tf.int64)
target_values = tf.cast(tf.constant([100, 101]), tf.int64)
indices = tf.map_fn(lambda x: tf.where(tf.equal(values, x)), target_values)

谢谢!

假设target_values中的所有值都在values中,这是一种简单的方法(TF 2.x,但该函数对于1.x应该相同(:

import tensorflow as tf
values = [101, 103, 105, 109, 107]
target_values = [105, 103]
# Assumes all values in target_values are in values
def find_in_array(values, target_values):
values = tf.convert_to_tensor(values)
target_values = tf.convert_to_tensor(target_values)
# stable=True if there may be repeated elements in values
# and you want always first occurrence
idx_s = tf.argsort(values, stable=True)
values_s = tf.gather(values, idx_s)
idx_search = tf.searchsorted(values_s, target_values)
return tf.gather(idx_s, idx_search)
print(find_in_array(values, target_values).numpy())
# [2 1]

最新更新