我在Tensorflow中很难找到关于Dataset
的map
方法的细节。示例
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x + 2)
list(dataset.as_numpy_iterator())
工作正常,但通过将map
应用为来更改元素类型
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x / 10.0)
list(dataset.as_numpy_iterator())
产生错误消息
TypeError: `x` and `y` must have the same dtype, got tf.int32 != tf.float32.
因为应用的映射函数的返回类型与其输入类型不同。为什么会这样?不可能改变类型吗?如果是这样,我如何才能实现将数据集中的元素类型更改为tf.float32
的预期结果?
请注意,实际的数据集更为复杂,但这是说明该问题的最小示例。
我终于自己发现了。这个问题根本与map
无关,但对于分区,需要显式强制转换。
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: tf.cast(x, tf.float32) / 10.0)
list(dataset.as_numpy_iterator())