如何在多个子图中绘制

How to plot in multiple subplots

我对这段代码的工作原理有点困惑:

fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()

在这种情况下,无花果、轴如何工作?它有什么作用?

还有为什么这不能做同样的事情:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

阅读文档:matplotlib.pyplot.subplots

pyplot.subplots() returns 元组 fig, ax 使用符号

在两个变量中解包
fig, axes = plt.subplots(nrows=2, ncols=2)

代码:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

不起作用,因为 subplots()pyplot 中的函数,不是对象 Figure.

的成员

有几种方法可以做到这一点。 subplots 方法创建图形以及随后存储在 ax 数组中的子图。例如:

import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig, ax = plt.subplots(nrows=2, ncols=2)

for row in ax:
    for col in row:
        col.plot(x, y)

plt.show()

但是,类似这样的方法也可以,但事实并非如此"clean",因为您正在创建一个带有子图的图形,然后在它们之上添加:

fig = plt.figure()

plt.subplot(2, 2, 1)
plt.plot(x, y)

plt.subplot(2, 2, 2)
plt.plot(x, y)

plt.subplot(2, 2, 3)
plt.plot(x, y)

plt.subplot(2, 2, 4)
plt.plot(x, y)

plt.show()

您可能对以下事实感兴趣:从 matplotlib 2.1 版开始,问题中的第二个代码也可以正常工作。

来自change log

Figure class now has subplots method The Figure class now has a subplots() method which behaves the same as pyplot.subplots() but on an existing figure.

示例:

import matplotlib.pyplot as plt

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

plt.show()
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)

ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()

  • 您也可以在子图中解压轴调用

  • 并设置子图之间是否共享x轴和y轴

像这样:

import matplotlib.pyplot as plt
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()

如果你真的想使用循环,请执行以下操作:

def plot(data):
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        x, y = data[k].keys(), data[k].values
        plt.subplot(63, 10, idx)
        plt.bar(x, y)  
    plt.show()

按顺序遍历所有子图:

fig, axes = plt.subplots(nrows, ncols)

for ax in axes.flatten():
    ax.plot(x,y)

访问特定索引:

for row in range(nrows):
    for col in range(ncols):
        axes[row,col].plot(x[row], y[col])

其他答案都很好,这个答案是一个组合,可能会有用。

import numpy as np
import matplotlib.pyplot as plt

# Optional: define x for all the sub-plots
x = np.linspace(0,2*np.pi,100)

# (1) Prepare the figure infrastructure 
fig, ax_array = plt.subplots(nrows=2, ncols=2)

# flatten the array of axes, which makes them easier to iterate through and assign
ax_array = ax_array.flatten()

# (2) Plot loop
for i, ax in enumerate(ax_array):
  ax.plot(x , np.sin(x + np.pi/2*i))
  #ax.set_title(f'plot {i}')

# Optional: main title
plt.suptitle('Plots')

总结

  1. 准备图形基础设施
    • 获取ax_array,一个子图数组
    • 展平数组以便在一个数组中使用它'for loop'
  2. 情节循环
    • 遍历扁平化的 ax_array 以更新子图
    • 可选:使用枚举来跟踪子图编号
  3. 一旦展平,每个 ax_array 都可以从 0nrows x ncols -1 单独索引(例如 ax_array[0]ax_array[1]ax_array[2]ax_array[3]).

带有 pandas

的子图
  • 此答案适用于以 pandas, which, uses matplotlib 作为默认绘图后端的子图。
  • 这是创建以 pandas.DataFrame 开头的子图的四个选项
    • 实施 1. 和 2. 适用于宽格式的数据,为每一列创建子图。
    • 实施 3. 和 4. 适用于长格式数据,为列中的每个唯一值创建子图。
  • 测试于 python 3.8.11pandas 1.3.2matplotlib 3.4.3seaborn 0.11.2

进口和数据

import seaborn as sns  # data only
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]

   orbital_period   mass  distance
0         269.300   7.10     77.40
1         874.774   2.21     56.95
2         763.000   2.60     19.84
3         326.030  19.40    110.62
4         516.220  10.50    119.47

# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()

         variable    value
0  orbital_period  269.300
1  orbital_period  874.774
2  orbital_period  763.000
3  orbital_period  326.030
4  orbital_period  516.220

1。 subplots=Truelayout,每列

  • pandas.DataFrame.plot
  • 中使用参数subplots=Truelayout=(rows, cols)
  • 此示例使用 kind='density',但 kind 有不同的选项,这适用于所有选项。不指定 kind,默认为线图。
  • axpandas.DataFrame.plot
  • 返回的 AxesSubplot 数组
  • 如果需要,请参阅 How to get a Figure object
    • How to save pandas subplots
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))

# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure() 

# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
    ax.set_title(title)
fig.tight_layout()
plt.show()

2。 plt.subplots,每列

  • 创建一个Axes with matplotlib.pyplot.subplots的数组,然后将axes[i, j]axes[n]传递给ax参数。
    • 此选项使用pandas.DataFrame.plot,但可以使用其他axes级绘图调用作为替代(例如sns.kdeplotplt.plot等)
    • .ravel or .flatten. See Axes的子图数组折叠成一维最简单。
    • 任何应用于每个 axes 的需要迭代的变量都与 .zip 组合(例如 colsaxescolors , palette, 等等)。每个对象的长度必须相同。
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))  # define the figure and subplots
axes = axes.ravel()  # array to 1D
cols = df.columns  # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green']  # list of colors for each subplot, otherwise all subplots will be one color

for col, color, ax in zip(cols, colors, axes):
    df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
    ax.legend()
    
fig.delaxes(axes[3])  # delete the empty subplot
fig.tight_layout()
plt.show()

1 和 2 的结果

3。 plt.subplots,对于 .groupby

中的每个组
  • 这类似于 2.,除了它将 coloraxes 压缩到 .groupby 对象。
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))  # define the figure and subplots
axes = axes.ravel()  # array to 1D
dfg = dfm.groupby('variable')  # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green']  # list of colors for each subplot, otherwise all subplots will be one color

for (group, data), color, ax in zip(dfg, colors, axes):
    data.plot(kind='density', ax=ax, color=color, title=group, legend=False)

fig.delaxes(axes[3])  # delete the empty subplot
fig.tight_layout()
plt.show()

4。 seaborn 图形级情节

  • 使用 seaborn 图级图,并使用 colrow 参数。 seabornmatplotlib 的高级 API。参见 seaborn: API reference
p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
                facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))

axes数组转换为一维

  • 使用 plt.subplots(nrows, ncols) 生成子图,其中 nrows 和 ncols 都大于 1,returns <AxesSubplot:> 对象的嵌套数组。
    • nrows=1ncols=1 的情况下,没有必要展平 axes,因为 axes 已经是一维的,这是默认参数 squeeze=True
  • 访问对象的最简单方法是使用 .ravel(), .flatten(), or .flat 将数组转换为一维。
      • flatten总是returns一份。
      • ravel returns 尽可能查看原始数组。
  • 一旦axes的数组转换为一维,就有多种绘图方式。
import matplotlib.pyplot as plt
import numpy as np  # sample data only

# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]

# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)

# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)

# convert the array to 1 dimension
axes = axes.ravel()

# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
      dtype=object)
  1. 遍历展平数组
    • 如果子图比数据多,这将导致 IndexError: list index out of range
      • 改为尝试选项 3,或 select 轴的子集(例如 axes[:-2]
for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])
  1. 通过索引访问每个轴
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
  1. 索引数据和轴
for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])
  1. zip 将轴和数据放在一起,然后遍历元组列表
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)

输出

这是一个简单的解决方案

fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
    sp.plot(range(10))