[py]Spark SQL:多列会话化



给定一个正长i和一个数据帧

+-----+--+--+                                                          
|group|n1|n2|                                                              
+-----+--+--+                                                              
|    1| 0| 0|                                                              
|    1| 1| 1|                                                              
|    1| 1| 5|                                                              
|    1| 2| 2|                                                              
|    1| 2| 6|                                                              
|    1| 3| 3|                                                              
|    1| 3| 7|                                                              
|    1| 4| 4|                                                              
|    1| 5| 1|                                                              
|    1| 5| 5|                                                              
+-----+--+--+

您将如何会话同一group中的行,以便对于会话中的每对连续行r1r2r2.n1>r1.n1r2.n2>r1.n2和 max(r2.n1-r1.n1r2.n2-r1.n2)

例如,给定数据帧且i=3 的结果为

+-----+--+--+-------+
|group|n1|n2|session|
+-----+--+--+-------+
|    1| 0| 0|      1|
|    1| 1| 1|      1|
|    1| 1| 5|      2|
|    1| 2| 2|      1|
|    1| 2| 6|      2|
|    1| 3| 3|      1|
|    1| 3| 7|      2|
|    1| 4| 4|      1|
|    1| 5| 1|      3|
|    1| 5| 5|      1|
+-----+--+--+-------+

任何帮助或提示将不胜感激。谢谢!

这看起来像您正在尝试用相同的数字标记图形的所有连接部分。一个好的解决方案是使用graphframes: https://graphframes.github.io/quick-start.html

从数据帧:

df = sc.parallelize([[1, 0, 0],[1, 1, 1],[1, 1, 5],[1, 2, 2],[1, 2, 6],
[1, 3, 3],[1, 3, 7],[1, 4, 4],[1, 5, 1],[1, 5, 5]]).toDF(["group","n1","n2"])

我们将创建一个顶点数据帧,其中包含唯一id的列表:

import pyspark.sql.functions as psf
v = df.select(psf.struct("n1", "n2").alias("id"), "group")
+-----+-----+
|   id|group|
+-----+-----+
|[0,0]|    1|
|[1,1]|    1|
|[1,5]|    1|
|[2,2]|    1|
|[2,6]|    1|
|[3,3]|    1|
|[3,7]|    1|
|[4,4]|    1|
|[5,1]|    1|
|[5,5]|    1|
+-----+-----+

以及根据您声明的布尔条件定义的边缘数据帧:

i = 3
e = df.alias("r1").join(
df.alias("r2"), 
(psf.col("r1.group") == psf.col("r2.group"))
& (psf.col("r1.n1") < psf.col("r2.n1"))
& (psf.col("r1.n2") < psf.col("r2.n2"))
& (psf.greatest(
psf.col("r2.n1") - psf.col("r1.n1"),
psf.col("r2.n2") - psf.col("r1.n2")) < i)
).select(psf.struct("r1.n1", "r1.n2").alias("src"), psf.struct("r2.n1", "r2.n2").alias("dst"))
+-----+-----+
|  src|  dst|
+-----+-----+
|[0,0]|[1,1]|
|[0,0]|[2,2]|
|[1,1]|[2,2]|
|[1,1]|[3,3]|
|[1,5]|[2,6]|
|[1,5]|[3,7]|
|[2,2]|[3,3]|
|[2,2]|[4,4]|
|[2,6]|[3,7]|
|[3,3]|[4,4]|
|[3,3]|[5,5]|
|[4,4]|[5,5]|
+-----+-----+

现在要查找所有连接的组件:

from graphframes import *
g = GraphFrame(v, e)
res = g.connectedComponents()
+-----+-----+------------+
|   id|group|   component|
+-----+-----+------------+
|[0,0]|    1|309237645312|
|[1,1]|    1|309237645312|
|[1,5]|    1| 85899345920|
|[2,2]|    1|309237645312|
|[2,6]|    1| 85899345920|
|[3,3]|    1|309237645312|
|[3,7]|    1| 85899345920|
|[4,4]|    1|309237645312|
|[5,1]|    1|292057776128|
|[5,5]|    1|309237645312|
+-----+-----+------------+

相关内容

  • 没有找到相关文章

最新更新