使用条件转置 Pyspark 数据帧中的多个列



我有一个看起来像这样的火花数据帧。

id cd1 version1   dt1   cd2 version2  dt2      cd3 version3    dt3
1  100    1    20100101 101    1     20100101  102            20100301        
1  101    1    20100102 102          20100201  100    1       20100302
2  201    1    20100103 100    1     20100301  100    1       20100303
2  202    2    20100104 100    1     20100105

我需要将所有代码转置为具有以下条件的单个列

  • 如果对应的版本码为 1,则在第一位数字后添加小数点
  • 每个患者应具有不同的代码

对于上面的示例,输出应如下所示。

id    code     dt
1     1.00   20100101
1     1.01   20100101
1     102    20100301
1     1.01   20100102
1     102    20100201
1     10.0   20100302
2     2.01   20100103
2     1.00   20100301
2     1.00   20100303
2     202    20100104
2     10.0   20100105

我正在使用Pyspark来做到这一点。在上面的示例中,我只显示了 3 个代码及其相应的版本列,但我有 30 个这样的列。此外,此数据大约有 2500 万行。

关于如何实现这一目标的任何想法都将非常有帮助。

您可以explode这些列的列表,以便每行只有一个(cd, version)对首先,让我们创建数据帧:

df = sc.parallelize([[1,100,1,101,1,102,None],[1,101,1,102,None,100,1],[2,201,1,100,1,100,1],
                               [2,202,2,100,1,None,None]]).toDF(["id","cd1","version1","cd2","version2","cd3","version3"])
  1. 使用posexplode

    import pyspark.sql.functions as psf
    from itertools import chain
    nb_versions = 4
    df = df.na.fill(-1).select(
        "id", 
        psf.posexplode(psf.create_map(list(chain(*[(psf.col("cd" + str(i)), psf.col("version"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "cd", "version")
    ).drop("pos").filter("cd != -1")
        +---+---+-------+
        | id| cd|version|
        +---+---+-------+
        |  1|100|      1|
        |  1|101|      1|
        |  1|102|     -1|
        |  1|101|      1|
        |  1|102|     -1|
        |  1|100|      1|
        |  2|201|      1|
        |  2|100|      1|
        |  2|100|      1|
        |  2|202|      2|
        |  2|100|      1|
        +---+---+-------+
    
  2. 使用explode

    nb_versions = 4
    df = df.select(
        "id", 
        psf.explode(psf.array(
            [psf.struct(
                psf.col("cd" + str(i)).alias("cd"), 
                psf.col("version" + str(i)).alias("version")) for i in range(1, nb_versions)])).alias("temp"))
        .select("id", "temp.*")
        +---+----+-------+
        | id|  cd|version|
        +---+----+-------+
        |  1| 100|      1|
        |  1| 101|      1|
        |  1| 102|   null|
        |  1| 101|      1|
        |  1| 102|   null|
        |  1| 100|      1|
        |  2| 201|      1|
        |  2| 100|      1|
        |  2| 100|      1|
        |  2| 202|      2|
        |  2| 100|      1|
        |  2|null|   null|
        +---+----+-------+
    

现在我们可以实施您的条件

  • 除以 100 表示版本==1
  • 非重复值

我们将使用when, otherwise条件和distinct的函数:

df.withColumn("cd", psf.when(df.version == 1, df.cd/100).otherwise(df.cd))
    .distinct().drop("version")
    +---+-----+
    | id|   cd|
    +---+-----+
    |  1|  1.0|
    |  1| 1.01|
    |  1|102.0|
    |  2|  1.0|
    |  2| 2.01|
    |  2|202.0|
    +---+-----+

我就是这样做的。我相信有更好的方法可以做到这一点。

def process_code(raw_data):
    for i in range(1,4):
        cd_col_name = "cd" + str(i)
        version_col_name = "version" + str(i)
        raw_data = raw_data.withColumn("mod_cd" + str(i), when(raw_data[version_col_name] == 1, concat(substring(raw_data[cd_col_name],1,1),lit("."),substring(raw_data[cd_col_name],2,20))).otherwise(raw_data[cd_col_name]))
    mod_cols = [col for col in raw_data.columns if 'mod_cd' in col]
    nb_versions = 3
    new = raw_data.fillna('9999', subset=mod_cols).select("id", psf.posexplode(psf.create_map(list(chain(*[(psf.col("mod_cd" + str(i)), psf.col("dt"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "final_cd", "final_date")).drop("pos")
    return new
test = process_code(df)
test = test.filter(test.final_cd != '9999')
test.show(100, False)

相关内容

  • 没有找到相关文章

最新更新