在Spark SQL中自动且优雅地展平DataFrame



全部,

是否有一种优雅且可接受的方法来用嵌套StructType的列展平Spark SQL表(Parquet)

例如

如果我的模式是:

foo
|_bar
|_baz
x
y
z

如何在不需要手动运行的情况下将其选择为扁平的表格形式

df.select("foo.bar","foo.baz","x","y","z")

换句话说,在给定StructTypeDataFrame的情况下,我如何通过编程获得上述代码的结果

简单的答案是,没有"已接受";方法,但您可以使用递归函数非常优雅地完成此操作,该函数通过遍历DataFrame.schema生成select(...)语句。

递归函数应该返回一个Array[Column]。每当函数命中StructType时,它都会调用自己,并将返回的Array[Column]附加到自己的Array[Column]中。

类似于:

import org.apache.spark.sql.Column
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.functions.col
def flattenSchema(schema: StructType): Seq[Column] = schema.fields.flatMap {
case StructField(name, inner: StructType, _, _) => allColumns(inner).map(sub => col(s"$name.$sub"))
case StructField(name, _, _, _)                 => Seq(col(name))
}

然后你会这样使用它:

df.select(flattenSchema(df.schema):_*)

我只是想分享我的Pyspark解决方案——它或多或少是@David Griffin解决方案的翻译,因此它支持任何级别的嵌套对象。

from pyspark.sql.types import StructType, ArrayType  
def flatten(schema, prefix=None):
fields = []
for field in schema.fields:
name = prefix + '.' + field.name if prefix else field.name
dtype = field.dataType
if isinstance(dtype, ArrayType):
dtype = dtype.elementType
if isinstance(dtype, StructType):
fields += flatten(dtype, prefix=name)
else:
fields.append(name)
return fields

df.select(flatten(df.schema)).show()

我正在改进我以前的答案,并为我自己的问题提供一个解决方案,该问题在接受的答案的评论中有所陈述。

这个公认的解决方案创建一个列对象数组,并使用它来选择这些列。在Spark中,如果您有一个嵌套的DataFrame,您可以选择这样的子列:df.select("Parent.Child"),这将返回一个具有子列值的DataFrame并命名为child。但是,如果不同父结构的属性具有相同的名称,则会丢失有关父结构的信息,最终可能会使用相同的列名,并且无法再按名称访问它们,因为它们是明确的。

这是我的问题。

我找到了解决问题的办法,也许它也能帮助其他人。我单独给flattenSchema打了电话:

val flattenedSchema = flattenSchema(df.schema)

这返回了一个列对象数组。我没有在select()中使用它,它会返回一个DataFrame,其中的列由上一级的子级命名,而是将原始列名映射为字符串,然后在选择Parent.Child列后,它将其重命名为Parent.Child而不是Child(为了方便起见,我还用下划线替换了点):

val renamedCols = flattenedSchema.map(name => col(name.toString()).as(name.toString().replace(".","_")))

然后你可以使用原始答案中所示的选择功能:

var newDf = df.select(renamedCols:_*)

我在开源的spark daria项目中添加了一个DataFrame#flattenSchema方法。

以下是如何在代码中使用该函数。

import com.github.mrpowers.spark.daria.sql.DataFrameExt._
df.flattenSchema().show()
+-------+-------+---------+----+---+
|foo.bar|foo.baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+

也可以使用flattenSchema()方法指定不同的列名分隔符。

df.flattenSchema(delimiter = "_").show()
+-------+-------+---------+----+---+
|foo_bar|foo_baz|        x|   y|  z|
+-------+-------+---------+----+---+
|   this|     is|something|cool| ;)|
+-------+-------+---------+----+---+

这个分隔符参数非常重要。如果您要在Redshift中展开架构以加载表,则不能使用句点作为分隔符。

以下是生成此输出的完整代码片段。

val data = Seq(
Row(Row("this", "is"), "something", "cool", ";)")
)
val schema = StructType(
Seq(
StructField(
"foo",
StructType(
Seq(
StructField("bar", StringType, true),
StructField("baz", StringType, true)
)
),
true
),
StructField("x", StringType, true),
StructField("y", StringType, true),
StructField("z", StringType, true)
)
)
val df = spark.createDataFrame(
spark.sparkContext.parallelize(data),
StructType(schema)
)
df.flattenSchema().show()

底层代码类似于David Griffin的代码(以防您不想在项目中添加spark-daria依赖项)。

object StructTypeHelpers {
def flattenSchema(schema: StructType, delimiter: String = ".", prefix: String = null): Array[Column] = {
schema.fields.flatMap(structField => {
val codeColName = if (prefix == null) structField.name else prefix + "." + structField.name
val colName = if (prefix == null) structField.name else prefix + delimiter + structField.name
structField.dataType match {
case st: StructType => flattenSchema(schema = st, delimiter = delimiter, prefix = colName)
case _ => Array(col(codeColName).alias(colName))
}
})
}
}
object DataFrameExt {
implicit class DataFrameMethods(df: DataFrame) {
def flattenSchema(delimiter: String = ".", prefix: String = null): DataFrame = {
df.select(
StructTypeHelpers.flattenSchema(df.schema, delimiter, prefix): _*
)
}
}
}

=====编辑====

这里还有一些针对更复杂模式的附加处理:https://medium.com/@lvhuyen/使用-spark-dataframe-having-acomplex-schema-a3bce8c3f44

===============

PySpark,添加到@Evan V的答案中,当你的域名有特殊字符时,比如点".",连字符"-",…:

from pyspark.sql.types import StructType, ArrayType  
def normalise_field(raw):
return raw.strip().lower() 
.replace('`', '') 
.replace('-', '_') 
.replace(' ', '_') 
.strip('_')
def flatten(schema, prefix=None):
fields = []
for field in schema.fields:
name = "%s.`%s`" % (prefix, field.name) if prefix else "`%s`" % field.name
dtype = field.dataType
if isinstance(dtype, ArrayType):
dtype = dtype.elementType
if isinstance(dtype, StructType):
fields += flatten(dtype, prefix=name)
else:
fields.append(col(name).alias(normalise_field(name)))
return fields
df.select(flatten(df.schema)).show()

您也可以使用SQL将列选择为平面列。

  1. 获取原始数据帧架构
  2. 通过浏览架构生成SQL字符串
  3. 查询原始数据帧

我用Java实现了一个:https://gist.github.com/ebuildy/3de0e2855498e5358e4eed1a4f72ea48

(也使用递归方法,我更喜欢SQL方式,所以您可以通过Spark shell轻松测试它)。

这里有一个函数,它可以执行您想要的操作,并且可以处理包含同名列的多个嵌套列,前缀为:

from pyspark.sql import functions as F
def flatten_df(nested_df):
flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']
flat_df = nested_df.select(flat_cols +
[F.col(nc+'.'+c).alias(nc+'_'+c)
for nc in nested_cols
for c in nested_df.select(nc+'.*').columns])
return flat_df

之前:

root
|-- x: string (nullable = true)
|-- y: string (nullable = true)
|-- foo: struct (nullable = true)
|    |-- a: float (nullable = true)
|    |-- b: float (nullable = true)
|    |-- c: integer (nullable = true)
|-- bar: struct (nullable = true)
|    |-- a: float (nullable = true)
|    |-- b: float (nullable = true)
|    |-- c: integer (nullable = true)

之后:

root
|-- x: string (nullable = true)
|-- y: string (nullable = true)
|-- foo_a: float (nullable = true)
|-- foo_b: float (nullable = true)
|-- foo_c: integer (nullable = true)
|-- bar_a: float (nullable = true)
|-- bar_b: float (nullable = true)
|-- bar_c: integer (nullable = true)

要将David Griffen和V.Samma的答案结合起来,您可以这样做来压平,同时避免重复的列名:

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
def flattenSchema(schema: StructType, prefix: String = null) : Array[Column] = {
schema.fields.flatMap(f => {
val colName = if (prefix == null) f.name else (prefix + "." + f.name)
f.dataType match {
case st: StructType => flattenSchema(st, colName)
case _ => Array(col(colName).as(colName.replace(".","_")))
}
})
}
def flattenDataFrame(df:DataFrame): DataFrame = {
df.select(flattenSchema(df.schema):_*)
}
var my_flattened_json_table = flattenDataFrame(my_json_table)

这是对解决方案的修改,但它使用了tailrec表示法


@tailrec
def flattenSchema(
splitter: String,
fields: List[(StructField, String)],
acc: Seq[Column]): Seq[Column] = {
fields match {
case (field, prefix) :: tail if field.dataType.isInstanceOf[StructType] =>
val newPrefix = s"$prefix${field.name}."
val newFields = field.dataType.asInstanceOf[StructType].fields.map((_, newPrefix)).toList
flattenSchema(splitter, tail ++ newFields, acc)
case (field, prefix) :: tail =>
val colName = s"$prefix${field.name}"
val newCol  = col(colName).as(colName.replace(".", splitter))
flattenSchema(splitter, tail, acc :+ newCol)
case _ => acc
}
}
def flattenDataFrame(df: DataFrame): DataFrame = {
val fields = df.schema.fields.map((_, ""))
df.select(flattenSchema("__", fields.toList, Seq.empty): _*)
}

如果您使用的是嵌套结构和数组,这是对上面代码的一点补充。

def flattenSchema(schema: StructType, prefix: String = null) : Array[Column] = {
schema.fields.flatMap(f => {
val colName = if (prefix == null) f.name else (prefix + "." + f.name)
f match {
case StructField(_, struct:StructType, _, _) => flattenSchema(struct, colName)
case StructField(_, ArrayType(x :StructType, _), _, _) => flattenSchema(x, colName)
case StructField(_, ArrayType(_, _), _, _) => Array(col(colName))
case _ => Array(col(colName))
}
})
}

我一直在使用一行代码,它导致了一个具有5列bar、baz、x、y、z:的扁平模式

df.select("foo.*", "x", "y", "z")

至于CCD_ 18:我通常会保留CCD_。例如,如果您有一个列idList,它是字符串列表,您可以执行以下操作:

df.withColumn("flattenedId", functions.explode(col("idList")))
.drop("idList")

这将导致一个新的数据帧,其列名为flattenedId(不再是列表)

这是基于@Evan V的解决方案来处理嵌套更重的Json文件。对我来说,原始解决方案的问题是,当一个ArrayType正好嵌套在另一个ArraeType中时,我得到了一个错误。

例如,如果Json看起来像:

{"e":[{"f":[{"g":"h"}]}]}

我会得到一个错误:

"cannot resolve '`e`.`f`['g']' due to data type mismatch: argument 2 requires integral type

为了解决这个问题,我修改了一些代码,我同意这看起来超级愚蠢,只是把它发布在这里,这样有人可能会想出一个更好的解决方案。

def flatten(schema, prefix=None):
fields = []
for field in schema.fields:
name = prefix + '.' + field.name if prefix else field.name
dtype = field.dataType
if isinstance(dtype, T.StructType):
fields += flatten(dtype, prefix=name)
else:
fields.append(name)
return fields

def explodeDF(df):
for (name, dtype) in df.dtypes:
if "array" in dtype:
df = df.withColumn(name, F.explode(name))
return df
def df_is_flat(df):
for (_, dtype) in df.dtypes:
if ("array" in dtype) or ("struct" in dtype):
return False
return True
def flatJson(jdf):
keepGoing = True
while(keepGoing):
fields = flatten(jdf.schema)
new_fields = [item.replace(".", "_") for item in fields]
jdf = jdf.select(fields).toDF(*new_fields)
jdf = explodeDF(jdf)
if df_is_flat(jdf):
keepGoing = False
return jdf

用法:

df = spark.read.json(path_to_json)
flat_df = flatJson(df)
flat_df.show()
+---+---+-----+
|  a|e_c|e_f_g|
+---+---+-----+
|  b|  d|    h|
+---+---+-----+
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf
import org.apache.spark.sql.types.StructType
import scala.collection.mutable.ListBuffer 
val columns=new ListBuffer[String]()
def flattenSchema(schema:StructType,prefix:String=null){
for(i<-schema.fields){
if(i.dataType.isInstanceOf[StructType]) {
val columnPrefix = i.name + "."
flattenSchema(i.dataType.asInstanceOf[StructType], columnPrefix)
}
else {
if(prefix == null)
columns.+=(i.name)
else
columns.+=(prefix+i.name)
}
}
}

结合Evan V、Avrell和Steco的思想。在PySpark中使用"处理具有特殊字符的查询字段时,我还提供了完整的SQL语法。

下面的解决方案给出了以下内容,

  1. 处理嵌套的JSON模式
  2. 在嵌套列中处理相同的列名(我们将给出用下划线分隔的整个层次结构的别名)
  3. 处理特殊字符。(我们用">"处理特殊字符,我没有处理">"的连续出现,但我们也可以用适当的"sub"替换来处理)
  4. 为我们提供SQL语法
  5. 查询字段包含在"中

下面是代码片段,

df=spark.read.json('<JSON FOLDER / FILE PATH>')
df.printSchema()
from pyspark.sql.types import StructType, ArrayType
def flatten(schema, prefix=None):
fields = []
for field in schema.fields:
name = prefix + '.' + field.name if prefix else field.name
dtype = field.dataType
if isinstance(dtype, ArrayType):
dtype = dtype.elementType

if isinstance(dtype, StructType):
fields += flatten(dtype, prefix=name)
else:
alias_name=name.replace('.','_').replace(' ','_').replace('(','').replace(')','').replace('-','_').replace('&','_').replace(r'(_){2,}',r'1')
name=name.replace('.','`.`')
field_name = "`" + name + "`" + " AS " + alias_name
fields.append(field_name)
return fields
df.createOrReplaceTempView("to_flatten_df")
query_fields=flatten(df.schema)
def listToString(s):  

# initialize an empty string 
str1 = ""
# traverse in the string   
for ele in s:  
str1 = str1 + ele + ','
# return string   
return str1  
spark.sql("SELECT " + listToString(query_fields)[:-1] + " FROM to_flatten_df" ).show()

相关内容

  • 没有找到相关文章

最新更新