检查Scala中两个Spark dataframe是否相等



我是Scala新手,在编写单元测试时遇到了一些问题。

我试图在Scala中比较和检查两个Spark dataframe的相等性以进行单元测试,并意识到没有简单的方法来检查两个Spark dataframe的相等性。

c++的等效代码将是(假设dataframe在c++中表示为双数组):

    int expected[10][2];
    int result[10][2];
    for (int row = 0; row < 10; row++) {
        for (int col = 0; col < 2; col++) {
            if (expected[row][col] != result[row][col]) return false;
        }
    }

实际的测试将涉及基于DataFrames列的数据类型的相等性测试(使用浮点数的精度公差测试等)。

似乎没有一种简单的方法可以使用Scala迭代遍历dataframe中的所有元素,而其他用于检查两个dataframe(如df1.except(df2))的相等性的解决方案在我的情况下不起作用,因为我需要能够提供支持,以允许浮点数和双精度数测试相等性。

当然,我可以尝试事先四舍五入所有的元素,然后比较结果,但我想看看是否有任何其他的解决方案,将允许我遍历DataFrames来检查是否相等

import org.scalatest.{BeforeAndAfterAll, FeatureSpec, Matchers}
outDf.collect() should contain theSameElementsAs (dfComparable.collect())
# or ( obs order matters ! )
// outDf.except(dfComparable).toDF().count should be(0)
outDf.except(dfComparable).count should be(0)   

如果您想要检查两个数据帧是否相等,您可以使用数据帧的subtract()方法(在1.3及以上版本中支持)

您可以检查两个数据帧的diff是否为空或0。例:df1.subtract(df2).count() == 0

假设您有固定的col和行#,一种解决方案可以通过行索引连接两个Df(如果您没有记录的id),然后在最终Df中直接迭代[与两个Df的所有列]。像这样:

Schemas
DF1
root
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)
DF2
root
 |-- col1: double (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)
df1
+----------+-----------+------+
|      col1|       col2|  col3|
+----------+-----------+------+
|1.20000001|       1.21|   1.2|
|    2.1111|        2.3|  22.2|
|       3.2|2.330000001| 2.333|
|    2.2444|      2.344|2.3331|
+----------+-----------+------+
df2
+------+-----+------+
|  col1| col2|  col3|
+------+-----+------+
|   1.2| 1.21|   1.2|
|2.1111|  2.3|  22.2|
|   3.2| 2.33| 2.333|
|2.2444|2.344|2.3331|
+------+-----+------+
Added row index
df1
+----------+-----------+------+---+
|      col1|       col2|  col3|row|
+----------+-----------+------+---+
|1.20000001|       1.21|   1.2|  0|
|    2.1111|        2.3|  22.2|  1|
|       3.2|2.330000001| 2.333|  2|
|    2.2444|      2.344|2.3331|  3|
+----------+-----------+------+---+
df2
+------+-----+------+---+
|  col1| col2|  col3|row|
+------+-----+------+---+
|   1.2| 1.21|   1.2|  0|
|2.1111|  2.3|  22.2|  1|
|   3.2| 2.33| 2.333|  2|
|2.2444|2.344|2.3331|  3|
+------+-----+------+---+
Combined DF
+---+----------+-----------+------+------+-----+------+
|row|      col1|       col2|  col3|  col1| col2|  col3|
+---+----------+-----------+------+------+-----+------+
|  0|1.20000001|       1.21|   1.2|   1.2| 1.21|   1.2|
|  1|    2.1111|        2.3|  22.2|2.1111|  2.3|  22.2|
|  2|       3.2|2.330000001| 2.333|   3.2| 2.33| 2.333|
|  3|    2.2444|      2.344|2.3331|2.2444|2.344|2.3331|
+---+----------+-----------+------+------+-----+------+

你可以这样做:

println("Schemas")
    println("DF1")
    df1.printSchema()
    println("DF2")
    df2.printSchema()
    println("df1")
    df1.show
    println("df2")
    df2.show
    val finaldf1 = df1.withColumn("row", monotonically_increasing_id())
    val finaldf2 = df2.withColumn("row", monotonically_increasing_id())
    println("Added row index")
    println("df1")
    finaldf1.show()
    println("df2")
    finaldf2.show()
    val joinedDfs = finaldf1.join(finaldf2, "row")
    println("Combined DF")
    joinedDfs.show()
    val tolerance = 0.001
    def isInValidRange(a: Double, b: Double): Boolean ={
      Math.abs(a-b)<=tolerance
    }
    joinedDfs.take(10).foreach(row => {
      assert( isInValidRange(row.getDouble(1), row.getDouble(4)) , "Col1 validation. Row %s".format(row.getLong(0)+1))
      assert( isInValidRange(row.getDouble(2), row.getDouble(5)) , "Col2 validation. Row %s".format(row.getLong(0)+1))
      assert( isInValidRange(row.getDouble(3), row.getDouble(6)) , "Col3 validation. Row %s".format(row.getLong(0)+1))
    })

注意:Assert是不序列化的,一个解决方法是使用take()来避免错误。

相关内容

  • 没有找到相关文章