I'm learning PySpark, and I have a function:

import re
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def function_1(string):
    new_string = re.sub(r"!", " ", string)
    return new_string

udf_function_1 = udf(lambda s: function_1(s), StringType())

def function_2(data):
    new_data = data \
        .withColumn("column_1", udf_function_1("column_1"))
    return new_data

My question is: how do I write a unit test for function_2() in Python?
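(For reference, the string logic inside function_1 is pure Python, so it can be exercised without a Spark session at all; a minimal pytest-style sketch, with made-up sample values:)

```python
import re

def function_1(string):
    # same logic as the question: replace every "!" with a space
    return re.sub(r"!", " ", string)

def test_function_1_replaces_bangs():
    # hypothetical sample inputs, not from the question
    assert function_1("hello!world") == "hello world"
    assert function_1("no bangs here") == "no bangs here"
    assert function_1("!!") == "  "
```

Testing the pure function separately keeps the fast, Spark-free checks apart from the slower tests that need a SparkSession.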
What exactly do you want to test in function_2?
Below is a simple test saved in a file named sample_test.py. I used pytest, but you can adapt very similar code for unittest.
# sample_test.py
from pyspark import sql

spark = sql.SparkSession.builder \
    .appName("local-spark-session") \
    .getOrCreate()

def test_create_session():
    assert isinstance(spark, sql.SparkSession)
    assert spark.sparkContext.appName == 'local-spark-session'

def test_spark_version():
    assert spark.version == '3.1.2'
Running the test:
C:\Users\user\Desktop>pytest -v sample_test.py
============================================= test session starts =============================================
platform win32 -- Python 3.6.7, pytest-6.2.5, py-1.10.0, pluggy-1.0.0 -- c:\users\user\appdata\local\programs\python\python36\python.exe
cachedir: .pytest_cache
rootdir: C:\Users\user\Desktop
collected 2 items

sample_test.py::test_create_session PASSED [ 50%]
sample_test.py::test_spark_version PASSED [100%]
============================================== 2 passed in 4.81s ==============================================