- I have a table with an "age" column. I want to group people by age into buckets such as [0,5), [5,10), [10,15), ...
- Then I will run the same calculation on each group and compare the results.
- The goal is to see whether age is correlated with other variables.
- Please help.
You can bucket the ages with this formula: range_start = age - (age % interval), where the interval here is 5.
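For intuition, here is that formula as a plain Scala function (a minimal sketch, independent of Spark):

def rangeStart(age: Int, interval: Int): Int = age - (age % interval)

rangeStart(7, 5)   // 5  -> bucket [5, 10)
rangeStart(23, 5)  // 20 -> bucket [20, 25)
rangeStart(54, 5)  // 50 -> bucket [50, 55)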
Create a test dataframe for the demonstration:
import spark.implicits._ // needed for .toDF when running outside spark-shell

val df = (100 to 400 by 7).map(id => (s"user$id", id % 60))
  .toDF("name", "age")
df.show(false)
+-------+---+
|name |age|
+-------+---+
|user100|40 |
|user107|47 |
|user114|54 |
|user121|1 |
|user128|8 |
|user135|15 |
|user142|22 |
|user149|29 |
|user156|36 |
|user163|43 |
|user170|50 |
|user177|57 |
|user184|4 |
|user191|11 |
|user198|18 |
|user205|25 |
|user212|32 |
|user219|39 |
|user226|46 |
|user233|53 |
+-------+---+
only showing top 20 rows
Group by the interval:
import org.apache.spark.sql.functions._
val interval = 5
df.withColumn("range", $"age" - ($"age" % interval))
  .withColumn("range", concat($"range", lit(" to "), $"range" + interval)) // optional: turn the bucket start into a readable "start to end" label
  .groupBy($"range")
  .agg(collect_list($"name").as("names")) // replace with whatever aggregation your comparison needs
  .show(false)
+--------+------------------------------------+
|range |names |
+--------+------------------------------------+
|10 to 15|[user191, user254, user310, user373]|
|50 to 55|[user114, user170, user233, user352]|
|5 to 10 |[user128, user247, user366] |
|55 to 60|[user177, user296, user359] |
|45 to 50|[user107, user226, user289, user345]|
|35 to 40|[user156, user219, user275, user338]|
|25 to 30|[user149, user205, user268, user387]|
|30 to 35|[user212, user331, user394] |
|0 to 5 |[user121, user184, user240, user303]|
|20 to 25|[user142, user261, user324, user380]|
|15 to 20|[user135, user198, user317] |
|40 to 45|[user100, user163, user282] |
+--------+------------------------------------+
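collect_list above is only for demonstration. For the original goal of running the same calculation on every group and comparing the results, swap in the aggregation you need. A minimal sketch, assuming a hypothetical score column you want to compare across age ranges:

df.withColumn("range", $"age" - ($"age" % interval))
  .groupBy($"range")
  .agg(avg($"score").as("avg_score"), count(lit(1)).as("n")) // `score` is hypothetical; use your own metric column
  .orderBy($"range")
  .show(false)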
We could even wrap the same formula in a UDF, though that would likely be slightly slower than the built-in column expressions.
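A minimal sketch of that UDF variant (same formula, same grouping as above):

val toRangeStart = udf((age: Int) => age - (age % interval)) // same formula, wrapped in a UDF

df.withColumn("range", toRangeStart($"age"))
  .groupBy($"range")
  .agg(collect_list($"name").as("names"))
  .show(false)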
Demo:
Sample DF:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.ml.feature.Bucketizer
scala> val df = spark.range(20).withColumn("age", round(rand()*90).cast(IntegerType))
df: org.apache.spark.sql.DataFrame = [id: bigint, age: int]
scala> df.show
+---+---+
| id|age|
+---+---+
| 0| 58|
| 1| 57|
| 2| 43|
| 3| 62|
| 4| 18|
| 5| 70|
| 6| 26|
| 7| 54|
| 8| 70|
| 9| 42|
| 10| 38|
| 11| 79|
| 12| 77|
| 13| 14|
| 14| 87|
| 15| 28|
| 16| 15|
| 17| 59|
| 18| 81|
| 19| 25|
+---+---+
Solution:
scala> :paste
// Entering paste mode (ctrl-D to finish)
val splits = Range.Double(0,120,5).toArray
val bucketizer = new Bucketizer()
.setInputCol("age")
.setOutputCol("age_range_id")
.setSplits(splits)
val df2 = bucketizer.transform(df)
// Exiting paste mode, now interpreting.
splits: Array[Double] = Array(0.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 65.0, 70.0, 75.0, 80.0, 85.0, 90.0, 95.0, 100.0, 105.0, 110.0, 115.0)
bucketizer: org.apache.spark.ml.feature.Bucketizer = bucketizer_3c2040bf50c7
df2: org.apache.spark.sql.DataFrame = [id: bigint, age: int ... 1 more field]
scala> df2.groupBy("age_range_id").count().show
+------------+-----+
|age_range_id|count|
+------------+-----+
| 8.0| 2|
| 7.0| 1|
| 11.0| 3|
| 14.0| 2|
| 3.0| 2|
| 2.0| 1|
| 17.0| 1|
| 10.0| 1|
| 5.0| 3|
| 15.0| 2|
| 16.0| 1|
| 12.0| 1|
+------------+-----+
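If you want a human-readable range instead of the raw bucket id, one option is to derive labels from the same splits array (a sketch; labels and labelUdf are names introduced here, not part of the Bucketizer API):

// Build "[start, end)" labels from consecutive pairs of split points.
val labels = splits.sliding(2).map(w => s"[${w(0).toInt}, ${w(1).toInt})").toArray
// Look up the label for each bucket id (Bucketizer emits ids as doubles).
val labelUdf = udf((id: Double) => labels(id.toInt))

df2.groupBy("age_range_id").count()
  .withColumn("age_range", labelUdf(col("age_range_id")))
  .show()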
Alternatively, you can use the Spark SQL API:
df.createOrReplaceTempView("tab")
val query = """
with t as (select int(age/5) as age_id from tab)
select age_id, count(*) as count
from t
group by age_id
"""
spark.sql(query).show
Result:
scala> spark.sql(query).show
+------+-----+
|age_id|count|
+------+-----+
| 12| 1|
| 16| 1|
| 3| 2|
| 5| 3|
| 15| 2|
| 17| 1|
| 8| 2|
| 7| 1|
| 10| 1|
| 11| 3|
| 14| 2|
| 2| 1|
+------+-----+
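To show readable range boundaries instead of the bare age_id, the id can be scaled back to its bucket start (a sketch of the same query, with explicit casts for the string concatenation):

val labeledQuery = """
with t as (select int(age/5)*5 as range_start from tab)
select concat(cast(range_start as string), ' - ', cast(range_start + 5 as string)) as age_range,
       count(*) as count
from t
group by range_start
order by range_start
"""
spark.sql(labeledQuery).show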