Tensorflow中对此的等效操作是什么?例如,我有一个x = np.array([-12,4,6,8,100])
。我想做这么简单的事情:x = x[x>5]
,但我找不到任何TF操作。谢谢
在TF
中,您可以执行类似的操作以获得类似的结果。
import numpy as np
import tensorflow as tf
x = np.array([-12,4,6,8,100])
y = tf.gather(x, tf.where(x > 5))
y.numpy().reshape(-1)
array([ 6, 8, 100])
详细信息
tf.where
将返回作为True
的condition
的索引。如
x = np.array([-12,4,6,8,100])
tf.where(x > 5)
<tf.Tensor: shape=(3, 1), dtype=int64, numpy=
array([[2],
[3],
[4]])>
然后,使用tf.gather
,根据索引(来自tf.where
(从params(x
(轴进行切片。如
tf.gather(x, tf.where(x > 5))