I want to do a count over a window. The aggregated count should be stored in a new column:
Input DataFrame:
val df = Seq(("N1", "M1","1"),("N1", "M1","2"),("N1", "M2","3")).toDF("NetworkID", "Station","value")
+---------+-------+-----+
|NetworkID|Station|value|
+---------+-------+-----+
| N1| M1| 1|
| N1| M1| 2|
| N1| M2| 3|
+---------+-------+-----+
val w = Window.partitionBy(df("NetworkID"))
What I have so far:
df.withColumn("count", count("Station").over(w)).show()
+---------+-------+-----+-----+
|NetworkID|Station|value|count|
+---------+-------+-----+-----+
| N1| M2| 3| 3|
| N1| M1| 1| 3|
| N1| M1| 2| 3|
+---------+-------+-----+-----+
The result I would like to have:
+---------+-------+-----+-----+
|NetworkID|Station|value|count|
+---------+-------+-----+-----+
| N1| M2| 3| 2|
| N1| M1| 1| 2|
| N1| M1| 2| 2|
+---------+-------+-----+-----+
That is because the count of Stations for NetworkID N1 is equal to 2 (M1 and M2).
I know I could do this by creating a new DataFrame with just the two columns NetworkID and Station, doing a groupBy, and then joining the result back to the first DataFrame. But I have many aggregated counts to compute on different columns of this DataFrame, so I have to avoid joins; the join-based approach I want to avoid is sketched below.
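A minimal sketch of that join-based approach, assuming the df and column names defined above:

// count distinct Stations per NetworkID in a separate DataFrame, then join it back
val stationCounts = df
  .select("NetworkID", "Station")
  .distinct()
  .groupBy("NetworkID")
  .agg(count("Station").as("count"))
df.join(stationCounts, Seq("NetworkID")).show()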
Thanks in advance.
You also need the partition on the "Station" column, since you are counting Stations for each NetworkID.
scala> val df = Seq(("N1", "M1","1"),("N1", "M1","2"),("N1", "M2","3"),("N2", "M1", "4"), ("N2", "M2", "2")).toDF("NetworkID", "Station", "value")
df: org.apache.spark.sql.DataFrame = [NetworkID: string, Station: string ... 1 more field]
scala> val w = Window.partitionBy("NetworkID", "Station")
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@5b481d77
scala> df.withColumn("count", count("Station").over(w)).show()
+---------+-------+-----+-----+
|NetworkID|Station|value|count|
+---------+-------+-----+-----+
| N2| M2| 2| 1|
| N1| M2| 3| 1|
| N2| M1| 4| 1|
| N1| M1| 1| 2|
| N1| M1| 2| 2|
+---------+-------+-----+-----+
What you want is the distinct count of the "Station" column, which can be expressed as countDistinct("Station") instead of count("Station"). Unfortunately, that is not supported as a window function yet (or is it only in my version of Spark?):
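For example, a direct attempt fails (a sketch, using a window partitioned by NetworkID as in the question):

scala> df.withColumn("count", countDistinct("Station").over(Window.partitionBy("NetworkID"))).show()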
org.apache.spark.sql.AnalysisException: Distinct window functions are not supported
As a workaround, you can use dense_rank both forward and backward: for any row, its dense rank of Station in ascending order plus its dense rank in descending order equals the number of distinct Stations in the partition plus one, so subtracting 1 gives the distinct count.
df.withColumn("count", (dense_rank() over w.orderBy(asc("Station"))) + (dense_rank() over w.orderBy(desc("Station"))) - 1).show()
+---------+-------+-----+-----+
|NetworkID|Station|value|count|
+---------+-------+-----+-----+
| N1| M1| 2| 2|
| N1| M1| 1| 2|
| N1| M2| 3| 2|
+---------+-------+-----+-----+
I know it's late, but try the following:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val result = df
  // dense rank of Station within each NetworkID: equal Stations get the same rank
  .withColumn("dr", dense_rank().over(Window.partitionBy("NetworkID").orderBy("Station")))
  // the maximum rank within the NetworkID partition equals the number of distinct Stations
  .withColumn("count", max("dr").over(Window.partitionBy("NetworkID")))