在Torch中,我如何从整数标签列表中创建一个1-hot张量



我有一个整数类标签的字节张量,例如来自MNIST数据集。

 1
 7
 5
[torch.ByteTensor of size 3]

如何用它来创建一个1热向量张量?

 1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0
 0  0  0  0  1  0  0  0  0  0
[torch.DoubleTensor of size 3x10]

我知道我可以用循环做到这一点,但我想知道是否有任何聪明的火炬索引,将得到它为我在单行

indices = torch.LongTensor{1,7,5}:view(-1,1)
one_hot = torch.zeros(3, 10)
one_hot:scatter(2, indices, 1)

您可以在torch/torch7 github自述文件(在主分支中)中找到scatter的文档。

另一种方法是对单位矩阵中的行进行洗牌:

indicies = torch.LongTensor{1,7,5}
one_hot = torch.eye(10):index(1, indicies)

这不是我的主意,我在karpathy/char-rnn找到的。

最新更新