我在Tensorflow中实现了一個特殊的損失函數。这是一个特殊函数的 numpy 样式代码,它选择顶部 q 元素并屏蔽每行和每列中的其他元素。请注意,A
是n*n
矩阵,q
是小于n
的整数。
def thresh(A, q):
A_ = A.copy()
n = A_.shape[1]
for i in range(n):
A_[i, :][A_[i, :].argsort()[0:n - q]] = 0
A_[:, i][A_[:, i].argsort()[0:n - q]] = 0
return A_
现在的问题是我有一个形状为(n,n)
的 Tensorflow 张量A
,我想实现与 numpy 相同的逻辑。但是,我不能使用索引直接为张量A
赋值。安永对此有一些解决方案吗?
TLDR;
我们可以创建一个函数,按行屏蔽除前k
个元素之外的所有元素,如下所示:
def mask_all_but_top_k(X, k):
n = X.shape[1]
top_k_indices = tf.math.top_k(X, k).indices
mask = tf.reduce_sum(tf.one_hot(top_k_indices, n), axis=1)
return mask * X
不幸的是,tf.map.top_k
不允许我们指定维度,但我们当然可以通过先转置X
然后用tf.transpose()
转置结果来复制此列
解释
我们可以通过创建一个 1 和 0 的掩码,然后逐个元素相乘来达到目标。
因此,例如,考虑n=4, k=2
的情况,我们有以下矩阵:
array([[0.67757607, 0.74070597, 0.89508283, 0.11858773],
[0.7661159 , 0.8737055 , 0.73599136, 0.1552105 ],
[0.7093129 , 0.44203556, 0.48861897, 0.83231044],
[0.24682868, 0.36648738, 0.92984104, 0.9881872 ]], dtype=float32)
然后我们可以使用tf.math.top_k
函数获取矩阵每行中前 2 个值的索引:
top_k_indices = tf.math.top_k(X, 2).indices
现在,我们使用一个小技巧来首先one_hot
编码这些:
tf.one_hot(top_k_indices, 4)
array([[[0., 0., 1., 0.],
[0., 1., 0., 0.]],
[[0., 1., 0., 0.],
[1., 0., 0., 0.]],
[[0., 0., 0., 1.],
[1., 0., 0., 0.]],
[[0., 0., 0., 1.],
[0., 0., 1., 0.]]], dtype=float32)>
然后将它们reduce_sum
到倒数第二个维度以创建我们的蒙版:
tf.reduce_sum(tf.one_hot(top_k_indices, 4), axis=1)
array([[0., 1., 1., 0.],
[1., 1., 0., 0.],
[1., 0., 0., 1.],
[0., 0., 1., 1.]], dtype=float32)>
现在我们可以做一个 Hadamard(元素)乘法来获得所需的结果:
array([[0. , 0.74070597, 0.89508283, 0. ],
[0.7661159 , 0.8737055 , 0. , 0. ],
[0.7093129 , 0. , 0. , 0.83231044],
[0. , 0. , 0.92984104, 0.9881872 ]], dtype=float32)>
将所有这些放在一起,我们可以创建一个函数,该函数按行屏蔽除顶部k
元素之外的所有元素,如下所示:
def mask_all_but_top_k(X, k):
n = X.shape[1]
top_k_indices = tf.math.top_k(X, k).indices
mask = tf.reduce_sum(tf.one_hot(top_k_indices, n), axis=1)
return mask * X