是否有一些有效的方法来重写以下代码以避免安装&导入pandas
并使用torch
/numpy
?我习惯于使用pandas
,所以我这样写,但我正在尝试学习numpy
和torch
,所以我正在寻找不使用pandas
的替代解决方案。
bins = torch.LongTensor(3072).random_(0, 35)
weights = torch.rand((3072))
df = pd.DataFrame({'weights': weights.numpy(), 'bins': bins.numpy()})
bins_sum = df.groupby('bins').weights.sum().values
那么,基本上:在不使用pandas
的情况下,如何获得由bins
分组的weights
的和?
您可以通过torch.unique
计算bins
的唯一元素(分组依据的值(,然后使用索引掩码访问weights
:中的相应元素
unique = torch.unique(bins)
result = torch.zeros(unique.size(), dtype=weights.dtype)
for i, val in enumerate(unique):
result[i] += weights[bins == val].sum()
numpy
具有类似于pandas.isin
的isin
Pandasgroupby
,选择数据(row
(并在group
上应用函数。
def groupby(data, bin_data, grouper, agg):
'''
data: numpy array
bin_data: bin's data
grouper: callable, which give returns a list of array values.
agg: callable, to be applied on group
'''
res = {}
for key,arr in grouper(data, bin_data):
res.update({key, agg(arr)})
return res
# Find the indices where `bins == b` and then use them to select the `arry` values
bin_grouper = lambda arry, bvalue: [(b, arry[np.isin(bvalue, b)]) for b in bvalue]
# Compute result
gdata = groupby(weights.numpy(), bins.numpy(), bin_grouper, np.sum)