我有一个形状为 Nx7 的张量,它看起来像这样:
[0.97863993 0.64479575 -0.202357 0.94678476 0.0080051 0.44507797 0.47864
0.05914348 -0.72649432 0.193803 0.47295245 0.8381458 0.30449861 0.46783]
我还有另一个相同形状的张量,它是一个布尔掩码:
[True False True True False True False
False True False False True False False]
我想获取第一个张量中每一行的 argmax,但只获取掩码为 True 的那些元素,所以基本上是以下数组的 argmax:
[0.97863993 X -0.202357 0.94678476 X 0.44507797 X
X -0.72649432 X X 0.8381458 X X]
因此,它应该变成:
[0
4]
这在TensorFlow中可能吗?我试图用tf.boolean_mask
弄清楚,但我看不出如何处理掩码中具有不同数量的True
值的不同行。
TF 中的输入代码:
mask = tf.placeholder(shape=[None, 7], dtype=tf.bool)
val = tf.placeholder(shape=[None, 7], dtype=tf.float32)
arg_max = ???
请注意,我也希望正确处理负值(否则Ishant Mrinal提出的方法将起作用(。
将布尔数组转换为浮点数组
# mask = tf.placeholder(shape=[None, 7], dtype=tf.bool)
# mask = tf.cast(mask, dtype=tf.float32)
mask = tf.placeholder(shape=[None, 7], dtype=tf.float32)
val = tf.placeholder(shape=[None, 7], dtype=tf.float32)
argmax = tf.argmax(tf.multiply(val, mask), axis=1)
sess.run(argmax, {val: your_val_array, mask: 2*mask_bool_array.astype(float)-1 })
要模拟屏蔽的 argmax,您可以将掩码外的值设置为-inf
,例如:
masked_val = tf.minimum(val, (2* tf.to_float(mask) - 1) * np.inf)
masked_arg_max = tf.argmax(masked_val, axis=1)
或者,要计算masked_val
,您可以使用
masked_val = tf.where(mask, val, -tf.ones_like(val) * np.inf)
这可以说更清楚,但可能会浪费内存。
对于蒙面的argmin,您将执行相反的操作:
masked_val = tf.maximum(val, (1 - 2* tf.to_float(mask)) * np.inf)
masked_arg_min = tf.argmin(masked_val, axis=1)