torch.nn.Softmax中dim参数的用途是什么



我不明白dim参数在torch.nn.Softmax中应用了什么。有一个警告告诉我使用它,我将其设置为1,但我不明白我正在设置什么。它在公式中的使用位置:

Softmax(xi​)=exp(xi)/∑j​exp(xj​)​

这里没有昏暗,那么它适用于什么呢?

torch.nn.Softmax上的Pytorch文档指出:dim(int(–将沿其计算Softmax的维度(因此沿dim的每个切片的总和将为1(。

例如,如果您有一个二维矩阵,您可以选择是将softmax应用于行还是列:

import torch 
import numpy as np
softmax0 = torch.nn.Softmax(dim=0) # Applies along columns
softmax1 = torch.nn.Softmax(dim=1) # Applies along rows 
v = np.array([[1,2,3],
[4,5,6]])
v =  torch.from_numpy(v).float()
softmax0(v)
# Returns
#[[0.0474, 0.0474, 0.0474],
# [0.9526, 0.9526, 0.9526]])

softmax1(v)
# Returns
#[[0.0900, 0.2447, 0.6652],
# [0.0900, 0.2447, 0.6652]]

请注意,对于softmax0,列是如何添加到1的,而对于softmax1,行是如何添加至1的。

最新更新