PySpark数据框中每一列乘以其他每一列的有效方法是什么,即不重复的组合



我试图计算两个不同对象之间的二阶相互作用,即"关键词",";所以我需要以有效的方式将PySpark数据框架中的每一列乘以其他每一列。例如,如果我有780个关键字,那么就会有超过30万个潜在的组合(没有重复)。所以我将多个关键字1 *关键字2,和关键字1 *关键字3,以此类推。我的数据帧大约有10,000行。此外,并不是所有的关键字都存在,所以我需要跟踪我的列名。

我的数据帧看起来像这样:

<表类> ID keywords_1 keywords_2 keywords_3 tbody><<tr>574086315505724437

有一种方法可以试试:

cols = df.columns[1:]
import itertools
out = (df.select("ID",*[(F.col(i[0])*F.col(i[1])).alias('*'.join(i)) 
for i in itertools.combinations(cols,2)]))

out.show()
+------+---------------------+---------------------+---------------------+
|    ID|keywords_1*keywords_2|keywords_1*keywords_3|keywords_2*keywords_3|
+------+---------------------+---------------------+---------------------+
|574086|                    3|                   15|                    5|
|505724|                   12|                   28|                   21|
+------+---------------------+---------------------+---------------------+

与@anky的答案使用itertools的方式类似,但这里我使用selectExpr是为了使语法稍微简单一些,利用了表达式与列名相同的事实:

import itertools
cols = df.columns[1:]
df2 = df.selectExpr(
'ID',
*[f'{pair} as `{pair}`'
for pair in [' * '.join(i) for i in itertools.combinations(cols, 2)]
]
)
df2.show()
+------+-----------------------+-----------------------+-----------------------+
|    ID|keywords_1 * keywords_2|keywords_1 * keywords_3|keywords_2 * keywords_3|
+------+-----------------------+-----------------------+-----------------------+
|574086|                    3.0|                   15.0|                    5.0|
|505724|                   12.0|                   28.0|                   21.0|
+------+-----------------------+-----------------------+-----------------------+

但是请注意,通常在列名中保留字符不是一个好主意。这就是为什么我要用反引号括起来。您可以考虑使用下划线(_)代替星号。

按另一列分组,可以执行

import itertools
cols = df.columns[1:]
df2 = df.groupBy('ID').agg(
*[F.expr(f'sum({pair}) as `{pair}`')
for pair in [' * '.join(i) for i in itertools.combinations(cols, 2)]
]
)

在测试了两种不同的PySpark解决方案之后,我还测试了一种使用pandas的方法,结果证明这种解决方案比任何一种PySpark解决方案都快得多。我正在计算相当多的列(~780)之间的相互作用,它产生超过300,000个组合。

df_second_order = df.groupby('ID').apply(
lambda x: pd.concat([x.mul(i[1], axis="index") for i in x.iteritems()],
axis=1,
keys=x.columns
)).reset_index()

该解决方案生成重复的关键字组合,但是过滤掉重复的组合很简单。我还尝试了列表推导(只有列的非重复组合)而不是生成器x.iteritems(),它要慢得多。

一个考虑因素是生成的数据帧可能相当大(24G;我还必须将数据类型更改为int8以最小化所需的内存)。可以简单地遍历一定数量的行,保存生成的数据帧,然后在开始新的迭代之前删除不需要的数据帧。使用这种方法,我能够在3小时30分钟内为将近100万行的数据帧生成超过780列的300,000个组合。

此解决方案基于https://stackoverflow.com/a/38970709/6287730

提供的解决方案