计算 Tensorflow 中的权重更新比率



我正在寻找一种方法来计算 Tensorflow 中优化器步骤的权重更新比率。权重-更新-比率定义为每个步骤中的更新尺度除以可变尺度,可用于检查网络训练。

理想情况下,我想要一种非侵入式的方式来在单个会话运行中计算它,但无法完成我想要的。由于更新尺度和参数尺度与训练步骤无关,因此需要向图添加显式依赖关系,以便在更新步骤之前和之后绘制可变尺度。不幸的是,在 TF 中似乎只能为新节点定义依赖项,这使问题进一步复杂化。

到目前为止,我想出的最好的是用于定义必要操作的上下文管理器。其用法如下

opt = tf.train.AdamOptimizer(1e0)
grads = tf.gradients(loss, tf.trainable_variables())
grads = list(zip(grads, tf.trainable_variables()))
with compute_weight_update_ratio('wur') as wur:
train = opt.apply_gradients(grads_and_vars=grads)
# ...
with tf.Session() as sess:
sess.run(wur.ratio)

compute_weight_update_ratio的完整代码可以在下面找到。让我烦恼的是,在当前状态下,权重更新比率(至少norm_before(是用每个训练步骤计算的,但出于性能原因,我宁愿有选择地这样做(例如,仅在计算摘要时(。

关于如何改进的任何想法?

@contextlib.contextmanager
def compute_weight_update_ratio(name, var_scope=None):
'''Injects training to compute weight-update-ratio.
The weight-update-ratio is computed as the update scale divided
by the variable scale before the update and should be somewhere in the 
range 1e-2 or 1e-3.
Params
------
name : str
Operation name
Kwargs
------
var_scope : str, optional
Name selection of variables to compute weight-update-ration for. Defaults to all. Regex supported.
'''
class WeightUpdateRatio:
def __init__(self):
self.num_train = len(tf.get_collection(tf.GraphKeys.TRAIN_OP))
self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=var_scope)
self.norm_before = tf.norm(self.variables, name='norm_before')
def compute_ratio(self,):
train_ops = tf.get_collection(tf.GraphKeys.TRAIN_OP)
assert len(train_ops) > self.num_train, 'Missing training op'
with tf.control_dependencies(train_ops[self.num_train:]):
self.norm_after = tf.norm(self.variables, name='norm_after')
absdiff = tf.abs(tf.subtract(self.norm_after, self.norm_before), name='absdiff')
self.ratio = tf.divide(absdiff, self.norm_before, name=name)
with tf.name_scope(name) as scope:
try:
wur = WeightUpdateRatio()
with tf.control_dependencies([wur.norm_before]):
yield wur
finally:
wur.compute_ratio()

您无需过多担心性能。Tensorflow只执行产生输出所需的子图。

因此,在训练循环中,如果在迭代期间未调用wur.ratio,则不会执行为计算它而创建的任何额外节点。

最新更新