I need to compute the exact median of a grouped Dataset of Double values in Spark using Scala.
It is different from the similar question "Find median in Spark SQL for multiple double datatype columns": that question is about finding the median at the RDD level, whereas this one is about grouped data.
Here is my sample data:
scala> sqlContext.sql("select * from test").show()
+---+---+
| id|num|
+---+---+
| A|0.0|
| A|1.0|
| A|1.0|
| A|1.0|
| A|0.0|
| A|1.0|
| B|0.0|
| B|1.0|
| B|1.0|
+---+---+
Expected answer:
+--------+
| Median |
+--------+
| 1 |
| 1 |
+--------+
I tried the following options, with no luck:
1) The Hive function percentile, which only works on BigInt.
2) The Hive function percentile_approx, which does not work as expected (it returns 0.25 instead of 1).
scala> sqlContext.sql("select percentile_approx(num, 0.5) from test group by id").show()
+----+
| _c0|
+----+
|0.25|
|0.25|
+----+
Simplest approach (requires Spark 2.0.1+; not an exact median)
As noted in the comments on the first question, "Find median in Spark SQL for double datatype columns", we can use percentile_approx to calculate the median on Spark 2.0.1+. To apply it to grouped data in Apache Spark, the query looks like this:
val df = Seq(("A", 0.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
df.createOrReplaceTempView("df")
spark.sql("select id, percentile_approx(num, 0.5) as median from df group by id order by id").show()
The output is:
+---+------+
| id|median|
+---+------+
| A| 1.0|
| B| 1.0|
+---+------+
That said, this is an approximation (not the exact median the question asks for).
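As an aside (not part of the original answer), percentile_approx also accepts an optional accuracy argument; larger values give a more precise approximation at the cost of memory (10000 is the default). A minimal sketch against the same temp view:
spark.sql("select id, percentile_approx(num, 0.5, 10000) as median from df group by id order by id").show()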
Calculating the exact median for grouped data
There are multiple approaches, so I'm sure others on SO can provide better or more efficient examples. But here is a code snippet to calculate the median for grouped data in Spark (verified on Spark 1.6 and Spark 2.1):
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

val rdd: RDD[(String, Double)] = sc.parallelize(Seq(("A", 1.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)))

// Scala median function (expects an already sorted list)
def median(inputList: List[Double]): Double = {
  val count = inputList.size
  if (count % 2 == 0) {
    val l = count / 2 - 1
    val r = l + 1
    (inputList(l) + inputList(r)).toDouble / 2
  } else
    inputList(count / 2).toDouble
}

// Group by key and sort the values
val setRDD = rdd.groupByKey()
val sortedListRDD = setRDD.mapValues(_.toList.sorted)

// Output DataFrame of id and median
sortedListRDD.map(m => {
  (m._1, median(m._2))
}).toDF("id", "median_of_num").show()
The output is:
+---+-------------+
| id|median_of_num|
+---+-------------+
| A| 1.0|
| B| 1.0|
+---+-------------+
I should point out a few caveats, since this is probably not the most efficient implementation:
- It currently uses groupByKey, which is not very performant. You may want to change it to reduceByKey instead (see Avoid GroupByKey for more details); a sketch of that variant follows below.
- It uses a Scala function to compute the median.
This approach should work fine for smaller amounts of data, but if you have millions of rows per key, it is recommended to use Spark 2.0.1+ with the percentile_approx approach.
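For reference, here is a minimal sketch of the reduceByKey variant mentioned in the first caveat (my own addition, assuming the rdd and median definitions above are in scope):
// Build the per-key list with reduceByKey instead of groupByKey
val medianByKey = rdd
  .mapValues(List(_))                 // wrap each value in a single-element list
  .reduceByKey(_ ::: _)               // concatenate the lists per key
  .mapValues(values => median(values.sorted))

medianByKey.toDF("id", "median_of_num").show()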
Here is my version of the PERCENTILE_CONT function in Spark. It can be used to find the median value for grouped data in a DataFrame. Hope it helps someone. Feel free to offer your suggestions to improve the solution.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val PERCENTILEFLOOR = udf((maxrank: Integer, percentile: Double) =>
  scala.math.floor(1 + (percentile * (maxrank - 1))))
val PERCENTILECEIL = udf((maxrank: Integer, percentile: Double) =>
  scala.math.ceil(1 + (percentile * (maxrank - 1))))
val PERCENTILECALC = udf((maxrank: Integer, percentile: Double, floorVal: Double, ceilVal: Double, floorNum: Double, ceilNum: Double) => {
  if (ceilNum == floorNum) {
    floorVal
  } else {
    val RN = 1 + (percentile * (maxrank - 1))
    ((ceilNum - RN) * floorVal) + ((RN - floorNum) * ceilVal)
  }
})

/**
 * The result of PERCENTILE_CONT is computed by linear interpolation between values after ordering them.
 * Using the percentile value (P) and the number of rows (N) in the aggregation group,
 * we compute the row number we are interested in after ordering the rows with respect to the sort specification.
 * This row number (RN) is computed according to the formula RN = (1 + (P * (N - 1))).
 * The final result of the aggregate function is computed by linear interpolation between the values
 * from rows at row numbers CRN = CEILING(RN) and FRN = FLOOR(RN).
 *
 * The final result will be:
 *   If (CRN = FRN = RN) then the result is
 *     (value of expression from row at RN)
 *   Otherwise the result is
 *     (CRN - RN) * (value of expression for row at FRN) +
 *     (RN - FRN) * (value of expression for row at CRN)
 *
 * @param inputDF    DataFrame for the computation
 * @param medianCol  column for which the percentile is to be calculated
 * @param groupList  grouping columns for the DataFrame before sorting
 * @param percentile numeric value between 0 and 1 expressing the percentile to be calculated
 */
def percentile_count(inputDF: DataFrame, medianCol: String, groupList: List[String], percentile: Double): DataFrame = {
  val orderList = List(medianCol)
  val wSpec3 = Window.partitionBy(groupList.head, groupList.tail: _*).orderBy(orderList.head, orderList.tail: _*)
  // Group, sort and rank the DF
  val rankedDF = inputDF.withColumn("rank", row_number().over(wSpec3))
  // Find the maximum rank for each group
  val groupedMaxDF = rankedDF.groupBy(groupList.head, groupList.tail: _*).agg(max("rank").as("maxval"))
  // CRN calculation
  val ceilNumDF = groupedMaxDF.withColumn("rank", PERCENTILECEIL(groupedMaxDF("maxval"), lit(percentile))).drop("maxval")
  // FRN calculation
  val floorNumDF = groupedMaxDF.withColumn("rank", PERCENTILEFLOOR(groupedMaxDF("maxval"), lit(percentile)))
  val ntileGroup = "rank" :: groupList
  // Get the values for the CRN and FRN
  val floorDF = floorNumDF.join(rankedDF, ntileGroup).withColumnRenamed("rank", "floorNum").withColumnRenamed(medianCol, "floorVal")
  val ceilDF = ceilNumDF.join(rankedDF, ntileGroup).withColumnRenamed("rank", "ceilNum").withColumnRenamed(medianCol, "ceilVal")
  // Get both the values for CRN and FRN in the same row
  val resultDF = floorDF.join(ceilDF, groupList)
  val finalList = "median_" + medianCol :: groupList
  // Calculate the median using the UDF PERCENTILECALC and return the DF
  resultDF.withColumn("median_" + medianCol, PERCENTILECALC(resultDF("maxval"), lit(percentile), resultDF("floorVal"), resultDF("ceilVal"), resultDF("floorNum"), resultDF("ceilNum"))).select(finalList.head, finalList.tail: _*)
}
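A hypothetical usage against the question's sample data might look like this (the sample DataFrame below is my own, not part of the original answer; it assumes spark.implicits._ is in scope):
// Hypothetical usage example: median (50th percentile) of "num" per "id"
val sample = Seq(("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
percentile_count(sample, "num", List("id"), 0.5).show()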
You can try this solution for the exact median. I have described the Spark SQL solution here: gist.github. To compute the exact median, I use the row_number() and count() functions in combination with a window function.
val data1 = Array( ("a", 0), ("a", 1), ("a", 1), ("a", 1), ("a", 0), ("a", 1))
val data2 = Array( ("b", 0), ("b", 1), ("b", 1))
val union = data1.union(data2)
val df = sc.parallelize(union).toDF("key", "val")
df.cache.createOrReplaceTempView("kvTable")
spark.sql("SET spark.sql.shuffle.partitions=2")
var ds = spark.sql("""
  SELECT key, avg(val) as median
  FROM (
    SELECT key, val, rN,
           (CASE WHEN cN % 2 = 0 THEN (cN DIV 2) ELSE (cN DIV 2) + 1 END) as m1,
           (cN DIV 2) + 1 as m2
    FROM (
      SELECT key, val,
             row_number() OVER (PARTITION BY key ORDER BY val) as rN,
             count(val) OVER (PARTITION BY key) as cN
      FROM kvTable
    ) s
  ) r
  WHERE rN BETWEEN m1 AND m2
  GROUP BY key
""")
Spark executes and optimizes this query efficiently, since it reuses the data partitioning.
scala> ds.show
+---+------+
|key|median|
+---+------+
| a| 1.0|
| b| 1.0|
+---+------+
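For reference, here is a rough DataFrame-API equivalent of the SQL above (my own sketch, not part of the original answer); the intermediate column half mirrors the integer division cN DIV 2:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val byKey = Window.partitionBy("key").orderBy("val")
val wholeKey = Window.partitionBy("key")
val half = floor(col("cN") / 2)   // resolved lazily once the cN column exists

val medians = df
  .withColumn("rN", row_number().over(byKey))        // rank within each key
  .withColumn("cN", count("val").over(wholeKey))     // number of rows per key
  .withColumn("m1", when(col("cN") % 2 === 0, half).otherwise(half + 1))
  .withColumn("m2", half + 1)
  .filter(col("rN").between(col("m1"), col("m2")))   // keep the middle row(s)
  .groupBy("key")
  .agg(avg("val").as("median"))

medians.show()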
The element_at function was added with the higher-order functions in Spark 2.4. We can use it together with a Window function, or with groupBy and then a re-join.
Sample data:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
case class Salary(depName: String, empNo: Long, salary: Long)
val empsalary = Seq(
Salary("sales", 1, 5000),
Salary("personnel", 2, 3900),
Salary("sales", 3, 4800),
Salary("sales", 4, 4800),
Salary("personnel", 5, 3500),
Salary("develop", 7, 4200),
Salary("develop", 8, 6000),
Salary("develop", 9, 4500),
Salary("develop", 10, 5200),
Salary("develop", 11, 5200)).toDS
With a window function
val byDepName = Window.partitionBy('depName).orderBy('salary)
val df = empsalary.withColumn(
"salaries", collect_list('salary) over byDepName).withColumn(
"median_salary", element_at('salaries, (size('salaries)/2 + 1).cast("int")))
df.show(false)
With groupBy and then re-join
val dfMedian = empsalary.groupBy("depName").agg(
sort_array(collect_list('salary)).as("salaries")).select(
'depName,
element_at('salaries, (size('salaries)/2 + 1).cast("int")).as("median_salary"))
empsalary.join(dfMedian, "depName").show(false)
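Note that (size('salaries)/2 + 1) picks the upper-middle element, so for even-sized groups this is not the average of the two middle values. A variant of the groupBy version that averages the two middle elements (my own sketch, not from the original answer):
val dfExactMedian = empsalary.groupBy("depName").agg(
  sort_array(collect_list('salary)).as("salaries")).select(
  'depName,
  // average the upper-middle and lower-middle elements (identical for odd-sized groups)
  ((element_at('salaries, (size('salaries) / 2 + 1).cast("int")) +
    element_at('salaries, ((size('salaries) + 1) / 2).cast("int"))) / 2.0).as("median_salary"))
dfExactMedian.show(false)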
If you don't want to use spark-sql (like me), you can use the cume_dist function.
See the following example:
import org.apache.spark.sql.{functions => F}
import org.apache.spark.sql.expressions.Window
val df = (1 to 10).toSeq.toDF
val win = Window.
partitionBy(F.col("value")).
orderBy(F.col("value")).
rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("c", F.cume_dist().over(win)).show
Result:
+-----+---+
|value| c|
+-----+---+
| 1|0.1|
| 2|0.2|
| 3|0.3|
| 4|0.4|
| 5|0.5|
| 6|0.6|
| 7|0.7|
| 8|0.8|
| 9|0.9|
| 10|1.0|
+-----+---+
The median is the value for which df("c") equals 0.5. I hope it helps, Elior.
Just to add to Elior's answer and in response to Erkan: the reason the output is 1.0 for every row is that partitionBy(F.col("value")) partitions the data into one row per partition, so when the window computes cume_dist it does so over a single value and gets 1.0.
Removing partitionBy(F.col("value")) from the window operation produces the expected quantiles.
Start of Elior's answer:
If you don't want to use spark-sql (like me), you can use the cume_dist function. See the following example:
import org.apache.spark.sql.{functions => F}
import org.apache.spark.sql.expressions.Window
val df = (1 to 10).toSeq.toDF
val win = Window.
partitionBy(F.col("value")). //Remove this line
orderBy(F.col("value")).
rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("c", F.cume_dist().over(win)).show
Result:
+-----+---+
|value| c|
+-----+---+
| 1|0.1|
| 2|0.2|
| 3|0.3|
| 4|0.4|
| 5|0.5|
| 6|0.6|
| 7|0.7|
| 8|0.8|
| 9|0.9|
| 10|1.0|
+-----+---+
The median is the value for which df("c") equals 0.5. I hope it helps, Elior.
End of Elior's answer.
The window without the partition definition:
val win = Window.
orderBy(F.col("value")).
rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("c", F.cume_dist().over(win)).show
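To tie this back to the grouped-data question, here is a sketch of my own (not from either answer above): partition the window by the group key and, per group, take the smallest value whose cume_dist reaches 0.5. The sample DataFrame below mirrors the question's data.
import org.apache.spark.sql.{functions => F}
import org.apache.spark.sql.expressions.Window

val grouped = Seq(("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0),
                  ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
val win = Window.partitionBy(F.col("id")).orderBy(F.col("num"))

grouped
  .withColumn("c", F.cume_dist().over(win))   // cumulative distribution within each id
  .filter(F.col("c") >= 0.5)                  // keep values at or beyond the 50th percentile
  .groupBy("id")
  .agg(F.min("num").as("median"))             // smallest such value per group
  .show()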