How to mock the chained .repartition() method when testing PySpark DataFrame code



I have code that uses the pyspark library, and I want to test it with pytest.

However, when the tests run, I want to mock the `.repartition()` method on the DataFrame.

1. Let's say the code I want to test is a chained PySpark function like this:

```python
import pyspark.sql


def transform(df: pyspark.sql.DataFrame):
    return (
        df
        .repartition("id")
        .groupby("id")
        .sum("quantity")
    )
```
2. Currently my test function looks like this (a sketch of how the fixture DataFrames might be built appears after this list):

```python
import pytest


@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
```
3. Now, how can I mock the `.repartition()` method for my test? Something like this pseudocode (which currently does not work):

```python
import pytest
from unittest import mock


@pytest.mark.parametrize("df, expected_df", [(..., ...)])  # my input args
@mock.patch("pyspark.sql.DataFrame.repartition")
def test_transform(df, expected_df):
    df_output = transform(df)
    assert df_output == expected_df
```
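
For context, here is a minimal sketch of how the input and expected DataFrames could be built, assuming a local SparkSession; the `spark` fixture and the sample values are purely illustrative, not my real test data:

```python
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    # Illustrative fixture: a local SparkSession for building test DataFrames.
    return SparkSession.builder.master("local[1]").appName("tests").getOrCreate()


def make_input_df(spark):
    # Illustrative input matching the columns that transform() expects.
    return spark.createDataFrame([(1, 10), (1, 20), (2, 5)], ["id", "quantity"])
```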

Chain the calls like below. See here for a similar example:

@mock.patch("pyspark.sql.DataFrame")
def test_transform(df: Mock):
expected_df = "expected value"
df.repartition.return_value.groupby.return_value.sum.return_value = expected_df
df_output = transform(df)
assert df_output == expected_df
df.repartition.assert_called_with("id")
df.repartition().groupby.assert_called_with("id")
df.repartition().groupby().sum.assert_called_with("quantity")
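
This works because a `MagicMock` returns the same child mock on every call, so `df.repartition()` inside the assertions is the exact object the chain in `transform` went through, and setting `...sum.return_value` pins the final result of the chain.

If you would rather keep the class-level patch from your question and let `groupby`/`sum` run against a real DataFrame, something like the sketch below should work; it assumes a real SparkSession-backed `df` fixture and is not from the original answer:

```python
from unittest import mock

import pyspark.sql


def test_transform_skips_repartition(df, expected_df):
    # Patch repartition on the class so it hands the DataFrame back unchanged;
    # the rest of the chain then executes for real. With autospec=True the
    # mock receives `self`, which the side_effect simply returns.
    with mock.patch.object(
        pyspark.sql.DataFrame,
        "repartition",
        autospec=True,
        side_effect=lambda self, *args, **kwargs: self,
    ) as repartition_mock:
        df_output = transform(df)

    repartition_mock.assert_called_once_with(df, "id")
    # Order-sensitive comparison, kept simple for the sketch.
    assert df_output.collect() == expected_df.collect()
```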
