我有一个df:
joined.printSchema
root
|-- cc_num: long (nullable = true)
|-- lat: double (nullable = true)
|-- long: double (nullable = true)
|-- merch_lat: double (nullable = true)
|-- merch_long: double (nullable = true)
我有一个UDF:
def getDistance (lat1:Double, lon1:Double, lat2:Double, lon2:Double) = {
val r : Int = 6371 //Earth radius
val latDistance : Double = Math.toRadians(lat2 - lat1)
val lonDistance : Double = Math.toRadians(lon2 - lon1)
val a : Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) + Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
val c : Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
val distance : Double = r * c
distance
}
我需要使用:
为DF生成新列joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
我在下面收到错误:
Name: Unknown Error
Message: <console>:35: error: type mismatch;
found : String("lat")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("long")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("merch_lat")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("merch_long")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
您可以从架构中看到,所有涉及的字段均为double
的类型,它符合UDF的参数类型定义,为什么我会看到数据类型不匹配错误?
任何人都可以在这里启发什么问题以及如何修复它?
非常感谢。
您的 getDistance
方法不是UDF,它是一个期望4 Double
参数的scala方法,而您正在通过4个字符串。
要解决此问题,您需要:
- "用UDF包装"您的方法,
- Pass 列参数,而不是使用UDF时的字符串,您可以通过将列名与
$
前缀前缀来执行此操作。
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession
val distanceUdf: UserDefinedFunction = udf(getDistance _)
joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))