粗糙张量的张量流tf.map_fn失败,对象类型为 'RaggedTensor' 没有透镜



这个Tensorflow文档给出了在粗糙张量上使用tf.map_fn的例子,它适用于Tensorflow 2.4.1及以上版本:

digits = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
print(tf.map_fn(tf.math.square, digits))

然而,下面的例子导致错误"'RaggedTensor'类型的对象没有len"在Tensorflow 2.4.1或Tensorflow 2.5中运行时:

import tensorflow as tf
X=tf.ragged.constant([[1.,2.],[3.,4.,5.]], dtype=tf.float32)
@tf.function
def powerX(i):
global X
return X**i
Y = tf.map_fn(powerX, tf.range(3, dtype=tf.float32))

有办法让这个工作吗?我不明白抛出的错误。一般来说,我试图通过映射一个用户定义的函数来获得完全并行性,该函数在粗糙张量上只有Tensorflow操作,结果是粗糙张量。

tf.map_fn需要一个输出签名。我不确定为什么它不能从输入中推断出来,但这是一个给张量流研究人员的问题。下面的代码将为您工作。

import tensorflow as tf
X=tf.ragged.constant([[1.,2.],[3.,4.,5.]], dtype=tf.float32)
@tf.function
def powerX(i):
global X
return X**i
signature = tf.type_spec_from_value(powerX(X))
Y = tf.map_fn(powerX, tf.range(3, dtype=tf.float32),fn_output_signature=signature)

最新更新