spark dataframe-将非Column类型的变量传递给udf



我正在尝试通过擦除一些特征（索引存储在 feature_idx_to_wipe 中）来修改 "features" 矢量列。伪代码如下，问题在于 udf 不接受 Set 类型的参数。我想知道如何解决这个问题，或者是否有更好的方法。

//data: indices to remove, plus a DataFrame of (id, sparse feature vector) rows; each vector has size 6
val feature_idx_to_wipe = Set(1, 2)
val dfA = spark.createDataFrame(Seq(
(0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
(1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
(2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
)).toDF("id", "features")
dfA.show(false)
+---+-------------------------+
|id |features                 |
+---+-------------------------+
|0  |(6,[0,1,2],[1.0,1.0,1.0])|
|1  |(6,[2,3,4],[1.0,1.0,1.0])|
|2  |(6,[0,2,4],[1.0,1.0,1.0])|
+---+-------------------------+
//udf 
/** Returns a copy of the sparse vector with every entry whose index is in
  * `idx2clean` removed; the vector keeps its original size.
  *
  * @param v         input sparse vector
  * @param idx2clean indices of the features to erase
  */
def wipe(v: NewSparseVector, idx2clean: Set[Int]): NewSparseVector = {
  val kept = ListBuffer.empty[(Int, Double)]
  // Walk only the active (non-zero) entries and keep those not being wiped.
  v.foreachActive { (index, value) =>
    if (!idx2clean.contains(index)) kept += ((index, value))
  }
  NewVectors.sparse(v.size, kept.toSeq).toSparse
}
// NOTE: this udf takes two arguments, and udf application expects each one to be a Column.
val udf_wipe = udf((x: NewSparseVector, idx2clean:Set[Int]) => wipe(x, idx2clean))
//apply udf -- fails: the second argument is a plain Set[Int], not a Column (see error below)
dfA.withColumn("features_wiped", udf_wipe(col("features"), feature_idx_to_wipe))
// error: 
// scala> dfA.withColumn("nf", udf_wipe(col("features"), tc))
// <console>:98: error: type mismatch;
//  found   : scala.collection.immutable.Set[Int]
//  required: org.apache.spark.sql.Column
//        dfA.withColumn("nf", udf_wipe(col("features"), tc))
//target (a new column of vector added, with features at index 1,2 are removed)
dfA.select("id","features_wiped").show(false)
+---+-------------------------+
|id |features_wiped           |
+---+-------------------------+
|0  |(6,[0],[1.0])            |
|1  |(6,[3,4],[1.0,1.0])      |
|2  |(6,[0,4],[1.0,1.0])      |
+---+-------------------------+

wipe 函数可以通过如下方式转换为柯里化（curried）函数：

def wipe(v: NewSparseVector)(idx2clean:Set[Int]) : NewSparseVector

为相应的函数创建udf:

val udf_wipe = udf((x: NewSparseVector) => wipe(x)(feature_idx_to_wipe))

最后将udf应用于数据帧:

dfA.withColumn("features_wiped", udf_wipe(col("features")))

另一种选择-

测试数据

//data: same test DataFrame as above -- (id, size-6 sparse vector) rows
val dfA = spark.createDataFrame(Seq(
(0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
(1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
(2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
)).toDF("id", "features")
dfA.show(false)
/**
* +---+-------------------------+
* |id |features                 |
* +---+-------------------------+
* |0  |(6,[0,1,2],[1.0,1.0,1.0])|
* |1  |(6,[2,3,4],[1.0,1.0,1.0])|
* |2  |(6,[0,2,4],[1.0,1.0,1.0])|
* +---+-------------------------+
*/

备选方案-1：使用 lit 将数组包装为常量列，如下——

// Alternative-1
//udf
val feature_idx_to_wipe = Array(1, 2)
import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector}
/** Returns a copy of the sparse vector with every entry whose index appears
  * in `idx2clean` removed; the vector keeps its original size.
  *
  * @param v         input sparse vector
  * @param idx2clean indices of the features to erase (a Seq so it can be
  *                  passed through `lit(Array(...))` as a Column)
  */
def wipe(v: NewSparseVector, idx2clean: Seq[Int]): NewSparseVector = {
  val kept = ListBuffer.empty[(Int, Double)]
  // Only active (non-zero) entries are visited; keep those not being wiped.
  v.foreachActive { (index, value) =>
    if (!idx2clean.contains(index)) kept += ((index, value))
  }
  Vectors.sparse(v.size, kept.toSeq).toSparse
}
// Two-argument udf: the index list arrives as a Column of array type.
val udf_wipe = udf((x: NewSparseVector, idx2clean:Seq[Int]) => wipe(x, idx2clean))
//apply udf: lit(Array(...)) wraps the plain array as a constant array Column, satisfying the udf signature
val newDF = dfA.withColumn("features_wiped", udf_wipe(col("features"), lit(feature_idx_to_wipe)))
//target (a new column of vector added, with features at index 1,2 are removed)
newDF.select("id","features_wiped").show(false)
/**
* +---+-------------------+
* |id |features_wiped     |
* +---+-------------------+
* |0  |(6,[0],[1.0])      |
* |1  |(6,[3,4],[1.0,1.0])|
* |2  |(6,[0,4],[1.0,1.0])|
* +---+-------------------+
*/

备选方案-2：使用广播变量 sparkContext.broadcast，如下——

//    Alternative2
//data
val feature_idx_to_wipe1 = Set(1, 2)
val broabcastSet = spark.sparkContext.broadcast(feature_idx_to_wipe1)
//udf
import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector}
/** Removes the broadcast set of feature indices from a sparse vector.
  *
  * Reads the index set from the `broabcastSet` broadcast variable declared
  * above, so the udf wrapping this function needs only the vector column.
  *
  * @param v input sparse vector; its size is preserved
  */
def wipe1(v: NewSparseVector): NewSparseVector = {
  val toRemove = broabcastSet.value
  val kept = ListBuffer.empty[(Int, Double)]
  // Visit only active (non-zero) entries; keep those not being wiped.
  v.foreachActive { (index, value) =>
    if (!toRemove.contains(index)) kept += ((index, value))
  }
  Vectors.sparse(v.size, kept.toSeq).toSparse
}
// One-argument udf: the index set is captured via the broadcast variable, so no extra Column is needed.
val udf_wipe1 = udf((x: NewSparseVector) => wipe1(x))
//apply udf
val newDF1 = dfA.withColumn("features_wiped", udf_wipe1(col("features")))
//target (a new column of vector added, with features at index 1,2 are removed)
newDF1.select("id","features_wiped").show(false)
/**
* +---+-------------------+
* |id |features_wiped     |
* +---+-------------------+
* |0  |(6,[0],[1.0])      |
* |1  |(6,[3,4],[1.0,1.0])|
* |2  |(6,[0,4],[1.0,1.0])|
* +---+-------------------+
*/

最新更新