I have code that uses the pyspark library, and I want to test it with pytest.
However, when running the tests, I want to mock the `.repartition()` method on the DataFrame.
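- Suppose the code I want to test is a chained pyspark function like this:

```python
import pyspark.sql


def transform(df: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    return (
        df
        .repartition("id")
        .groupby("id")
        .sum("quantity")
    )
```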
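- Currently my test function looks like this:

```python
import pytest


@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
def test_transform(df, expected_df):
    df_output = transform(df)
    # DataFrame does not define ==, so compare the collected rows instead
    assert sorted(df_output.collect()) == sorted(expected_df.collect())
```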
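- Now, how can I mock the `.repartition()` method in my test? Something like this pseudocode (which currently does not work):

```python
from unittest import mock

import pytest


@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
@mock.patch("pyspark.sql.DataFrame.repartition")
def test_transform(mock_repartition, df, expected_df):
    # mock.patch passes the MagicMock in as the first argument. Its return
    # value is another MagicMock, so .groupby().sum() never reach a real
    # DataFrame and the comparison below cannot succeed.
    df_output = transform(df)
    assert sorted(df_output.collect()) == sorted(expected_df.collect())
```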
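If the aim is to keep the real `groupby().sum()` running and only neutralise `.repartition()`, one option is to patch it with a pass-through that returns the DataFrame unchanged. The sketch below uses hypothetical stand-in classes `FakeFrame` and `FakeGroupedData` instead of `pyspark.sql.DataFrame`, so it runs without Spark. `mock.patch.object(..., autospec=True)` makes the mock bind like a normal method, so its `side_effect` receives `self` and can simply return it. Swapping `FakeFrame` for `pyspark.sql.DataFrame` should work the same way, but that is an assumption, not something verified against every PySpark version.

```python
from unittest import mock


class FakeGroupedData:
    """Hypothetical stand-in for pyspark GroupedData (illustration only)."""

    def __init__(self, rows, key):
        self.rows = rows
        self.key = key

    def sum(self, col):
        # Returns {key_value: total} instead of a DataFrame, to keep it small.
        totals = {}
        for row in self.rows:
            totals[row[self.key]] = totals.get(row[self.key], 0) + row[col]
        return totals


class FakeFrame:
    """Hypothetical stand-in for pyspark.sql.DataFrame (illustration only)."""

    def __init__(self, rows):
        self.rows = rows

    def repartition(self, *cols):
        # Stands in for the step we do not want to execute in unit tests.
        raise RuntimeError("repartition should not run in tests")

    def groupby(self, key):
        return FakeGroupedData(self.rows, key)


def transform(df):
    # Same chain as in the question.
    return df.repartition("id").groupby("id").sum("quantity")


def run_without_repartition(frame):
    # autospec=True: side_effect gets `self` first; returning it makes
    # repartition a no-op while groupby()/sum() still run for real.
    with mock.patch.object(
        FakeFrame,
        "repartition",
        autospec=True,
        side_effect=lambda self, *args, **kwargs: self,
    ) as mock_repartition:
        result = transform(frame)
        # Because of autospec, the recorded call includes `self`.
        mock_repartition.assert_called_once_with(frame, "id")
    return result


frame = FakeFrame([
    {"id": 1, "quantity": 2},
    {"id": 1, "quantity": 3},
    {"id": 2, "quantity": 5},
])
print(run_without_repartition(frame))  # {1: 5, 2: 5}
```

With a real DataFrame the result is a DataFrame rather than a dict, so the assertion would compare `collect()` output instead.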
Mock the whole call chain, like this. See the similar question here.
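```python
from unittest import mock
from unittest.mock import MagicMock


@mock.patch("pyspark.sql.DataFrame")
def test_transform(df: MagicMock):
    expected_df = "expected value"
    # Configure what the end of the chain df.repartition().groupby().sum() returns
    df.repartition.return_value.groupby.return_value.sum.return_value = expected_df

    df_output = transform(df)

    assert df_output == expected_df
    # Use .return_value instead of calling the mocks again (e.g. df.repartition()),
    # which would record extra calls
    df.repartition.assert_called_once_with("id")
    df.repartition.return_value.groupby.assert_called_once_with("id")
    df.repartition.return_value.groupby.return_value.sum.assert_called_once_with("quantity")
```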
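Because every call in the chain goes to a mock, the same check works without `mock.patch` and without Spark installed: a plain `MagicMock` can stand in for the DataFrame, and comparing `mock_calls` against a `call(...).call_list()` chain also verifies the order of the calls. A minimal sketch using only the standard library; `transform` is the question's function with the pyspark type hint dropped, and `check_transform_chain` is a hypothetical helper name.

```python
from unittest import mock


def transform(df):
    # The question's chain, minus the pyspark type hint.
    return df.repartition("id").groupby("id").sum("quantity")


def check_transform_chain():
    df = mock.MagicMock(name="df")
    sentinel = object()
    # Configure the value returned at the end of the chain.
    df.repartition.return_value.groupby.return_value.sum.return_value = sentinel

    result = transform(df)

    assert result is sentinel
    # One comparison covers every call in the chain, including their order.
    expected = mock.call.repartition("id").groupby("id").sum("quantity").call_list()
    assert df.mock_calls == expected
    return df.mock_calls


print(check_transform_chain())
```

The trade-off: this only checks that the right methods were called with the right arguments. It says nothing about whether the aggregation itself is correct, which still needs a real (for example local) SparkSession.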