隐式更改nn值的最佳方式.Pytorch中的Parameter()



假设我想优化向量v,使其范数等于1。为了做到这一点,我用这个向量定义了一个网络,如下所示:

class myNetwork(nn.Module):
def __init__(self,initial_vector):
super(myNetwork, self).__init__()
#Define vector according to an initial column vector
self.v = nn.Parameter(initial_vector)
def forward(self,x):
#Normalize vector so that its norm is equal to 1
self.v.data = self.v.data / torch.sqrt(self.v.data.transpose(1,0) @ self.v.data) 
#Multiply v times a row vector 
out = x @ self.v
return out 

使用.data是更新v的最佳方式吗?它是否考虑了反向传播过程中的归一化?

您可以简单地使用

def forward(self,x):
return x @ self.v / (x**2).sum()

根据您的损失或网络的下游层,甚至可以跳过规范化。

def forward(self,x):
return x @ self.v

只要你的损失对规模是不变的,这就应该起作用,每一步的范数应该只有轻微的变化,但它不是严格稳定的。如果你给出了很多步骤,也许值得在你的损失中添加一个术语thinyvalue * ((myNetwork.v**2).sum()-1)**2,以确保$v$的范数被吸引到1。

最新更新