Filter then count for many different thresholds



I want to count the number of rows in a very large dataframe that satisfy a condition, which I can do with:

df.filter(col("value") >= thresh).count()

I want to know the result for every threshold in the range [1, 10]. Enumerating each threshold and running the count for it would scan the dataframe 10 times (roughly the loop below), which is slow.
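For reference, the slow version is roughly this (a sketch of the naive approach; `thresh` ranges over the 10 thresholds):

from pyspark.sql.functions import col

counts = {}
for thresh in range(1, 11):
    # each iteration launches a separate job that scans the whole dataframe
    counts[thresh] = df.filter(col("value") >= thresh).count()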

Is it possible to achieve this by scanning df only once?

Create an indicator column for each threshold, then sum them:

import random
import pyspark.sql.functions as F
from pyspark.sql import Row

df = spark.createDataFrame([Row(value=random.randint(0, 10)) for _ in range(1_000_000)])
df.select([
    (F.col("value") >= thresh)   # boolean indicator: value >= thresh
        .cast("int")             # cast True/False to 1/0
        .alias(f"ind_{thresh}")
    for thresh in range(1, 11)
]).groupBy().sum().show()
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |sum(ind_1)|sum(ind_2)|sum(ind_3)|sum(ind_4)|sum(ind_5)|sum(ind_6)|sum(ind_7)|sum(ind_8)|sum(ind_9)|sum(ind_10)|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |    908971|    818171|    727240|    636334|    545463|    454279|    363143|    272460|    181729|      90965|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
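If the counts are needed as Python values rather than just displayed, the single aggregated row can be collected; a minimal sketch reusing the query above (the dict comprehension is my own addition):

row = df.select([
    (F.col("value") >= thresh).cast("int").alias(f"ind_{thresh}")
    for thresh in range(1, 11)
]).groupBy().sum().first()
counts = {thresh: row[f"sum(ind_{thresh})"] for thresh in range(1, 11)}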

Conditional aggregation with when expressions will do the job.

Here's an example:

from pyspark.sql import functions as F
df = spark.createDataFrame([(1,), (2,), (3,), (4,), (4,), (6,), (7,)], ["value"])
count_expr = [
    F.count(F.when(F.col("value") >= th, 1)).alias(f"gte_{th}")
    for th in range(1, 11)
]
df.select(*count_expr).show()
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|gte_1|gte_2|gte_3|gte_4|gte_5|gte_6|gte_7|gte_8|gte_9|gte_10|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|    7|    6|    5|    4|    2|    2|    1|    0|    0|     0|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
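The same expressions can also be passed to agg, which again computes all counts in a single pass; a small sketch of collecting the result into a dict (my own wrapping, not part of the original answer):

result = df.agg(*count_expr).first().asDict()
# e.g. {'gte_1': 7, 'gte_2': 6, ..., 'gte_10': 0} for the example dataframe above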

Using a user-defined function (udf) from pyspark.sql.functions:

import pandas as pd
import numpy as np
df = pd.DataFrame(np.random.randint(0,100, size=(20)), columns=['val'])
thres =  [90, 80, 30]     # these are the thresholds
thres.sort(reverse=True)  # list needs to be sorted
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
spark = SparkSession.builder \
    .master("local[2]") \
    .appName("myApp") \
    .getOrCreate()
sparkDF = spark.createDataFrame(df)
# rank 0 for values above thres[0], 1 above thres[1], 2 above thres[2], else 3
myUdf = udf(lambda x: 0 if x > thres[0] else 1 if x > thres[1] else 2 if x > thres[2] else 3)
sparkDF = sparkDF.withColumn("rank", myUdf(sparkDF.val))
sparkDF.show()
# +---+----+                                                                      
# |val|rank|
# +---+----+
# | 28|   3|
# | 54|   2|
# | 19|   3|
# |  4|   3|
# | 74|   2|
# | 62|   2|
# | 95|   0|
# | 19|   3|
# | 55|   2|
# | 62|   2|
# | 33|   2|
# | 93|   0|
# | 81|   1|
# | 41|   2|
# | 80|   2|
# | 53|   2|
# | 14|   3|
# | 16|   3|
# | 30|   3|
# | 77|   2|
# +---+----+
sparkDF.groupby(['rank']).count().show()
# Out: 
# +----+-----+
# |rank|count|
# +----+-----+
# |   3|    7|
# |   0|    2|
# |   1|    1|
# |   2|   10|
# +----+-----+

A value gets rank i if it is strictly greater than thres[i] but less than or equal to thres[i-1]. Because the chained conditional short-circuits, this keeps the number of comparisons to a minimum.

For thres = [90, 80, 30], the ranks map to the intervals 0 -> (90, max], 1 -> (80, 90], 2 -> (30, 80], 3 -> [min, 30].
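As a generalization (my own sketch, not part of the original answer), the same ranking rule can be written for an arbitrary descending-sorted threshold list instead of hard-coding three comparisons:

def rank_of(x, thresholds):
    # thresholds must be sorted in descending order, e.g. [90, 80, 30]
    for i, t in enumerate(thresholds):
        if x > t:           # strictly greater than thresholds[i] -> rank i
            return i
    return len(thresholds)  # less than or equal to the smallest threshold

myUdf = udf(lambda x: rank_of(x, thres))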
