为什么数据类型在Scala中调用UDF时会更改



我有一个df:

joined.printSchema
root
 |-- cc_num: long (nullable = true)
 |-- lat: double (nullable = true)
 |-- long: double (nullable = true)
 |-- merch_lat: double (nullable = true)
 |-- merch_long: double (nullable = true)

我有一个UDF:

def getDistance (lat1:Double, lon1:Double, lat2:Double, lon2:Double) = {
    val r : Int = 6371 //Earth radius
    val latDistance : Double = Math.toRadians(lat2 - lat1)
    val lonDistance : Double = Math.toRadians(lon2 - lon1)
    val a : Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) + Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    val c : Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
    val distance : Double = r * c
    distance
  }

我需要使用:

为DF生成新列
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))

我在下面收到错误:

Name: Unknown Error
Message: <console>:35: error: type mismatch;
 found   : String("lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                          ^
<console>:35: error: type mismatch;
 found   : String("long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                 ^
<console>:35: error: type mismatch;
 found   : String("merch_lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                         ^
<console>:35: error: type mismatch;
 found   : String("merch_long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                                      ^

您可以从架构中看到,所有涉及的字段均为double的类型,它符合UDF的参数类型定义,为什么我会看到数据类型不匹配错误?

任何人都可以在这里启发什么问题以及如何修复它?

非常感谢。

您的 getDistance方法不是UDF,它是一个期望4 Double参数的scala方法,而您正在通过4个字符串。

要解决此问题,您需要:

  • "用UDF包装"您的方法,
  • Pass 参数,而不是使用UDF时的字符串,您可以通过将列名与$
  • 前缀前缀来执行此操作。
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession
val distanceUdf: UserDefinedFunction = udf(getDistance _)
joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))

相关内容

  • 没有找到相关文章

最新更新