我有一个看起来像这样的火花数据帧。
id cd1 version1 dt1 cd2 version2 dt2 cd3 version3 dt3
1 100 1 20100101 101 1 20100101 102 20100301
1 101 1 20100102 102 20100201 100 1 20100302
2 201 1 20100103 100 1 20100301 100 1 20100303
2 202 2 20100104 100 1 20100105
我需要将所有代码转置为具有以下条件的单个列
- 如果对应的版本码为 1,则在第一位数字后添加小数点
- 每个患者应具有不同的代码
对于上面的示例,输出应如下所示。
id code dt
1 1.00 20100101
1 1.01 20100101
1 102 20100301
1 1.01 20100102
1 102 20100201
1 10.0 20100302
2 2.01 20100103
2 1.00 20100301
2 1.00 20100303
2 202 20100104
2 10.0 20100105
我正在使用Pyspark来做到这一点。在上面的示例中,我只显示了 3 个代码及其相应的版本列,但我有 30 个这样的列。此外,此数据大约有 2500 万行。
关于如何实现这一目标的任何想法都将非常有帮助。
您可以explode
这些列的列表,以便每行只有一个(cd, version)
对首先,让我们创建数据帧:
df = sc.parallelize([[1,100,1,101,1,102,None],[1,101,1,102,None,100,1],[2,201,1,100,1,100,1],
[2,202,2,100,1,None,None]]).toDF(["id","cd1","version1","cd2","version2","cd3","version3"])
使用
posexplode
:import pyspark.sql.functions as psf from itertools import chain nb_versions = 4 df = df.na.fill(-1).select( "id", psf.posexplode(psf.create_map(list(chain(*[(psf.col("cd" + str(i)), psf.col("version"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "cd", "version") ).drop("pos").filter("cd != -1") +---+---+-------+ | id| cd|version| +---+---+-------+ | 1|100| 1| | 1|101| 1| | 1|102| -1| | 1|101| 1| | 1|102| -1| | 1|100| 1| | 2|201| 1| | 2|100| 1| | 2|100| 1| | 2|202| 2| | 2|100| 1| +---+---+-------+
使用
explode
:nb_versions = 4 df = df.select( "id", psf.explode(psf.array( [psf.struct( psf.col("cd" + str(i)).alias("cd"), psf.col("version" + str(i)).alias("version")) for i in range(1, nb_versions)])).alias("temp")) .select("id", "temp.*") +---+----+-------+ | id| cd|version| +---+----+-------+ | 1| 100| 1| | 1| 101| 1| | 1| 102| null| | 1| 101| 1| | 1| 102| null| | 1| 100| 1| | 2| 201| 1| | 2| 100| 1| | 2| 100| 1| | 2| 202| 2| | 2| 100| 1| | 2|null| null| +---+----+-------+
现在我们可以实施您的条件
- 除以 100 表示版本==1
- 非重复值
我们将使用when, otherwise
条件和distinct
的函数:
df.withColumn("cd", psf.when(df.version == 1, df.cd/100).otherwise(df.cd))
.distinct().drop("version")
+---+-----+
| id| cd|
+---+-----+
| 1| 1.0|
| 1| 1.01|
| 1|102.0|
| 2| 1.0|
| 2| 2.01|
| 2|202.0|
+---+-----+
我就是这样做的。我相信有更好的方法可以做到这一点。
def process_code(raw_data):
for i in range(1,4):
cd_col_name = "cd" + str(i)
version_col_name = "version" + str(i)
raw_data = raw_data.withColumn("mod_cd" + str(i), when(raw_data[version_col_name] == 1, concat(substring(raw_data[cd_col_name],1,1),lit("."),substring(raw_data[cd_col_name],2,20))).otherwise(raw_data[cd_col_name]))
mod_cols = [col for col in raw_data.columns if 'mod_cd' in col]
nb_versions = 3
new = raw_data.fillna('9999', subset=mod_cols).select("id", psf.posexplode(psf.create_map(list(chain(*[(psf.col("mod_cd" + str(i)), psf.col("dt"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "final_cd", "final_date")).drop("pos")
return new
test = process_code(df)
test = test.filter(test.final_cd != '9999')
test.show(100, False)