pytest 是否支持在测试文件中使用函数工厂?

Does pytest support the use of function factories in test files?

示例 test.py 文件:

import torch

def one():
    return torch.tensor(0.0132005215)

def two():
    return torch.tensor(4.4345855713e-05)

def three():
    return torch.tensor(7.1525573730e-07)


def test_method(method, expected_value):
    value = method()
    assert(torch.isclose(value, expected_value))

def test_one():
    test_method(one, torch.tensor(0.0132005215))

def test_two():
    test_method(two, torch.tensor(4.4345855713e-05))

def test_three():
    test_method(three, torch.tensor(7.1525573730e-07))
    # test_method(three, torch.tensor(1.0))

if __name__ == '__main__':
    test_one()
    test_two()
    test_three()

基本上,我有几个函数要测试(这里称为 onetwothree),它们都具有相同的签名但内部结构不同。因此,我没有编写函数 test_one()test_two() 等并因此重复代码,而是编写了一个 "function factory"(这是正确的术语吗?)test_method,它作为一个输入函数、预期结果和 returns assert 命令的结果。

如您所见,现在测试是手动执行的:我 运行 脚本 test.py,查看屏幕,如果没有 Assertion error 被打印,我'我很高兴。当然,我想通过使用 pytest 来改进这一点,因为有人告诉我它是最简单和最常用的 Python 测试框架之一。问题是,通过查看 pytest documentation,我得到的印象是 pytest 将尝试 运行ning 所有名称以 test_ 开头的函数。当然,测试 test_method 本身没有任何意义。你能帮我重构这个测试脚本,这样我就可以 运行 它和 pytest?

在pytest中,可以使用test parametrization来实现。在您的情况下,您必须为测试提供不同的参数:

import pytest

@pytest.mark.parametrize("method, expected_value",
                         [(one, 0.0132005215),
                          (two, 4.4345855713e-05),
                          (three, 7.1525573730e-07)])
def test_method(method, expected_value):
    value = method()
    assert(torch.isclose(value, expected_value))

如果您 运行 python -m pytest -rA(请参阅 the documentation 了解输出选项),​​您将获得三个测试的输出,例如:

======================================================= PASSES ========================================================
=============================================== short test summary info ===============================================
PASSED test.py::test_method[one-0.0132005215]
PASSED test.py::test_method[two-4.4345855713e-05]
PASSED test.py::test_method[three-7.152557373e-07]
================================================== 3 passed in 0.07s ==================================================

如果您不喜欢灯具名称,您可以修改它们:

@pytest.mark.parametrize("method, expected_value",
                          [(one, 0.0132005215),
                          (two, 4.4345855713e-05),
                          (three, 7.1525573730e-07),
                          ],
                         ids=["one", "two", "three"])
...

改为:

=============================================== short test summary info ===============================================
PASSED test.py::test_method[one]
PASSED test.py::test_method[two]
PASSED test.py::test_method[three]
================================================== 3 passed in 0.06s ==================================================