用pytorch张量中的子集替换特定列的有效方法是什么?



我之前已经问过这个问题,但在不同的上下文中,但是该方法太耗时了。目前,我的代码需要12秒来完成一个epoch。

我需要在张量的特定列上应用掩码。我找不到任何有用的函数,所以我创建了一个子集,然后在子集上应用掩码。现在我需要将这个子集合并到原始张量中,但是我找不到一个有效的方法。

python
import torch
X=torch.rand(10,9) 
tensorsize = X.size()
indices = torch.tensor([0,3,7,5,4])  #sorting them will make the process faster?
candidateCF=torch.index_select(X,1,indices)
mask=torch.FloatTensor(candidateCF.size()).uniform_() >= 0.3
output=candidateCF.mul(mask)
print(output)

以上代码没问题,output是掩码张量

现在我需要把这个output张量替换成X。最有效的方法是什么呢?

有任何方法可以直接屏蔽X的特定列吗?我认为这样会更有效率。

注:澄清一下,output will be a subset of X

我找到了一个解决办法,贴出来,也许对别人有帮助。

X[:, indices] = output # for column replacements
X[indices, :] = output # for row replacements

最新更新