通过在另一个数据框中的所有行上迭代一个数据框中的一行来检查最小值



假设我有以下两个数据帧:

DF1:
+----------+----------+----------+
|     Place|Population|    IndexA|     
+----------+----------+----------+
|         A|       Int|       X_A|
|         B|       Int|       X_B|
|         C|       Int|       X_C|
+----------+----------+----------+
DF2:
+----------+----------+
|      City|    IndexB|     
+----------+----------+
|         D|       X_D|      
|         E|       X_E|   
|         F|       X_F|  
|      ....|      ....|
|        ZZ|      X_ZZ|
+----------+----------+

以上数据帧的大小通常要大得多。

我想确定哪个City(DF2)距离DF1的每个Place最近。可以根据该指数计算出距离。因此,对于DF1中的每一行,我必须遍历DF2中的每一行,并根据索引的计算寻找最短距离。对于距离计算,定义了一个函数:

val distance = udf(
(indexA: Long, indexB: Long) => {
h3.instance.h3Distance(indexA, indexB)
})

我试了如下:

val output =  DF1.agg(functions.min(distance(col("IndexA"), DF2.col("IndexB"))))

但是这个,代码编译但是我得到以下错误:

线程"main"异常org.apache.spark.sql.AnalysisException: Resolved attribute(s)
H3Index#220L从Places#316,Population#330, indexax# 338L中丢失。Aggregate
[min(if (isnull(indexax# 338L) OR isnull(indexb# 220L))))) null elseUDF(knownnotnull(IndexA#338L), knownnotnull(IndexB#220L))) AS min(UDF(IndexA, IndexB))#346].

所以我想我在DF2中的每一行迭代时做了一些错误,当从DF1中取一行时,但我找不到解决方案。

我做错了什么?我的方向对吗?

您得到这个错误是因为您正在使用的索引列只存在于DF2中,而不存在于您试图执行聚合的DF1中。

为了使该字段可访问并确定到所有点的距离,您需要

  1. DF1Df2交叉连接,使Df1的每一个指标都匹配DF2的每一个指标
  2. 使用您的udf
  3. 确定距离
  4. 找到这个新交叉连接udf的最小值与距离

这可能像:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, min, udf}
val distance = udf(
(indexA: Long, indexB: Long) => {
h3.instance.h3Distance(indexA, indexB)
})
val resultDF = DF1.crossJoin(DF2)
.withColumn("distance", distance(col("IndexA"), col("IndexB")))
//instead of using a groupby then matching the min distance of the aggregation with the initial df. I've chosen to use a window function min to determine the min_distance of each group (determined by Place) and filter by the city with the min distance to each place
.withColumn("min_distance", min("distance").over(Window.partitionBy("Place")))
.where(col("distance") === col("min_distance"))
.drop("min_distance")

这将产生一个数据框,其中包含来自数据框和distance的列。

NB。你目前的方法是比较一个df中的每个项目和另一个df中的每个项目,这是一个昂贵的操作。如果您有机会提前过滤(例如加入启发式列,即可能表明某个地方可能更接近某个城市的其他列),建议这样做。

让我知道这是否适合你。

如果您只有几个城市(少于或大约1000),您可以通过在数组中收集城市,然后使用这个收集的数组为每个地方执行距离计算来避免crossJoinWindow洗牌:

import org.apache.spark.sql.functions.{array_min, col, struct, transform, typedLit, udf}
val citiesIndexes = df2.select("City", "IndexB")
.collect()
.map(row => (row.getString(0), row.getLong(1)))
val result = df1.withColumn(
"City",
array_min(
transform(
typedLit(citiesIndexes),
x => struct(distance(col("IndexA"), x.getItem("_2")), x.getItem("_1"))
)
).getItem("col2")
)

这段代码适用于Spark 3及更高版本。如果您使用的Spark版本小于3.0,请使用自定义函数替换array_min(...).getItem("col2")部分。

最新更新