我想计算满足可以实现的超大数据帧条件的行数
df.filter(col("value") >= thresh).count()
我想知道范围[1, 10]
中每个阈值的结果。枚举每个阈值,然后执行此操作将扫描数据帧10次。它很慢。
如果我只扫描df一次就可以实现它?
为每个阈值创建一个指标列,然后求和:
import random
import pyspark.sql.functions as F
from pyspark.sql import Row
df = spark.createDataFrame([Row(value=random.randint(0,10)) for _ in range(1_000_000)])
df.select([
(F.col("value") >= thresh)
.cast("int")
.alias(f"ind_{thresh}")
for thresh in range(1,11)
]).groupBy().sum().show()
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |sum(ind_1)|sum(ind_2)|sum(ind_3)|sum(ind_4)|sum(ind_5)|sum(ind_6)|sum(ind_7)|sum(ind_8)|sum(ind_9)|sum(ind_10)|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# | 908971| 818171| 727240| 636334| 545463| 454279| 363143| 272460| 181729| 90965|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
将条件聚合与when
表达式一起使用就可以了。
这里有一个例子:
from pyspark.sql import functions as F
df = spark.createDataFrame([(1,), (2,), (3,), (4,), (4,), (6,), (7,)], ["value"])
count_expr = [
F.count(F.when(F.col("value") >= th, 1)).alias(f"gte_{th}")
for th in range(1, 11)
]
df.select(*count_expr).show()
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|gte_1|gte_2|gte_3|gte_4|gte_5|gte_6|gte_7|gte_8|gte_9|gte_10|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#| 7| 6| 5| 4| 2| 2| 1| 0| 0| 0|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
使用pyspark.sql.functions
:中的用户定义函数udf
import pandas as pd
import numpy as np
df = pd.DataFrame(np.random.randint(0,100, size=(20)), columns=['val'])
thres = [90, 80, 30] # these are the thresholds
thres.sort(reverse=True) # list needs to be sorted
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
spark = SparkSession.builder
.master("local[2]")
.appName("myApp")
.getOrCreate()
sparkDF = spark.createDataFrame(df)
myUdf = udf(lambda x: 0 if x>thres[0] else 1 if x>thres[1] else 2 if x>thres[2] else 3)
sparkDF = sparkDF.withColumn("rank", myUdf(sparkDF.val))
sparkDF.show()
# +---+----+
# |val|rank|
# +---+----+
# | 28| 3|
# | 54| 2|
# | 19| 3|
# | 4| 3|
# | 74| 2|
# | 62| 2|
# | 95| 0|
# | 19| 3|
# | 55| 2|
# | 62| 2|
# | 33| 2|
# | 93| 0|
# | 81| 1|
# | 41| 2|
# | 80| 2|
# | 53| 2|
# | 14| 3|
# | 16| 3|
# | 30| 3|
# | 77| 2|
# +---+----+
sparkDF.groupby(['rank']).count().show()
# Out:
# +----+-----+
# |rank|count|
# +----+-----+
# | 3| 7|
# | 0| 2|
# | 1| 1|
# | 2| 10|
# +----+-----+
如果一个值严格大于thres[i]
但小于或等于thres[i-1]
,则该值的级别为i
。这样可以最大限度地减少比较次数。
对于CCD_ 8,我们具有秩0->[max, 90[
,1->[90, 80[
,2->[80, 30[
,3->[30, min]