Tensorflow中的过滤张量元素



Tensorflow中对此的等效操作是什么?例如,我有一个x = np.array([-12,4,6,8,100])。我想做这么简单的事情:x = x[x>5],但我找不到任何TF操作。谢谢

TF中,您可以执行类似的操作以获得类似的结果。

import numpy as np 
import tensorflow as tf 
x = np.array([-12,4,6,8,100])
y = tf.gather(x, tf.where(x > 5))
y.numpy().reshape(-1)
array([  6,   8, 100])

详细信息

tf.where将返回作为Truecondition索引。如

x = np.array([-12,4,6,8,100])
tf.where(x > 5)
<tf.Tensor: shape=(3, 1), dtype=int64, numpy=
array([[2],
[3],
[4]])>

然后,使用tf.gather,根据索引(来自tf.where(从params(x(轴进行切片。如

tf.gather(x, tf.where(x > 5))

最新更新