通过在行中查找NumPy数组匹配项,有效地筛选DataFrame



给定

df = pd.DataFrame({'x': [np.array(['1', '2.3']), np.array(['30', '99'])]},
index=[pd.date_range('2020-01-01', '2020-01-02', freq='D')])

我想筛选np.array(['1', '2.3'])。我可以做

df[df['x'].apply(lambda x: np.array_equal(x, np.array(['1', '2.3'])))]

但是这是最快的方法吗?

EDIT:让我们假设numpy数组中的所有元素都是字符串,尽管这不是一个好的做法!

DataFrame的长度可以达到500k行,每个numpy数组中的值的数量可以达到10。

您可以依靠列表理解来获得性能:

df[np.array([np.array_equal(x,np.array([1, 2.3])) for x in df['x'].values])]

通过timeit的性能(在我目前使用4gb ram的系统上):

%timeit -n 2000 df[np.array([np.array_equal(x,np.array([1, 2.3])) for x in df['x'].values])]
#output:
425 µs ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
%timeit -n 2000 df[df['x'].apply(lambda x: np.array_equal(x, np.array([1, 2.3])))]
#output:
875 µs ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)

我的建议是执行以下操作:

import numpy as np
mat = np.stack([np.array(["a","b","c"]),np.array(["d","e","f"])])

事实上,这将是来自数据帧cols的实际数据。请确保这些是单个numpy数组。

然后执行:

matching_rows = (np.array(["a","b","c"]) == mat).all(axis=1)

它向您输出一个bool数组,指示匹配项的位置。所以你可以这样过滤你的行:

df[matching_rows]

最新更新