我遇到了一个使用torch.einsum
计算张量乘法的代码。我能理解低阶张量的工作原理,但不能理解4D张量,如下图:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
我需要帮助关于:
- 这里执行的操作是什么(解释矩阵如何相乘/转置等)?
- 在这种情况下
torch.einsum
实际上是有益的吗?
(如果您只想了解einsum中涉及的步骤分解,请跳到tl;dr部分)
我将尝试解释einsum
如何一步一步地为这个例子工作,但不是使用torch.einsum
,我将使用numpy.einsum
(文档),它做了完全相同的,但我只是,一般来说,更舒服。然而,同样的步骤也发生在火炬上。
让我们用NumPy -
重写上面的代码import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
一步一步np.einsum
Einsum由3个步骤组成:multiply
,sum
和transpose
让我们看看我们的维度。我们有一个(3, 5, 2, 10)
和一个(3, 4, 2, 10)
,我们需要基于'nxhd,nyhd->nhxy'
把它们带到(3, 2, 5, 4)
1。用
让我们不要担心n,x,y,h,d
轴的顺序,只担心你是想保留它们还是删除(减少)它们。把它们以表格的形式写下来,看看我们如何安排我们的维度-
## Multiply ##
n x y h d
--------------------
a -> 3 5 2 10
b -> 3 4 2 10
c1 -> 3 5 4 2 10
为了让x
和y
轴之间的广播乘法得到(x, y)
,我们必须在正确的位置添加一个新轴,然后相乘。
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
2。Sum/Reduce
接下来,我们要减少最后一个轴10。这将得到尺寸(n,x,y,h)
。
## Reduce ##
n x y h d
--------------------
c1 -> 3 5 4 2 10
c2 -> 3 5 4 2
这很简单。让我们把np.sum
放到axis=-1
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
3。'
最后一步是使用转置重新排列轴。我们可以使用np.transpose
。np.transpose(0,3,1,2)
基本上是在第0轴之后引入第3轴,并推动第1轴和第2轴。因此,(n,x,y,h)
变成了(n,h,x,y)
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
4。最后检查
让我们做最后的检查,看看c3是否与由np.einsum
-
np.allclose(c,c3)
#True
TL;博士
因此,我们将'nxhd , nyhd -> nhxy'
实现为-
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
优势与多个步骤相比,np.einsum
的优点是您可以选择"路径"。用同一个函数来做计算和执行多个操作。这可以通过optimize
参数来实现,该参数将优化insum表达式的收缩顺序。
这些操作的非详尽列表,可以由einsum
计算,如下所示并附有示例:
- 数组跟踪,
numpy.trace
. - 返回对角线,
numpy.diag
。 - 数组轴求和,
numpy.sum
. - 换位和排列,
numpy.transpose
. - 矩阵乘法和点积,
numpy.matmul
numpy.dot
. - 矢量内外产品,
numpy.inner
numpy.outer
. - 广播,元素和标量乘法,
numpy.multiply
. - 张量收缩,
numpy.tensordot
. - 链式数组操作,低效的计算顺序,
numpy.einsum_path
.
基准
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
显示np.einsum
的操作速度比单个步骤快。