假设我有以下两个数据帧:
DF1:
+----------+----------+----------+
| Place|Population| IndexA|
+----------+----------+----------+
| A| Int| X_A|
| B| Int| X_B|
| C| Int| X_C|
+----------+----------+----------+
DF2:
+----------+----------+
| City| IndexB|
+----------+----------+
| D| X_D|
| E| X_E|
| F| X_F|
| ....| ....|
| ZZ| X_ZZ|
+----------+----------+
以上数据帧的大小通常要大得多。
我想确定哪个City
(DF2
)距离DF1
的每个Place
最近。可以根据该指数计算出距离。因此,对于DF1
中的每一行,我必须遍历DF2
中的每一行,并根据索引的计算寻找最短距离。对于距离计算,定义了一个函数:
val distance = udf(
(indexA: Long, indexB: Long) => {
h3.instance.h3Distance(indexA, indexB)
})
我试了如下:
val output = DF1.agg(functions.min(distance(col("IndexA"), DF2.col("IndexB"))))
但是这个,代码编译但是我得到以下错误:
线程"main"异常org.apache.spark.sql.AnalysisException: Resolved attribute(s)
H3Index#220L从Places#316,Population#330, indexax# 338L中丢失。Aggregate
[min(if (isnull(indexax# 338L) OR isnull(indexb# 220L))))) null elseUDF(knownnotnull(IndexA#338L), knownnotnull(IndexB#220L))) AS min(UDF(IndexA, IndexB))#346].
所以我想我在DF2
中的每一行迭代时做了一些错误,当从DF1
中取一行时,但我找不到解决方案。
我做错了什么?我的方向对吗?
您得到这个错误是因为您正在使用的索引列只存在于DF2
中,而不存在于您试图执行聚合的DF1
中。
为了使该字段可访问并确定到所有点的距离,您需要
- 将
DF1
和Df2
交叉连接,使Df1
的每一个指标都匹配DF2
的每一个指标 - 使用您的udf 确定距离
- 找到这个新交叉连接udf的最小值与距离
这可能像:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, min, udf}
val distance = udf(
(indexA: Long, indexB: Long) => {
h3.instance.h3Distance(indexA, indexB)
})
val resultDF = DF1.crossJoin(DF2)
.withColumn("distance", distance(col("IndexA"), col("IndexB")))
//instead of using a groupby then matching the min distance of the aggregation with the initial df. I've chosen to use a window function min to determine the min_distance of each group (determined by Place) and filter by the city with the min distance to each place
.withColumn("min_distance", min("distance").over(Window.partitionBy("Place")))
.where(col("distance") === col("min_distance"))
.drop("min_distance")
这将产生一个数据框,其中包含来自数据框和distance
的列。
NB。你目前的方法是比较一个df中的每个项目和另一个df中的每个项目,这是一个昂贵的操作。如果您有机会提前过滤(例如加入启发式列,即可能表明某个地方可能更接近某个城市的其他列),建议这样做。
让我知道这是否适合你。
如果您只有几个城市(少于或大约1000),您可以通过在数组中收集城市,然后使用这个收集的数组为每个地方执行距离计算来避免crossJoin
和Window
洗牌:
import org.apache.spark.sql.functions.{array_min, col, struct, transform, typedLit, udf}
val citiesIndexes = df2.select("City", "IndexB")
.collect()
.map(row => (row.getString(0), row.getLong(1)))
val result = df1.withColumn(
"City",
array_min(
transform(
typedLit(citiesIndexes),
x => struct(distance(col("IndexA"), x.getItem("_2")), x.getItem("_1"))
)
).getItem("col2")
)
这段代码适用于Spark 3及更高版本。如果您使用的Spark版本小于3.0,请使用自定义函数替换array_min(...).getItem("col2")
部分。