I have a dataframe like the following...
+-----------+--------+----+---+----------+
|passengerId|flightId|from| to|      date|
+-----------+--------+----+---+----------+
|       3173|      41|  fr| cn|2017-01-11|
|       3173|      48|  cn| at|2017-01-13|
|       3173|      57|  at| pk|2017-01-17|
|       3173|      71|  pk| il|2017-01-21|
|       3173|     118|  il| se|2017-02-12|
|       3173|     137|  se| iq|2017-02-18|
|       3173|     154|  iq| at|2017-02-24|
|       3173|     231|  at| ar|2017-03-22|
|       3173|     245|  ar| cl|2017-03-28|
|       3173|     270|  cl| sg|2017-04-08|
|       3173|     287|  sg| iq|2017-04-14|
|       3173|     308|  iq| nl|2017-04-21|
|       3173|     317|  nl| dk|2017-04-24|
|       3173|     336|  dk| se|2017-04-29|
|       3173|     463|  se| th|2017-06-14|
|       3173|     480|  th| th|2017-06-20|
|       3173|     650|  th| th|2017-08-21|
|       3173|     660|  th| nl|2017-08-26|
|       3173|     670|  nl| sg|2017-09-01|
|       3173|     695|  sg| ca|2017-09-10|
+-----------+--------+----+---+----------+
I want to find the greatest number of countries a passenger has visited, excluding their starting country. For example, if the countries a passenger was in were: at -> pk -> il -> se -> iq -> at, the correct answer would be 4 countries. The output should be in the following format:

Passenger ID   Longest Run
3173           4
1234           n
…              …
Since Spark 2.4: you can do this with a combination of the collect_list, concat, array_remove, array_distinct and size Spark functions.
import org.apache.spark.sql.functions._
import spark.implicits._

val data = Seq(
  (3173, 41, "fr", "cn", "2017-01-11"),
  (3173, 48, "cn", "at", "2017-01-13"),
  (3173, 57, "at", "pk", "2017-01-17"),
  (3173, 71, "pk", "il", "2017-01-21"),
  (3173, 118, "il", "se", "2017-02-12"),
  (3173, 137, "se", "iq", "2017-02-18"),
  (3173, 154, "iq", "at", "2017-02-24"),
  (3173, 231, "at", "ar", "2017-03-22"),
  (3173, 245, "ar", "cl", "2017-03-28"),
  (3173, 270, "cl", "sg", "2017-04-08"),
  (3173, 287, "sg", "iq", "2017-04-14"),
  (3173, 308, "iq", "nl", "2017-04-21"),
  (3173, 317, "nl", "dk", "2017-04-24"),
  (3173, 336, "dk", "se", "2017-04-29"),
  (3173, 463, "se", "th", "2017-06-14"),
  (3173, 480, "th", "th", "2017-06-20"),
  (3173, 650, "th", "th", "2017-08-21"),
  (3173, 660, "th", "nl", "2017-08-26"),
  (3173, 670, "nl", "sg", "2017-09-01"),
  (3173, 695, "sg", "fr", "2017-09-10")
).toDF("passengerId", "flightId", "from", "to", "date")
// first group by passenger to collect all of their "from" and "to" countries
val dataWithCountries = data.groupBy("passengerId")
  .agg(
    // concat concatenates the two arrays of strings built from the "from" and "to" columns
    concat(
      // collect_list gathers all values of the given column into an array
      collect_list(col("from")),
      collect_list(col("to"))
    ).name("countries")
  )
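For intuition, here is what that aggregation produces on the mini route from the question, as a plain-Scala sketch with illustrative values (note that collect_list after groupBy does not guarantee row order in general, so this approach assumes the first array element is the passenger's starting country):

// Plain-Scala sketch, mirroring concat(collect_list(from), collect_list(to))
// for the route at -> pk -> il -> se -> iq -> at (illustrative values only):
val from = Seq("at", "pk", "il", "se", "iq")
val to = Seq("pk", "il", "se", "iq", "at")
val countries = from ++ to
// -> Seq(at, pk, il, se, iq, pk, il, se, iq, at): duplicates included,
// with the starting country "at" as the first element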
After the aggregation we get, for each passenger, a list of all of their countries, duplicates included. Next we take the distinct list of countries (see the array_distinct function), remove the passenger's home country from it (the first value of the countries array, i.e. the first value of the from column; see array_remove), and count the remaining countries with the size function:
val passengerLongestRuns = dataWithCountries.withColumn(
  "longest_run",
  size(array_remove(array_distinct(col("countries")), col("countries").getItem(0)))
)

passengerLongestRuns.show(false)
Output:

+-----------+-----------+
|passengerId|longest_run|
+-----------+-----------+
|3173       |12         |
+-----------+-----------+
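As a quick sanity check, the same distinct-then-remove-then-count chain applied to the mini route from the question yields the expected 4 (plain Scala, illustrative values only):

// Mirrors size(array_remove(array_distinct(countries), countries(0)))
val countries = Seq("at", "pk", "il", "se", "iq", "pk", "il", "se", "iq", "at")
val longestRun = countries.distinct.filter(_ != countries.head).size
// -> 4 (pk, il, se, iq)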
For Spark < 2.4, you can define remove and distinct as user-defined functions:
// removes the first element of the list plus every later occurrence of it
def removeAllFirstOccurrences(list: Seq[String]): Seq[String] = list.tail.filter(_ != list.head)
val removeFirstCountry = spark.udf.register[Seq[String], Seq[String]]("remove_first_country", removeAllFirstOccurrences)

// deduplicates the list, keeping the first occurrence of each value
def distinct(list: Seq[String]): Seq[String] = list.distinct
val distinctArray = spark.udf.register[Seq[String], Seq[String]]("array_distinct", distinct)
val passengerLongestRuns = dataWithCountries.withColumn(
  "longest_run",
  size(
    distinctArray(
      removeFirstCountry(
        col("countries")
      )
    )
  )
)

passengerLongestRuns.show(false)
Output:

+-----------+-----------+
|passengerId|longest_run|
+-----------+-----------+
|3173       |12         |
+-----------+-----------+
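A quick check of the Spark < 2.4 helpers on plain collections (values are illustrative):

removeAllFirstOccurrences(Seq("at", "pk", "il", "at", "iq"))
// -> Seq(pk, il, iq): the head "at" and its later occurrence are removed
distinct(Seq("pk", "il", "pk"))
// -> Seq(pk, il)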