Tensorflow Graph - 检查节点是否依赖于占位符



Tensorflow图中,有没有办法找出node是否依赖于placeholder,就像node.depends(placeholder) -> Bool

import tensorflow as tf
x = tf.placeholder(name='X', dtype=tf.int64, shape=[])
y = tf.placeholder(name='Y', dtype=tf.int64, shape=[])
p = tf.add(x, y)
q = tf.add(x, x)
sess = tf.Session()
result = sess.run([p,q], feed_dict={x: 1, y: 2})
print(result)
result = sess.run([p,q], feed_dict={x: 1, y: 3})
print(result)

在上面的代码示例中,q不依赖于y。在session.run的第二次调用中,我们只修改y。因此,q不需要再次评估。在这些情况下,会话是否会自动重用现有值?如果是这样,有没有办法找出哪些是.run期间评估的节点?

否则,如果我能快速找出哪些节点依赖于我修改的占位符,我只能将这些节点作为输入发送以运行,并重用现有值(在字典中作为缓存(。

这个想法是为了避免昂贵的评估,更重要的是,在我的应用程序中,最大限度地减少在输出节点更改时需要触发的昂贵操作(外部tensorflow- 这是我应用程序中的必要条件。

检查图中张量之间的依赖关系可以使用如下函数来完成:

import tensorflow as tf
# Checks if tensor a depends on tensor b
def tensor_depends(a, b):
if a.graph is not b.graph:
return False
gd = a.graph.as_graph_def()
gd_sub = tf.graph_util.extract_sub_graph(gd, [a.op.name])
return b.op.name in {n.name for n in gd_sub.node}

例如:

import tensorflow as tf
x = tf.placeholder(name='X', dtype=tf.int64, shape=[])
y = tf.placeholder(name='Y', dtype=tf.int64, shape=[])
p = tf.add(x, y)
q = tf.add(x, x)
print(tensor_depends(q, x))
# True
print(tensor_depends(q, y))
# False

关于您的问题,通常 TensorFlow 会在每次运行时重新计算所有内容,即使输入没有变化。如果你想缓存结果,你可以在更高的级别自己做 - 对于 TensorFlow,不清楚它应该保留什么结果(只有最新的,一些最近的,......在任何情况下,即使输入没有变化,输出也可能发生变化,就像循环模型一样,或者更一般地说,由于变量或数据集等有状态对象的任何变化。有一些优化机会丢失了,但期望 TensorFlow 解决这些问题可能是不合理的(分析模型以确定它是否可以缓存结果、可以缓存哪些结果、多少、如何配置它等(。如果您知道图中的某些中间值不会更改,您还可以计算该部分,然后将其作为输入(您实际上可以为feed_dict中的任何张量输入值(。否则,正如您所建议的,您可以只检查什么取决于什么,并仅在必要时重新计算。

最新更新