单元测试pyspark和累加器
Unit test pyspark and accumulator
我正在尝试在 Python 中测试我的 Spark 代码,但每当我的测试代码 运行 时,我的所有累加器都是空的。但是,当我 运行 没有模拟的本地代码时,代码工作正常并且累加器有值。这是代码的精简版:
代码:
from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession
columns: Any = []
class SetAccumulator(AccumulatorParam):
def zero(self, value):
return value.copy()
def addInPlace(self, value1, value2):
return value1.union(value2)
def read_columns(obj: dict) -> None:
global columns
for key in obj.keys():
columns += {key}
def run(spark: SparkSession, df: DataFrame) -> list:
global columns
columns = spark.sparkContext.accumulator(set(), SetAccumulator())
df.rdd.foreach(lambda row: read_columns(row.asDict()))
return list(columns.value)
模拟 Spark 测试代码:
import pydeequ
from unittest import TestCase
from pyspark.sql import SparkSession
class SparkTestCase(TestCase):
spark: SparkSession
@classmethod
def setUpClass(cls) -> None:
cls.spark = (
SparkSession.builder.appName("testspark")
.master("local")
.enableHiveSupport()
.config("spark.jars.packages", pydeequ.deequ_maven_coord)
.config("spark.jars.excludes", pydeequ.f2j_maven_coord)
.config("spark.sql.shuffle.partitions", 8)
.getOrCreate()
)
测试代码:
from tests.spark.testcase import SparkTestCase
from foo.bar import run
class TestFoo(SparkTestCase):
def test_foo(self):
columns = [
"test",
"bar",
"name"
]
data = [
(
"Hello!",
100,
"Foobar"
)
]
df = self.spark.createDataFrame(data, columns)
response = run(self.spark, df)
print(response)
测试打印出一个空列表。但如前所述,当我 运行 在测试框架之外(本地,在我的计算机上)时,它会打印出 ["test", "bar", "name"]
.
我做错了什么或者我需要添加什么才能使其在测试用例中工作?
我找到了一种使单元测试工作的方法。我创建了一个累加器字典并将其传递给每个任务,测试能够正确更新值。我假设 global
不能正常使用 Spark 单元测试。
更新后的代码如下所示。上面问题的测试代码保持不变。
from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession
class SetAccumulator(AccumulatorParam):
def zero(self, value):
return value.copy()
def addInPlace(self, value1, value2):
return value1.union(value2)
def read_columns(obj: dict, accumulators: dict) -> None:
for key in obj.keys():
accumulators["columns"] += {key}
def run(spark: SparkSession, df: DataFrame) -> list:
columns = spark.sparkContext.accumulator(set(), SetAccumulator())
accumulators = {"columns": columns}
df.rdd.foreach(lambda row: read_columns(row.asDict(), accumulators))
return list(columns.value)
我删除了顶部的 columns
变量并删除了对它的所有全局引用。相反,我创建了一个字典 accumulators = {"columns": columns}
,我将其传递给 read_columns
函数并按键获取累加器。
测试现在可以正确打印 ["test", "bar", "name"]
并且它在测试环境之外仍然有效。
我正在尝试在 Python 中测试我的 Spark 代码,但每当我的测试代码 运行 时,我的所有累加器都是空的。但是,当我 运行 没有模拟的本地代码时,代码工作正常并且累加器有值。这是代码的精简版:
代码:
from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession
columns: Any = []
class SetAccumulator(AccumulatorParam):
def zero(self, value):
return value.copy()
def addInPlace(self, value1, value2):
return value1.union(value2)
def read_columns(obj: dict) -> None:
global columns
for key in obj.keys():
columns += {key}
def run(spark: SparkSession, df: DataFrame) -> list:
global columns
columns = spark.sparkContext.accumulator(set(), SetAccumulator())
df.rdd.foreach(lambda row: read_columns(row.asDict()))
return list(columns.value)
模拟 Spark 测试代码:
import pydeequ
from unittest import TestCase
from pyspark.sql import SparkSession
class SparkTestCase(TestCase):
spark: SparkSession
@classmethod
def setUpClass(cls) -> None:
cls.spark = (
SparkSession.builder.appName("testspark")
.master("local")
.enableHiveSupport()
.config("spark.jars.packages", pydeequ.deequ_maven_coord)
.config("spark.jars.excludes", pydeequ.f2j_maven_coord)
.config("spark.sql.shuffle.partitions", 8)
.getOrCreate()
)
测试代码:
from tests.spark.testcase import SparkTestCase
from foo.bar import run
class TestFoo(SparkTestCase):
def test_foo(self):
columns = [
"test",
"bar",
"name"
]
data = [
(
"Hello!",
100,
"Foobar"
)
]
df = self.spark.createDataFrame(data, columns)
response = run(self.spark, df)
print(response)
测试打印出一个空列表。但如前所述,当我 运行 在测试框架之外(本地,在我的计算机上)时,它会打印出 ["test", "bar", "name"]
.
我做错了什么或者我需要添加什么才能使其在测试用例中工作?
我找到了一种使单元测试工作的方法。我创建了一个累加器字典并将其传递给每个任务,测试能够正确更新值。我假设 global
不能正常使用 Spark 单元测试。
更新后的代码如下所示。上面问题的测试代码保持不变。
from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession
class SetAccumulator(AccumulatorParam):
def zero(self, value):
return value.copy()
def addInPlace(self, value1, value2):
return value1.union(value2)
def read_columns(obj: dict, accumulators: dict) -> None:
for key in obj.keys():
accumulators["columns"] += {key}
def run(spark: SparkSession, df: DataFrame) -> list:
columns = spark.sparkContext.accumulator(set(), SetAccumulator())
accumulators = {"columns": columns}
df.rdd.foreach(lambda row: read_columns(row.asDict(), accumulators))
return list(columns.value)
我删除了顶部的 columns
变量并删除了对它的所有全局引用。相反,我创建了一个字典 accumulators = {"columns": columns}
,我将其传递给 read_columns
函数并按键获取累加器。
测试现在可以正确打印 ["test", "bar", "name"]
并且它在测试环境之外仍然有效。