Scala Spark UDF to filter an array of structs

I have a DataFrame with this schema:

root
|-- x: long (nullable = false)
|-- y: long (nullable = false)
|-- features: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- name: string (nullable = true)
|    |    |-- score: double (nullable = true)

For example, I have the data:

+--------------------+--------------------+------------------------------------------+
|                x   |              y     |       features                           |
+--------------------+--------------------+------------------------------------------+
|10                  |          9         |[["f1", 5.9], ["ft2", 6.0], ["ft3", 10.9]]|
|11                  |          0         |[["f4", 0.9], ["ft1", 4.0], ["ft2", 0.9] ]|
|20                  |          9         |[["f5", 5.9], ["ft2", 6.4], ["ft3", 1.9] ]|
|18                  |          8         |[["f1", 5.9], ["ft4", 8.1], ["ft2", 18.9]]|
+--------------------+--------------------+------------------------------------------+
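For reference, a minimal way to reproduce this DataFrame in a spark-shell session (spark.implicits._ is already in scope there; the Feature case class is a hypothetical helper, introduced only so the struct fields get the names name and score):

// Hypothetical helper type mirroring the struct elements in the schema above
case class Feature(name: String, score: Double)

val df = Seq(
  (10L, 9L, Seq(Feature("f1", 5.9), Feature("ft2", 6.0), Feature("ft3", 10.9))),
  (11L, 0L, Seq(Feature("f4", 0.9), Feature("ft1", 4.0), Feature("ft2", 0.9))),
  (20L, 9L, Seq(Feature("f5", 5.9), Feature("ft2", 6.4), Feature("ft3", 1.9))),
  (18L, 8L, Seq(Feature("f1", 5.9), Feature("ft4", 8.1), Feature("ft2", 18.9)))
).toDF("x", "y", "features")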

I want to keep only the features whose name has a certain prefix, say "ft", so in the end I want the result:

+--------------------+--------------------+-----------------------------+
|                x   |              y     |       features              |
+--------------------+--------------------+-----------------------------+
|10                  |          9         |[["ft2", 6.0], ["ft3", 10.9]]|
|11                  |          0         |[["ft1", 4.0], ["ft2", 0.9] ]|
|20                  |          9         |[["ft2", 6.4], ["ft3", 1.9] ]|
|18                  |          8         |[["ft4", 8.1], ["ft2", 18.9]]|
+--------------------+--------------------+-----------------------------+

I am not on Spark 2.4+, so I cannot use the solution given here: Spark (Scala) filter array of structs without explode

I tried to use a UDF, but it still does not work. Here are my attempts. I defined a UDF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }
  )

But when I apply this UDF:

df.withColumn("filtered", filterFeature($"features"))

I get the error Schema for type org.apache.spark.sql.Row is not supported. I found that I cannot return Row from a UDF. So I tried:

import org.apache.spark.sql.types.{StringType, DoubleType}

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, (StringType, DoubleType)
  )

Then I got an error:

error: type mismatch;
found   : (org.apache.spark.sql.types.StringType.type, org.apache.spark.sql.types.DoubleType.type)
required: org.apache.spark.sql.types.DataType
}, (StringType, DoubleType)
^
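Looking at the error, the second argument of that udf overload has to be a single org.apache.spark.sql.types.DataType describing the whole return value, not a tuple of element types. For an array of structs it would presumably look something like this (a sketch; field names taken from the schema above):

import org.apache.spark.sql.types._

// One DataType for the entire return value: an array whose elements are
// structs with a string "name" and a double "score"
val returnType: DataType = ArrayType(
  StructType(Seq(
    StructField("name", StringType),
    StructField("score", DoubleType)
  ))
)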

I also tried a case class, as some answers suggest:

case class FilteredFeature(featureName: String, featureScore: Double)

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, FilteredFeature
  )

But I got:

error: type mismatch;
found   : FilteredFeature.type
required: org.apache.spark.sql.types.DataType
}, FilteredFeature
^

I tried:

case class FilteredFeature(featureName: String, featureScore: Double)

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, Seq[FilteredFeature]
  )

I got:

<console>:192: error: missing argument list for method apply in class GenericCompanion
Unapplied methods are only converted to functions when a function type is expected.
You can make this conversion explicit by writing `apply _` or `apply(_)` instead of `apply`.
}, Seq[FilteredFeature]
^

I tried:

case class FilteredFeature(featureName: String, featureScore: Double)

def filterFeature: UserDefinedFunction =
  udf((features: Seq[Row]) =>
    features.filter { x =>
      x.getString(0).startsWith("ft")
    }, Seq[FilteredFeature](_)
  )

I got:

<console>:201: error: type mismatch;
found   : Seq[FilteredFeature]
required: FilteredFeature
}, Seq[FilteredFeature](_)
^

What should I do in this case?

You have two options:

a) Provide a schema to the UDF, which lets you return Seq[Row].

b) Convert the Seq[Row] to a Seq of Tuple2 or of a case class; then you don't need to provide a schema (but the struct field names are lost if you use tuples!). A sketch of this variant follows the code for option a) below.

For your case, I prefer option a) (it also works for structs with many fields):

// Reuse the existing element type of the "features" column as the return schema
val schema = df.schema("features").dataType

val filterFeature = udf((features: Seq[Row]) =>
  features.filter(_.getAs[String]("name").startsWith("ft")), schema)
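For completeness, a sketch of option b), reusing the FilteredFeature case class from the question (the name filterFeatureB is made up here):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

case class FilteredFeature(featureName: String, featureScore: Double)

// Mapping each Row to a case class lets Spark derive the return schema
// from the type, so no explicit DataType argument is needed
val filterFeatureB = udf((features: Seq[Row]) =>
  features
    .filter(_.getAs[String]("name").startsWith("ft"))
    .map(r => FilteredFeature(r.getAs[String]("name"), r.getAs[Double]("score")))
)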

Try this:

def filterFeature: UserDefinedFunction =
  udf((features: Row) => {
    features.getAs[Array[Array[Any]]]("features")
      .filter(in => in(0).asInstanceOf[String].startsWith("ft"))
  })

If you are not using Spark 2.4, then this should work for your case:

case class FilteredFeature(featureName: String, featureScore: Double)

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._

def filterFeature: UserDefinedFunction = udf((features: Seq[Row]) => {
  features.filter { x =>
    // keep only the features whose name starts with "ft"
    x.getString(0).startsWith("ft")
  }.map(r => FilteredFeature(r.getString(0), r.getDouble(1)))
})

df.select($"x", $"y", filterFeature($"features") as "filter").show(false)

Output:

+---+---+-----------------------+
|x  |y  |filter                 |
+---+---+-----------------------+
|10 |9  |[[ft2,6.0], [ft3,10.9]]|
|11 |0  |[[ft1,4.0], [ft2,0.9]] |
|20 |9  |[[ft2,6.4], [ft3,1.9]] |
|18 |8  |[[ft4,8.1], [ft2,18.9]]|
+---+---+-----------------------+
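If you would rather overwrite the original column than add a new one, the same UDF can be applied with withColumn (the column in the schema is named "features", so that is the column passed in):

df.withColumn("features", filterFeature($"features")).show(false)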
