Spark sql如何在不丢失null值的情况下进行分解



我有一个数据帧,我正在尝试将其压平。作为这个过程的一部分,我想分解它,所以如果我有一列数组,数组的每个值都将用于创建一个单独的行。例如,

id | name | likes
_______________________________
1  | Luke | [baseball, soccer]

应该成为

id | name | likes
_______________________________
1  | Luke | baseball
1  | Luke | soccer

这是我的代码

private DataFrame explodeDataFrame(DataFrame df) {
    DataFrame resultDf = df;
    for (StructField field : df.schema().fields()) {
        if (field.dataType() instanceof ArrayType) {
            resultDf = resultDf.withColumn(field.name(), org.apache.spark.sql.functions.explode(resultDf.col(field.name())));
            resultDf.show();
        }
    }
    return resultDf;
}

问题是,在我的数据中,一些数组列有null。在这种情况下,将删除整行。所以这个数据帧:

id | name | likes
_______________________________
1  | Luke | [baseball, soccer]
2  | Lucy | null

成为

id | name | likes
_______________________________
1  | Luke | baseball
1  | Luke | soccer

而不是

id | name | likes
_______________________________
1  | Luke | baseball
1  | Luke | soccer
2  | Lucy | null

如何分解数组以避免丢失空行?

我使用的是Spark 1.5.2和Java 8

Spark 2.2+

您可以使用explode_outer功能:

import org.apache.spark.sql.functions.explode_outer
df.withColumn("likes", explode_outer($"likes")).show
// +---+----+--------+
// | id|name|   likes|
// +---+----+--------+
// |  1|Luke|baseball|
// |  1|Luke|  soccer|
// |  2|Lucy|    null|
// +---+----+--------+

火花<=2.1

在Scala中,Java等效函数应该几乎相同(要导入单个函数,请使用import static(。

import org.apache.spark.sql.functions.{array, col, explode, lit, when}
val df = Seq(
  (1, "Luke", Some(Array("baseball", "soccer"))),
  (2, "Lucy", None)
).toDF("id", "name", "likes")
df.withColumn("likes", explode(
  when(col("likes").isNotNull, col("likes"))
    // If null explode an array<string> with a single null
    .otherwise(array(lit(null).cast("string")))))

这里的想法基本上是用所需类型的CCD_ 4来代替CCD_。对于复杂类型(也称为structs(,您必须提供完整的模式:

val dfStruct = Seq((1L, Some(Array((1, "a")))), (2L, None)).toDF("x", "y")
val st =  StructType(Seq(
  StructField("_1", IntegerType, false), StructField("_2", StringType, true)
))
dfStruct.withColumn("y", explode(
  when(col("y").isNotNull, col("y"))
    .otherwise(array(lit(null).cast(st)))))

dfStruct.withColumn("y", explode(
  when(col("y").isNotNull, col("y"))
    .otherwise(array(lit(null).cast("struct<_1:int,_2:string>")))))

注意

如果阵列Column是在containsNull设置为false的情况下创建的,则应首先更改此设置(使用Spark 2.1测试(:

df.withColumn("array_column", $"array_column".cast(ArrayType(SomeType, true)))

您可以使用explode_outer()函数。

根据已接受的答案,当数组元素是复杂类型时,很难手动定义(例如,使用大型结构(。

为了自动做到这一点,我编写了以下辅助方法:

  def explodeOuter(df: Dataset[Row], columnsToExplode: List[String]) = {
      val arrayFields = df.schema.fields
          .map(field => field.name -> field.dataType)
          .collect { case (name: String, type: ArrayType) => (name, type.asInstanceOf[ArrayType])}
          .toMap
      columnsToExplode.foldLeft(df) { (dataFrame, arrayCol) =>
      dataFrame.withColumn(arrayCol, explode(when(size(col(arrayCol)) =!= 0, col(arrayCol))
        .otherwise(array(lit(null).cast(arrayFields(arrayCol).elementType)))))    
 }

编辑:看来spark 2.2和更新版本已经内置了这个。

要处理空映射类型列:对于Spark<=2.1

 List((1, Array(2, 3, 4), Map(1 -> "a")),
(2, Array(5, 6, 7), Map(2 -> "b")),
(3, Array[Int](), Map[Int, String]())).toDF("col1", "col2", "col3").show()

 df.select('col1, explode(when(size(map_keys('col3)) === 0, map(lit("null"), lit("null"))).
otherwise('col3))).show()
from pyspark.sql.functions import *
def flatten_df(nested_df):
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
    nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']
    flat_df = nested_df.select(flat_cols +
                               [col(nc + '.' + c).alias(nc + '_' + c)
                                for nc in nested_cols
                                for c in nested_df.select(nc + '.*').columns])
    print("flatten_df_count :", flat_df.count())
    return flat_df
def explode_df(nested_df):
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct' and c[1][:5] != 'array']
    array_cols = [c[0] for c in nested_df.dtypes if c[1][:5] == 'array']
    for array_col in array_cols:
        schema = new_df.select(array_col).dtypes[0][1]
        nested_df = nested_df.withColumn(array_col, when(col(array_col).isNotNull(), col(array_col)).otherwise(array(lit(None)).cast(schema))) 
    nested_df = nested_df.withColumn("tmp", arrays_zip(*array_cols)).withColumn("tmp", explode("tmp")).select([col("tmp."+c).alias(c) for c in array_cols] + flat_cols)
    print("explode_dfs_count :", nested_df.count())
    return nested_df

new_df = flatten_df(myDf)
while True:
    array_cols = [c[0] for c in new_df.dtypes if c[1][:5] == 'array']
    if len(array_cols):
        new_df = flatten_df(explode_df(new_df))
    else:
        break
    
new_df.printSchema()

使用arrays_zipexplode可以更快地完成任务并解决null问题。

相关内容

  • 没有找到相关文章

最新更新