PyTorch:将向量中除顶部k以外的所有元素归零



我正在尝试创建一个新的激活层,让我们称之为topk,它的工作原理如下。它将采用大小为n的矢量x作为输入(将前一层输出乘以权重矩阵并添加偏差的结果(和正整数k,并将输出大小为n(x(的矢量topk(x(,其元素为:

x_i (if x_i is one of the top k elements of x) 
topk(x)_i = 
0 (otherwise)

在计算topk(x(的梯度时,x的顶部k个元素应该具有梯度1,其他元素都为0。

我应该如何实现这一点?任何帮助都将不胜感激。

您可以使用torch.topk进行以下操作:

k = 2
output = torch.randn(5)
vals, idx = output.topk(k)
topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557, 0.0000, 0.0000, 1.4562, 0.0000])

注意,虽然topk()'values'是可微的,但'indices'不是(类似于argmax不是可微函数(。

# Find the top-k values and their indices along the last dimension of the tensor.
topk_values, topk_indices = torch.topk(x, k, dim=-1)
# Create a mask tensor with the same shape as 'x', initialized with zeros.
mask = torch.zeros_like(x)
# Set the top-k values to their original values in 'x'.
mask.scatter_(-1, topk_indices, topk_values)