这个函数是如何通过切片遍历令牌的,你能解释一下它是如何选择列表中的元素的吗



我所指的代码:

predicted_index = torch.argmax(predictions[0, -1, :]).item()

这是tensor而不是list,主要区别在于:

  • tensor有一个指定的dtype(通常是PyTorch中的float32(
  • 运行操作更快

你的预测是3D张量,你正在使用它:

  • 0第8行
  • 最后一列(-1索引(
  • 来自第三维度的所有元素(:(

基本上,切片后会留下一个向量。

torch.argmax返回最大元素所在的索引,例如:

torch.argmax(torch.tensor([-1, 0, 1.5, 1, 0])) # would return 2'

argmax的代码在C++中实现,并保持迄今为止找到的最大值的索引,并返回最后找到的索引(O(n)复杂度(。

.item()tensor更改为它的Python对应项(通常float来自任何浮点,int来自整数族类型等(

相关内容

最新更新