我正在尝试在 Spark 中从头开始编写 ML 算法,并且在设置每个分区的每个功能的直方图时遇到问题。
目标是有一些最终变量N
,并获取每个分区中每列的max
和min
。然后我想映射行以将它们存放在N
箱中,箱长度为(max - min)/N
.我已经尝试mapWithIndex
获取max
,但后来我不确定如何将其与map
函数联系起来并确保正确的max
与正确的分区连接。
尝试以下代码: 假设我们要为每个分区使用N=3
个箱,这是我的数据帧:
from pyspark.sql.window import Window
from pyspark.sql import functions as F
N = 3
values = [
(1, 5),
(2, 13),
(3, 25),
(4, 30),
(5, 38),
(6, 50),
(7, 11),
(8, 73),
(9, 48),
(10, 65),
(11, 55),
(12, 42)
]
columns = ['ID', 'Amount']
df=spark.createDataFrame(values, columns)
df.show()
数据帧如下所示:
+---+------+
| ID|Amount|
+---+------+
| 1| 5|
| 2| 13|
| 3| 25|
| 4| 30|
| 5| 38|
| 6| 50|
| 7| 11|
| 8| 73|
| 9| 48|
| 10| 65|
| 11| 55|
| 12| 42|
+---+------+
让我们将数据帧重新分区为 3 个分区,这样我们就不会有太多的分区:
df = df.repartition(3)
在此之后,我们首先获取每行的分区 ID:
df = df.withColumn('pid', F.spark_partition_id())
计算每个分区中的最大和最小Amount
,并使用它们来计算正确的bin_length
。
df = df.withColumn('max_a', F.max(col('Amount')).over(Window.partitionBy('pid')))
df = df.withColumn('min_a', F.min(col('Amount')).over(Window.partitionBy('pid')))
df = df.withColumn('bin_len', (df['max_a'] - df['min_a'])/N)
现在,我们可以计算每个分区中每行到第一行的距离,并使用它来计算存储桶编号。在这里,我假设存储桶编号从 1 开始。
df = df.withColumn('diff_a', F.col('Amount')-F.first('Amount').over(Window.partitionBy('pid').orderBy('Amount')))
df = df.withColumn('bucket', F.floor(F.col('diff_a')/F.col('bin_len')))
df = df.withColumn('bucket', F.when(col('bucket')==N, col('bucket')).otherwise(col('bucket')+1))
df.show()
最终输出为:
+---+------+---+-----+-----+------------------+------+------+
| ID|Amount|pid|max_a|min_a| bin_len|diff_a|bucket|
+---+------+---+-----+-----+------------------+------+------+
| 1| 5| 1| 73| 5|22.666666666666668| 0| 1|
| 2| 13| 1| 73| 5|22.666666666666668| 8| 1|
| 5| 38| 1| 73| 5|22.666666666666668| 33| 2|
| 8| 73| 1| 73| 5|22.666666666666668| 68| 3|
| 3| 25| 2| 65| 25|13.333333333333334| 0| 1|
| 4| 30| 2| 65| 25|13.333333333333334| 5| 1|
| 12| 42| 2| 65| 25|13.333333333333334| 17| 2|
| 10| 65| 2| 65| 25|13.333333333333334| 40| 3|
| 7| 11| 0| 55| 11|14.666666666666666| 0| 1|
| 9| 48| 0| 55| 11|14.666666666666666| 37| 3|
| 6| 50| 0| 55| 11|14.666666666666666| 39| 3|
| 11| 55| 0| 55| 11|14.666666666666666| 44| 3|
+---+------+---+-----+-----+------------------+------+------+
您可以看到,现在数据帧首先按pid
分组,然后按每个组中的Amount
排序。如果选中pid==1
组,则最小金额 = 5,最大金额 = 73,箱长度 = (73-5)/3 = 22.66666。最小 5 应降至存储桶 #1,最大 73 应降至存储桶 #3,编号 38(介于 27.666 和 50.33333 之间)应降至存储桶 #2。