Torch: back-propagating a loss computed on a subset of the outputs

I have a simple convolutional neural network whose output is a single-channel 4x4 feature map. During training, the (regression) loss needs to be computed on only a single value out of the 16 outputs; the location of that value is determined after the forward pass. How can I compute the loss from just this one output, while making sure that all irrelevant gradients are zeroed out during back-propagation?

Suppose I have the following simple model in Torch:

require 'nn'
-- the input
local batch_sz = 2
local x = torch.Tensor(batch_sz, 3, 100, 100):uniform(-1,1)
-- the model
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 128, 9, 9, 9, 9, 1, 1))
net:add(nn.SpatialConvolution(128, 1, 3, 3, 3, 3, 1, 1))
net:add(nn.Squeeze(1, 3))
print(net)
-- the loss (don't know how to employ it yet)
local loss = nn.SmoothL1Criterion()
-- forward'ing x through the network would result in a 2x4x4 output
y = net:forward(x)
print(y)
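
Just to make the goal concrete: what I am effectively after, written out naively, is presumably something like the sketch below, where the loss is computed on one selected scalar of the 2x4x4 output and a hand-built gradient tensor (zeros everywhere except that location) is passed to net:backward(). The location indices and ground truth here are made up for illustration:

-- pick a (sample, row, col) location after the forward pass (arbitrary values here)
local smpl_idx, feat_r, feat_c = 2, 3, 4
-- a fake ground-truth value
local gt = torch.rand(1)
-- view the single selected output as a 1-element tensor
local y_sel = y[{ {smpl_idx}, {feat_r}, {feat_c} }]:view(1)
-- loss and its gradient w.r.t. that one output value
local err = loss:forward(y_sel, gt)
local dE_dy_sel = loss:backward(y_sel, gt)
-- full gradient w.r.t. y: zeros everywhere except the selected location
local full_dE_dy = y:clone():zero()
full_dE_dy[{smpl_idx, feat_r, feat_c}] = dE_dy_sel[1]
-- back-propagate through the whole network
net:backward(x, full_dE_dy)

This is the behaviour I want; the question is whether there is a cleaner way to express it with the nn modules and criteria themselves.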

I have looked at nn.SelectTable, and it seems that if I convert the output into table form I should be able to achieve what I want?
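
If I understand the documentation correctly, nn.SelectTable's backward places the incoming gradient at the selected index and zero tensors at every other index of the input table. A tiny standalone sketch of that behaviour, with made-up values:

-- toy table of three 1-element tensors
local t = { torch.Tensor{1}, torch.Tensor{2}, torch.Tensor{3} }
local sel = nn.SelectTable(2)
sel:forward(t)                                  -- returns t[2]
local grads = sel:backward(t, torch.Tensor{10})
-- grads[2] is the given gradient; grads[1] and grads[3] are zero tensors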

Here is my current solution. It works by splitting the output into a table and then using nn.SelectTable():backward() to obtain the full gradient:

require 'nn'
-- the input
local batch_sz = 2
local x = torch.Tensor(batch_sz, 3, 100, 100):uniform(-1,1)
-- the model
local net = nn.Sequential()
net:add(nn.SpatialConvolution(3, 128, 9, 9, 9, 9, 1, 1))
net:add(nn.SpatialConvolution(128, 1, 3, 3, 3, 3, 1, 1))
net:add(nn.Squeeze(1, 3))
-- convert output into a table format
net:add(nn.View(1, -1))         -- vectorize
net:add(nn.SplitTable(1, 1))    -- split all outputs into table elements
print(net)
-- the loss
local loss = nn.SmoothL1Criterion()
-- forward'ing x through the network now results in a table of 32 single-element outputs (2 samples x 4x4 locations)
y = net:forward(x)
print(y)
-- returns the output table index corresponding to a specific (sample, row, column) location
function get_sample_idx(feat_h, feat_w, smpl_idx, feat_r, feat_c)
    local idx = (smpl_idx - 1) * feat_h * feat_w
    return idx + feat_c + ((feat_r - 1) * feat_w)
end
-- I want to back-propagate the loss of this sample at this feature location
local smpl_idx = 2
local feat_r = 3
local feat_c = 4
-- get the actual index location in the output table (for a 4x4 output feature map)
local out_idx = get_sample_idx(4, 4, smpl_idx, feat_r, feat_c)
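-- e.g. with smpl_idx = 2, feat_r = 3, feat_c = 4 on a 4x4 map:
--   out_idx = (2-1)*4*4 + 4 + (3-1)*4 = 16 + 4 + 8 = 28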
-- the (fake) ground-truth
local gt = torch.rand(1)
-- compute loss on the selected feature map location for the selected sample
local err = loss:forward(y[out_idx], gt)
-- compute the loss gradient, as if this were the only output location
local dE_dy = loss:backward(y[out_idx], gt)
-- now convert into the full loss gradient (zeroing out the irrelevant gradients)
local full_dE_dy = nn.SelectTable(out_idx):backward(y, dE_dy)
-- do back-prop through the whole network
net:backward(x, full_dE_dy)
print("The full dE/dy")
print(table.unpack(full_dE_dy))

I would really appreciate it if someone could point out a simpler or more efficient way of doing this.

Latest update