如何创建三维Nx.来自四维Nx的张量.张量(长生不老药)?



我需要从两个矩阵a,b(两个二维张量)创建三维张量m,这样m[i][k][j] = a[i][k] * b[k][j](事实上,维度顺序在这里并不重要,所以它也可以:)m[k][i][j] = a[i][k] * b[k][j]m[i][j][k] = a[i][k] * b[k][j].

我找到了如何创建4-D张量n,n[i][l][k][j] = a[i][l]*b[k][j]的方法,现在我需要切片/收集/以某种方式只取l==k的元素。

是否有办法切片n,或创建m在其他更好的方式?

a_mat = Nx.random_uniform({5,5}, names: [:i, :l])
b_mat = Nx.random_uniform({5,5}, names: [:k, :j]) 
n_tensor = Nx.dot(a_mat,[], b_mat,[])
m_tensor = ???

谢谢。

经过一番探索,我认为最好的方法是:

m_tensor = Nx.dot( Nx.transpose(a_mat), [], [0], b_mat, [],[0])

就会创建一个张量

m[k][i][j] = a[i][k] * b[k][j]