TF map_fn或while_loop,用于不同形状的张量列表



我想处理一个不同形状的张量序列(列表)并输出另一个张量列表。想想每个时间戳上具有不同隐藏状态大小的 RNN。类似的东西

输入: [tf.ones((1, 2, 2)), tf.ones((2, 2,3)), tf.ones((3, 2, 1))]

输出: [tf.zeros((1, 2, 4)), tf.zeros((4, 2, 6)), tf.zeros((6, 2, 1))]

我无法将输入(或输出)堆叠到单个张量中,因为它们都有不同的形状,因此我不能将tf.map_fn用于任务。目前,我使用 python 进行循环,但它似乎是次优的。

我能做些什么更好的事情吗?

您可以使用tf.while_loop重复执行任意 TensorFlow 操作,直到出现某种停止条件。停止条件本身被指定为 op。

请注意,应谨慎使用tf.while_loop,因为默认情况下,其迭代将并行运行。例如,如果循环体递增 tf。变量,则必须使用控件依赖项来确保迭代按顺序运行。

但是,您提到您有一个带有 Python 循环的工作实现。如果可能的话,使用 Python 进行循环通常是最有效的解决方案。在 Python 中构建循环时,会为循环中的每个迭代创建单独的操作。这让TensorFlow在图构建中决定如何将计算资源分配给每个操作。例如,如果提前知道迭代次数,则更容易预测内存要求和并行化可能性。

因此,当在图形构建时停止条件未知时,最常使用tf.while_looptf.map_fn

如果迭代次数固定但非常多,您可能仍希望使用 tf.while_loop 而不是 Python 循环,因为每个操作的内存成本非常重要。

相关内容

  • 没有找到相关文章

最新更新