Optimizing Spark code with filtering and UDFs

I'm using Spark to process a dataset of 20 million XML documents. I was originally processing all of them, but I actually only need about a third. In another Spark workflow, I created a DataFrame keyfilter whose first column is the key of each XML and whose second column is a Boolean that is True if the XML corresponding to that key should be processed and False otherwise.
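For concreteness, keyfilter looks roughly like the sketch below (the column name filter matches what the join code further down expects; the key values are made up for illustration):

# Illustrative sketch only: keyfilter's schema is one string key per XML
# document plus a Boolean 'filter' column saying whether to process it.
keyfilter_example = spark.createDataFrame(
    [('doc-000001', True), ('doc-000002', False), ('doc-000003', True)],
    schema='key string, filter boolean',
)
keyfilter_example.printSchema()
# root
#  |-- key: string (nullable = true)
#  |-- filter: boolean (nullable = true)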

The XMLs themselves are processed with a Pandas UDF that I can't share.
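For context, it has roughly the shape of an ordinary scalar (series-to-series) Pandas UDF returning strings; the body below is only a stand-in, not the real conversion logic:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

@pandas_udf(StringType())
def pandas_xml_convert_string(xml_series):
    # Stand-in body: the real UDF parses each XML document and returns a
    # plain-text rendering of it. Here the strings are just passed through.
    return xml_series.str.strip()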

My notebook on Databricks basically works like this:

import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>
keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()
def process_part(part, fraction=1, filter=True, return_df=False):
    try:
        df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, part))
    # Sometimes, the file part-xxxxx doesn't exist
    except AnalysisException:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/part-%05d_%s_%d' % (part, DATE, time.time()))
    if return_df:
        return mod_df

n_cores = 6
i = 0
while n_cores*i < 1024:
    with ThreadPool(n_cores) as p:
        p.map(process_part, range(n_cores*i, min(1024, n_cores*i+n_cores)))
    i += 1

The reason I'm posting this question is that even though the Pandas UDF should be the most expensive operation going on, adding the filtering actually makes my code run much slower than not filtering at all. I'm new to Spark, and I'd like to know whether I'm doing something silly here that makes the joins against keyfilter very slow, and if so, whether there is a way to make them fast (for example, is there a way to make keyfilter behave like a hash table from key to Boolean, something like CREATE INDEX in SQL?). I suspect keyfilter's large size plays some role here; it has 20 million rows, while the df inside process_part contains only a small fraction of those rows (although df is much larger in terms of bytes, since it holds the XML documents). Should I combine all the parts into one giant DataFrame instead of processing them one at a time?
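For anyone looking at this, one diagnostic I can run inside process_part is printing the physical plan of the join before writing, to see whether Spark is planning a shuffle-based SortMergeJoin or a BroadcastHashJoin; a quick sketch, reusing the names from the code above:

# Diagnostic sketch: inspect which join strategy Spark plans for the filter join.
df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
df_with_filter.explain()  # look for SortMergeJoin vs. BroadcastHashJoin in the output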

Alternatively, is there a way to tell Spark that the keys are unique in both DataFrames?

The key to making the join happen in a reasonable amount of time was to use broadcast on keyfilter so that a broadcast hash join is performed instead of a standard join. I also merged several parts together and lowered the degree of parallelism (for some reason, too many threads sometimes seemed to crash the engine). My new, performant code is below:

import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col, broadcast
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
from functools import reduce
DATE = '20200314'
<define UDF pandas_xml_convert_string()>
keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()
def process_parts(part_pair, fraction=1, return_df=False, filter=True):
    dfs = []
    parts_start, parts_end = part_pair
    parts = range(parts_start, parts_end)
    for part in parts:
        try:
            df = spark.read.parquet('/input/path/on/s3/%s/part-%05d*' % (DATE, part))
            dfs.append(df)
        except AnalysisException:
            print("There is no part %05d!" % part)
            continue
    if len(dfs) >= 2:
        df = reduce(lambda x, y: x.union(y), dfs)
    elif len(dfs) == 1:
        df = dfs[0]
    else:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        df_with_filter = df.join(broadcast(keyfilter), on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/parts-%05d-%05d_%s_%d' % (parts_start, parts_end-1, DATE, time.time()))
    if return_df:
        return mod_df

start_time = time.time()
pairs = [(i*4, i*4+4) for i in range(256)]
with ThreadPool(3) as p:
    batch_start_time = time.time()
    for i, _ in enumerate(p.imap_unordered(process_parts, pairs, chunksize=1)):
        batch_end_time = time.time()
        batch_len = batch_end_time - batch_start_time
        cum_len = batch_end_time - start_time
        print('Processed group %d/256 %d minutes and %d seconds after previous group.' % (i+1, batch_len // 60, batch_len % 60))
        print('%d hours, %d minutes, %d seconds since start.' % (cum_len // 3600, (cum_len % 3600) // 60, cum_len % 60))
        batch_start_time = time.time()
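A note on this design choice: Spark only broadcasts a join side on its own when it estimates that side to be smaller than spark.sql.autoBroadcastJoinThreshold (10 MB by default), which is presumably why the 20-million-row keyfilter was being shuffle-joined without the explicit hint. Raising that threshold is an alternative to calling broadcast() everywhere; a sketch (the 200 MB figure is just an example, not a tuned value):

# Alternative to explicit broadcast() hints: raise the size threshold under which
# Spark automatically broadcasts the smaller side of a join. 200 MB is only an
# example value, not a recommendation.
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', str(200 * 1024 * 1024))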

Latest update