Python 单元测试模拟 pyspark 链
Python unittest mock pyspark chain
我想为具有 pyspark 代码的简单方法编写一些单元测试。
def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
    """Left-join ``df1`` with the parquet data loaded from ``df2_path``.

    ``df2`` is read via ``self.spark.read.format('parquet').load(df2_path)``;
    that attribute/call chain is what a unit test has to mock.
    """
    df2 = self.spark.read.format('parquet').load(df2_path)
    # NOTE(review): unqualified f.col(...) references become ambiguous if
    # df1_key == df2_key; qualifying them (df1[df1_key] == df2[df2_key]) avoids that.
    return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')
如何模拟(mock)`spark.read` 这一部分?我尝试了以下代码:
@patch("class_to_test.SparkSession")
def test_do_stuff(self, mock_spark: MagicMock) -> None:
    """Original attempt at mocking the parquet read inside ``do_stuff``.

    The configured test DataFrame is never returned by the code under test,
    so ``df2`` ends up being a mock instead of a DataFrame and ``df1.join``
    fails inside py4j.
    """
    spark = MagicMock()
    # BUG: do_stuff accesses `read` as an attribute (spark.read.format(...)),
    # it does not call it, so configuring spark.read.return_value... has no
    # effect on that chain. The matching chain would be
    # spark.read.format.return_value.load.return_value.
    spark.read.return_value.format.return_value.load.return_value = \
        self.spark.createDataFrame([(1, 2)], ["key2", "c2"])
    # NOTE(review): mock_spark.return_value only replaces the result of calling
    # SparkSession(...) directly. If ClassToTest obtains its session through
    # SparkSession.builder.getOrCreate(), the object it receives is
    # mock_spark.builder.getOrCreate.return_value instead -- confirm.
    mock_spark.return_value = spark
    input_df = self.spark.createDataFrame([(1, 1)], ["key1", "c1"])
    actual_df = ClassToTest().do_stuff(input_df, "df2", "key1", "key2")
    expected_df = self.spark.createDataFrame([(1, 1, 1, 2)], ["key1", "c1", "key2", "c2"])
    assert_pyspark_df_equal(actual_df, expected_df)
但失败并出现此错误:
py4j.Py4JException: Method join([class java.util.ArrayList, class org.apache.spark.sql.Column, class java.lang.String]) does not exist
看起来模拟(mock)没有像我预期的那样生效。我应该如何处理,才能让 `spark.read.format(...).load(...)` 返回我指定的测试 DataFrame?
您可以使用 `PropertyMock` 来完成。这是一个例子:
# test.py
import unittest
from unittest.mock import patch, PropertyMock, Mock
from pyspark.sql import SparkSession, DataFrame, functions as f
from pyspark_test import assert_pyspark_df_equal
class ClassToTest:
    """Minimal version of the class under test, holding a SparkSession."""

    def __init__(self) -> None:
        # Reuse the active SparkSession, or create one if none exists.
        self._spark = SparkSession.builder.getOrCreate()

    @property
    def spark(self):
        """The SparkSession this object reads data with.

        Because this is a property, a test can replace it on the class with
        ``unittest.mock.PropertyMock``.
        """
        return self._spark

    def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
        """Left-join ``df1`` with the parquet data stored at ``df2_path``.

        Args:
            df1: Left side of the join.
            df2_path: Path passed to ``spark.read.format('parquet').load``.
            df1_key: Name of the join column in ``df1``.
            df2_key: Name of the join column in the loaded DataFrame.

        Returns:
            ``df1`` left-joined with the loaded DataFrame on
            ``df1[df1_key] == df2[df2_key]``.
        """
        df2 = self.spark.read.format('parquet').load(df2_path)
        # Qualify each key with its own DataFrame: unqualified f.col(name)
        # references are ambiguous when df1_key and df2_key are the same name.
        return df1.join(df2, [df1[df1_key] == df2[df2_key]], 'left')
class TestClassToTest(unittest.TestCase):
    """Tests ClassToTest.do_stuff without reading any parquet files."""

    def setUp(self) -> None:
        # Real local session, used to build the input and expected DataFrames.
        self.spark = SparkSession.builder.getOrCreate()

    def test_do_stuff(self) -> None:
        # The DataFrame the mocked spark.read.format(...).load(...) returns.
        df2 = self.spark.createDataFrame([(1, 2)], ['key2', 'c2'])

        # `read` is an attribute (not called); `format` and `load` are calls:
        #   fake_spark.read.format('parquet')     -> format.return_value
        #   format.return_value.load('df2_path')  -> df2
        fake_spark = Mock()
        fake_spark.read.format.return_value.load.return_value = df2

        # patch.object targets the class object directly, so this works no
        # matter which module ClassToTest is defined in (no '__main__.' string).
        with patch.object(
            ClassToTest,
            'spark',
            new_callable=PropertyMock,
            return_value=fake_spark,
        ):
            input_df = self.spark.createDataFrame([(1, 1)], ['key1', 'c1'])
            df = ClassToTest().do_stuff(input_df, 'df2_path', 'key1', 'key2')

        # Verify do_stuff drove the read chain with the expected arguments.
        fake_spark.read.format.assert_called_once_with('parquet')
        fake_spark.read.format.return_value.load.assert_called_once_with('df2_path')
        assert_pyspark_df_equal(
            df,
            self.spark.createDataFrame([(1, 1, 1, 2)], ['key1', 'c1', 'key2', 'c2']),
        )
if __name__ == "__main__":
    # Discover and run the TestCase classes above when executed as a script.
    unittest.main()
让我们检查一下:
python test.py
# result:
----------------------------------------------------------------------
Ran 1 test in 7.460s
OK
我想为具有 pyspark 代码的简单方法编写一些单元测试。
def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
    """Left-join ``df1`` with the parquet data loaded from ``df2_path``.

    ``df2`` is read via ``self.spark.read.format('parquet').load(df2_path)``;
    that attribute/call chain is what a unit test has to mock.
    """
    df2 = self.spark.read.format('parquet').load(df2_path)
    # NOTE(review): unqualified f.col(...) references become ambiguous if
    # df1_key == df2_key; qualifying them (df1[df1_key] == df2[df2_key]) avoids that.
    return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')
如何模拟(mock)`spark.read` 这一部分?我尝试了以下代码:
@patch("class_to_test.SparkSession")
def test_do_stuff(self, mock_spark: MagicMock) -> None:
    """Original attempt at mocking the parquet read inside ``do_stuff``.

    The configured test DataFrame is never returned by the code under test,
    so ``df2`` ends up being a mock instead of a DataFrame and ``df1.join``
    fails inside py4j.
    """
    spark = MagicMock()
    # BUG: do_stuff accesses `read` as an attribute (spark.read.format(...)),
    # it does not call it, so configuring spark.read.return_value... has no
    # effect on that chain. The matching chain would be
    # spark.read.format.return_value.load.return_value.
    spark.read.return_value.format.return_value.load.return_value = \
        self.spark.createDataFrame([(1, 2)], ["key2", "c2"])
    # NOTE(review): mock_spark.return_value only replaces the result of calling
    # SparkSession(...) directly. If ClassToTest obtains its session through
    # SparkSession.builder.getOrCreate(), the object it receives is
    # mock_spark.builder.getOrCreate.return_value instead -- confirm.
    mock_spark.return_value = spark
    input_df = self.spark.createDataFrame([(1, 1)], ["key1", "c1"])
    actual_df = ClassToTest().do_stuff(input_df, "df2", "key1", "key2")
    expected_df = self.spark.createDataFrame([(1, 1, 1, 2)], ["key1", "c1", "key2", "c2"])
    assert_pyspark_df_equal(actual_df, expected_df)
但失败并出现此错误:
py4j.Py4JException: Method join([class java.util.ArrayList, class org.apache.spark.sql.Column, class java.lang.String]) does not exist
看起来模拟(mock)没有像我预期的那样生效。我应该如何处理,才能让 `spark.read.format(...).load(...)` 返回我指定的测试 DataFrame?
您可以使用 `PropertyMock` 来完成。这是一个例子:
# test.py
import unittest
from unittest.mock import patch, PropertyMock, Mock
from pyspark.sql import SparkSession, DataFrame, functions as f
from pyspark_test import assert_pyspark_df_equal
class ClassToTest:
    """Minimal version of the class under test, holding a SparkSession."""

    def __init__(self) -> None:
        # Reuse the active SparkSession, or create one if none exists.
        self._spark = SparkSession.builder.getOrCreate()

    @property
    def spark(self):
        """The SparkSession this object reads data with.

        Because this is a property, a test can replace it on the class with
        ``unittest.mock.PropertyMock``.
        """
        return self._spark

    def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
        """Left-join ``df1`` with the parquet data stored at ``df2_path``.

        Args:
            df1: Left side of the join.
            df2_path: Path passed to ``spark.read.format('parquet').load``.
            df1_key: Name of the join column in ``df1``.
            df2_key: Name of the join column in the loaded DataFrame.

        Returns:
            ``df1`` left-joined with the loaded DataFrame on
            ``df1[df1_key] == df2[df2_key]``.
        """
        df2 = self.spark.read.format('parquet').load(df2_path)
        # Qualify each key with its own DataFrame: unqualified f.col(name)
        # references are ambiguous when df1_key and df2_key are the same name.
        return df1.join(df2, [df1[df1_key] == df2[df2_key]], 'left')
class TestClassToTest(unittest.TestCase):
    """Tests ClassToTest.do_stuff without reading any parquet files."""

    def setUp(self) -> None:
        # Real local session, used to build the input and expected DataFrames.
        self.spark = SparkSession.builder.getOrCreate()

    def test_do_stuff(self) -> None:
        # The DataFrame the mocked spark.read.format(...).load(...) returns.
        df2 = self.spark.createDataFrame([(1, 2)], ['key2', 'c2'])

        # `read` is an attribute (not called); `format` and `load` are calls:
        #   fake_spark.read.format('parquet')     -> format.return_value
        #   format.return_value.load('df2_path')  -> df2
        fake_spark = Mock()
        fake_spark.read.format.return_value.load.return_value = df2

        # patch.object targets the class object directly, so this works no
        # matter which module ClassToTest is defined in (no '__main__.' string).
        with patch.object(
            ClassToTest,
            'spark',
            new_callable=PropertyMock,
            return_value=fake_spark,
        ):
            input_df = self.spark.createDataFrame([(1, 1)], ['key1', 'c1'])
            df = ClassToTest().do_stuff(input_df, 'df2_path', 'key1', 'key2')

        # Verify do_stuff drove the read chain with the expected arguments.
        fake_spark.read.format.assert_called_once_with('parquet')
        fake_spark.read.format.return_value.load.assert_called_once_with('df2_path')
        assert_pyspark_df_equal(
            df,
            self.spark.createDataFrame([(1, 1, 1, 2)], ['key1', 'c1', 'key2', 'c2']),
        )
if __name__ == "__main__":
    # Discover and run the TestCase classes above when executed as a script.
    unittest.main()
让我们检查一下:
python test.py
# result:
----------------------------------------------------------------------
Ran 1 test in 7.460s
OK