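I have a PySpark DataFrame with an array column and an integer column. I want to find the position in the array whose value is closest to the value in the integer column. See below: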
df = spark.createDataFrame(
    [
        (1, [5, 20, 100, 250], 2),
        (2, [16, 53, 120, 180], 168),
        (3, [100, 200, 1000, 2500], 3500),
    ],
    ["id", "array_col", "int_col"]
)
I want to create a new column holding the index of the element in array_col that is closest to int_col, producing a new df like this:
| id | array_col | int_col | closest_index |
| 1 | [5, 20, 100, 250] | 2 | 0 |
| 2 | [16, 53, 120, 180] | 168 | 3 |
| 3 | [100, 200, 1000, 2500] | 3500 | 3 |
I tried this:
def find_nearest(value):
    res = bin_array[np.newaxis, :] - value.values[:, np.newaxis]
    ret_vals = [bin_array[np.argmin(np.abs(i))] for i in res]
    return pd.Series(ret_vals)
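and then, from there, using the array_position function to locate the index, but I had no luck getting it to work on the DataFrame. Any help would be much appreciated!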
You can define a UDF that finds the nearest index and apply it to each row.
Here is an example:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType


def find_nearest_index(array, value):
    # Index of the element with the smallest absolute difference to value.
    return min(range(len(array)), key=lambda i: abs(array[i] - value))


if __name__ == "__main__":
    spark = SparkSession.builder.master("local").appName("Test").getOrCreate()
    df = spark.createDataFrame(
        [
            (1, [5, 20, 100, 250], 2),
            (2, [16, 53, 120, 180], 168),
            (3, [100, 200, 1000, 2500], 3500),
        ],
        ["id", "array_col", "int_col"],
    )
    # Declare the return type; without it F.udf produces a string column.
    nearest_index_udf = F.udf(find_nearest_index, IntegerType())
    df = df.withColumn(
        "Nearest Index", nearest_index_udf(F.col("array_col"), F.col("int_col"))
    )
    df.show()
This gives the result:
+---+--------------------+-------+-------------+
| id| array_col|int_col|Nearest Index|
+---+--------------------+-------+-------------+
| 1| [5, 20, 100, 250]| 2| 0|
| 2| [16, 53, 120, 180]| 168| 3|
| 3|[100, 200, 1000, ...| 3500| 3|
+---+--------------------+-------+-------------+
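To check the helper's logic without starting a Spark session, the same function can be run as plain Python on the three sample rows. This is only a minimal sketch; the names `rows` and `results` are introduced here for illustration. It also shows what happens on a tie, which the question does not specify: `min` keeps the first index among equally close elements.

```python
def find_nearest_index(array, value):
    # Index of the element with the smallest absolute difference to value.
    return min(range(len(array)), key=lambda i: abs(array[i] - value))


# Sample rows from the question: (id, array_col, int_col)
rows = [
    (1, [5, 20, 100, 250], 2),
    (2, [16, 53, 120, 180], 168),
    (3, [100, 200, 1000, 2500], 3500),
]

results = [find_nearest_index(arr, val) for _, arr, val in rows]
print(results)  # [0, 3, 3]

# On a tie, min() keeps the first minimal index: 15 is equally far from 10 and 20.
print(find_nearest_index([10, 20], 15))  # 0
```

If a different tie-breaking rule is wanted (for example, the last index), the `key` function would need to encode it explicitly.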
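You can use the posexplode() function to explode the array, which also adds a position column, then identify the closest value and retrieve the matching row, as shown in the code below: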
import pyspark.sql.functions as F

# explode the original df and add a "delta" column: the distance of each array value from "int_col"
df1 = (df
       .select("*", F.posexplode("array_col"))
       .withColumn("delta", F.abs(F.col("col") - F.col("int_col"))))

# group by id to find the minimum delta for each id
df2 = (df1
       .groupBy("id")
       .agg(F.min("delta").alias("delta")))

# join df1 and df2 to retrieve the rows with minimal "delta" from df1
df_output = (df1
             .join(df2, ["id", "delta"])
             .select("id", "array_col", "int_col", F.col("pos").alias("closest_index")))
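The explode, aggregate, and join steps can be mimicked in plain Python to see what they return for a single row. This is a rough sketch rather than Spark code; `closest_positions` is a hypothetical helper introduced here. It suggests one property worth knowing: because the join keeps every row whose delta equals the minimum, a tie yields more than one position for the same id.

```python
def closest_positions(array, value):
    # posexplode step: one (pos, delta) pair per array element
    exploded = [(pos, abs(col - value)) for pos, col in enumerate(array)]
    # groupBy/min step: smallest delta for this row
    min_delta = min(delta for _, delta in exploded)
    # join step: keep every position whose delta equals the minimum
    return [pos for pos, delta in exploded if delta == min_delta]


print(closest_positions([16, 53, 120, 180], 168))  # [3]
print(closest_positions([10, 20], 15))  # [0, 1]
```

If exactly one row per id is required, the tied rows would need to be reduced afterwards, for example by keeping the smallest position.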