tensorflow map_fn出错.无法指定输出签名



我正试图使用tensorflow的tf.map_fn来映射一个粗糙的张量,但我遇到了一个无法修复的错误。以下是一些演示错误的最小代码:

import tensorflow as tf
X = tf.ragged.constant([[0,1,2], [0,1]])
def outer_product(x):
return x[...,None]*x[None,...]
tf.map_fn(outer_product, X)

我想要的输出是:

tf.ragged.constant([
[[0, 0, 0],
[0, 1, 2],
[0, 2, 4]],
[[0, 0],
[0, 1]]
])

我得到的错误是:

"InvalidArgumentError:所有平面值都必须具有兼容的形状。索引0处的形状:[3]。索引1处的形状:[2]。如果您正在使用tf.map_fn,则可能需要指定一个显式fn_output_signature和适当的ragged_bank,和/或转换将张量输出到RaggedTensors。[操作:RaggedTensorFromVariant]";

我意识到我需要指定fn_output_signature,但尽管进行了实验,我还是无法确定它应该是什么。

编辑:我稍微整理了一下AloneTogether的优秀答案,并创建了一个映射粗糙张量的函数。他的答案使用tf.ragged.stack函数将张量转换为不规则张量,tf.map_fn出于某种原因需要

def ragged_map_fn(func, t): 
def new_func(t):
return tf.ragged.stack(func(t),0)
signature = tf.type_spec_from_value(new_func(t[0]))
ans = tf.map_fn(new_func, t, fn_output_signature=signature)
ans = tf.squeeze(ans, 1)
return ans

粗糙张量有时真的很棘手。这里有一个你可以尝试的选项:

import tensorflow as tf
X = tf.ragged.constant([
[0,1,2], 
[0,1]
])
def outer_product(x):
t = x[...,None] * x[None,...]
return tf.ragged.stack(t)

y = tf.map_fn(outer_product, X, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None],
dtype=tf.type_spec_from_value(X).dtype,
ragged_rank=2,
row_splits_dtype=tf.type_spec_from_value(X).row_splits_dtype))
tf.print(y)
#y = tf.concat([y[0, :], y[1, :]], axis=0) # Remove additional dimension from Ragged Tensor
y = y.merge_dims(0, 1)
tf.print(y)
[
[
[
[0, 0, 0], 
[0, 1, 2], 
[0, 2, 4]
]
], 
[
[
[0, 0], 
[0, 1]
]
]
]

y.merge_dims(0, 1)tf.concat:去除附加维度后

[
[
[0, 0, 0], 
[0, 1, 2], 
[0, 2, 4]
], 
[
[0, 0], 
[0, 1]
]
]

相关内容

  • 没有找到相关文章

最新更新