使用tensorflow 2.3.0, python 3.8.11.
代码如下:
a = tf.constant([2, 2, 3, 3], shape=[2, 2], dtype=tf.float32)
print('-------------------')
print(a)
a2 = tf.matmul(a,a)
print('-------------------')
print(a2)
输出如下(在不同的运行中得到其他错误的结果):
-------------------
tf.Tensor(
[[2. 2.]
[3. 3.]], shape=(2, 2), dtype=float32)
-------------------
tf.Tensor(
[[10. 10.]
[ 0. 0.]], shape=(2, 2), dtype=float32)
但是如果将dtype设置为int32或float64,则得到正确的结果,float64结果如下:
-------------------
tf.Tensor(
[[2. 2.]
[3. 3.]], shape=(2, 2), dtype=float64)
-------------------
tf.Tensor(
[[10. 10.]
[15. 15.]], shape=(2, 2), dtype=float64)
这是一个bug吗?
tf.matmul
与Tensorflow 2.6.0
工作良好。
import tensorflow as tf
a = tf.constant([2, 2, 3, 3], shape=[2, 2], dtype=tf.float32)
print('-------------------')
print(a)
a2 = tf.matmul(a,a)
print('-------------------')
print(a2)
-------------------
tf.Tensor(
[[2. 2.]
[3. 3.]], shape=(2, 2), dtype=float32)
-------------------
tf.Tensor(
[[10. 10.]
[15. 15.]], shape=(2, 2), dtype=float32)