如果我们传入张量,为什么tf.constant会给出dtype错误



以下代码

a = tf.range(10)
b = tf.constant(a, dtype=tf.float32)

给出以下错误:

TypeError: Expected tensor with type tf.float32 not tf.int32

尽管在文档中,设置dtype意味着tf.constant应该将a强制转换为指定的数据类型。所以我不明白为什么会出现类型错误。

我也知道:

a = np.arange(10)
b = tf.constant(a, dtype=tf.float32)

不会给出错误。

所以实际上,我主要想知道这里的引擎盖下发生了什么。

如果您查看此处的源代码,您将看到EagerTensor得到了特殊处理。基本上,如果EagerTensordtype与新的dtype不匹配,则会引发错误。

这里,tf.range()产生EagerTensor。我不知道为什么对EagerTensors进行特殊治疗。可能是与性能相关的限制。

最新更新