根据spark中的移动和向DataFrame添加批号



我有一个数据集,我需要批量处理(由于API限制)。

一个批处理的列text_length之和不能超过1000。批处理中的最大行数不能大于5.

为此,我想在单个批次中添加批号,以便稍后根据batch_number处理数据。

我如何在pyspark(在Databricks)中实现这一点。我对这一切都很陌生,我甚至不知道在网上找什么。

非常感谢你的帮助。

下面的表格说明了我正在努力实现的目标:

原始表

<表类>idtext_lengthtbody><<tr>15002400320043005100610071008100910010300

如果您不是在寻找最佳解决方案,而是在Spark中寻找一种不太复杂的解决问题的方法,我们可以将问题分为两步:

  1. 将数据分成5行块,忽略文本长度
  2. 如果一个块中的文本长度总和太大,则将该块拆分为其他块

这个解决方案不是最优的,因为它生产的批次太多了。

步骤1可以使用zipWithIndex实现。在创建批id时,我们要留出足够的"空间"。用于稍后分割批次。在此步骤结束时,将一个块中的所有行分组到一个列表中,作为步骤2的输入:

df = ...
r = df.rdd.zipWithIndex().toDF() 
.select("_1.id", "_1.text_length", "_2") 
.withColumn("batch", F.expr("cast(_2 / 5 as long)*5")) 
.withColumn("data", F.struct("id", "text_length", "batch")) 
.groupBy("batch") 
.agg(F.collect_list("data").alias("data"))

第2部分主要由一个udf组成,它检查在一个批处理中是否超过了最大文本长度。如果是,则以下元素的批处理id加1。因为我们在第1部分中跳过了足够多的批处理id,所以我们没有得到任何冲突。

def splitBatchIfNecessary(data):
text_length = 0
batch = -1
for d in data:
text_length = text_length + d.text_length
if text_length > 1000:
if batch == -1:
text_length = 0
batch = d.batch + 1
yield (d.id, d.text_length, d.batch)
else:
text_length = d.text_length
batch = batch + 1
yield (d.id, d.text_length, batch)          
else:
if batch == -1:
batch = d.batch
yield (d.id, d.text_length, batch)
schema=r.schema["data"].dataType
split_udf = F.udf(splitBatchIfNecessary, schema)
r = r.withColumn("data",split_udf(F.col("data")) ) 
.selectExpr("explode(data)") 
.select("col.*") 

输出:

+---+-----------+-----+                                                         
| id|text_length|batch|
+---+-----------+-----+
|  1|        500|    0|
|  2|        400|    0|
|  3|        200|    1|
|  4|        300|    1|
|  5|        100|    1|
|  6|        100|    5|
|  7|        100|    5|
|  8|        100|    5|
|  9|        100|    5|
| 10|        300|    5|
+---+-----------+-----+

可能的优化是将zipWithIndex替换为zipWithUniqueIds(但会变得更加"不完整")。批量)或使用矢量化udf。

与mck状态不同,这不是"分区问题"。

问题是,1)Spark与分区一起工作-不仅仅是一个这样的分区才能有效;2)没有分组属性来确保"批"可以在一个分区中自然形成或自然提取,只有。此外,我们可以有负数或分数吗?-这没有说明。

  1. 这意味着处理只需要基于一个分区,但它可能不够大,也就是OOM。

  2. 尝试处理每个分区是没有意义的,因为所有的工作都需要在每个分区N、N+1等完成,以抵消分区N-1中的影响。我已经在SO上制定了一个考虑分区边界的解决方案,但这违背了Spark的原则,而且用例更简单。

  3. 实际上不是一个Spark用例。它是一个顺序算法,而不是并行算法,使用PL/SQL, Scala, JAVA, c++。

  4. 唯一的方法是:

    • 在全局应用zipWithIndex的固定大小分区上循环(为了安全)
      • 使用Scala分批处理-临时结果
      • 从上次创建的批处理中获取所有项并与下一个分区合并
      • 从temp结果中删除最后一批
      • 重复周期
  5. NB:绕过数据分区边界方面的近似似乎不起作用——>另一个答案证明了这一点。你会得到一个折衷的结果,而不是真正的答案。而且要纠正它实际上并不容易,因为批次有间隙,并且可能由于分组而在其他分区中。

最新更新