如何针对使用 matplotlib 的代码编写单元测试?

How can I write unit tests against code that uses matplotlib?

我正在开发一个 python (2.7) 程序,它可以生成很多不同的 matplotlib 图(数据不是随机的)。我愿意实施一些测试(使用 unittest)以确保生成的数字是正确的。例如,我将预期的图形(数据或图像)存储在某个地方,我 运行 我的函数并将结果与​​参考进行比较。有办法吗?

Matplotlib 有一个 testing infrastructure。例如:

import numpy as np
import matplotlib
from matplotlib.testing.decorators import image_comparison
import matplotlib.pyplot as plt

@image_comparison(baseline_images=['spines_axes_positions'])
def test_spines_axes_positions():
    # SF bug 2852168
    fig = plt.figure()
    x = np.linspace(0,2*np.pi,100)
    y = 2*np.sin(x)
    ax = fig.add_subplot(1,1,1)
    ax.set_title('centered spines')
    ax.plot(x,y)
    ax.spines['right'].set_position(('axes',0.1))
    ax.yaxis.set_ticks_position('right')
    ax.spines['top'].set_position(('axes',0.25))
    ax.xaxis.set_ticks_position('top')
    ax.spines['left'].set_color('none')
    ax.spines['bottom'].set_color('none')

来自docs

The first time this test is run, there will be no baseline image to compare against, so the test will fail. Copy the output images (in this case result_images/test_category/spines_axes_positions.*) to the correct subdirectory of baseline_images tree in the source directory (in this case lib/matplotlib/tests/baseline_images/test_category). When rerunning the tests, they should now pass.

在我的 experience 中,图像比较测试最终带来的麻烦多于它们的价值。如果您想 运行 跨多个系统(如 TravisCI)持续集成,这些系统可能具有略微不同的字体或可用的绘图后端,则尤其如此。即使功能完美无缺地工作,要保持测试通过也可能需要大量工作。此外,以这种方式进行测试需要将图像保存在您的 git 存储库中,如果您经常更改代码,这会很快导致存储库膨胀。

我认为更好的方法是 (1) 假设 matplotlib 实际上会正确绘制图形,以及 (2) 运行 针对绘图函数返回的数据进行数值测试。 (如果您知道在哪里查找,您也可以随时在 Axes 对象中找到此数据。)

例如,假设您想测试这样一个简单的函数:

import numpy as np
import matplotlib.pyplot as plt
def plot_square(x, y):
    y_squared = np.square(y)
    return plt.plot(x, y_squared)

你的单元测试可能看起来像

def test_plot_square1():
    x, y = [0, 1, 2], [0, 1, 2]
    line, = plot_square(x, y)
    x_plot, y_plot = line.get_xydata().T
    np.testing.assert_array_equal(y_plot, np.square(y))

或者,等价地,

def test_plot_square2():
    f, ax = plt.subplots()
    x, y = [0, 1, 2], [0, 1, 2]
    plot_square(x, y)
    x_plot, y_plot = ax.lines[0].get_xydata().T
    np.testing.assert_array_equal(y_plot, np.square(y))

您还可以使用 unittest.mock 模拟 matplotlib.pyplot 并检查是否使用适当的参数对其进行了适当的调用。假设您在 module.py 中有一个 plot_data(data) 函数(假设它位于 package/src/ 中)您想要测试它,它看起来像这样:

import matplotlib.pyplot as plt

def plot_data(x, y, title):
    plt.figure()
    plt.title(title)
    plt.plot(x, y)
    plt.show()

为了在您的 test_module.py 文件中测试此功能,您需要:

import numpy as np

from unittest import mock
import package.src.module as my_module  # Specify path to your module.py


@mock.patch("%s.my_module.plt" % __name__)
def test_module(mock_plt):
    x = np.arange(0, 5, 0.1)
    y = np.sin(x)
    my_module.plot_data(x, y, "my title")

    # Assert plt.title has been called with expected arg
    mock_plt.title.assert_called_once_with("my title")

    # Assert plt.figure got called
    assert mock_plt.figure.called

这会检查是否使用参数 my title 调用了 title 方法,并且是否在 plt 对象的 plot_data 内部调用了 figure 方法.

更详细的解释:

@mock.patch("module.plt") 装饰器“修补”module.py 中导入的 plt 模块,并将其作为 mock 对象 (mock_plt) 注入到 test_module作为参数。这个模拟对象(作为 mock_plt 传递)现在可以在我们的测试中使用来记录 plot_data(我们正在测试的函数)对 plt 所做的一切 - 这是因为所有调用plt by plot_data 现在将改为在我们的模拟对象上制作。

另外,除了assert_called_once_with you might want to use other, similar methods such as assert_not_called, assert_called_once