我看到很多人建议Dataframe.explode
是一种有用的方法,但是它会导致比原始数据帧更多的行,这根本不是我想要的。我只是想做一个非常简单的Dataframe:
rdd.map(lambda row: row + [row.my_str_col.split('-')])
的格式如下:
col1 | my_str_col
-----+-----------
18 | 856-yygrm
201 | 777-psgdg
并将其转换为:
col1 | my_str_col | _col3 | _col4
-----+------------+-------+------
18 | 856-yygrm | 856 | yygrm
201 | 777-psgdg | 777 | psgdg
我知道pyspark.sql.functions.split()
,但它导致嵌套的数组列,而不是像我想要的两个顶级列。
理想情况下,我希望这些新列也被命名。
pyspark.sql.functions.split()
在这里是正确的方法-您只需要将嵌套的ArrayType列扁平化为多个顶级列。在本例中,每个数组只包含2个元素,这非常简单。您只需使用Column.getItem()
将数组的每个部分作为列本身检索:
split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
df = df.withColumn('NAME1', split_col.getItem(0))
df = df.withColumn('NAME2', split_col.getItem(1))
结果将是:
col1 | my_str_col | NAME1 | NAME2
-----+------------+-------+------
18 | 856-yygrm | 856 | yygrm
201 | 777-psgdg | 777 | psgdg
我不确定如何在嵌套数组从行到行大小不相同的一般情况下解决这个问题。
对于一般情况,这里有一个解决方案,它不需要提前知道数组的长度,使用collect
或udf
s。不幸的是,这只适用于spark
版本2.1及以上,因为它需要posexplode
函数。
假设您有以下DataFrame:
df = spark.createDataFrame(
[
[1, 'A, B, C, D'],
[2, 'E, F, G'],
[3, 'H, I'],
[4, 'J']
]
, ["num", "letters"]
)
df.show()
#+---+----------+
#|num| letters|
#+---+----------+
#| 1|A, B, C, D|
#| 2| E, F, G|
#| 3| H, I|
#| 4| J|
#+---+----------+
分割letters
列,然后使用posexplode
将结果数组与数组中的位置一起爆炸。接下来,使用pyspark.sql.functions.expr
抓取索引为pos
的元素。
import pyspark.sql.functions as f
df.select(
"num",
f.split("letters", ", ").alias("letters"),
f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)
.show()
#+---+------------+---+---+
#|num| letters|pos|val|
#+---+------------+---+---+
#| 1|[A, B, C, D]| 0| A|
#| 1|[A, B, C, D]| 1| B|
#| 1|[A, B, C, D]| 2| C|
#| 1|[A, B, C, D]| 3| D|
#| 2| [E, F, G]| 0| E|
#| 2| [E, F, G]| 1| F|
#| 2| [E, F, G]| 2| G|
#| 3| [H, I]| 0| H|
#| 3| [H, I]| 1| I|
#| 4| [J]| 0| J|
#+---+------------+---+---+
现在我们根据这个结果创建两个新列。第一个是新列的名称,它将是letter
和数组中的索引的连接。第二列将是数组中对应索引处的值。我们通过利用pyspark.sql.functions.expr
的功能得到后者,该功能允许我们使用列值作为参数。
df.select(
"num",
f.split("letters", ", ").alias("letters"),
f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)
.drop("val")
.select(
"num",
f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
f.expr("letters[pos]").alias("val")
)
.show()
#+---+-------+---+
#|num| name|val|
#+---+-------+---+
#| 1|letter0| A|
#| 1|letter1| B|
#| 1|letter2| C|
#| 1|letter3| D|
#| 2|letter0| E|
#| 2|letter1| F|
#| 2|letter2| G|
#| 3|letter0| H|
#| 3|letter1| I|
#| 4|letter0| J|
#+---+-------+---+
现在我们可以只groupBy
num
和pivot
的数据框架。把这些放在一起,我们得到:
df.select(
"num",
f.split("letters", ", ").alias("letters"),
f.posexplode(f.split("letters", ", ")).alias("pos", "val")
)
.drop("val")
.select(
"num",
f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
f.expr("letters[pos]").alias("val")
)
.groupBy("num").pivot("name").agg(f.first("val"))
.show()
#+---+-------+-------+-------+-------+
#|num|letter0|letter1|letter2|letter3|
#+---+-------+-------+-------+-------+
#| 1| A| B| C| D|
#| 3| H| I| null| null|
#| 2| E| F| G| null|
#| 4| J| null| null| null|
#+---+-------+-------+-------+-------+
如果您想用分隔符分割字符串,还有另一种方法。
import pyspark.sql.functions as f
df = spark.createDataFrame([("1:a:2001",),("2:b:2002",),("3:c:2003",)],["value"])
df.show()
+--------+
| value|
+--------+
|1:a:2001|
|2:b:2002|
|3:c:2003|
+--------+
df_split = df.select(f.split(df.value,":")).rdd.flatMap(
lambda x: x).toDF(schema=["col1","col2","col3"])
df_split.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
| 1| a|2001|
| 2| b|2002|
| 3| c|2003|
+----+----+----+
我不认为这种来回转换到rdd会减慢你的速度…也不要担心最后一个模式规范:它是可选的,您可以避免它将解决方案泛化到列大小未知的数据。
我理解你的痛苦。使用split()可以工作,但也可能导致中断。
让我们把你的df稍微修改一下:
df = spark.createDataFrame([('1:"a:3":2001',),('2:"b":2002',),('3:"c":2003',)],["value"])
df.show()
+------------+
| value|
+------------+
|1:"a:3":2001|
| 2:"b":2002|
| 3:"c":2003|
+------------+
如果您尝试像上面所述的那样对它应用split():
df_split = df.select(split(df.value,":")).rdd.flatMap(
lambda x: x).toDF(schema=["col1","col2","col3"]).show()
你会得到
IllegalStateException:输入行没有模式所需的预期值。需要4个字段,提供3个值。
那么,有没有更优雅的方式来解决这个问题呢?我很高兴有人给我指出来。Pyspark.sql.functions.from_csv()是你的好朋友。
以上面的例子为例df:
from pyspark.sql.functions import from_csv
# Define a column schema to apply with from_csv()
col_schema = ["col1 INTEGER","col2 STRING","col3 INTEGER"]
schema_str = ",".join(col_schema)
# define the separator because it isn't a ','
options = {'sep': ":"}
# create a df from the value column using schema and options
df_csv = df.select(from_csv(df.value, schema_str, options).alias("value_parsed"))
df_csv.show()
+--------------+
| value_parsed|
+--------------+
|[1, a:3, 2001]|
| [2, b, 2002]|
| [3, c, 2003]|
+--------------+
然后我们可以很容易地平坦df,把值放在列中:
df2 = df_csv.select("value_parsed.*").toDF("col1","col2","col3")
df2.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
| 1| a:3|2001|
| 2| b|2002|
| 3| c|2003|
+----+----+----+
没有休息。正确解析的数据。生活是美好的。喝杯啤酒。
我们可以用Column[i]
代替Column.getItem(i)
。
此外,enumerate
在大数据框架中也很有用。
from pyspark.sql import functions as F
保持父列:
for i, c in enumerate(['new_1', 'new_2']): df = df.withColumn(c, F.split('my_str_col', '-')[i])
或
new_cols = ['new_1', 'new_2'] df = df.select('*', *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)])
替换父列:
for i, c in enumerate(['new_1', 'new_2']): df = df.withColumn(c, F.split('my_str_col', '-')[i]) df = df.drop('my_str_col')
或
new_cols = ['new_1', 'new_2'] df = df.select( *[c for c in df.columns if c != 'my_str_col'], *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)] )