我是Scala新手,在编写单元测试时遇到了一些问题。
我试图在Scala中比较和检查两个Spark dataframe的相等性以进行单元测试,并意识到没有简单的方法来检查两个Spark dataframe的相等性。
c++的等效代码将是(假设dataframe在c++中表示为双数组):
int expected[10][2];
int result[10][2];
for (int row = 0; row < 10; row++) {
for (int col = 0; col < 2; col++) {
if (expected[row][col] != result[row][col]) return false;
}
}
实际的测试将涉及基于DataFrames列的数据类型的相等性测试(使用浮点数的精度公差测试等)。
似乎没有一种简单的方法可以使用Scala迭代遍历dataframe中的所有元素,而其他用于检查两个dataframe(如df1.except(df2)
)的相等性的解决方案在我的情况下不起作用,因为我需要能够提供支持,以允许浮点数和双精度数测试相等性。
当然,我可以尝试事先四舍五入所有的元素,然后比较结果,但我想看看是否有任何其他的解决方案,将允许我遍历DataFrames来检查是否相等
import org.scalatest.{BeforeAndAfterAll, FeatureSpec, Matchers}
outDf.collect() should contain theSameElementsAs (dfComparable.collect())
# or ( obs order matters ! )
// outDf.except(dfComparable).toDF().count should be(0)
outDf.except(dfComparable).count should be(0)
如果您想要检查两个数据帧是否相等,您可以使用数据帧的subtract()
方法(在1.3及以上版本中支持)
您可以检查两个数据帧的diff是否为空或0。例:df1.subtract(df2).count() == 0
假设您有固定的col和行#,一种解决方案可以通过行索引连接两个Df(如果您没有记录的id),然后在最终Df中直接迭代[与两个Df的所有列]。像这样:
Schemas
DF1
root
|-- col1: double (nullable = true)
|-- col2: double (nullable = true)
|-- col3: double (nullable = true)
DF2
root
|-- col1: double (nullable = true)
|-- col2: double (nullable = true)
|-- col3: double (nullable = true)
df1
+----------+-----------+------+
| col1| col2| col3|
+----------+-----------+------+
|1.20000001| 1.21| 1.2|
| 2.1111| 2.3| 22.2|
| 3.2|2.330000001| 2.333|
| 2.2444| 2.344|2.3331|
+----------+-----------+------+
df2
+------+-----+------+
| col1| col2| col3|
+------+-----+------+
| 1.2| 1.21| 1.2|
|2.1111| 2.3| 22.2|
| 3.2| 2.33| 2.333|
|2.2444|2.344|2.3331|
+------+-----+------+
Added row index
df1
+----------+-----------+------+---+
| col1| col2| col3|row|
+----------+-----------+------+---+
|1.20000001| 1.21| 1.2| 0|
| 2.1111| 2.3| 22.2| 1|
| 3.2|2.330000001| 2.333| 2|
| 2.2444| 2.344|2.3331| 3|
+----------+-----------+------+---+
df2
+------+-----+------+---+
| col1| col2| col3|row|
+------+-----+------+---+
| 1.2| 1.21| 1.2| 0|
|2.1111| 2.3| 22.2| 1|
| 3.2| 2.33| 2.333| 2|
|2.2444|2.344|2.3331| 3|
+------+-----+------+---+
Combined DF
+---+----------+-----------+------+------+-----+------+
|row| col1| col2| col3| col1| col2| col3|
+---+----------+-----------+------+------+-----+------+
| 0|1.20000001| 1.21| 1.2| 1.2| 1.21| 1.2|
| 1| 2.1111| 2.3| 22.2|2.1111| 2.3| 22.2|
| 2| 3.2|2.330000001| 2.333| 3.2| 2.33| 2.333|
| 3| 2.2444| 2.344|2.3331|2.2444|2.344|2.3331|
+---+----------+-----------+------+------+-----+------+
你可以这样做:
println("Schemas")
println("DF1")
df1.printSchema()
println("DF2")
df2.printSchema()
println("df1")
df1.show
println("df2")
df2.show
val finaldf1 = df1.withColumn("row", monotonically_increasing_id())
val finaldf2 = df2.withColumn("row", monotonically_increasing_id())
println("Added row index")
println("df1")
finaldf1.show()
println("df2")
finaldf2.show()
val joinedDfs = finaldf1.join(finaldf2, "row")
println("Combined DF")
joinedDfs.show()
val tolerance = 0.001
def isInValidRange(a: Double, b: Double): Boolean ={
Math.abs(a-b)<=tolerance
}
joinedDfs.take(10).foreach(row => {
assert( isInValidRange(row.getDouble(1), row.getDouble(4)) , "Col1 validation. Row %s".format(row.getLong(0)+1))
assert( isInValidRange(row.getDouble(2), row.getDouble(5)) , "Col2 validation. Row %s".format(row.getLong(0)+1))
assert( isInValidRange(row.getDouble(3), row.getDouble(6)) , "Col3 validation. Row %s".format(row.getLong(0)+1))
})
注意:Assert是不序列化的,一个解决方法是使用take()来避免错误。