我有以下初始 PySpark DataFrame:
+----------+--------------------------------+
|product_PK| products|
+----------+--------------------------------+
| 686 | [[686,520.70],[645,2]]|
| 685 |[[685,45.556],[678,23],[655,21]]|
| 693 | []|
df = sqlCtx.createDataFrame(
[(686, [[686,520.70], [645,2]]), (685, [[685,45.556], [678,23],[655,21]]), (693, [])],
["product_PK", "products"]
)
列products
包含嵌套数据。我需要提取每对值中的第二个值。我正在运行此代码:
temp_dataframe = dataframe.withColumn("exploded" , explode(col("products"))).withColumn("score", col("exploded").getItem("_2"))
它适用于特定的数据帧。但是,我想将此代码放入一个函数中,并在不同的数据帧上运行它。我的所有数据帧都具有相同的结构。唯一的区别是子列"_2"
在某些数据帧中可能有不同的命名,例如 "col1"
或"col2"
.
例如:
DataFrame content
root
|-- product_PK: long (nullable = true)
|-- products: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- _1: long (nullable = true)
| | |-- _2: double (nullable = true)
|-- exploded: struct (nullable = true)
| |-- _1: long (nullable = true)
| |-- _2: double (nullable = true)
DataFrame content
root
|-- product_PK: long (nullable = true)
|-- products: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- product_PK: long (nullable = true)
| | |-- col2: integer (nullable = true)
|-- exploded: struct (nullable = true)
| |-- product_PK: long (nullable = true)
| |-- col2: integer (nullable = true)
我尝试使用像 getItem(1)
这样的索引,但它说必须提供列的名称。
有没有办法避免指定列名或以某种方式概括代码的这一部分?
我的目标是exploded
包含嵌套数据中每对的第二个值,即 _2
或col1
或col2
.
听起来你走在正确的轨道上。 我认为实现此目的的方法是读取架构以确定要爆炸的字段的名称。 但是,您需要使用 schema.fields 来查找结构字段,然后使用它的属性来找出结构中的字段,而不是 schema.names。 下面是一个示例:
from pyspark.sql.functions import *
from pyspark.sql.types import *
# Setup the test dataframe
data = [
(686, [(686, 520.70), (645, 2.)]),
(685, [(685, 45.556), (678, 23.), (655, 21.)]),
(693, [])
]
schema = StructType([
StructField("product_PK", StringType()),
StructField("products",
ArrayType(StructType([
StructField("_1", IntegerType()),
StructField("col2", FloatType())
]))
)
])
df = sqlCtx.createDataFrame(data, schema)
# Find the products field in the schema, then find the name of the 2nd field
productsField = next(f for f in df.schema.fields if f.name == 'products')
target_field = productsField.dataType.elementType.names[1]
# Do your explode using the field name
temp_dataframe = df.withColumn("exploded" , explode(col("products"))).withColumn("score", col("exploded").getItem(target_field))
现在,如果你检查结果,你会得到这个:
>>> temp_dataframe.printSchema()
root
|-- product_PK: string (nullable = true)
|-- products: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- _1: integer (nullable = true)
| | |-- col2: float (nullable = true)
|-- exploded: struct (nullable = true)
| |-- _1: integer (nullable = true)
| |-- col2: float (nullable = true)
|-- score: float (nullable = true)
这就是你想要的吗?
>>> df.show(10, False)
+----------+-----------------------------------------------------------------------+
|product_PK|products |
+----------+-----------------------------------------------------------------------+
|686 |[WrappedArray(686, null), WrappedArray(645, 2)] |
|685 |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|
|693 |[] |
+----------+-----------------------------------------------------------------------+
>>> import pyspark.sql.functions as F
>>> df.withColumn("exploded", F.explode("products"))
... .withColumn("exploded", F.col("exploded").getItem(1))
... .show(10,False)
+----------+-----------------------------------------------------------------------+--------+
|product_PK|products |exploded|
+----------+-----------------------------------------------------------------------+--------+
|686 |[WrappedArray(686, null), WrappedArray(645, 2)] |null |
|686 |[WrappedArray(686, null), WrappedArray(645, 2)] |2 |
|685 |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|null |
|685 |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|23 |
|685 |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|21 |
+----------+-----------------------------------------------------------------------+--------+
鉴于您的exploded
列是一个struct
|-- exploded: struct (nullable = true)
| |-- _1: integer (nullable = true)
| |-- col2: float (nullable = true)
您可以使用以下逻辑在不知道名称的情况下获取第二个元素
from pyspark.sql import functions as F
temp_dataframe = df.withColumn("exploded" , F.explode(F.col("products")))
temp_dataframe.withColumn("score", F.col("exploded."+temp_dataframe.select(F.col("exploded.*")).columns[1]))
您应该将输出为
+----------+--------------------------------------+------------+------+
|product_PK|products |exploded |score |
+----------+--------------------------------------+------------+------+
|686 |[[686,520.7], [645,2.0]] |[686,520.7] |520.7 |
|686 |[[686,520.7], [645,2.0]] |[645,2.0] |2.0 |
|685 |[[685,45.556], [678,23.0], [655,21.0]]|[685,45.556]|45.556|
|685 |[[685,45.556], [678,23.0], [655,21.0]]|[678,23.0] |23.0 |
|685 |[[685,45.556], [678,23.0], [655,21.0]]|[655,21.0] |21.0 |
+----------+--------------------------------------+------------+------+