我有一个整数类标签的字节张量,例如来自MNIST数据集。
1
7
5
[torch.ByteTensor of size 3]
如何用它来创建一个1热向量张量?
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0
0 0 0 0 1 0 0 0 0 0
[torch.DoubleTensor of size 3x10]
我知道我可以用循环做到这一点,但我想知道是否有任何聪明的火炬索引,将得到它为我在单行
indices = torch.LongTensor{1,7,5}:view(-1,1)
one_hot = torch.zeros(3, 10)
one_hot:scatter(2, indices, 1)
您可以在torch/torch7 github自述文件(在主分支中)中找到scatter
的文档。
另一种方法是对单位矩阵中的行进行洗牌:
indicies = torch.LongTensor{1,7,5}
one_hot = torch.eye(10):index(1, indicies)
这不是我的主意,我在karpathy/char-rnn找到的。