我想通过self.put_variable跟踪梯度。有什么办法让这成为可能吗?或者更新提供给跟踪模块的参数的另一种方法?
import jax
from jax import numpy as jnp
from jax import grad,random,jit,vmap
import flax
from flax import linen as nn
class network(nn.Module):
input_size : int
output_size : int
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
self.put_variable("params","b",(x@W+b).reshape(5,))
return jnp.sum(x+b)
if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
#x,param = module.apply(param,x,mutable=["params"])
#print(param)
print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))
我的输出梯度是:
FrozenDict({
params: {
W: DeviceArray([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
什么表明它不通过自我跟踪梯度。variable_put方法,因为到W的梯度都是零,而b显然依赖于W。
就像@jakevdp注意到的那样,上面的测试是不正确的,因为b仍然与之前的b相关联。
https://github.com/google/flax/discussions/2215说self。跟踪Put_variable
使用下面的代码测试是否真的是这样:
import jax
from jax import numpy as jnp
from jax import grad,random,jit,vmap
import flax
from flax import linen as nn
class network(nn.Module):
input_size : int
output_size : int
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
b = x@W+b #update the b variable else it is still tied to the previous one.
self.put_variable("params","b",(b).reshape(5,))
return jnp.sum(x+b)
def test_update(param,x):
_, param = module.apply(param,x,mutable=["params"])
return jnp.sum(param["params"]["b"]+x),param
if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
print(grad(test_update,has_aux=True)(param,x))
输出:
FrozenDict({
params: {
W: DeviceArray([[ 0.01678762, 0.00234134, 0.00906202, 0.00027337,
0.00599653],
[-0.00729604, -0.00417799, 0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531, 0.01898266, -0.01733185,
-0.00616944],
[-0.00806503, 0.00409351, 0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197, 0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.00905032, -0.00574646, 0.01621638, -0.01165553,
-0.0285466 ], dtype=float32),
},
})
(FrozenDict({
params: {
W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
-1.1489547 ],
[-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
-2.0069852 ],
[ 0.98777294, 0.98777294, 0.98777294, 0.98777294,
0.98777294],
[ 0.9311977 , 0.9311977 , 0.9311977 , 0.9311977 ,
0.9311977 ],
[-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
-0.2883922 ]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
}), FrozenDict({
params: {
W: DeviceArray([[ 0.01678762, 0.00234134, 0.00906202, 0.00027337,
0.00599653],
[-0.00729604, -0.00417799, 0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531, 0.01898266, -0.01733185,
-0.00616944],
[-0.00806503, 0.00409351, 0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197, 0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.01861148, -0.00523183, 0.03968921, -0.01952654,
-0.06145691], dtype=float32),
},
}))
第一个FrozenDict是原始参数。
第二个FrozenDict是梯度,显然是通过self.put_variable跟踪的。
最后一个FrozenDict是参数,我们可以看到b是正确更新的。
你的模型的输出是jnp.sum(x + b)
,它不依赖于W
,这反过来意味着相对于W
的梯度应该为零。考虑到这一点,上面显示的输出看起来是正确的。
编辑:听起来你希望你在变量中使用的x@W+b
的结果反映在返回语句中使用的b
的值中;也许你想要这样的东西?
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
b = x@W+b
self.put_variable("params","b",b.reshape(5,))
return jnp.sum(x+b)
也就是说,我不清楚你的最终目标是什么,考虑到你问的是这样一个不常见的结构,我怀疑这可能是一个XY问题。也许你可以修改一下你的问题,多说一些你想要完成的任务。