什么是Pytorch等同于Pandas的groupby.apply(list)?



我有以下pytorch张量long_format:

tensor([[ 1.,  1.],
[ 1.,  2.],
[ 1.,  3.],
[ 1.,  4.],
[ 0.,  5.],
[ 0.,  6.],
[ 0.,  7.],
[ 1.,  8.],
[ 0.,  9.],
[ 0., 10.]])

我想将第一列分组,并将第二列存储为张量。不能保证每个分组的结果大小相同。

[tensor([ 1., 2., 3., 4., 8.]),
tensor([ 5.,  6., 7., 9., 10.])]

是否有很好的方法来做这个使用纯Pytorch操作符?出于可追溯性的考虑,我希望避免使用for循环。

我尝试使用for循环和空张量的空列表,但这会导致不正确的跟踪(不同的输入值给出相同的结果)

n_groups = 2
inverted = [torch.empty([0]) for _ in range(n_groups)]
for index, value in long_format:
value = value.unsqueeze(dim=0)
index = index.int()
if type(inverted[index]) != torch.Tensor:
inverted[index] = value
else:
inverted[index] = torch.cat((inverted[index], value))

您可以使用以下代码:

import torch
x = torch.tensor([[ 1.,  1.],
[ 1.,  2.],
[ 1.,  3.],
[ 1.,  4.],
[ 0.,  5.],
[ 0.,  6.],
[ 0.,  7.],
[ 1.,  8.],
[ 0.,  9.],
[ 0., 10.]])
result =  [x[x[:,0]==i][:,1] for i in x[:,0].unique()]

[tensor([ 5.,  6.,  7.,  9., 10.]), tensor([1., 2., 3., 4., 8.])]

最新更新