我看过几篇关于向TensorFlow权重变量添加简单约束(即非负性(的文章,但没有一篇关于如何防止权重改变符号的文章。例如,如果我有W = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
如何添加约束,以便在初始化后W[i,j]
无法更改符号?我没有看到在 tf.get_variable(( 中使用"约束"选项的明确方法。
我解决这个问题的方法如下。
对于每个重量,您存储初始符号。这可以使用以下代码完成
w1 = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
w1_sign = tf.zeros_like(w1)
store_sign = tf.assign(w1_sign, tf.sign(w1))
可以使用以下代码在权重中断符号约束时使权重为 0。
constraint_op = tf.assign(w1, tf.where(w1_sign * w1 >= 0, w1, 0))
现在您可以按如下方式运行上面的代码
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(store_sign)
for _ in range(train_itr):
sess.run(some_train_op)
sess.run(constraint_op)
请注意,在上面的代码中,您只运行一次操作store_sign
,并且在每次运行train_op
后运行操作constraint_op
。
同样的想法可以应用于tf.get_variable
的论证constraints
。