当udf函数不能接受足够大的输入变量时,使用Spark dataframe



我正在准备一个带有id和我的特征向量的DataFrame,以便稍后用于进行预测。我在我的数据框架上做了一个groupBy,在我的groupBy中,我将几个列合并为一个新列:

def mergeFunction(...) // with 14 input variables
val myudffunction( mergeFunction ) // Spark doesn't support this
df.groupBy("id").agg(
   collect_list(df(...)) as ...
   ... // too many of these (something like 14 of them)
).withColumn("features_labels",
  myudffunction(
     col(...)
     , col(...) )
.select("id", "feature_labels")

这就是我如何创建我的特征向量和它们的标签。到目前为止,它一直在为我工作,但这是第一次使用这种方法的特征向量大于10,这是Spark中udf函数接受的最大值。

我不知道我还能怎么解决这个问题?udf输入的大小是火花会变大,是我理解错了吗,还是有更好的方法吗?

用户自定义函数最多可定义22个参数。只有udf helper被定义为最多10个参数。要处理具有大量参数的函数,可以使用org.apache.spark.sql.UDFRegistration

例如

val dummy = ((
  x0: Int, x1: Int, x2: Int, x3: Int, x4: Int, x5: Int, x6: Int, x7: Int, 
  x8: Int, x9: Int, x10: Int, x11: Int, x12: Int, x13: Int, x14: Int, 
  x15: Int, x16: Int, x17: Int, x18: Int, x19: Int, x20: Int, x21: Int) => 1)

可以注册:

import org.apache.spark.sql.expressions.UserDefinedFunction
val dummyUdf: UserDefinedFunction = spark.udf.register("dummy", dummy)

和直接使用

val df = spark.range(1)
val exprs =  (0 to 21).map(_ => lit(1))
df.select(dummyUdf(exprs: _*))

或通过callUdf命名

import org.apache.spark.sql.functions.callUDF
df.select(
  callUDF("dummy", exprs:  _*).alias("dummy")
)

或SQL表达式:

df.selectExpr(s"""dummy(${Seq.fill(22)(1).mkString(",")})""")

您也可以创建一个UserDefinedFunction对象:

import org.apache.spark.sql.expressions.UserDefinedFunction
Seq(1).toDF.select(UserDefinedFunction(dummy, IntegerType, None)(exprs: _*))

在实践中,有22个参数的函数并不是很有用,除非你想使用Scala反射来生成这些参数,否则维护起来会很麻烦。

我会考虑使用集合(array, map)或struct作为输入或将其分为多个模块。例如:

val aLongArray = array((0 to 256).map(_ => lit(1)): _*)
val udfWitharray = udf((xs: Seq[Int]) => 1)
Seq(1).toDF.select(udfWitharray(aLongArray).alias("dummy"))

只是在0的答案上展开,可以让.withColumn()函数处理具有超过10个参数的UDF。只需要spark.udf.register()函数,然后使用expr作为添加列的参数(而不是udf)。

例如,应该这样做:

def mergeFunction(...) // with 14 input variables
spark.udf.register("mergeFunction", mergeFunction) // make available in expressions
df.groupBy("id").agg(
   collect_list(df(...)) as ...
   ... // too many of these (something like 14 of them)
).withColumn("features_labels",
  expr("mergeFunction(col1, col2, col3, col4, ...)") ) //pass in the 14 column names
.select("id", "feature_labels")

底层表达式解析器似乎处理超过10个参数,所以我不认为你必须通过传递数组来调用函数。此外,如果它们的参数恰好是不同的数据类型,数组将不能很好地工作。

相关内容

  • 没有找到相关文章

最新更新