单元测试pyspark和累加器

Unit test pyspark and accumulator

我正在尝试在 Python 中测试我的 Spark 代码,但每当我的测试代码 运行 时,我的所有累加器都是空的。但是,当我 运行 没有模拟的本地代码时,代码工作正常并且累加器有值。这是代码的精简版:

代码:

from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession

columns: Any = []

class SetAccumulator(AccumulatorParam):
    def zero(self, value):
        return value.copy()

    def addInPlace(self, value1, value2):
        return value1.union(value2)

def read_columns(obj: dict) -> None:
    global columns

    for key in obj.keys():
        columns += {key}

def run(spark: SparkSession, df: DataFrame) -> list:
    global columns
    columns = spark.sparkContext.accumulator(set(), SetAccumulator())
    df.rdd.foreach(lambda row: read_columns(row.asDict()))
    return list(columns.value)

模拟 Spark 测试代码:

import pydeequ
from unittest import TestCase
from pyspark.sql import SparkSession

class SparkTestCase(TestCase):
    spark: SparkSession

    @classmethod
    def setUpClass(cls) -> None:
        cls.spark = (
            SparkSession.builder.appName("testspark")
              .master("local")
              .enableHiveSupport()
              .config("spark.jars.packages", pydeequ.deequ_maven_coord)
              .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
              .config("spark.sql.shuffle.partitions", 8)
              .getOrCreate()
        )

测试代码:

from tests.spark.testcase import SparkTestCase
from foo.bar import run

class TestFoo(SparkTestCase):
    def test_foo(self):
        columns = [
            "test",
            "bar",
            "name"
        ]
        data = [
            (
                "Hello!",
                100,
                "Foobar"
            )
        ]

        df = self.spark.createDataFrame(data, columns)
        response = run(self.spark, df)
        print(response)

测试打印出一个空列表。但如前所述,当我 运行 在测试框架之外(本地,在我的计算机上)时,它会打印出 ["test", "bar", "name"].

我做错了什么或者我需要添加什么才能使其在测试用例中工作?

我找到了一种使单元测试工作的方法。我创建了一个累加器字典并将其传递给每个任务,测试能够正确更新值。我假设 global 不能正常使用 Spark 单元测试。

更新后的代码如下所示。上面问题的测试代码保持不变。

from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession

class SetAccumulator(AccumulatorParam):
    def zero(self, value):
        return value.copy()

    def addInPlace(self, value1, value2):
        return value1.union(value2)

def read_columns(obj: dict, accumulators: dict) -> None:
    for key in obj.keys():
        accumulators["columns"] += {key}

def run(spark: SparkSession, df: DataFrame) -> list:
    columns = spark.sparkContext.accumulator(set(), SetAccumulator())
    accumulators = {"columns": columns}
    df.rdd.foreach(lambda row: read_columns(row.asDict(), accumulators))
    return list(columns.value)

我删除了顶部的 columns 变量并删除了对它的所有全局引用。相反,我创建了一个字典 accumulators = {"columns": columns},我将其传递给 read_columns 函数并按键获取累加器。

测试现在可以正确打印 ["test", "bar", "name"] 并且它在测试环境之外仍然有效。