在组上展开().mean()的性能调整



具有以下用户事件DF:

id                timestamp
0    1  2021-11-23 11:01:00.000
1    1  2021-11-23 11:02:00.000
2    1  2021-11-23 11:10:00.000
3    1  2021-11-23 11:11:00.000
4    1  2021-11-23 11:22:00.000
5    1  2021-11-23 11:40:00.000
6    1  2021-11-23 11:41:00.000
7    1  2021-11-23 11:42:00.000
8    1  2021-11-23 11:43:00.000
9    1  2021-11-23 11:44:00.000
10   2  2021-11-23 11:01:00.000
11   2  2021-11-23 11:02:00.000
12   2  2021-11-23 11:10:00.000
13   2  2021-11-23 11:11:00.000
14   2  2021-11-23 11:22:00.000
15   2  2021-11-23 11:40:00.000
16   2  2021-11-23 11:41:00.000
17   2  2021-11-23 11:42:00.000
18   2  2021-11-23 11:43:00.000
19   2  2021-11-23 11:44:00.000

我计算每行的平均会话时间如下:

  1. 每个会话是一系列间隔不到5分钟的事件
  2. 我计算会话中第一个事件与当前事件之间的秒数
  3. 然后,我为每个用户计算expanding((.mean((

这是我的代码:

def average_session_time(**kwargs):
df = kwargs['df'].copy()
df['timestamp'] = pd.to_datetime(df.timestamp)
df['session_grp'] = df.groupby('id').apply(
lambda x: (x.groupby([pd.Grouper(key="timestamp", freq='5min', origin='start')])).ngroup()).reset_index(
drop=True).values.reshape(-1)
# Getting relevant 5min groups
ng = df.groupby(['id', 'session_grp'])
df['fts'] = ng['timestamp'].transform('first')
df['delta'] = df['timestamp'].sub(df['fts']).dt.total_seconds()
return df.groupby('id')['delta'].expanding().mean().reset_index(drop=True)

输出为:

0      0.000000
1     30.000000
2     20.000000
3     15.000000
4     12.000000
5     10.000000
6      8.571429
7     15.000000
8     26.666667
9     42.000000
10     0.000000
11    30.000000
12    20.000000
13    15.000000
14    12.000000
15    10.000000
16     8.571429
17    15.000000
18    26.666667
19    42.000000
Name: delta, dtype: float64

代码运行良好,但当它在大型数据集上运行时,性能会受到影响,并且需要很长时间进行计算。我试着调整代码,但无法获得更多的性能。如何以不同的方式编写此函数以提高性能?

这是一个带有运行代码的Colab。

一个非常快速的解决方案是将NumpyNumba组合在一起,以您的方式对连续行进行分组。

首先,列需要转换为本机Numpy类型,因为CPython对象的计算速度非常慢(而且占用更多内存(。你可以用

ids = df['Id'].values.astype('S32')
timestamps = df['timestamp'].values.astype('datetime64[ns]')

这假设ID最多由32个ASCII字符组成。如果ID可以包含unicode字符,则可以使用'U32'(稍微慢一点(。您也可以使用np.unicode_让Numpy为您查找绑定。然而,这要慢得多(因为Numpy需要解析两次所有字符串(。

一旦转换为datetime64[ns],时间戳就可以转换为64位整数,以便Numba进行非常快速的计算。

然后,我们的想法是将字符串ID转换为基本整数,因为处理字符串非常缓慢。您可以通过搜索不同的相邻字符串来定位具有相同ID的块:

ids_int = np.insert(np.cumsum(ids[1:] != ids[:-1], dtype=np.int64), 0, 0)

请注意,在所提供的数据集中,没有一组行与具有不同ID的另一行共享相同ID。如果这个假设并不总是正确的,您可以使用np.argsort(ids, kind='stable')对输入字符串(ids(进行排序,应用此解决方案,然后根据np.argsort的输出对结果进行重新排序。请注意,对字符串进行排序有点慢,但仍然比问题中提供的解决方案的计算时间快得多(在我的机器上大约为100-200ms(。

最后,您可以使用基本循环使用Numba计算结果。


完整解决方案

这是生成的代码:

import pandas as pd
import numpy as np
import numba as nb
@nb.njit('float64[:](int64[::1], int64[::1])')
def compute_result(ids, timestamps):
n = len(ids)
result = np.empty(n, dtype=np.float64)
if n == 0:
return result
id_group_first_timestamp = timestamps[0]
session_group_first_timestamp = timestamps[0]
id_group_count = 1
id_group_delta_sum = 0.0
last_session_group = 0
result[0] = 0
delay = np.int64(300e9) # 5 min (in ns)
for i in range(1, n):
# If there is a new group of IDs
if ids[i-1] != ids[i]:
id_group_first_timestamp = timestamps[i]
id_group_delta_sum = 0.0
id_group_count = 1
last_session_group = 0
session_group_first_timestamp = timestamps[i]
else:
id_group_count += 1
session_group = (timestamps[i] - id_group_first_timestamp) // delay
# If there is a new session group
if session_group != last_session_group:
session_group_first_timestamp = timestamps[i]
delta = (timestamps[i] - session_group_first_timestamp) * 1e-9
id_group_delta_sum += delta
result[i] = id_group_delta_sum / id_group_count
last_session_group = session_group
return result
def fast_average_session_time(df):
ids = df['Id'].values.astype('S32')
timestamps = df['timestamp'].values.astype('datetime64[ns]').astype(np.int64)
ids_int = np.insert(np.cumsum(ids[1:] != ids[:-1], dtype=np.int64), 0, 0)
return compute_result(ids_int, timestamps)

请注意,输出是Numpy数组,而不是数据帧,但您可以使用pd.DataFrame(result)从Numpy数组轻松构建数据帧。


基准

以下是我的机器上的最终性能:

Initial solution:            12_537 ms   (   x1.0)
Scott Boston's solution:      7_431 ms   (   x1.7)
This solution:                   64 ms   ( x195.9)
Time taken by 
compute_result only:              2.5 ms

因此,此解决方案的速度几乎是初始解决方案的200倍。

请注意,大约85%的时间用于unicode/datetime字符串解析,这几乎无法优化。事实上,这种处理很慢,因为在现代处理器上处理短unicode字符串本身就很昂贵,而且CPython对象引入了大量开销(例如引用计数(。此外,由于CPython GIL和进程间通信缓慢,该处理无法有效地并行化。因此,这个代码几乎是最优的(只要您使用CPython(。

我想我可以通过在pd.to_datetime中使用format并在groupby中使用as_index参数而不是调用rest_index:来将您的时间减半

def average_session_time(**kwargs):
df = kwargs['df'].copy()
df['timestamp'] = pd.to_datetime(df.timestamp, format='%Y-%m-%d %H:%M:%S.%f')
grp_id = df.groupby('Id', as_index=False)
df['session_grp'] = grp_id.apply(
lambda x: (x.groupby([pd.Grouper(key="timestamp", freq='5min', origin='start')])).ngroup()).values.reshape(-1)
# Getting relevant 5min groups
ng = df.groupby(['Id', 'session_grp'])
df['fts'] = ng['timestamp'].transform('first')
df['delta'] = df['timestamp'].sub(df['fts']).dt.total_seconds()
return grp_id['delta'].expanding().mean().reset_index(level=0, drop=True)

原始时间:

40.228641986846924

新定时:

16.08320665359497

最新更新