如何火炬.如何执行这个四维张量乘法?



我遇到了一个使用torch.einsum计算张量乘法的代码。我能理解低阶张量的工作原理,但不能理解4D张量,如下图:

import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])

我需要帮助关于:

  1. 这里执行的操作是什么(解释矩阵如何相乘/转置等)?
  2. 在这种情况下torch.einsum实际上是有益的吗?

(如果您只想了解einsum中涉及的步骤分解,请跳到tl;dr部分)

我将尝试解释einsum如何一步一步地为这个例子工作,但不是使用torch.einsum,我将使用numpy.einsum(文档),它做了完全相同的,但我只是,一般来说,更舒服。然而,同样的步骤也发生在火炬上。

让我们用NumPy -

重写上面的代码
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)

一步一步np.einsum

Einsum由3个步骤组成:multiply,sumtranspose

让我们看看我们的维度。我们有一个(3, 5, 2, 10)和一个(3, 4, 2, 10),我们需要基于'nxhd,nyhd->nhxy'把它们带到(3, 2, 5, 4)

1。用

让我们不要担心n,x,y,h,d轴的顺序,只担心你是想保留它们还是删除(减少)它们。把它们以表格的形式写下来,看看我们如何安排我们的维度-

## Multiply ##
n   x   y   h   d
--------------------
a  ->  3   5       2   10
b  ->  3       4   2   10
c1 ->  3   5   4   2   10

为了让xy轴之间的广播乘法得到(x, y),我们必须在正确的位置添加一个新轴,然后相乘。

a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)

2。Sum/Reduce

接下来,我们要减少最后一个轴10。这将得到尺寸(n,x,y,h)

## Reduce ##
n   x   y   h   d
--------------------
c1  ->  3   5   4   2   10
c2  ->  3   5   4   2

这很简单。让我们把np.sum放到axis=-1

c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2)  #<-- (n, x, y, h)

3。'

最后一步是使用转置重新排列轴。我们可以使用np.transposenp.transpose(0,3,1,2)基本上是在第0轴之后引入第3轴,并推动第1轴和第2轴。因此,(n,x,y,h)变成了(n,h,x,y)

c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4)  #<-- (n, h, x, y)

4。最后检查

让我们做最后的检查,看看c3是否与由np.einsum-

生成的c相同
np.allclose(c,c3)
#True

TL;博士

因此,我们将'nxhd , nyhd -> nhxy'实现为-
input     -> nxhd, nyhd
multiply  -> nxyhd      #broadcasting
sum       -> nxyh       #reduce
transpose -> nhxy

优势与多个步骤相比,np.einsum的优点是您可以选择"路径"。用同一个函数来做计算和执行多个操作。这可以通过optimize参数来实现,该参数将优化insum表达式的收缩顺序。

这些操作的非详尽列表,可以由einsum计算,如下所示并附有示例:

  • 数组跟踪,numpy.trace.
  • 返回对角线,numpy.diag
  • 数组轴求和,numpy.sum.
  • 换位和排列,numpy.transpose.
  • 矩阵乘法和点积,numpy.matmulnumpy.dot.
  • 矢量内外产品,numpy.innernumpy.outer.
  • 广播,元素和标量乘法,numpy.multiply.
  • 张量收缩,numpy.tensordot.
  • 链式数组操作,低效的计算顺序,numpy.einsum_path.

基准
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

显示np.einsum的操作速度比单个步骤快。

最新更新