从前面的行中累积数组(PySpark数据帧)



一个(Python)示例将使我的问题变得清晰。假设我有一个Spark数据帧,其中包含在特定日期观看某些电影的人,如下所示:

movierecord = spark.createDataFrame([("Alice", 1, ["Avatar"]),("Bob", 2, ["Fargo", "Tron"]),("Alice", 4, ["Babe"]), ("Alice", 6, ["Avatar", "Airplane"]), ("Alice", 7, ["Pulp Fiction"]), ("Bob", 9, ["Star Wars"])],["name","unixdate","movies"])

上面定义的模式和数据帧如下所示:

root
|-- name: string (nullable = true)
|-- unixdate: long (nullable = true)
|-- movies: array (nullable = true)
|    |-- element: string (containsNull = true)
+-----+--------+------------------+
|name |unixdate|movies            |
+-----+--------+------------------+
|Alice|1       |[Avatar]          |
|Bob  |2       |[Fargo, Tron]     |
|Alice|4       |[Babe]            |
|Alice|6       |[Avatar, Airplane]|
|Alice|7       |[Pulp Fiction]    |
|Bob  |9       |[Star Wars]       |
+-----+--------+------------------+

我想从上面开始生成一个新的数据帧列,该列包含每个用户看到的所有以前的电影,没有重复(每个unixdate字段的"以前")。所以它应该是这样的:

+-----+--------+------------------+------------------------+
|name |unixdate|movies            |previous_movies         |
+-----+--------+------------------+------------------------+
|Alice|1       |[Avatar]          |[]                      |
|Bob  |2       |[Fargo, Tron]     |[]                      |
|Alice|4       |[Babe]            |[Avatar]                |
|Alice|6       |[Avatar, Airplane]|[Avatar, Babe]          |
|Alice|7       |[Pulp Fiction]    |[Avatar, Babe, Airplane]|
|Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
+-----+--------+------------------+------------------------+

我如何以一种非常有效的方式实现这一点?

SQL仅 而不保留对象的顺序

  • 所需进口:

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window
    
  • 窗口定义:

    w = Window.partitionBy("name").orderBy("unixdate")
    
  • 完整解决方案:

    (movierecord
    # Flatten movies
    .withColumn("previous_movie", f.explode("movies"))
    # Collect unique
    .withColumn("previous_movies", f.collect_set("previous_movie").over(w))
    # Drop duplicates for a single unixdate
    .groupBy("name", "unixdate")
    .agg(f.max(f.struct(
    f.size("previous_movies"),
    f.col("movies").alias("movies"),
    f.col("previous_movies").alias("previous_movies")
    )).alias("tmp"))
    # Shift by one and extract
    .select(
    "name", "unixdate", "tmp.movies", 
    f.lag("tmp.previous_movies", 1).over(w).alias("previous_movies")))
    
  • 结果:

    +-----+--------+------------------+------------------------+
    |name |unixdate|movies            |previous_movies         |
    +-----+--------+------------------+------------------------+
    |Bob  |2       |[Fargo, Tron]     |null                    |
    |Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
    |Alice|1       |[Avatar]          |null                    |
    |Alice|4       |[Babe]            |[Avatar]                |
    |Alice|6       |[Avatar, Airplane]|[Babe, Avatar]          |
    |Alice|7       |[Pulp Fiction]    |[Babe, Airplane, Avatar]|
    +-----+--------+------------------+------------------------+
    

SQL一个Python UDF保留顺序:

  • 进口:

    import pyspark.sql.functions as f
    from pyspark.sql.window import Window
    from pyspark.sql import Column
    from pyspark.sql.types import ArrayType, StringType
    from typing import List, Union
    # https://github.com/pytoolz/toolz
    from toolz import unique, concat, compose
    
  • UDF:

    def flatten_distinct(col: Union[Column, str]) -> Column:
    def flatten_distinct_(xss: Union[List[List[str]], None]) -> List[str]:
    return compose(list, unique, concat)(xss or [])
    return f.udf(flatten_distinct_, ArrayType(StringType()))(col)
    
  • 窗口定义与以前一样。

  • 完整解决方案:

    (movierecord
    # Collect lists
    .withColumn("previous_movies", f.collect_list("movies").over(w))
    # Flatten and drop duplicates
    .withColumn("previous_movies", flatten_distinct("previous_movies"))
    # Shift by one
    .withColumn("previous_movies", f.lag("previous_movies", 1).over(w))
    # For presentation only
    .orderBy("unixdate")) 
    
  • 结果:

    +-----+--------+------------------+------------------------+
    |name |unixdate|movies            |previous_movies         |
    +-----+--------+------------------+------------------------+
    |Alice|1       |[Avatar]          |null                    |
    |Bob  |2       |[Fargo, Tron]     |null                    |
    |Alice|4       |[Babe]            |[Avatar]                |
    |Alice|6       |[Avatar, Airplane]|[Avatar, Babe]          |
    |Alice|7       |[Pulp Fiction]    |[Avatar, Babe, Airplane]|
    |Bob  |9       |[Star Wars]       |[Fargo, Tron]           |
    +-----+--------+------------------+------------------------+
    

性能

我认为,在限制条件下,没有有效的方法来解决这个问题。请求的输出不仅需要大量的数据复制(数据是二进制编码的,以适应Tungsten格式,因此您可以获得可能的压缩,但对象标识松散),而且在Spark计算模型下,还需要大量昂贵的操作,包括昂贵的分组和排序。

previous_movies的期望大小是有界的并且很小,但在一般情况下是不可行的,那个么这应该是可以的。

通过为用户保留单一的惰性历史记录,可以很容易地解决数据重复问题。这在SQL中无法完成,但在低级别的RDD操作中非常容易。

爆炸和collect_模式是昂贵的。如果你的要求很严格,但你想提高性能,你可以用Scala UDF代替Python UDF。

相关内容

  • 没有找到相关文章

最新更新