我之前已经问过这个问题,但在不同的上下文中,但是该方法太耗时了。目前,我的代码需要12秒来完成一个epoch。
我需要在张量的特定列上应用掩码。我找不到任何有用的函数,所以我创建了一个子集,然后在子集上应用掩码。现在我需要将这个子集合并到原始张量中,但是我找不到一个有效的方法。
python
import torch
X=torch.rand(10,9)
tensorsize = X.size()
indices = torch.tensor([0,3,7,5,4]) #sorting them will make the process faster?
candidateCF=torch.index_select(X,1,indices)
mask=torch.FloatTensor(candidateCF.size()).uniform_() >= 0.3
output=candidateCF.mul(mask)
print(output)
以上代码没问题,output
是掩码张量
现在我需要把这个output
张量替换成X
。最有效的方法是什么呢?
或
有任何方法可以直接屏蔽X的特定列吗?我认为这样会更有效率。
注:澄清一下,output will be a subset of X
我找到了一个解决办法,贴出来,也许对别人有帮助。
X[:, indices] = output # for column replacements
X[indices, :] = output # for row replacements