我在火花(v1.5.0)代码中有两个 DataFrame
:
aDF = [user_id : Int, user_purchases: array<int> ]
bDF = [user_id : Int, user_purchases: array<int> ]
我想做的是加入这两个数据范围,但是我只需要aDF.user_purchases
和bDF.user_purchases
之间的相交的行具有2个以上的元素(交集> 2)。
我是否必须使用RDD API,还是可以使用org.apache.sql.functions的某些函数?
一种可能的解决方案是找到有趣的对并用数组增强它们。首先,让我们导入一些功能:
import org.apache.spark.sql.functions.explode
和重命名列:
val aDF_ = aDF.toDF("a_user_id", "a_user_purchases")
val bDF_ = bDF.toDF("b_user_id", "b_user_purchases")
对匹配谓词的配对可以识别为:
val filtered = aDF_.withColumn("purchase", explode($"a_user_purchases"))
.join(bDF_.withColumn("purchase", explode($"b_user_purchases")), Seq("purchase"))
.groupBy("a_user_id", "b_user_id")
.count()
.where($"count" > 2)
最终可以与输入数据集一起过滤数据以获得完整的结果:
filtered.join(aDF_, Seq("a_user_id")).join(bDF_, Seq("b_user_id")).drop("count")
在Spark 2.4或以后您还可以使用内置功能:
import org.apache.spark.sql.functions.{size, array_intersect}
aDF_
.crossJoin(bDF_)
.where(size(
array_intersect($"a_user_purchases", $"b_user_purchases"
)) > 2)
尽管这可能仍然比有针对性的哈希(Join)慢。
我看不到任何内置功能,但是您可以使用UDF:
import scala.collection.mutable.WrappedArray;
val intersect = udf ((a : WrappedArray[Int], b : WrappedArray[Int]) => {
var count = 0;
a.foreach (x => {
if (b.contains(x)) count = count + 1;
});
count;
});
// test data sets
val one = sc.parallelize(List(
(1, Array(1, 2, 3)),
(2, Array(1,2 ,3, 4)),
(3, Array(1, 2,3)),
(4, Array(1,2))
)).toDF("user", "arr");
val two = sc.parallelize(List(
(1, Array(1, 2, 3)),
(2, Array(1,2 ,3, 4)),
(3, Array(1, 2, 3)),
(4, Array(1))
)).toDF("user", "arr");
// usage
one.join(two, one("user") === two("user"))
.select (one("user"), intersect(one("arr"), two("arr")).as("intersect"))
.where(col("intersect") > 2).show
// version from comment
one.join(two)
.select (one("user"), two("user"), intersect(one("arr"), two("arr")).as("intersect")).
where('intersect > 2).show