是否有通过self跟踪梯度的方法.Put_variable法在亚麻?



我想通过self.put_variable跟踪梯度。有什么办法让这成为可能吗?或者更新提供给跟踪模块的参数的另一种方法?

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 

class network(nn.Module):
input_size : int 
output_size : int 
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))

self.put_variable("params","b",(x@W+b).reshape(5,))  

return jnp.sum(x+b)

if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
#x,param = module.apply(param,x,mutable=["params"])
#print(param)
print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))

我的输出梯度是:

FrozenDict({
params: {
W: DeviceArray([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},

什么表明它不通过自我跟踪梯度。variable_put方法,因为到W的梯度都是零,而b显然依赖于W。

就像@jakevdp注意到的那样,上面的测试是不正确的,因为b仍然与之前的b相关联。
https://github.com/google/flax/discussions/2215说self。跟踪Put_variable

使用下面的代码测试是否真的是这样:

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 
class network(nn.Module):
input_size : int 
output_size : int 
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
b = x@W+b #update the b variable else it is still tied to the previous one.
self.put_variable("params","b",(b).reshape(5,))  

return jnp.sum(x+b)
def test_update(param,x):
_, param = module.apply(param,x,mutable=["params"])
return jnp.sum(param["params"]["b"]+x),param 
if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
print(grad(test_update,has_aux=True)(param,x))

输出:

FrozenDict({
params: {
W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
0.00599653],
[-0.00729604, -0.00417799,  0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
-0.00616944],
[-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.00905032, -0.00574646,  0.01621638, -0.01165553,
-0.0285466 ], dtype=float32),
},
})
(FrozenDict({
params: {
W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
-1.1489547 ],
[-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
-2.0069852 ],
[ 0.98777294,  0.98777294,  0.98777294,  0.98777294,
0.98777294],
[ 0.9311977 ,  0.9311977 ,  0.9311977 ,  0.9311977 ,
0.9311977 ],
[-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
-0.2883922 ]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
}), FrozenDict({
params: {
W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
0.00599653],
[-0.00729604, -0.00417799,  0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
-0.00616944],
[-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.01861148, -0.00523183,  0.03968921, -0.01952654,
-0.06145691], dtype=float32),
},
}))

第一个FrozenDict是原始参数。
第二个FrozenDict是梯度,显然是通过self.put_variable跟踪的。
最后一个FrozenDict是参数,我们可以看到b是正确更新的。

你的模型的输出是jnp.sum(x + b),它不依赖于W,这反过来意味着相对于W的梯度应该为零。考虑到这一点,上面显示的输出看起来是正确的。

编辑:听起来你希望你在变量中使用的x@W+b的结果反映在返回语句中使用的b的值中;也许你想要这样的东西?

def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
b = x@W+b
self.put_variable("params","b",b.reshape(5,)) 

return jnp.sum(x+b)

也就是说,我不清楚你的最终目标是什么,考虑到你问的是这样一个不常见的结构,我怀疑这可能是一个XY问题。也许你可以修改一下你的问题,多说一些你想要完成的任务。

最新更新