具有值向量的回归模型的pytorch损失函数



我正在训练一个CNN架构,以使用PyTorch解决回归问题,其中我的输出是25个值的张量。输入/目标张量可以是全零或西格玛值为2的高斯分布。一个4样本批次的例子如下:

[[0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534, 0.043937, 0.011109, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,  0., 0.],
[0., 0., 0.,  0., 0., 0., 0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534, 0.043937, 0.011109, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534 ],
[0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]

我的问题是如何为模型设计损失函数,有效地学习25个值的回归输出。

我试过两种类型的损失,torch.nn.MSELoss()torch.nn.MSELoss()-torch.nn.CosineSimilarity()。他们有点工作。然而,有时网络很难收敛,尤其是当存在大量样本且所有样本都为"0"时;零";,这导致网络输出具有全部25个小值的矢量。

我的问题是,还有其他损失我们可以尝试吗?

您的值在规模上似乎没有太大差异,因此MSELoss似乎可以正常工作。你的模型可能会因为你的目标中有很多零而崩溃。

你总是可以尝试torch.nn.L1Loss()(但我不希望它比torch.nn.MSELoss()好多少(

我建议你试着预测高斯平均值/mu,如果你真的需要的话,以后试着为每个样本重新创建高斯

因此,如果你选择尝试这种方法,你有两种选择。

Alt 1

一个很好的选择是将目标编码为分类目标。您的25个元素向量将成为一个值,其中原始目标==1(可能的类将为0、1、2、…、24(。然后我们可以分配一个包含";只有零";作为我们的最后一堂课&;25〃;。所以你的目标:

[[0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534, 0.043937, 0.011109, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,  0., 0.],
[0., 0., 0.,  0., 0., 0., 0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534, 0.043937, 0.011109, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534 ],
[0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]

成为

[4,
10,
20,
25]

如果您这样做,那么您可以尝试常见的torch.nn.CrossEntropyLoss()

我不知道你的数据加载器是什么样子的,但如果你有一个原始格式的样本,你可以用将其转换为我提议的格式

def encode(tensor):
if tensor.sum() == 0:
return len(tensor)
return torch.argmax(tensor)

并返回到高斯:

def decode(value):
n_values = 25
zero = torch.zeros(n_values)
if value == n_values:
return zero
# Create gaussian around value
std = 2
n = torch.arange(n_values) - value
sig = 2*std**2
gauss = torch.exp(-n**2 / sig2)
# Only return 9 values from the gaussian
start_ix = max(value-6, 0)
end_ix = min(value+7,n_values)
zero[start_ix:end_ix] = gauss[start_ix:end_ix]
return zero

(注意,我没有批量试用过,只有样品(

Alt 2

第二种选择是将你的回归目标(仍然只有argmax位置(mu((改变为0-1范围内的更好的回归值,并具有一个单独的神经元;掩码值";(也是0-1(。然后你的批次:

[[0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534, 0.043937, 0.011109, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,  0., 0.],
[0., 0., 0.,  0., 0., 0., 0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534, 0.043937, 0.011109, 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.13534, 0.32465, 0.60653, 0.8825, 1.0000, 0.88250,0.60653, 0.32465, 0.13534 ],
[0., 0., 0.,  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]

成为

# [Mask, mu]
[
[1, 0.1666], # True, 4/24
[1, 0.4166], # True, 10/24
[1, 0.8333], # True, 20/24
[0, 0]       # False, undefined
]

如果您正在使用此设置,那么您应该能够使用MSELoss并进行修改:

def custom_loss(input, target):
# Assume target and input is of shape [Batch, 2]
mask = target[...,1]
mask_loss = torch.nn.functional.mse_loss(input[...,0], target[...,0])
mu_loss = torch.nn.functional.mse_loss(mask*input[...,1], mask*target[...,1])
return (mask_loss + mu_loss) / 2

如果目标的掩码为1,则这种损失将仅关注第二个值(μ(。否则,它只会尝试为正确的掩码进行优化。

要编码为这种格式,您将使用:

def encode(tensor):
n_values = 25
if tensor.sum() == 0:
return torch.tensor([0,0])
return torch.argmax(tensor) / (n_values-1)

和解码:

def decode(tensor):
n_values = 25
# Parse values
mask, value = tensor
mask = torch.round(mask)
value = torch.round((n_values-1)*value)
zero = torch.zeros(n_values)
if mask == 0:
return zero
# Create gaussian around value
std = 2
n = torch.arange(n_values) - value
sig = 2*std**2
gauss = torch.exp(-n**2 / sig2)
# Only return 9 values from the gaussian
start_ix = max(value-6, 0)
end_ix = min(value+7,n_values)
zero[start_ix:end_ix] = gauss[start_ix:end_ix]
return zero

最新更新