使用函数生成的图形填充 matplotlib 子图

Populating matplotlib subplots with figures generated from function

我找到并改编了以下代码片段,用于生成线性回归的诊断图。目前这是使用以下函数完成的:

def residual_plot(some_values):
    plot_lm_1 = plt.figure(1)    
    plot_lm_1 = sns.residplot()
    plot_lm_1.axes[0].set_title('title')
    plot_lm_1.axes[0].set_xlabel('label')
    plot_lm_1.axes[0].set_ylabel('label')
    plt.show()


def qq_plot(residuals):
    QQ = ProbPlot(residuals)
    plot_lm_2 = QQ.qqplot()    
    plot_lm_2.axes[0].set_title('title')
    plot_lm_2.axes[0].set_xlabel('label')
    plot_lm_2.axes[0].set_ylabel('label')
    plt.show()

它们被称为:

plot1 = residual_plot(value_set1)
plot2 = qq_plot(value_set1)
plot3 = residual_plot(value_set2)
plot4 = qq_plot(value_set2)

如何创建 subplots 以便这 4 个图显示在 2x2 网格中?
我试过使用:

fig, axes = plt.subplots(2,2)
    axes[0,0].plot1
    axes[0,1].plot2
    axes[1,0].plot3
    axes[1,1].plot4
    plt.show()

但收到错误:

AttributeError: 'AxesSubplot' object has no attribute 'plot1'

我应该从函数内部还是其他地方设置轴属性?

您应该创建一个具有四个子图轴的单个图形,这些轴将用作自定义绘图函数的输入轴,如下所示

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import probplot


def residual_plot(x, y, axes = None):
    if axes is None:
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 1, 1)
    else:
        ax1 = axes
    p = sns.residplot(x, y, ax = ax1)
    ax1.set_xlabel("Data")
    ax1.set_ylabel("Residual")
    ax1.set_title("Residuals")
    return p


def qq_plot(x, axes = None):
    if axes is None:
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 1, 1)
    else:
        ax1 = axes
    p = probplot(x, plot = ax1)
    ax1.set_xlim(-3, 3)
    return p


if __name__ == "__main__":
    # Generate data
    x = np.arange(100)
    y = 0.5 * x
    y1 = y + np.random.randn(100)
    y2 = y + np.random.randn(100)

    # Initialize figure and axes
    fig = plt.figure(figsize = (8, 8), facecolor = "white")
    ax1 = fig.add_subplot(2, 2, 1)
    ax2 = fig.add_subplot(2, 2, 2)
    ax3 = fig.add_subplot(2, 2, 3)
    ax4 = fig.add_subplot(2, 2, 4)

    # Plot data
    p1 = residual_plot(y, y1, ax1)
    p2 = qq_plot(y1, ax2)
    p3 = residual_plot(y, y2, ax3)
    p4 = qq_plot(y2, ax4)

    fig.tight_layout()
    fig.show()

我不知道你的ProbPlot函数是什么,所以我就拿了SciPy的。