NumPy Tensordot axes=2



我知道有很多关于tensordot的问题,我已经浏览了一些15页的迷你书的答案,我肯定人们花了很多时间,但我还没有找到axes=2的解释。

这让我认为np.tensordot(b,c,axes=2) == np.sum(b * c),但作为一个数组:

b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)

但是这个失败了:

a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)

如果有人能提供一个简短,简明的np.tensordot(x,y,axes=2)解释,只有axes=2,那么我很乐意接受。

In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
np.tensordot(a,a,axes=2)
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

在我之前的文章中,我推断axis=2翻译成axes=([-2,-1],[0,1])

numpy。Tensordot函数一步一步地工作?

In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
np.tensordot(a,a,axes=([-2,-1],[0,1]))
File "<__array_function__ internals>", line 5, in tensordot
File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

这是在尝试对第一个a的最后两个维度和第二个a的前两个维度进行双轴约简。这个a的尺寸不匹配。显然,这个axes是为2d数组设计的,没有太多考虑3d数组。它不是3轴复位。

这些个位数轴值是一些开发人员认为方便的东西,但这并不意味着它们是经过严格考虑或测试的。

元组轴给你更多的控制:

In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]: 
array([[ 880,  940, 1000, 1060],
[ 940, 1006, 1072, 1138],
[1000, 1072, 1144, 1216],
[1060, 1138, 1216, 1294]])

最新更新