如何将先前转换为标准张量的具有嵌套不规则维度的不规则张量转换回


RaggedTensor有方法to_tensor()from_tensor()。然而,如果ragged_tensor具有嵌套维度,则应用tf.RaggedTensor.from_tensor(ragged_tensor.to_tensor(), padding=0)似乎失败

示例:

data = tf.ragged.constant([
[[4,35,6,33], [7,2], [89,56,12]],
[[2,11], [9]]
])
tf.RaggedTensor.from_tensor(data.to_tensor(), padding=0)

返回错误

Traceback (most recent call last):
File "./src/tppmodel.py", line 34, in <module>
tf.RaggedTensor.from_tensor(data.to_tensor(), padding=0)
File "/opt/conda/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in    error_handler
raise e.with_traceback(filtered_tb) from None
File "/opt/conda/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py", line 1307, in assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes () and (4,) are incompatible

预期

tf.RaggedTensor.from_tensor(data.to_tensor(),..., padding=0) = data

我自己可能已经找到了答案。张贴以备不时之需:不使用填充参数,而是传递原始张量的嵌套行长度

tf.RaggedTensor.from_tensor(data.to_tensor(), lengths=data.nested_row_lengths())
data = tf.ragged.constant([
[[4,35,6,33], [7,2], [89,56,12]],
[[2,11], [9]]
])
data
tf.RaggedTensor.from_tensor(data.to_tensor(), lengths=data.nested_row_lengths())
> <tf.RaggedTensor [[[4, 35, 6, 33], [7, 2], [89, 56, 12]], [[2, 11], [9]]]>

相关内容

  • 没有找到相关文章

最新更新