I need to implement the following scenarios on a DataFrame using Spark Scala:
Scenario 1: If the "KEY" exists only once, take the "TYPE_VAL" as is.
E.g.: KEY=66 exists once, so take TYPE_VAL=100.
Scenario 2: If the "KEY" exists more than once and every TYPE_VAL is the same, take that TYPE_VAL once.
E.g.: for KEY=68, TYPE_VAL=23.
Scenario 3: If the "KEY" exists more than once, find the repeated TYPE_VAL and subtract the other TYPE_VALs from it.
E.g.: for KEY=67, TYPE_VAL=10 exists twice, so subtract 2 and 4 from 10; finally TYPE_VAL=4.
I tried using group by on the same key, but could not derive all the scenarios.
// Sample input values
val values = List(List("66","100"),
  List("67","10"), List("67","10"), List("67","2"), List("67","4"),
  List("68","23"), List("68","23")).map(x => (x(0), x(1)))
import spark.implicits._
//created a dataframe
val df1 = values.toDF("KEY","TYPE_VAL")
df1.show(false)
------------------------
KEY |TYPE_VAL |
------------------------
66 |100 |
67 |10 |
67 |10 |
67 |2 |
67 |4 |
68 |23 |
68 |23 |
-------------------------
Expected output:
df2.show(false)
------------------------
KEY |TYPE_VAL |
------------------------
66 |100 | -------> [single row ,so 100]
67 |4 | -------> [four rows, of which two are the same and the rest are different, so (10 - 2 - 4) = 4]
68 |23 | -------> [two rows with same values, so 23]
-------------------------
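If you can assume that the number of records per key is not too large (say, a few thousand at most), you can use collect_list after grouping
to gather all matching values into an array, and then use a UDF to compute the result from that array: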
import org.apache.spark.sql.functions._
import spark.implicits._
// create the sample data:
val df1 = List(
(66, 100),
(67, 10),
(67, 10),
(67, 2),
(67, 4),
(68, 23),
(68, 23)
).toDF("KEY", "TYPE_VAL")
// define a UDF that computes the result per scenario for a given Seq[Int].
// This is just one possible implementation, simpler ones probably exist...
val computeTypeVal = udf { (vals: Seq[Int]) =>
  vals.groupBy(identity).values.toList.sortBy(-_.size).flatten match {
    case a :: Nil => a
    case a :: b :: tail if a == b => a - tail.filterNot(_ == a).sum
    case _ => 0 // or whatever else should be done for other cases
  }
}
// group by key, use functions.collect_list to collect all value per key and apply UDF
df1.groupBy($"KEY")
  .agg(collect_list($"TYPE_VAL") as "VALS")
  .select($"KEY", computeTypeVal($"VALS") as "TYPE_VAL")
  .sort($"KEY")
  .show()
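The per-key rule inside the UDF can also be checked without a Spark session. The following is a minimal sketch in plain Scala; the names `TypeValLogic` and `resolveTypeVal` are made up for illustration and do not come from the answer, but the pattern match is the same one used in `computeTypeVal`.

```scala
object TypeValLogic {
  // Hypothetical helper mirroring the body of the computeTypeVal UDF.
  // Groups equal values, puts the most frequent group first, then matches:
  //   one value only          -> take it as is            (scenario 1)
  //   repeated leading value  -> subtract the other values (scenarios 2 and 3)
  //   anything else           -> 0 (placeholder, as in the answer)
  def resolveTypeVal(vals: Seq[Int]): Int =
    vals.groupBy(identity).values.toList.sortBy(-_.size).flatten match {
      case a :: Nil                 => a
      case a :: b :: tail if a == b => a - tail.filterNot(_ == a).sum
      case _                        => 0
    }
}
```

For the sample data, `TypeValLogic.resolveTypeVal(Seq(10, 10, 2, 4))` evaluates to 4 and `TypeValLogic.resolveTypeVal(Seq(23, 23))` to 23. Note that when two different values are repeated equally often (for example 2, 2, 10, 10), which one ends up first depends on the map's iteration order, so the question's rules would need to say what should happen in that case.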
Enhancing the solution shared by user Tzach Zohar so that it handles an input column holding different kinds of values, such as Int, Double, and null:
val df1 = List(
(66, Some("100")),
(67, Some("10.4")),
(67, Some("10.4")),
(67, Some("2")),
(67, Some("4")),
(68, Some("23")),
(68, Some("23")),
(99, None),
(999,Some(""))
).toDF("KEY", "TYPE_VAL")
df1.show()
+---+--------+
|KEY|TYPE_VAL|
+---+--------+
| 66| 100|
| 67| 10.4|
| 67| 10.4|
| 67| 2|
| 67| 4|
| 68| 23|
| 68| 23|
| 99| null|
|999| |
+---+--------+
So the enhanced UDF is as follows:
val computeTypeVal = udf { (vals: Seq[String]) =>
  vals.groupBy(identity).values.toList.sortBy(-_.size).flatten match {
    case a :: Nil => if (a == "") None else Some(a.toDouble)
    case a :: b :: tail if a == b => Some(a.toDouble - tail.map(_.toDouble).filterNot(_ == a.toDouble).sum)
    case _ => Some(0.00) // or whatever else should be done for other cases
  }
}
df1.groupBy($"KEY").agg(collect_list($"TYPE_VAL") as "VALS").select($"KEY", computeTypeVal($"VALS") as "TYPE_VAL").show()
+---+--------+
|KEY|TYPE_VAL|
+---+--------+
| 68| 23.0|
|999| null|
| 99| 0.0|
| 66| 100.0|
| 67| 4.4|
+---+--------+
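One thing to watch for with the enhanced UDF: the empty-string check only covers a key with a single value, so a key with two empty strings, or any non-numeric string, would reach `toDouble` and throw a NumberFormatException. A possible variation, sketched here in plain Scala under the assumption that unparseable values should simply be ignored, parses with `scala.util.Try` before grouping. `SafeTypeVal.resolve` is a hypothetical name, and unlike the UDF above it returns None (null in Spark) rather than 0.0 when a key has no usable values.

```scala
import scala.util.Try

object SafeTypeVal {
  // Hypothetical helper: same matching rules as the enhanced UDF, but values
  // that are null, empty, or not parseable as Double are dropped up front.
  def resolve(vals: Seq[String]): Option[Double] = {
    val nums: Seq[Double] =
      vals.flatMap(v => Option(v).flatMap(s => Try(s.trim.toDouble).toOption))

    nums.groupBy(identity).values.toList.sortBy(-_.size).flatten match {
      case Nil                      => None      // assumption: no usable value -> null
      case a :: Nil                 => Some(a)
      case a :: b :: tail if a == b => Some(a - tail.filterNot(_ == a).sum)
      case _                        => Some(0.0) // placeholder, as in the answer
    }
  }
}
```

The body could be wrapped as `udf { (vals: Seq[String]) => SafeTypeVal.resolve(vals) }` and applied to the collect_list column in the same way as before.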