如何正确使用CTC损失Seq2Seq



我正在尝试自己创建ASR模型,并学习如何使用CTC损失。

我测试并看到:

ctc_loss = nn.CTCLoss(blank=95)
output: tensor([[63,  8,  1, 38, 29, 14, 41, 71, 14, 29, 45, 41, 3]]): torch.Size([1, 13]); output_size: tensor([13])
input1: torch.Size([167, 1, 96]); input1_size: tensor([167])

在将argmax应用于该输入(=音素的预测(之后,

torch.argmax(input1, dim=2)

我得到一系列符号:

tensor([[63, 63, 63, 63, 63, 63, 95, 95, 63, 63, 95, 95,  8,  8,  8, 95,  8, 95,
8,  8, 95, 95, 95,  1,  1, 95,  1, 95,  1,  1, 95, 95, 38, 95, 95, 38,
38, 38, 38, 38, 29, 29, 29, 29, 29, 29, 29, 95, 29, 29, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 14, 95, 14, 95, 95, 95, 95, 14, 95, 14, 41, 41,
41, 95, 41, 41, 41, 41, 41, 41, 71, 71, 71, 95, 71, 71, 71, 71, 71, 95,
95, 14, 14, 95, 14, 14, 95, 14, 14, 95, 29, 29, 95, 29, 29, 29, 29, 29,
29, 29, 45, 95, 95, 45, 45, 95, 45, 45, 45, 45, 41, 95, 41, 41, 95, 95,
95, 41, 41, 41,  3,  3,  3,  3,  3, 95,  3,  3,  3, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95]])

以及以下损失值。

ctc_loss(input1, output, input_size, output_size)
# Returns 222.8446

使用不同的输入:

input2: torch.Size([167, 1, 96]) input2_size: tensor([167])
torch.argmax(input2, dim=2)

预测只是一系列空白符号。

tensor([[95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95]]) 

然而,具有相同期望输出的损耗值要低得多。

ctc_loss(input2, output, input_size, output_size)
# Returns 3.7955

我不知道为什么input1input2好,但input1的损失比input2高?有人能解释一下吗?

CTC损失不影响argmax预测,而是影响整个输出分布。CTC损失是产生期望输出的所有可能输出序列的负对数似然的总和。输出符号可能与空白符号交错,这留下了指数级的许多可能性。理论上,正确输出的负对数似然的和可能较低,并且最可能的序列仍然是全空的。

在实践中,这种情况非常罕见,所以我想其他地方可能有问题。PyTorch中实现的CTCLoss需要对数概率作为输入,例如通过应用log_softmax函数。不同种类的输入可能会导致奇怪的结果,比如你观察到的结果。

相关内容

  • 没有找到相关文章

最新更新