我所指的代码:
predicted_index = torch.argmax(predictions[0, -1, :]).item()
这是tensor
而不是list
,主要区别在于:
tensor
有一个指定的dtype
(通常是PyTorch中的float32
(- 运行操作更快
你的预测是3D
张量,你正在使用它:
0
第8行- 最后一列(
-1
索引( - 来自第三维度的所有元素(
:
(
基本上,切片后会留下一个向量。
torch.argmax
返回最大元素所在的索引,例如:
torch.argmax(torch.tensor([-1, 0, 1.5, 1, 0])) # would return 2'
argmax
的代码在C++
中实现,并保持迄今为止找到的最大值的索引,并返回最后找到的索引(O(n)
复杂度(。
.item()
将tensor
更改为它的Python对应项(通常float
来自任何浮点,int
来自整数族类型等(