计算两个 Seq 列与第三列之间的相关性的正确方法



我有一个数据帧,其中每行有 3 列:

ID:Long, ratings1:Seq[Double], ratings2:Seq[Double]

对于每一行,我需要计算这些向量之间的相关性。

我想出了以下解决方案,它似乎效率低下(不像Jarrod Roberson提到的那样工作),因为我必须为每个Seq创建RDD:

val similarities = ratingPairs.map(row => {
      val ratings1 = sc.parallelize(row.getAs[Seq[Double]]("ratings1"))
      val ratings2 = sc.parallelize(row.getAs[Seq[Double]]("ratings2"))
      val corr:Double = Statistics.corr(ratings1, ratings2)
      Similarity(row.getAs[Long]("ID"), corr)
    })

有没有办法正确计算这种相关性?

假设你有一个数组的相关函数:

def correlation(arr1: Array[Double], arr2: Array[Double]): Double

(对于该函数的潜在实现,它完全独立于Spark,你可以问一个单独的问题或在线搜索,有一些足够接近的资源,例如这个实现)。

现在,剩下要做的就是用 UDF 包装这个函数并使用它:

import org.apache.spark.sql.functions._
import spark.implicits._
val corrUdf = udf {
  (arr1: Seq[Double], arr2: Seq[Double]) => correlation(arr1.toArray, arr2.toArray)
}
val result = df.select($"ID", corrUdf($"ratings1", $"ratings2") as "correlation")

相关内容

最新更新