Spark SQL - find the greatest number of countries a passenger has visited



I have a DataFrame like the following...

+-----------+--------+----+---+----------+
|passengerId|flightId|from| to|      date|
+-----------+--------+----+---+----------+
|       3173|      41|  fr| cn|2017-01-11|
|       3173|      48|  cn| at|2017-01-13|
|       3173|      57|  at| pk|2017-01-17|
|       3173|      71|  pk| il|2017-01-21|
|       3173|     118|  il| se|2017-02-12|
|       3173|     137|  se| iq|2017-02-18|
|       3173|     154|  iq| at|2017-02-24|
|       3173|     231|  at| ar|2017-03-22|
|       3173|     245|  ar| cl|2017-03-28|
|       3173|     270|  cl| sg|2017-04-08|
|       3173|     287|  sg| iq|2017-04-14|
|       3173|     308|  iq| nl|2017-04-21|
|       3173|     317|  nl| dk|2017-04-24|
|       3173|     336|  dk| se|2017-04-29|
|       3173|     463|  se| th|2017-06-14|
|       3173|     480|  th| th|2017-06-20|
|       3173|     650|  th| th|2017-08-21|
|       3173|     660|  th| nl|2017-08-26|
|       3173|     670|  nl| sg|2017-09-01|
|       3173|     695|  sg| ca|2017-09-10|
+-----------+--------+----+---+----------+

I want to find the greatest number of countries a passenger has been to, excluding the starting country. For example, if the passenger's countries were: at -> pk -> il -> se -> iq -> at, the correct answer would be 4 countries. The output should be in the following format:

Passenger ID    Longest Run
3173                4
1234                n
…                   …
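
To make the counting rule concrete, here is a minimal plain-Scala sketch (not part of the original question) of the expected result for the example sequence above:

// Countries in visit order for the example; the starting country is "at"
val visited = Seq("at", "pk", "il", "se", "iq", "at")
// Distinct countries excluding the starting one: pk, il, se, iq => 4
val longestRun = visited.distinct.filterNot(_ == visited.head).size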

Since Spark 2.4: you can do this with a combination of the collect_list, concat, array_remove, array_distinct and size Spark functions.

import org.apache.spark.sql.functions._
import spark.implicits._
val data = Seq(
  (3173, 41, "fr", "cn", "2017-01-11"),
  (3173, 48, "cn", "at", "2017-01-13"),
  (3173, 57, "at", "pk", "2017-01-17"),
  (3173, 71, "pk", "il", "2017-01-21"),
  (3173, 118, "il", "se", "2017-02-12"),
  (3173, 137, "se", "iq", "2017-02-18"),
  (3173, 154, "iq", "at", "2017-02-24"),
  (3173, 231, "at", "ar", "2017-03-22"),
  (3173, 245, "ar", "cl", "2017-03-28"),
  (3173, 270, "cl", "sg", "2017-04-08"),
  (3173, 287, "sg", "iq", "2017-04-14"),
  (3173, 308, "iq", "nl", "2017-04-21"),
  (3173, 317, "nl", "dk", "2017-04-24"),
  (3173, 336, "dk", "se", "2017-04-29"),
  (3173, 463, "se", "th", "2017-06-14"),
  (3173, 480, "th", "th", "2017-06-20"),
  (3173, 650, "th", "th", "2017-08-21"),
  (3173, 660, "th", "nl", "2017-08-26"),
  (3173, 670, "nl", "sg", "2017-09-01"),
  (3173, 695, "sg", "fr", "2017-09-10")
).toDF("passengerId", "flightId", "from", "to", "date")

// First, group by passenger to collect all of their "from" and "to" countries
val dataWithCountries = data.groupBy("passengerId")
  .agg(
    // concat concatenates the two lists of strings built from the "from" and "to" columns
    concat(
      // collect_list gathers all values of the given column into an array
      collect_list(col("from")),
      collect_list(col("to"))
    ).name("countries")
  )

After the aggregation we have, for each passenger, a list of all of their countries, including duplicates. Next we remove the passenger's starting country (the first value of the "from" column, see the array_remove function) from the deduplicated list of countries (see array_distinct), and count the remaining countries with the size function:

val passengerLongestRuns = dataWithCountries.withColumn(
  "longest_run",
  size(array_remove(array_distinct(col("countries")), col("countries").getItem(0)))
)
passengerLongestRuns.show(false)

Output:

+-----------+-----------+
|passengerId|longest_run|
+-----------+-----------+
|3173       |12         |
+-----------+-----------+
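
As a variant, the aggregation and the count can also be folded into a single agg call. The sketch below is not part of the original answer (Spark 2.4+ assumed); it uses first(col("from")) in place of col("countries").getItem(0) as the starting country, relying on the first collected "from" value being the passenger's home country:

// Sketch: same logic in a single aggregation step
val singleAgg = data
  .groupBy("passengerId")
  .agg(
    size(
      array_remove(
        array_distinct(concat(collect_list(col("from")), collect_list(col("to")))),
        first(col("from"))
      )
    ).as("longest_run")
  )
singleAgg.show(false)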

For Spark < 2.4, you can define the remove and distinct steps as user-defined functions:

// Removes every occurrence of the first (starting) country from the rest of the list
def removeAllFirstOccurrences(list: Seq[String]): Seq[String] = list.tail.filter(_ != list.head)
val removeFirstCountry = spark.udf.register[Seq[String], Seq[String]]("remove_first_country", removeAllFirstOccurrences)

// Deduplicates the remaining countries
def distinct(list: Seq[String]): Seq[String] = list.distinct
val distinctArray = spark.udf.register[Seq[String], Seq[String]]("array_distinct", distinct)

val passengerLongestRuns = dataWithCountries.withColumn(
  "longest_run",
  size(
    distinctArray(
      removeFirstCountry(
        col("countries")
      )
    )
  )
)
passengerLongestRuns.show(false)

Output:

+-----------+-----------+
|passengerId|longest_run|
+-----------+-----------+
|3173       |12         |
+-----------+-----------+
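
For the pre-2.4 path, the two registered functions can also be collapsed into a single udf that removes the starting country, deduplicates, and counts in one step. This is only a sketch, not code from the original answer:

import org.apache.spark.sql.functions.{col, udf}

// Sketch: one UDF returning the count directly; guards against an empty list
val longestRunUdf = udf { countries: Seq[String] =>
  if (countries.isEmpty) 0
  else countries.tail.filter(_ != countries.head).distinct.size
}

val passengerLongestRunsUdf = dataWithCountries
  .withColumn("longest_run", longestRunUdf(col("countries")))
passengerLongestRunsUdf.show(false)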
