Pytorch:沿多个轴使用张量的索引,或一次分散到多个索引



我正在尝试更新 Pytorch 中多维张量的非常具体的索引,但我不确定如何访问正确的索引。我可以在 Numpy 中以非常简单的方式做到这一点:

import numpy as np
#set up the array containing the data
data = 100*np.ones((10,10,2))
data[5:,:,:] = 0
#select the data points that I want to update
idxs = np.nonzero(data.sum(2))
#generate the updates that I am going to do
updates = np.random.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates

我需要在 Pytorch 中实现这一点,但我不确定如何做到这一点。似乎我需要scatter函数,但这仅适用于单个维度而不是我需要的多个维度。我该怎么做?

这些操作在它们的 PyTorch 对应项中的工作方式完全相同,除了torch.nonzero,默认情况下返回大小为 [z, n] 的张量(其中 z 是非零元素的数量,n 是维度的数量(,而不是大小为[z]n个张量的元组(如 NumPy(,但可以通过设置as_tuple=True来更改该行为。

除此之外,您可以直接将其转换为 PyTorch,但您需要确保类型匹配,因为您不能将类型为torch.long(默认值为torch.randint(的张量分配给torch.float类型(默认值为torch.ones(的张量。在这种情况下,data可能意味着具有类型torch.long

#set up the array containing the data
data = 100*torch.ones((10,10,2), dtype=torch.long)
data[5:,:,:] = 0
#select the data points that I want to update
idxs = torch.nonzero(data.sum(2), as_tuple=True)
#generate the updates that I am going to do
updates = torch.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates

最新更新