获取给定优化器在 Tensorflow 中最小化的损失



我正在为我的 tensorflow 工作区在单元测试系统中工作,我想知道是否有任何方法或属性,给定一个带有优化器操作的图形(在调用 .minimize() 之后),以获得它正在优化的最终损失张量和它控制的变量。

例如,如果我调用train_op = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)我想检索cross_entropy只能访问train_op。

我可以访问train_op对象,我只想知道它引用了哪些损失以及哪些变量控制。

相当微不足道:

def build_graph():
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(...)
train_op = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
return cross_entropy, train_op    # both tensorflow OPs
cross_entropy, train_op = build_graph()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Evaluate and return cross_entropy, and execute the train_op (update the weights)
result = sess.run([cross_entropy, train_op], feed_dict={...})
print(result[0])   # The value of the cross entropy loss function

这里有许多优秀的教程:https://github.com/aymericdamien/TensorFlow-Examples

您会发现他们在完整的工作模型中正是这样做的。

如果您无法访问张量,则可以按名称在图中查找它:

tf.get_default_graph().get_tensor_by_name("example:0")

看到这个问题: 张量流:如何按名称获取张量?

请注意,如果你没有很好地命名你的张量,这将是后方的皇家痛苦,所以,这是命名你的张量的众多好理由之一。张量的默认名称将使用对操作的引用、冒号、索引号,例如"add:2"表示第三个添加操作。

您可以使用以下内容获取图形中所有张量的列表:

[n.name for n in tf.get_default_graph().as_graph_def().node]

该代码是从这个问题中复制的:在Tensorflow中,获取图中所有张量的名称


在评论中回答此后续问题:

我想知道哪一个是优化train_op没有 必须使用特定名称命名它们。所以给定一个train_op对象, 有没有办法检索张量(或张量的名称) 哪个表示train_op最小化的最后一个值?我需要它 因为我正在自动化一组单元测试,以便如果我插入 我的系统会自动找到的张量流模型,给定 优化器,表示损失的张量(这样我可以 自动执行梯度检查)。

作为我研究的一部分,我已经编写了一个梯度下降优化器。以下是您可以考虑的一些想法:

1)这是我在做同样的事情时遵循的优化器的链接: https://github.com/openai/iaf/blob/master/tf_utils/adamax.py 这是在python中实现AdaMax。你会对_apply_dense()感兴趣,它采用渐变及其变量并执行更新。为每个可训练变量调用它。请注意,tensorflow 中的大多数优化器都是用 C 编码的,而不是使用 python 接口。所以我不确定这是否有帮助,但更好地理解这个过程不会是一件坏事。

2)您可以获得任何变量相对于任何其他变量的梯度。因此,您可以使用tf.trainable_variables()获取可训练变量的集合,然后调用tf.gradients以获取可训练变量相对于损失函数的梯度。不过,您需要为此使用损失函数,而不是训练 OP。我希望您可以从优化器自动找到损失。

如果您只是想从训练 OP 中找到损失函数,您可以通过遵循图依赖关系找到您需要的内容,如以下问题中所述:如何列出节点依赖的所有 Tensorflow 变量?

这是我之前用来获取每个变量及其输入和输出的列表的一种方法。我怀疑你可以弄清楚如何遍历这个数据结构来找到你需要的东西。

tf.get_default_graph().as_graph_def()
Out[6]: 
node {
name: "x"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 42.0
}
}
}
}

最新更新