PyTorch中两个张量的所有可能的连接



假设我有两个张量ST定义为:

S = torch.rand((3,2,1))
T = torch.ones((3,2,1))

我们可以把它们看作包含形状为(2, 1)的张量的批。在本例中,批大小为3

我想连接批次之间所有可能的配对。单个批的串联产生形状为(4, 1)的张量。而且还有3*3的组合,所以最终得到的张量C必须具有(3, 3, 4, 1)的形状。

一个解决方案是这样做:

for i in range(S.shape[0]):
for j in range(T.shape[0]):
C[i,j,:,:] = torch.cat((S[i,:,:],T[j,:,:]))

但是for循环不能很好地扩展到大批量。是否有PyTorch命令来做到这一点?

我不知道有任何现成的命令可以做这样的操作。但是,您可以使用单个矩阵乘法以一种简单的方式实现它。


技巧是从已经堆叠的S,T张量开始构建一个包含所有批元素对的张量。然后将其与适当选择的mask张量相乘…在这种方法中,跟踪形状和尺寸大小是必不可少的。

  1. 堆栈由(注意重塑,我们实际上是将ST的批元素平铺成ST上的单个批轴):

    >>> ST = torch.stack((S, T)).reshape(6, 2)
    >>> ST
    tensor([[0.7792, 0.0095],
    [0.1893, 0.8159],
    [0.0680, 0.7194],
    [1.0000, 1.0000],
    [1.0000, 1.0000],
    [1.0000, 1.0000]]
    # ST.shape = (6, 2)
    
  2. 您可以使用rangeitertools.product检索所有(S[i], T[j])对:

    >>> indices = torch.tensor(list(product(range(0, 3), range(3, 6))))
    tensor([[0, 3],
    [0, 4],
    [0, 5],
    [1, 3],
    [1, 4],
    [1, 5],
    [2, 3],
    [2, 4],
    [2, 5]])
    # indices.shape = (9, 2)
    
  3. 从那里,我们使用torch.nn.functional.one_hot构建索引的单热编码:

    >>> mask = one_hot(indices).float()
    tensor([[[1., 0., 0., 0., 0., 0.],
    [0., 0., 0., 1., 0., 0.]],
    [[1., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 1., 0.]],
    [[1., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 1.]],
    [[0., 1., 0., 0., 0., 0.],
    [0., 0., 0., 1., 0., 0.]],
    [[0., 1., 0., 0., 0., 0.],
    [0., 0., 0., 0., 1., 0.]],
    [[0., 1., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 1.]],
    [[0., 0., 1., 0., 0., 0.],
    [0., 0., 0., 1., 0., 0.]],
    [[0., 0., 1., 0., 0., 0.],
    [0., 0., 0., 0., 1., 0.]],
    [[0., 0., 1., 0., 0., 0.],
    [0., 0., 0., 0., 0., 1.]]])
    # mask.shape = (9, 2, 6)
    
  4. 最后,我们计算矩阵乘法并将其重塑为最终形式:

    >>> (mask@ST).reshape(3, 3, 4, 1)
    tensor([[[[0.7792],
    [0.0095],
    [1.0000],
    [1.0000]],
    [[0.7792],
    [0.0095],
    [1.0000],
    [1.0000]],
    [[0.7792],
    [0.0095],
    [1.0000],
    [1.0000]]],
    
    [[[0.1893],
    [0.8159],
    [1.0000],
    [1.0000]],
    [[0.1893],
    [0.8159],
    [1.0000],
    [1.0000]],
    [[0.1893],
    [0.8159],
    [1.0000],
    [1.0000]]],
    
    [[[0.0680],
    [0.7194],
    [1.0000],
    [1.0000]],
    [[0.0680],
    [0.7194],
    [1.0000],
    [1.0000]],
    [[0.0680],
    [0.7194],
    [1.0000],
    [1.0000]]]])
    

我最初使用torch.einsum:torch.einsum('bf,pib->pif', ST, mask)。但是,后来意识到,如果我们交换两个操作数,bf,pib->pif可以很好地简化为一个简单的torch.Tensor.matmul操作:即。

pib,bf->pif(下标b中间缩小)。

在numpy中叫做np。使用网格

https://stackoverflow.com/a/35608701/3259896

在pytorch中,应该是

torch.stack(
torch.meshgrid(x, y)
).T.reshape(-1,2)

其中x和y是你的两个列表。你可以用任何数字。X, y, z等

然后将其重塑为您使用的列表数量。

如果你使用三个列表,使用.reshape(-1,3),四个使用.reshape(-1,4),等等

对于5张量,使用

torch.stack(
torch.meshgrid(a, b, c, d, e)
).T.reshape(-1,5)

我的解决方案是使用torch.repeat_interleaveTensor.repeat来复制for循环。

例如我有

>>> tensor_1 # shape(3, 4)
tensor([[0.1164, 0.6336, 0.7037, 0.1360],
[0.9316, 0.9569, 0.4108, 0.5415],
[0.6325, 0.3159, 0.3307, 0.0700]])
>>> tensor_2 # shape(2, 4)
tensor([[0.1687, 0.3315, 0.1523, 0.1123],
[0.1792, 0.8289, 0.7350, 0.2479]])

得到

的结果
for i in range(tensor_1.shape[0]):
for j in range(tensor_2.shape[0]):
torch.cat([tensor_1[i, ...], tensor_2[j, ...]], dim=0) # shape (8, )

我们可以做

b, h = tensor_1.shape
e, h = tensor_2.shape
result = torch.cat(
[torch.repeat_interleave(tensor_1, repeats=e, dim=0), tensor_2.repeat(b, 1), ]
, dim=-1,
).reshape(b, e, 2 * h)

(torch.repeat_interleave用于外部for循环,它以元素方式重复tensor_1e次。Tensor.repeat是内部for循环,重复tensor_2(b作为一个整体),它给出

>>> result
tensor([[[0.1164, 0.6336, 0.7037, 0.1360, 0.1687, 0.3315, 0.1523, 0.1123],
[0.1164, 0.6336, 0.7037, 0.1360, 0.1792, 0.8289, 0.7350, 0.2479]],
[[0.9316, 0.9569, 0.4108, 0.5415, 0.1687, 0.3315, 0.1523, 0.1123],
[0.9316, 0.9569, 0.4108, 0.5415, 0.1792, 0.8289, 0.7350, 0.2479]],
[[0.6325, 0.3159, 0.3307, 0.0700, 0.1687, 0.3315, 0.1523, 0.1123],
[0.6325, 0.3159, 0.3307, 0.0700, 0.1792, 0.8289, 0.7350, 0.2479]]])

最新更新