假设我想优化向量v
,使其范数等于1。为了做到这一点,我用这个向量定义了一个网络,如下所示:
class myNetwork(nn.Module):
def __init__(self,initial_vector):
super(myNetwork, self).__init__()
#Define vector according to an initial column vector
self.v = nn.Parameter(initial_vector)
def forward(self,x):
#Normalize vector so that its norm is equal to 1
self.v.data = self.v.data / torch.sqrt(self.v.data.transpose(1,0) @ self.v.data)
#Multiply v times a row vector
out = x @ self.v
return out
使用.data
是更新v
的最佳方式吗?它是否考虑了反向传播过程中的归一化?
您可以简单地使用
def forward(self,x):
return x @ self.v / (x**2).sum()
根据您的损失或网络的下游层,甚至可以跳过规范化。
def forward(self,x):
return x @ self.v
只要你的损失对规模是不变的,这就应该起作用,每一步的范数应该只有轻微的变化,但它不是严格稳定的。如果你给出了很多步骤,也许值得在你的损失中添加一个术语thinyvalue * ((myNetwork.v**2).sum()-1)**2
,以确保$v$的范数被吸引到1。