Computing the edit distance over consecutive rows of a Spark DataFrame



I have a DataFrame that looks like this:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import spark.implicits._
// some data...
val df = Seq(
  (1, "AA", "BB", ("AA", "BB")),
  (2, "AA", "BB", ("AA", "BB")),
  (3, "AB", "BB", ("AB", "BB"))
).toDF("id", "name", "surname", "array")
df.show()

I want to compute the edit distance between the "array" column entries of consecutive rows. For example, I want the edit distance between the "array" entry ("AA", "BB") in row 1 and the "array" entry ("AA", "BB") in row 2. This is the edit-distance function I am using:

def editDist2[A](a: Iterable[A], b: Iterable[A]): Int = {
  val startRow = (0 to b.size).toList
  a.foldLeft(startRow) { (prevRow, aElem) =>
    prevRow.zip(prevRow.tail).zip(b).scanLeft(prevRow.head + 1) {
      case (left, ((diag, up), bElem)) =>
        val aGapScore = up + 1
        val bGapScore = left + 1
        val matchScore = diag + (if (aElem == bElem) 0 else 1)
        List(aGapScore, bGapScore, matchScore).min
    }
  }.last
}
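For reference, the function can be exercised on its own, outside Spark; it works on any `Iterable`, so both plain `String`s (`Iterable[Char]`) and `Seq[String]` are fine. A quick sanity check (the function is repeated here so the snippet is self-contained):

```scala
// Edit-distance function from above, repeated so this snippet runs on its own.
def editDist2[A](a: Iterable[A], b: Iterable[A]): Int = {
  val startRow = (0 to b.size).toList
  a.foldLeft(startRow) { (prevRow, aElem) =>
    prevRow.zip(prevRow.tail).zip(b).scanLeft(prevRow.head + 1) {
      case (left, ((diag, up), bElem)) =>
        List(up + 1, left + 1, diag + (if (aElem == bElem) 0 else 1)).min
    }
  }.last
}

// On Strings: classic Levenshtein example.
println(editDist2("kitten", "sitting")) // 3
// On Seq[String]: matches rows 2 and 3 of the example table.
println(editDist2(Seq("AA", "BB"), Seq("AB", "BB"))) // 1
```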

I know I need to create a UDF for this function, but I can't seem to manage it. If I use the function as-is, with Spark windowing to get the previous row:

// creating window - ordered by ID
val window = Window.orderBy("id")
// using the window with lag function to compare to previous value in each column
df.withColumn("edit-d", editDist2(($"array"), lag("array", 1).over(window))).show()

I get the following error:

<console>:245: error: type mismatch;
found   : org.apache.spark.sql.ColumnName
required: Iterable[?]
df.withColumn("edit-d", editDist2(($"array"), lag("array", 1).over(window))).show()
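The mismatch arises because `editDist2` expects in-memory `Iterable`s, while `$"array"` and `lag(...)` are `Column` expressions. One way around this (a sketch, untested against the exact schema above) is to wrap the function in a `udf` and feed it an `ArrayType` column built with `array(...)`, since the tuple-typed "array" column would reach a UDF as a `Row` rather than a `Seq`:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{array, lag, udf}

// Null-safe wrapper: lag() yields null on the first row of the window,
// so return None there instead of calling editDist2 on a null.
val editDistUdf = udf { (a: Seq[String], b: Seq[String]) =>
  Option(b).map(bb => editDist2(a, bb))
}

val window = Window.orderBy("id")
// Build a real ArrayType column from the two string columns.
val cols = array($"name", $"surname")
df.withColumn("edit-d", editDistUdf(cols, lag(cols, 1).over(window))).show()
```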

I found that Spark's own levenshtein function can be used. This function takes two strings to compare, so it cannot be applied to the array column directly, but it can be applied to each string column separately:

// creating window - ordered by ID
val window = Window.orderBy("id")
// using the window with lag function to compare to previous value in each column
df.withColumn("edit-d", levenshtein(($"name"), lag("name", 1).over(window)) + levenshtein(($"surname"), lag("surname", 1).over(window))).show()

which gives the desired output:

+---+----+-------+--------+------+
| id|name|surname|   array|edit-d|
+---+----+-------+--------+------+
|  1|  AA|     BB|[AA, BB]|  null|
|  2|  AA|     BB|[AA, BB]|     0|
|  3|  AB|     BB|[AB, BB]|     1|
+---+----+-------+--------+------+
