我正在尝试自己创建ASR模型,并学习如何使用CTC损失。
我测试并看到:
ctc_loss = nn.CTCLoss(blank=95)
output: tensor([[63, 8, 1, 38, 29, 14, 41, 71, 14, 29, 45, 41, 3]]): torch.Size([1, 13]); output_size: tensor([13])
input1: torch.Size([167, 1, 96]); input1_size: tensor([167])
在将argmax
应用于该输入(=音素的预测(之后,
torch.argmax(input1, dim=2)
我得到一系列符号:
tensor([[63, 63, 63, 63, 63, 63, 95, 95, 63, 63, 95, 95, 8, 8, 8, 95, 8, 95,
8, 8, 95, 95, 95, 1, 1, 95, 1, 95, 1, 1, 95, 95, 38, 95, 95, 38,
38, 38, 38, 38, 29, 29, 29, 29, 29, 29, 29, 95, 29, 29, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 14, 95, 14, 95, 95, 95, 95, 14, 95, 14, 41, 41,
41, 95, 41, 41, 41, 41, 41, 41, 71, 71, 71, 95, 71, 71, 71, 71, 71, 95,
95, 14, 14, 95, 14, 14, 95, 14, 14, 95, 29, 29, 95, 29, 29, 29, 29, 29,
29, 29, 45, 95, 95, 45, 45, 95, 45, 45, 45, 45, 41, 95, 41, 41, 95, 95,
95, 41, 41, 41, 3, 3, 3, 3, 3, 95, 3, 3, 3, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95]])
以及以下损失值。
ctc_loss(input1, output, input_size, output_size)
# Returns 222.8446
使用不同的输入:
input2: torch.Size([167, 1, 96]) input2_size: tensor([167])
torch.argmax(input2, dim=2)
预测只是一系列空白符号。
tensor([[95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95]])
然而,具有相同期望输出的损耗值要低得多。
ctc_loss(input2, output, input_size, output_size)
# Returns 3.7955
我不知道为什么input1
比input2
好,但input1
的损失比input2
高?有人能解释一下吗?
CTC损失不影响argmax预测,而是影响整个输出分布。CTC损失是产生期望输出的所有可能输出序列的负对数似然的总和。输出符号可能与空白符号交错,这留下了指数级的许多可能性。理论上,正确输出的负对数似然的和可能较低,并且最可能的序列仍然是全空的。
在实践中,这种情况非常罕见,所以我想其他地方可能有问题。PyTorch中实现的CTCLoss
需要对数概率作为输入,例如通过应用log_softmax
函数。不同种类的输入可能会导致奇怪的结果,比如你观察到的结果。