如何在 Spark 数据帧上应用自定义筛选函数



我有一个数据帧的形式:

A_DF = |id_A: Int|concatCSV: String|

另一个:

B_DF = |id_B: Int|triplet: List[String]|

concatCSV的示例可能如下所示:

"StringD, StringB, StringF, StringE, StringZ"
"StringA, StringB, StringX, StringY, StringZ"
...

triplet是这样的:

("StringA", "StringF", "StringZ")
("StringB", "StringU", "StringR")
...

我想生成A_DFB_DF笛卡尔集合,例如;

| id_A: Int | concatCSV: String                             | id_B: Int | triplet: List[String]            |
|     14    | "StringD, StringB, StringF, StringE, StringZ" |     21    | ("StringA", "StringF", "StringZ")|
|     14    | "StringD, StringB, StringF, StringE, StringZ" |     45    | ("StringB", "StringU", "StringR")|
|     18    | "StringA, StringB, StringX, StringY, StringG" |     21    | ("StringA", "StringF", "StringZ")|
|     18    | "StringA, StringB, StringX, StringY, StringG" |     45    | ("StringB", "StringU", "StringR")|
|    ...    |                                               |           |                                  |

然后只保留至少有两个子字符串(例如StringA, StringB)的记录,从B_DF("triplet")中出现的A_DF("concatCSV"),即使用filter排除那些不满足此条件的记录。

第一个问题是:我可以在不将DF转换为RDD的情况下做到这一点吗?

第二个问题是:我能否理想地在join步骤中完成整个事情 - 作为一个where条件?

我尝试过尝试如下内容:

val cartesianRDD = A_DF
.join(B_DF,"right")
.where($"triplet".exists($"concatCSV".contains(_)))

where无法解决。我用filter而不是where尝试了它,但仍然没有运气。另外,出于某种奇怪的原因,cartesianRDD的类型注释是SchemaRDD而不是DataFrame。我怎么会这样?最后,我上面尝试的(我写的短代码)是不完整的,因为它会保留triplet中找到concatCSV的只有一个子字符串的记录。

所以,第三个问题是:我应该只更改为RDD并使用自定义过滤功能解决它吗?

最后,最后一个问题:是否可以对数据帧使用自定义筛选函数?

感谢您的帮助。

函数CROSS JOINHive中实现,所以你可以先使用Hive SQL进行交叉连接:

A_DF.registerTempTable("a")
B_DF.registerTempTable("b")
// sqlContext should be really a HiveContext
val result = sqlContext.sql("SELECT * FROM a CROSS JOIN b") 

然后,您可以使用两个udf过滤到预期的输出。一个将您的字符串转换为单词数组,另一个为我们提供生成的数组列和现有列的交集长度"triplet"

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.col
val splitArr = udf { (s: String) => s.split(",").map(_.trim) }
val commonLen = udf { (a: WrappedArray[String], 
b: WrappedArray[String]) => a.intersect(b).length }
val temp = (result.withColumn("concatArr",
splitArr(col("concatCSV"))).select(col("*"),
commonLen(col("triplet"), col("concatArr")).alias("comm"))
.filter(col("comm") >= 2)
.drop("comm")
.drop("concatArr"))
temp.show
+----+--------------------+----+--------------------+
|id_A|           concatCSV|id_B|             triplet|
+----+--------------------+----+--------------------+
|  14|StringD, StringB,...|  21|[StringA, StringF...|
|  18|StringA, StringB,...|  21|[StringA, StringF...|
+----+--------------------+----+--------------------+

相关内容

  • 没有找到相关文章