Learnable scalar weights in PyTorch, constrained to sum to 1



I have code like this:

import torch
import torch.nn as nn


class MyModule(nn.Module):

    def __init__(self, channel, reduction=16, n_segment=8):
        super(MyModule, self).__init__()
        self.channel = channel
        self.reduction = reduction
        self.n_segment = n_segment

        self.conv1 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        # whatever
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # assumed definition; used in forward() but missing from the original snippet

        # learnable weights
        self.W_1 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_2 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_3 = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):

        # whatever

        ## branch1
        bottleneck_1 = self.conv1(x)

        ## branch2
        bottleneck_2 = self.conv2(x)

        ## branch3
        bottleneck_3 = self.conv3(x)

        ## summation: weighted sum of the three branches
        output = self.avg_pool(self.W_1*bottleneck_1 +
                               self.W_2*bottleneck_2 +
                               self.W_3*bottleneck_3)

        return output

As you can see, three learnable scalars (W_1, W_2 and W_3) are used for weighting. However, nothing guarantees that these scalars sum to 1. How can I make learnable scalars sum to 1 in PyTorch?
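A quick way to see the problem concretely (a minimal check, assuming the MyModule definition above):

import torch

m = MyModule(channel=64)
total = m.W_1 + m.W_2 + m.W_3
print(total)  # an arbitrary value: torch.randn initialization puts no constraint on the sum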

Keeping it simple: divide each weight by the sum of all three before applying it:

## summation
WSum = self.W_1 + self.W_2 + self.W_3
output = self.avg_pool(self.W_1/WSum * bottleneck_1 +
                       self.W_2/WSum * bottleneck_2 +
                       self.W_3/WSum * bottleneck_3)

You can also use the distributive law and divide once at the end:

output = self.avg_pool(self.W_1*bottleneck_1 +
                       self.W_2*bottleneck_2 +
                       self.W_3*bottleneck_3) / WSum
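One caveat: dividing by WSum only makes the weights sum to 1; it does not keep them positive, and it fails if WSum is zero or close to it, which can happen since the parameters are initialized with torch.randn and may take either sign. A common alternative, shown here as a minimal sketch rather than as part of the answer above, is to store unconstrained logits and pass them through torch.softmax, which produces strictly positive weights that always sum to 1:

import torch
import torch.nn as nn

class SoftmaxWeightedSum(nn.Module):  # hypothetical helper, for illustration
    def __init__(self, n_branches=3):
        super().__init__()
        # unconstrained logits; softmax maps them to positive weights summing to 1
        self.logits = nn.Parameter(torch.zeros(n_branches))

    def forward(self, branches):
        # branches: list/tuple of equally-shaped tensors, one per branch
        w = torch.softmax(self.logits, dim=0)  # each w[i] > 0 and w.sum() == 1
        return sum(w[i] * b for i, b in enumerate(branches))

In MyModule, this would amount to replacing W_1, W_2 and W_3 with a single nn.Parameter(torch.zeros(3)) and computing the weights with torch.softmax in forward() before the weighted sum.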
