使用TensorFlow matmul时的意外结果,dtype=tf.float32



使用tensorflow 2.3.0, python 3.8.11.

代码如下:

a = tf.constant([2, 2, 3, 3], shape=[2, 2], dtype=tf.float32)
print('-------------------')
print(a)
a2 = tf.matmul(a,a)
print('-------------------')
print(a2)

输出如下(在不同的运行中得到其他错误的结果):

-------------------
tf.Tensor(
[[2. 2.]
[3. 3.]], shape=(2, 2), dtype=float32)
-------------------
tf.Tensor(
[[10. 10.]
[ 0.  0.]], shape=(2, 2), dtype=float32)

但是如果将dtype设置为int32或float64,则得到正确的结果,float64结果如下:

-------------------
tf.Tensor(
[[2. 2.]
[3. 3.]], shape=(2, 2), dtype=float64)
-------------------
tf.Tensor(
[[10. 10.]
[15. 15.]], shape=(2, 2), dtype=float64)

这是一个bug吗?

tf.matmulTensorflow 2.6.0工作良好。

import tensorflow as tf
a = tf.constant([2, 2, 3, 3], shape=[2, 2], dtype=tf.float32)
print('-------------------')
print(a)
a2 = tf.matmul(a,a)
print('-------------------')
print(a2)

-------------------
tf.Tensor(
[[2. 2.]
[3. 3.]], shape=(2, 2), dtype=float32)
-------------------
tf.Tensor(
[[10. 10.]
[15. 15.]], shape=(2, 2), dtype=float32)

相关内容