使用自定义(外部)函数在子图 matplotlib 中绘制

Plot in subplot matplotlib with custom (external) function

我有一个与之前回答的问题类似的问题(). However, I want to have more advanced plots. I'm using this function for plotting (taken from https://towardsdatascience.com/an-introduction-to-bayesian-inference-in-pystan-c27078e58d53):

def plot_trace(param, param_name='parameter', ax=None, **kwargs):
  """Plot the trace and posterior of a parameter."""

  # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, 2.5), np.percentile(param, 97.5)

  # Plotting
    #ax = ax or plt.gca()
    plt.subplot(2,1,1)
    plt.plot(param)
    plt.xlabel('samples')
    plt.ylabel(param_name)
    plt.axhline(mean, color='r', lw=2, linestyle='--')
    plt.axhline(median, color='c', lw=2, linestyle='--')
    plt.axhline(cred_min, linestyle=':', color='k', alpha=0.2)
    plt.axhline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.title('Trace and Posterior Distribution for {}'.format(param_name))

    plt.subplot(2,1,2)
    plt.hist(param, 30, density=True); sns.kdeplot(param, shade=True)
    plt.xlabel(param_name)
    plt.ylabel('density')
    plt.axvline(mean, color='r', lw=2, linestyle='--',label='mean')
    plt.axvline(median, color='c', lw=2, linestyle='--',label='median')
    plt.axvline(cred_min, linestyle=':', color='k', alpha=0.2, label='95% CI')
    plt.axvline(cred_max, linestyle=':', color='k', alpha=0.2)

    plt.gcf().tight_layout()
    plt.legend()

而且我想要这两个子图用于不同的参数。如果我使用这段代码,它只会覆盖绘图并忽略 ax 参数。你能帮我看看如何使它不仅适用于 2 个参数吗?

params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')

plot_trace(params[0], param_name=param_names[0], ax = ax1)
plot_trace(params[1], param_name=param_names[1], ax = ax2)

预期的结果是将这些地块并排放置 How one subplot should look like (what the function makes)。谢谢。

正如@Sebastian-R 和@ImportanceOfBeingErnest 所建议的那样,问题在于您是在函数内部创建子图,而不是使用传递给 ax= 关键字的内容。此外,如果你想在两个不同的子图上绘图,你需要将两个轴实例传递给你的函数

有两种方法可以解决问题:

  • 第一个解决方案是重新编写您的函数以使用 matplotlib 的 the object-oriented API(这是我推荐的解决方案,尽管它需要更多工作)

代码:

def plot_trace(param, axs, param_name='parameter', **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    (...)
    # Plotting
    ax1 = axs[0]
    ax1.plot(param)
    ax1.set_xlabel('samples')  # notice the difference in notation here
    (...)

    ax2 = axs[1]
    ax2.hist(...); sns.kdeplot(..., ax=ax2)
    (...)
  • 第二个解决方案保留您当前的代码,但由于 plt.xxx() 函数在当前轴上工作,因此您需要将作为参数传递的轴设为当前轴:

代码:

def plot_trace(param, axs, param_name='parameter', **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    (...)
    # Plotting
    plt.sca(axs[0])
    plt.plot(param)
    plt.xlabel('samples')
    (...)

    plt.sca(axs[1])
    plt.hist(...)
    (...)

然后你需要在最后创建你需要的子图数量(所以如果你想绘制两个参数,你必须创建 4 个子图)并按照你的需要调用你的函数:

params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')

plot_trace(params[0], axs=[ax1, ax3], param_name=param_names[0])
plot_trace(params[1], axs=[ax2, ax4], param_name=param_names[1])