为什么tf.matmul(a,b, transpose_b=True)有效,而tf.matmul(a, tf.trans



代码:

x = tf.constant([1.,2.,3.], shape = (3,2,4))
y = tf.constant([1.,2.,3.], shape = (3,21,4))
tf.matmul(x,y)                     # Doesn't work. 
tf.matmul(x,y,transpose_b = True)  # This works. Shape is (3,2,21)
tf.matmul(x,tf.transpose(y))       # Doesn't work.

我想知道ytf.matmul(x,y,transpose_b = True)内部变成什么形状,这样我就可以全神贯注地弄清楚 LSTM 内部到底发生了什么。

对于秩为> 2 的张量,可以以不同的方式定义转置,这里的区别在于由tf.transposetf.matmul(..., transpose_b=True)转置的轴。

默认情况下,tf.transpose执行以下操作:

返回的张量的维数i将对应于输入维数perm[i]。如果未给出perm,则将其设置为(n-1...0),其中n是输入张量的秩。因此,默认情况下,此操作对二维输入张量执行常规矩阵转置。

所以在你的情况下,它会y转换为形状为(4, 21, 3)的张量,这与x不兼容(见下文)。

但是如果你设置perm=[0, 2, 1],结果是兼容的:

# Works! (3, 2, 4) * (3, 4, 21) -> (3, 2, 21).
tf.matmul(x, tf.transpose(y, [0, 2, 1]))

关于tf.matmul

您可以计算点积:(a, b, c) * (a, c, d) -> (a, b, d)。但它不是张量点积 - 它是一个批处理操作(见这个问题)。

在这种情况下,a被视为批大小,因此tf.matmul计算矩阵(b, c) * (c, d)a点积。

批处理可以是多个维度,因此这也是有效的:

(a, b, c, d) * (a, b, d, e) -> (a, b, c, e)

相关内容

  • 没有找到相关文章

最新更新