如何组合tf.map_fn和tf.split



所以我想要的东西的pseucode是:

splitted_outputs = [tf.split(output, rate, axis=0) for output in outputs]

在哪里输出是形状的张量(512,?,128),而splitted_outputs是张量或张量的列表,具有3个维度。因此,我可以迭代这种张量张量。

我尝试使用tf.map_fn

splitted_outputs = tf.map_fn(
    lambda output: tf.split(output, rate, axis=0),
    outputs,
    dtype=list
)

不可能,因为list不是合法的TF dtype

您可以在 outputs上使用 tf.unstack获取"子观察器"的列表,然后在每个上使用 tf.split

splitted_outputs = [tf.split(output, rate, axis=0) for output in tf.unstack(outputs, axis=0)]

请注意,只有知道给定axis的大小,或者您需要提供num参数。

才能使用tf.unstack

最新更新