How to partition a numeric column into ranges in Apache Spark and assign a label to each range

I have the following Spark DataFrame:

id weekly_sale
1    40000
2    120000
3    135000
4    211000
5    215000
6    331000
7    337000

I need to determine which of the following intervals each value in the weekly_sale column falls into:

under 100000
between 100000 and 200000
between 200000 and 300000
more than 300000

So the output I want is:

id weekly_sale  label
1    40000       under 100000    
2    120000      between 100000 and 200000
3    135000      between 100000 and 200000
4    211000      between 200000 and 300000
5    215000      between 200000 and 300000
6    331000      more than 300000
7    337000      more than 300000

An implementation in PySpark, Spark SQL, or HiveContext would help me.

Assuming the ranges and labels are defined as follows:

splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000", 
    "between 200000 and 300000", "more than 300000"]
df = sc.parallelize([
    (1, 40000.0), (2, 120000.0), (3, 135000.0),
    (4, 211000.0), (5, 215000.0), (6, 331000.0),
    (7, 337000.0)
]).toDF(["id", "weekly_sale"])

One possible approach is to use Bucketizer:

from pyspark.ml.feature import Bucketizer
from pyspark.sql.functions import array, col, lit
bucketizer = Bucketizer(
    splits=splits, inputCol="weekly_sale", outputCol="split"
)
with_split = bucketizer.transform(df)

and attach the labels afterwards:

label_array = array(*(lit(label) for label in labels))
with_split.withColumn(
    "label", label_array.getItem(col("split").cast("integer"))
).show(10, False)
## +---+-----------+-----+-------------------------+
## |id |weekly_sale|split|label                    |
## +---+-----------+-----+-------------------------+
## |1  |40000.0    |0.0  |under 100000             |
## |2  |120000.0   |1.0  |between 100000 and 200000|
## |3  |135000.0   |1.0  |between 100000 and 200000|
## |4  |211000.0   |2.0  |between 200000 and 300000|
## |5  |215000.0   |2.0  |between 200000 and 300000|
## |6  |331000.0   |3.0  |more than 300000         |
## |7  |337000.0   |3.0  |more than 300000         |
## +---+-----------+-----+-------------------------+

Of course there are different ways to achieve the same goal. For example, you can create a lookup table and join against it:

from toolz import sliding_window
from pyspark.sql.functions import broadcast
mapping = [
    (lower, upper, label) for ((lower, upper), label)
    in zip(sliding_window(2, splits), labels)
]
lookup_df = sc.parallelize(mapping).toDF(["lower", "upper", "label"])
df.join(
    broadcast(lookup_df),
    (col("weekly_sale") >= col("lower")) & (col("weekly_sale") < col("upper"))
).drop("lower").drop("upper")

or generate a lookup expression:

from functools import reduce
from pyspark.sql.functions import when
def in_range(c):
    def in_range_(acc, x):
        lower, upper, label = x
        return when(
            (c >= lit(lower)) & (c < lit(upper)), lit(label)
        ).otherwise(acc)
    return in_range_
label = reduce(in_range(col("weekly_sale")), mapping, lit(None))
df.withColumn("label", label)
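Since the question also asks for a Spark SQL / HiveContext variant, the same chained conditions can be rendered as a SQL CASE expression. A minimal sketch, assuming the same splits and labels as above; the `case_expr` helper and the table name `sales` are names introduced here for illustration, not part of any Spark API:

```python
splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]

def case_expr(column, splits, labels):
    # Build a SQL CASE expression from bucket boundaries and labels,
    # keeping the same semantics as above: lower bound inclusive,
    # upper bound exclusive. Infinite bounds are simply omitted.
    clauses = []
    for lower, upper, label in zip(splits, splits[1:], labels):
        conds = []
        if lower != float("-inf"):
            conds.append("{} >= {}".format(column, lower))
        if upper != float("inf"):
            conds.append("{} < {}".format(column, upper))
        clauses.append("WHEN {} THEN '{}'".format(" AND ".join(conds), label))
    return "CASE {} END".format(" ".join(clauses))

label_sql = case_expr("weekly_sale", splits, labels)

# Usable with DataFrames or a registered table, for example:
# df.selectExpr("id", "weekly_sale", label_sql + " AS label")
# sqlContext.sql("SELECT id, weekly_sale, " + label_sql + " AS label FROM sales")
```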

The least efficient method is a UDF, since it forces row-at-a-time serialization between the JVM and Python.
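For completeness, a UDF version could look like the sketch below. The plain-Python `assign_label` helper (a name introduced here) does the actual lookup against the same splits and labels as above; wrapping it as a UDF is what makes this the slowest option:

```python
from bisect import bisect_right

# Same splits and labels as defined earlier.
splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
labels = [
    "under 100000", "between 100000 and 200000",
    "between 200000 and 300000", "more than 300000"]

def assign_label(sale):
    # bisect_right returns the number of boundaries <= sale, so the
    # matching bucket (lower inclusive, upper exclusive) is one less.
    return labels[bisect_right(splits, sale) - 1]

# Registering and applying it as a UDF:
# from pyspark.sql.functions import col, udf
# from pyspark.sql.types import StringType
# assign_label_udf = udf(assign_label, StringType())
# df.withColumn("label", assign_label_udf(col("weekly_sale")))
```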
