我正在寻找Python/numpy/numba/C-extension实现的数据结构和算法,以提高我目前解决以下约简问题的方法的性能:
输入我有一个非常大的结构化(Numpy)数组,格式为' '
iarr = numpy.array(
[([entityId, subentityId], subentityValue),
...,
...
], dtype=[('e', '<2u4'), ('r', '<f4')])
- 有
m
实体(百万数量级)和n
子实体(<20)。 - 没有重复的实体/子实体组合
m
和n
是什么,目前还不知道。- 子实体的数量因实体而异,但主要是每个实体8或6个。
- 数组无序
预期输出
我需要找到每个entityId
的最大或最小subentityValue
。
我不需要保留subentityId
值的来源信息。
结果应该是这样的记录数组:
oarr = numpy.array(
[(entityId, subentityValue),
...,
...
], dtype=[('e', '<u4'), ('r', '<f4')])
- 结果数组不需要排序
- 数组可以为最大值或最小值创建,因此数组中的
entityId
s是唯一的。
同样,输出也可以是一个字典,其中entityId
s作为键,最大或最小subentityValues
作为值。
当前实现(慢!)
我最初使用Python, Numpy和Numba的方法是(在这里描述用于查找每个entitId
的最大subentityValue
):
初始化一个字典(
numba.typed.Dict
),其中键是唯一的entityId
s,并且保证初始值小于数组中的任何subentityValue
(例如-99999.9)。odict = numba.typed.Dict.empty(key_type=nb.int64, value_type=nb.float64) # types for compatability to Python's dict smallest_r = nb.float64(-99999.9) for entity_id in np.unique(iarr['e'].astype(np.int64)): odict[entity_id] = smallest_r
循环遍历输入数组中的记录,并将
dictionary[entityId]
的值与记录的entityValue
和进行比较a)如果
dictionary[entityId]
大于entityValue
,则不做任何操作,b)如果
dictionary[entityId]
小于entityValue
,则用entityValue
覆盖它。for i in numba.prange(iarr['e'].shape[0]): if odict[iarr['e'][i]] < iarr['r'][i]: odict[iarr['e'][i]] = iarr['r'][i]
返回
odict
字典作为结果。
这工作得很好,但这是目前系统中最大的瓶颈。
为了提高性能,我试图并行化这个(@numba.jit(..., parallel=True)
),却发现numba
的typed.Dict
不是线程安全的,在这种情况下给我不正确的结果。
我非常高兴完全放弃我的解决方案,以支持更好(更快)的东西。有什么建议吗?
按e的第一个元素分组源行,然后计算min和max,对于每一组来说,使用起来都比较方便Pandas代替Numpy。
从必要的导入开始:
import numpy as np
import pandas as pd
出于测试目的,我将源数组创建为:
iarr = np.array([
([10, 1], 10.5), ([10, 1], 9.5), ([10, 1], 10.0),
([10, 2], 9.1), ([10, 2], 9.2), ([10, 2], 9.4),
([10, 3], 7.5), ([10, 3], 9.7), ([10, 3], 8.0),
([20, 2], 7.3), ([20, 2], 7.1), ([20, 2], 8.0),
([20, 3], 7.3), ([20, 3], 9.7), ([20, 3], 8.0)],
dtype=[('e', '<u4', (2,)), ('r', '<f4')])
第一步是用:
创建一个pandasonic系列- r列的值,
- 索引(实际上是MultiIndex))创建自e列
执行该操作的代码为:
s = pd.Series(iarr['r'], index=pd.MultiIndex.from_arrays(iarr['e'].T))
然后,要获得结果,同时使用min和max作为DataFrame,运行:
result = s.groupby(level=0).agg(['min', 'max'])
索引(最左边的未命名列)包含entityId和"actual"列同时包含min和max。
结果是:
min max
10 7.5 10.5
20 7.1 9.7
如果需要,可以将其转换为Numpy数组:
oarr = np.core.records.fromarrays(
result.reset_index().values.T,
names='entityId, min, max', formats='u4, f4, f4')
我的代码应该比普通的python代码运行得快得多解决方案。