PySpark DataFrame: find the index of the array element closest to an integer column's value



I have a PySpark DataFrame with an array-type column and an integer-type column. I want to work out which array position the integer column's value is closest to. See below:

df = spark.createDataFrame(
    [
        (1, [5, 20, 100, 250], 2),
        (2, [16, 53, 120, 180], 168),
        (3, [100, 200, 1000, 2500], 3500),
    ],
    ["id", "array_col", "int_col"],
)

I want to create a new column that finds the index of the value in array_col closest to int_col, producing a new df like this:

| id | array_col              | int_col | closest_index |
|----|------------------------|---------|---------------|
| 1  | [5, 20, 100, 250]      | 2       | 0             |
| 2  | [16, 53, 120, 180]     | 168     | 3             |
| 3  | [100, 200, 1000, 2500] | 3500    | 3             |

Here is what I tried:

import numpy as np
import pandas as pd

def find_nearest(value):
    # bin_array is a fixed NumPy array defined elsewhere
    res = bin_array[np.newaxis, :] - value.values[:, np.newaxis]
    ret_vals = [bin_array[np.argmin(np.abs(i))] for i in res]
    return pd.Series(ret_vals)

and then, from there, use the array_position function to locate the index position, but I've had no luck applying this to the DataFrame. Any help would be much appreciated!

You can define a UDF that finds the nearest index, then apply it to each row.

Here is an example:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

def find_nearest_index(array, value):
    # index whose element has the smallest absolute distance to value
    return min(range(len(array)), key=lambda i: abs(array[i] - value))

if __name__ == "__main__":
    spark = SparkSession.builder.master("local").appName("Test").getOrCreate()
    df = spark.createDataFrame(
        [
            (1, [5, 20, 100, 250], 2),
            (2, [16, 53, 120, 180], 168),
            (3, [100, 200, 1000, 2500], 3500),
        ],
        ["id", "array_col", "int_col"],
    )
    # declare the return type; without it the UDF produces a string column
    nearest_index_udf = F.udf(lambda x, y: find_nearest_index(x, y), "int")
    df = df.withColumn(
        "Nearest Index", nearest_index_udf(F.col("array_col"), F.col("int_col"))
    )
    df.show()

which gives the result:

+---+--------------------+-------+-------------+
| id|           array_col|int_col|Nearest Index|
+---+--------------------+-------+-------------+
|  1|   [5, 20, 100, 250]|      2|            0|
|  2|  [16, 53, 120, 180]|    168|            3|
|  3|[100, 200, 1000, ...|   3500|            3|
+---+--------------------+-------+-------------+
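The helper itself is plain Python, so its behaviour is easy to check without a Spark session. One detail worth knowing (an observation about `min`, not something stated in the answer): when two elements are equally close, `min` keeps the first matching index.

```python
def find_nearest_index(array, value):
    # index whose element has the smallest absolute distance to value;
    # min() returns the first index on ties
    return min(range(len(array)), key=lambda i: abs(array[i] - value))

print(find_nearest_index([5, 20, 100, 250], 2))         # 0
print(find_nearest_index([16, 53, 120, 180], 168))      # 3
print(find_nearest_index([100, 200, 1000, 2500], 3500)) # 3
print(find_nearest_index([10, 30], 20))                 # tie: 0
```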

Alternatively, you can use the posexplode() function to explode the array, which also adds a position column, then identify the closest value and retrieve the matching rows, as in the code below:

import pyspark.sql.functions as F

# explode the array; posexplode adds "pos" (index) and "col" (value) columns;
# "delta" is the absolute distance of each element from "int_col"
df1 = (df
    .select("*", F.posexplode("array_col"))
    .withColumn("delta", F.abs(F.col("col") - F.col("int_col"))))

# group by id to find the minimum delta for each row
df2 = (df1
    .groupBy("id")
    .agg(F.min("delta").alias("delta")))

# join df1 and df2 to retrieve rows with minimal "delta" from df1
# (note: if two elements tie on delta, both rows are kept)
df_output = (df1
    .join(df2, ["id", "delta"])
    .select("id", "array_col", "int_col", F.col("pos").alias("closest_index")))
