如何获得PySpark DataFrame的引用列?



给定PySpark DataFrame是否有可能获得由DataFrame引用的源列列表?

也许一个更具体的例子可以帮助解释我所追求的。假设我有一个定义为:

的DataFrame
import pyspark.sql.functions as func
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
source_df = spark.createDataFrame(
[("pru", 23, "finance"), ("paul", 26, "HR"), ("noel", 20, "HR")],
["name", "age", "department"],
)
source_df.createOrReplaceTempView("people")
sqlDF = spark.sql("SELECT name, age, department FROM people")
df = sqlDF.groupBy("department").agg(func.max("age").alias("max_age"))
df.show()

返回:

+----------+--------+                                                           
|department|max_age |
+----------+--------+
|   finance|      23|
|        HR|      26|
+----------+--------+

df引用的列是[department, age]。是否有可能以编程方式获得引用列的列表?

由于在pyspark中捕获explain()的结果,我知道我可以将计划提取为字符串:

df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "formatted")

返回:

== Physical Plan ==
AdaptiveSparkPlan (6)
+- HashAggregate (5)
+- Exchange (4)
+- HashAggregate (3)
+- Project (2)
+- Scan ExistingRDD (1)

(1) Scan ExistingRDD
Output [3]: [name#0, age#1L, department#2]
Arguments: [name#0, age#1L, department#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(2) Project
Output [2]: [age#1L, department#2]
Input [3]: [name#0, age#1L, department#2]
(3) HashAggregate
Input [2]: [age#1L, department#2]
Keys [1]: [department#2]
Functions [1]: [partial_max(age#1L)]
Aggregate Attributes [1]: [max#22L]
Results [2]: [department#2, max#23L]
(4) Exchange
Input [2]: [department#2, max#23L]
Arguments: hashpartitioning(department#2, 200), ENSURE_REQUIREMENTS, [plan_id=60]
(5) HashAggregate
Input [2]: [department#2, max#23L]
Keys [1]: [department#2]
Functions [1]: [max(age#1L)]
Aggregate Attributes [1]: [max(age#1L)#12L]
Results [2]: [department#2, max(age#1L)#12L AS max_age#13L]
(6) AdaptiveSparkPlan
Output [2]: [department#2, max_age#13L]
Arguments: isFinalPlan=false

这是有用的,但它不是我需要的。我需要一个引用列的列表。这可能吗?

也许问这个问题的另一种方法是……是否有办法获得解释计划作为一个对象,我可以迭代/探索?

更新。感谢@matt-andruff的回复,我得到了这个:

df._jdf.queryExecution().executedPlan().treeString().split("+-")[-2]

返回:

' Project [age#1L, department#2]n            '

,我想我可以从中解析我想要的信息,但这是一种远非优雅的方式,而且特别容易出错。

我真正想要的是一种安全的、可靠的、api支持的方式来获取这些信息。我开始觉得这是不可能的。

有一个对象,不幸的是它是一个java对象,并没有翻译成pyspark。

你仍然可以使用Spark构造访问它:

>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(0).toString()
u'department#1621'
>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(1).toString()
u'max_age#1632L'

你可以循环以上两个apply来获得你正在寻找的信息,如:

plan = df._jdf.queryExecution().executedPlan()
steps = [ plan.apply(i) for i in range(1,100) if not isinstance(plan.apply(i), type(None)) ]
iterator = steps[0].inputSet().iterator()
>>> iterator.next().toString()
u'department#1621'
>>> iterator.next().toString()
u'max#1642L'
steps = [ plan.apply(i) for i in range(1,100) if not isinstance(plan.apply(i), type(None)) ]
projections = [ (steps[0].p(i).toJSON().encode('ascii','ignore')) for i in range(1,100) if not( isinstance(steps[0].p(i), type(None) )) and steps[0].p(i).nodeName().encode('ascii','ignore') == 'Project' ]
dd = spark.sparkContext.parallelize(projections)
df2 = spark.read.json(rdd)
>>> df2.show(1,False)
+-----+------------------------------------------+----+------------+------+--------------+------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+
|child|class                                     |name|num-children|output|outputOrdering|outputPartitioning|projectList                                                                                                                                                                                                                                                                                                                                                                                              |rdd |
+-----+------------------------------------------+----+------------+------+--------------+------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+
|0    |org.apache.spark.sql.execution.ProjectExec|null|1           |null  |null          |null              |[[[org.apache.spark.sql.catalyst.expressions.AttributeReference, long, [1620, 4ad48da6-03cf-45d4-9b35-76ac246fadac, org.apache.spark.sql.catalyst.expressions.ExprId], age, true, 0, [people]]], [[org.apache.spark.sql.catalyst.expressions.AttributeReference, string, [1621, 4ad48da6-03cf-45d4-9b35-76ac246fadac, org.apache.spark.sql.catalyst.expressions.ExprId], department, true, 0, [people]]]]|null|
+-----+------------------------------------------+----+------------+------+--------------+------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+
df2.select(func.explode(func.col('projectList'))).select( func.col('col')[0]["name"] ) .show(100,False)
+-----------+
|col[0].name|
+-----------+
|age        |
|department |
+-----------+

range—>有点黑客,但显然size不工作。我敢肯定,如果有更多的时间,我可以改进范围hack。

您可以使用json以编程方式提取信息。

我有一些东西,虽然不是我最初问题的答案(参见Matt Andruff的答案),但在这里仍然有用。这是一种获取pyspark.sql.column.Column引用的所有源列的方法。

简单再生产:

from pyspark.sql import functions as f, SparkSession
SparkSession.builder.getOrCreate()
col = f.concat(f.col("A"), f.col("B"))
type(col)
col._jc.expr().references().toList().toString()

的回报:

& lt;类pyspark.sql.column.Column的祝辞
"列表(',' B)">

绝对不是完美的,它仍然需要您从返回的字符串中解析出列名,但至少我所追求的信息是可用的。从references()返回的对象上可能有更多的方法,这使得解析返回的字符串更容易,但如果有,我还没有找到它!

下面是我写的解析函数

def parse_references(references: str):
return sorted(
"".join(
references.replace("'", "")
.replace("List(", "")
.replace(")", "")
.replace(")", "")
.split()
).split(",")
)
assert parse_references("List('A, 'B)") == ["A", "B"]

PySpark并不是真正为这种低级技巧而设计的(这更需要Scala,因为Spark是用Scala开发的,所以它提供了里面的所有功能)。

访问QueryExecution的这一步是进入Spark SQL查询执行引擎的主要入口。

问题是py4j(用作JVM和Python环境之间的桥梁)使它在PySpark方面没有用处。

如果需要访问最终查询计划(就在它转换为rdd之前),可以使用以下命令:

df._jdf.queryExecution().executedPlan().prettyJson()

查看QueryExecutionAPI

QueryExecutionListener

你真的应该考虑Scala拦截任何你想要的查询,QueryExecutionListener似乎是一个相当可行的起点。

还有更多,但都是Scala:)

我真正追求的是一种故障安全,可靠,api支持的方式来获取此信息。我开始觉得这是不可能的。

我并不感到惊讶,因为你抛弃了最好的答案:Scala。我建议将它用于PoC,看看你能得到什么,然后(如果你必须的话)寻找Python解决方案(我认为这是可行的,但非常容易出错)。

您可以尝试下面的代码,这将为您提供数据框架中的列列表及其数据类型。

for field in df.schema.fields:
print(field.name +" , "+str(field.dataType))

最新更新