How to get a column by index instead of name



I have the following initial PySpark DataFrame:

+----------+--------------------------------+
|product_PK|                        products|
+----------+--------------------------------+
|      686 |          [[686,520.70],[645,2]]|
|      685 |[[685,45.556],[678,23],[655,21]]|
|      693 |                              []|
+----------+--------------------------------+

which I create with:

df = sqlCtx.createDataFrame(
    [(686, [[686,520.70], [645,2]]), (685, [[685,45.556], [678,23],[655,21]]), (693, [])],
    ["product_PK", "products"]
)

products contains nested data. I need to extract the second value from each pair. I am running this code:

temp_dataframe = dataframe.withColumn("exploded" , explode(col("products"))).withColumn("score", col("exploded").getItem("_2"))

It works for a specific DataFrame. However, I want to put this code into a function and run it on different DataFrames. All of my DataFrames have the same structure. The only difference is that the sub-column "_2" may be named differently in some DataFrames, e.g. "col1" or "col2".

For example:

DataFrame content
root
 |-- product_PK: long (nullable = true)
 |-- products: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: long (nullable = true)
 |    |    |-- _2: double (nullable = true)
 |-- exploded: struct (nullable = true)
 |    |-- _1: long (nullable = true)
 |    |-- _2: double (nullable = true)

DataFrame content
root
 |-- product_PK: long (nullable = true)
 |-- products: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- product_PK: long (nullable = true)
 |    |    |-- col2: integer (nullable = true)
 |-- exploded: struct (nullable = true)
 |    |-- product_PK: long (nullable = true)
 |    |-- col2: integer (nullable = true)

I tried to use an index like getItem(1), but it says that a column name has to be provided.

Is there any way to avoid specifying the column name, or to somehow generalize this part of the code?

My goal is for exploded to contain the second value of each pair in the nested data, i.e. _2, col1 or col2.

It sounds like you are on the right track. I think the way to accomplish this is to read the schema to determine the name of the field to explode on. But rather than schema.names, you need to use schema.fields to find the struct field, and then use its properties to figure out the fields in the struct. Here is an example:

from pyspark.sql.functions import *
from pyspark.sql.types import *
# Setup the test dataframe
data = [
    (686, [(686, 520.70), (645, 2.)]), 
    (685, [(685, 45.556), (678, 23.), (655, 21.)]), 
    (693, [])
]
schema = StructType([
    StructField("product_PK", StringType()),
    StructField("products", 
        ArrayType(StructType([
            StructField("_1", IntegerType()),
            StructField("col2", FloatType())
        ]))
    )
])
df = sqlCtx.createDataFrame(data, schema) 
# Find the products field in the schema, then find the name of the 2nd field
productsField = next(f for f in df.schema.fields if f.name == 'products')
target_field = productsField.dataType.elementType.names[1]
# Do your explode using the field name
temp_dataframe = df.withColumn("exploded" , explode(col("products"))).withColumn("score", col("exploded").getItem(target_field))

Now, if you check the result, you get this:

>>> temp_dataframe.printSchema()
root
 |-- product_PK: string (nullable = true)
 |-- products: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: integer (nullable = true)
 |    |    |-- col2: float (nullable = true)
 |-- exploded: struct (nullable = true)
 |    |-- _1: integer (nullable = true)
 |    |-- col2: float (nullable = true)
 |-- score: float (nullable = true)

Is that what you wanted?
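
Since the goal is to reuse this across DataFrames, the same schema lookup can be wrapped in a small helper function (a minimal sketch; the function and parameter names here are just illustrative):

from pyspark.sql.functions import col, explode

def add_score(df, array_col="products", out_col="score"):
    # Find the array column in the schema
    array_field = next(f for f in df.schema.fields if f.name == array_col)
    # Name of the 2nd field inside the struct, whatever it is (_2, col1, col2, ...)
    second_field = array_field.dataType.elementType.names[1]
    return (df
            .withColumn("exploded", explode(col(array_col)))
            .withColumn(out_col, col("exploded").getItem(second_field)))

temp_dataframe = add_score(df)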

Note that if you build the DataFrame straight from the nested Python lists in the question (without an explicit schema), Spark infers the inner elements as arrays of longs rather than structs, which is apparently why 520.70 shows up as null below. In that case positional access with getItem(1) does work:

>>> df.show(10, False)
+----------+-----------------------------------------------------------------------+
|product_PK|products                                                               |
+----------+-----------------------------------------------------------------------+
|686       |[WrappedArray(686, null), WrappedArray(645, 2)]                        |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|
|693       |[]                                                                     |
+----------+-----------------------------------------------------------------------+
>>> import pyspark.sql.functions as F
>>> df.withColumn("exploded", F.explode("products")) \
...   .withColumn("exploded", F.col("exploded").getItem(1)) \
...   .show(10,False)
+----------+-----------------------------------------------------------------------+--------+
|product_PK|products                                                               |exploded|
+----------+-----------------------------------------------------------------------+--------+
|686       |[WrappedArray(686, null), WrappedArray(645, 2)]                        |null    |
|686       |[WrappedArray(686, null), WrappedArray(645, 2)]                        |2       |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|null    |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|23      |
|685       |[WrappedArray(685, null), WrappedArray(678, 23), WrappedArray(655, 21)]|21      |
+----------+-----------------------------------------------------------------------+--------+

Given that your exploded column is a struct, as in

 |-- exploded: struct (nullable = true)
 |    |-- _1: integer (nullable = true)
 |    |-- col2: float (nullable = true)

you can use the following logic to get the second element without knowing its name:

from pyspark.sql import functions as F
temp_dataframe = df.withColumn("exploded" , F.explode(F.col("products")))
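# Expand the struct with "exploded.*" and take the name of its 2nd column, whatever it is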
temp_dataframe.withColumn("score", F.col("exploded."+temp_dataframe.select(F.col("exploded.*")).columns[1]))

You should get the output as:

+----------+--------------------------------------+------------+------+
|product_PK|products                              |exploded    |score |
+----------+--------------------------------------+------------+------+
|686       |[[686,520.7], [645,2.0]]              |[686,520.7] |520.7 |
|686       |[[686,520.7], [645,2.0]]              |[645,2.0]   |2.0   |
|685       |[[685,45.556], [678,23.0], [655,21.0]]|[685,45.556]|45.556|
|685       |[[685,45.556], [678,23.0], [655,21.0]]|[678,23.0]  |23.0  |
|685       |[[685,45.556], [678,23.0], [655,21.0]]|[655,21.0]  |21.0  |
+----------+--------------------------------------+------------+------+
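
Equivalently, assuming the same temp_dataframe as above, the field name can be read straight from the schema instead of doing an extra select (just a sketch of the same idea):

# Second field of the exploded struct, taken by position rather than by name
second_field = temp_dataframe.schema["exploded"].dataType.names[1]
temp_dataframe.withColumn("score", F.col("exploded." + second_field))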
