Tensorflow:可以阻止tf.where的一个分支执行吗?



我正在研究编码器解码器设置。我希望能够运行编码器一次,然后执行多个解码器运行。我想出的解决方案是向解码器提供一个 TF 条件节点(使用 tf.where),其中包含编码器的最终隐藏状态(在这种情况下,当我请求解码器输出时,TF 将运行编码器),或者一个占位符包含编码器的存储结果(在这种情况下,理论上 TF 不需要运行编码器)。

以下是代码的相关部分:

encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

由于我没有从这种方法中获得加速,所以我很确定它不起作用,并且 tf.where 的两个分支每次都由 TF 运行,即使它只需要从占位符读取。

有没有办法使用 tf.where 这样它就不会运行编码器?我查看了该方法的描述,但我不确定是否总是计算两个分支,我看到了关于这个问题的矛盾信息。

谢谢!

当您想要延迟执行其中一个分支直到计算谓词时,可以使用tf.cond()函数。

encoder_state = tf.cond(
tf.greater_equal(branching_points, 0),
lambda: encoder_state,
lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

我尝试使用 tf.cond创建一个模型并输入字典,但 tf.cond 只会接受一个输入,因此如果您有多个branching_points,这将不起作用。
我已经创建了解决方法,但它非常复杂,我希望看到更好的解决方案,特别是如果true_fn和false_fn计算成本高昂,这只会提高性能。 如果未选取true_fn或false_fn的分支,则不应执行其分支(例如,如果您在这些函数中使用 tf.assign 时),此解决方案也很有用

首先,我创建布尔张量:

branch_1 = tf.greater_equal(branching_points, 0)
branch_2 = tf.logical_not(branch_1)

然后我使用布尔掩码只从分支执行 True 条件

result_1 = tf.boolean_mask(branch_1)
result_2 = tf.boolean_mask(branch_2)

最后,如果需要,您可以形成一个张量。 如果顺序很重要,您可以使用tf.where(tf.equal(branch_1,True))tf.where(tf.equal(branch_2,True))分别获取result_1和result_2的索引。然后你应用tf.scatter_nd。 如果顺序无关紧要,您可以简单地使用 tf.concat

最新更新