如何有效地均值规范化列



我有一个大数据集(~4M行~3K列),我目前使用Python/PySpark中的以下代码对每个列进行均值规范化:

import pyspark.sql.functions as f
means_pd = df.select(*[f.mean(c).alias(c) for c in df.columns]).toPandas()
diffs = df
for c in df.columns:
mean = means_pd.loc[0,c]
diffs = diffs.withColumn(c, f.col(c) - f.lit(mean))

这是相当慢的,特别是循环遍历列的步骤。必须有更好的方法来做到这一点,因为有像MinMaxScalar这样的函数包含这样的步骤,但不会永远使用。我怎样才能加快速度?

你可以计算一个窗口的平均值,然后减去结果。

df.select([(F.col(c) - F.mean(c).over(W.orderBy())).alias(c) for c in df.columns])

这样可以避免循环(3kwithColumn),并且完全在Spark中完成,而不需要Pandas。

测试:

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
[(1, 1, 0),
(2, 3, 3),
(3, 5, 6)],
['c1', 'c2', 'c3'])
df_diffs = df.select([(F.col(c) - F.mean(c).over(W.orderBy())).alias(c) for c in df.columns])
df_diffs.show()
# +----+----+----+
# |  c1|  c2|  c3|
# +----+----+----+
# |-1.0|-2.0|-3.0|
# | 0.0| 0.0| 0.0|
# | 1.0| 2.0| 3.0|
# +----+----+----+

相关内容

  • 没有找到相关文章

最新更新