给定一个固定大小的TensorArray和具有统一形状的条目,我想去一个包含相同值的张量,只是通过将TensorArray的索引维度作为规则轴。
TensorArrays有一个方法叫做"gather"据说它应该做到这一点。事实上,下面的例子是可行的:
array = tf.TensorArray(tf.int32, size=3)
array.write(0, 10)
array.write(1, 20)
array.write(2, 30)
gathered = array.gather([0, 1, 2])
"gathered"然后生成所需的张量:
tf.Tensor([10 20 30], shape=(3,), dtype=int32)
不幸的是,当将其包装在tf中时,这将停止工作。函数,如下所示:
@tf.function
def func():
array = tf.TensorArray(tf.int32, size=3)
array.write(0, 10)
array.write(1, 20)
array.write(2, 30)
gathered = array.gather([0, 1, 2])
return gathered
tensor = func()
"tensor"则错误地生成以下张量:
tf.Tensor([0 0 0], shape=(3,), dtype=int32)
你对此有一个解释,或者你能建议一种替代方法从TensorArray到张量在tf.function中?
按https://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873你必须:
取代加勒比海盗。写(j, t), arr = arr。写(j, t)
问题是。函数以图的形式执行。在急于模式下,数组将被更新(为了方便),但您实际上应该使用返回值进行链式操作:https://www.tensorflow.org/api_docs/python/tf/TensorArray#returns_6
代替array.gather()
,尝试使用array.stack()
,它将从TensorArray
返回Tensor