使用循环内不同函数的子图创建图形

Create figure with subplots from different functions within loop

与这里(here)的问题类似,我想创建一个循环,用预先定义好的函数所生成的子图来组成一个图形。这些函数创建不同类型的图(如线图或表格),并且各自已经使用了 plt.subplots。最终,我想通过循环为数据集中的每个国家/地区创建一个包含多个子图的图形,并把每个国家/地区的图形保存到 pdf 文件的单独页面上。

import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec


# Toy panel data: one row per (country, year) pair -- USA and UK, 2006-2008 --
# with one column per indicator that the plotting functions below read.
dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                        'year': [2006,2007,2008,2006,2007,2008],
                        'gdp':   [10,13,7,8,2,10],
                        'empowerment':   [0.2,0.13,0.7,0.8,0.2,0.10],
                        'solidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                        'envir':   [55,34,79,65,59,88]})

函数构造如下:

def prepare_line(countries):
    """Plot one country's index and GDP time series on a new figure.

    The empowerment and solidarity indices share the left y-axis; GDP is
    drawn against a twin y-axis on the right.

    Parameters
    ----------
    countries : str
        Value compared against ``dataset.country`` to select the rows to plot.

    Returns
    -------
    matplotlib.figure.Figure
        The figure created by ``plt.subplots`` inside this function.
    """
    # Close the current pyplot figure before a new one is created below.
    plt.close()
    select_country = dataset[dataset.country == countries].round(4)
    fig, ax = plt.subplots(figsize=(12, 5))

    # Index lines on the primary (left) y-axis.
    ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",
                   label="Empowerment Index")
    ind2 = ax.plot(select_country.year, select_country.solidarity, color="red",
                   label="Solidarity Index")

    ax.set_xlabel("year", fontsize=14)
    ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)

    # GDP on a second y-axis that shares the x-axis with `ax`.
    ax2 = ax.twinx()
    axes = plt.gca()
    axes.yaxis.grid()
    ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green",
                    label="GDP per Capita (const. US 2010)")
    ax2.set_ylabel("GDP per Capita", fontsize=14)
    plt.title(countries, fontsize=18)

    # One tick per year. The tick range used to be read from an undefined
    # name (`visual`), which raised NameError; derive it from the plotted
    # years instead, and skip it when no rows matched (min/max would be NaN).
    if not select_country.empty:
        first_year = select_country.year.min()
        last_year = select_country.year.max()
        plt.xticks(np.arange(first_year, last_year + 1, 1.0))

    # Merge the line handles of both y-axes so a single legend lists them all.
    ind = ind1 + ind2 + ind3
    labs = [line.get_label() for line in ind]
    ax.legend(ind, labs, loc=2)

    return fig

def prepare_table(countries):
    """Render one country's rounded indicator values as a table.

    A new pyplot figure is opened and the table is drawn on its current
    axes, whose axis decorations are switched off.

    Parameters
    ----------
    countries : str
        Value compared against ``dataset.country`` to select the rows shown.

    Returns
    -------
    matplotlib.table.Table
        The table object returned by ``plt.table``.
    """
    shown_columns = ['year', 'empowerment', 'solidarity', 'gdp', 'envir']
    rows = dataset[dataset.country == countries].reindex(columns=shown_columns)

    # Two decimals for the indices, none for GDP and the environment score.
    precision = pd.Series({'empowerment': 2, 'solidarity': 2, 'gdp': 0, 'envir': 0})
    rounded = rows.round(precision)
    for int_column in ('gdp', 'envir'):
        rounded[int_column] = rounded[int_column].astype('Int64')
    rounded = rounded.fillna(0)

    plt.figure()

    header = ['Year', 'Emp. Ind.', 'Sol. Ind.', 'GDP p.C.', 'Env. Ind.']
    fig_table = plt.table(cellText=rounded.values, colLabels=header, loc='center')
    # Fix the font size explicitly, then stretch the cells for readability.
    fig_table.auto_set_font_size(False)
    fig_table.set_fontsize(10)
    fig_table.scale(1.8, 1.5)
    plt.axis('off')

    return fig_table

为了给不同的国家生成一个每个国家各占一页的 pdf,并把上述函数生成的图形作为子图放进去,我使用了以下代码:

!pip import simplejson as json
import webbrowser

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec

# Distinct country names found in the dataset; the loop below visits each one.
country = dataset.country.unique()


if __name__ == "__main__":
  
  with PdfPages('Summary.pdf') as pdf:
    
    for i in country:
      # A fresh 2x2 GridSpec per country.
      # NOTE(review): nothing below places any axes on `gs`.
      gs = gridspec.GridSpec(2, 2)
      # NOTE(review): despite the names, prepare_line returns a Figure and
      # prepare_table returns a Table; each call opens its own pyplot figure.
      ax1 = prepare_line(i)
      ax2 = prepare_table(i) 
    # NOTE(review): this line sits outside the `for` loop, so it runs once
    # after all countries were processed, and it is handed the last GridSpec
    # rather than a Figure.
    pdf.savefig(gs)

不幸的是,线图和表格都没有保存到 pdf 文件中,它们只是在循环中针对每个国家依次生成。

我还尝试了其他几种写法,包括让各个函数接收 'ax' 参数的结构。非常感谢任何帮助。这些函数写得比较乱,抱歉。

############################################# ###############################

编辑: @gepcel 的解决方案适用于上述情况。但是,我在尝试将雷达图作为子图嵌入时遇到了问题。我使用的雷达图函数如下:

def prepare_spider(ax, countries):
    """Draw one country's radar (spider) chart on a newly created figure.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Made the current pyplot axes and then shadowed by the radar axes
        created below.
        NOTE(review): nothing is drawn on the axes passed in; all plotting
        goes to the new figure.
    countries : str
        Value compared against ``spider_data['country']``; also used as title.

    Returns
    -------
    matplotlib.figure.Figure
        The radar figure created inside this function.
    """
    select_country = spider_data.loc[(spider_data['country'] == countries)]
    # Year 3000 is a placeholder row; show it as the baseline in the legend.
    base = select_country.replace({'year': {3000: "Baseline 2009"}})
    data_legend = list(base.year)

    # Column order must match `spoke_labels` below.
    data_red = DataFrame(select_country, columns=['spidersolidarity',  'spidergdp', 'spiderempowerment', 'spiderenvir'])
    data = data_red.values.tolist()

    N = 4
    # Registers the 'radar' projection and returns the spoke angles.
    theta = radar_factory(N, frame='polygon')

    spoke_labels = ["Solidarity", "GDP", "Agency", "EPI"]
    title = countries

    plt.sca(ax)

    fig_spider, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))
    fig_spider.subplots_adjust(top=0.85, bottom=0.05)

    ax.set_ylim(-3, 3)
    ax.set_rgrids([])
    ax.set_title(title, position=(0.5, 1.1), ha='center')

    # One closed polygon per data row (i.e. per year).
    for d in data:
        ax.plot(theta, d)
    ax.set_varlabels(spoke_labels)

    ax.legend(data_legend, loc=(0.8, .08),
              labelspacing=0.1, fontsize='small')

    # The original returned `fig`, a name this function never defines;
    # return the figure that was actually created here.
    return fig_spider

具体来说,我不知道应该把 plt.sca(ax) 放在哪里,才能同时保留 fig_spider, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))

该函数调用按这里(here)所定义的 radar_factory(num_vars, frame='circle'),并使用如下规范化数据:

# Toy input for the radar chart: three rows per country, one column per spoke.
# prepare_spider relabels the placeholder year 3000 as "Baseline 2009".
spider_data = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                            'year': [2009,2019,3000,2009,2019,3000],
                            'spiderempowerment':   [0.5,0.6,0.7,0.8,0.2,0.10],
                            'spidersolidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                            'spidergdp':   [0.10,0.13,0.7,0.8,0.2,0.10],
                            'spiderenvir':   [0.55,0.34,0.79,0.65,0.59,0.88]})

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D


def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it under the
    name 'radar'.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes. Any other value raises ValueError
        when an axes with this projection is built, not when this function
        is called.

    Returns
    -------
    theta : numpy.ndarray
        The `num_vars` evenly spaced spoke angles in radians, starting at 0.

    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            # Positional arguments first: a keyword placed before *args is
            # legal but easy to misread.
            return super().fill(*args, closed=closed, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            # Hand the artists back like Axes.plot does, so callers can keep
            # using the return value (it used to be None).
            return lines

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            # An empty line has no first point to repeat.
            if len(x) > 0 and x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def draw(self, renderer):
            """ Draw. If frame is polygon, make gridlines polygon-shaped """
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)

                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta

代码中有很多错误。PdfPages 每页保存一个图形(figure)。所以你应该为每个国家生成一个图形,其中包含两个坐标轴(线图和表格)。完整代码如下:

import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec


# Toy panel data: one row per (country, year) pair -- USA and UK, 2006-2008 --
# with one column per indicator that the plotting functions below read.
dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                        'year': [2006,2007,2008,2006,2007,2008],
                        'gdp':   [10,13,7,8,2,10],
                        'empowerment':   [0.2,0.13,0.7,0.8,0.2,0.10],
                        'solidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                        'envir':   [55,34,79,65,59,88]})

def prepare_line(ax, countries):
    """Draw one country's index and GDP time series onto an existing axes.

    The empowerment and solidarity indices are plotted on ``ax``; GDP is
    plotted on a twin y-axis created from ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the line plot.
    countries : str
        Value compared against ``dataset.country`` to select the rows to plot.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that owns ``ax``.
    """
    select_country = dataset[dataset.country == countries].round(4)
    # The pyplot-level calls below (gca/title) act on the current axes.
    plt.sca(ax)

    # Index lines on the primary (left) y-axis.
    ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",
                   label="Empowerment Index")
    ind2 = ax.plot(select_country.year, select_country.solidarity, color="red",
                   label="Solidarity Index")

    ax.set_xlabel("year", fontsize=14)
    ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)

    # GDP on a second y-axis that shares the x-axis with `ax`.
    ax2 = ax.twinx()
    # NOTE(review): after twinx() the current axes is presumably the twin, so
    # the grid and title below attach to it -- confirm if the grid should
    # follow the left-hand scale instead.
    plt.gca().yaxis.grid()
    ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green",
                    label="GDP per Capita (const. US 2010)")
    ax2.set_ylabel("GDP per Capita", fontsize=14)
    plt.title(countries, fontsize=18)

    # Merge the line handles of both y-axes so a single legend lists them all.
    ind = ind1 + ind2 + ind3
    labs = [line.get_label() for line in ind]
    ax.legend(ind, labs, loc=2)

    # Return the figure through the axes rather than through a module-level
    # `fig`, which only exists when the caller happens to define one.
    return ax.figure


def prepare_table(ax, countries):
    """Render one country's rounded indicator values as a table on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes made current before the table is drawn; its axis decorations
        are switched off.
    countries : str
        Value compared against ``dataset.country`` to select the rows shown.

    Returns
    -------
    matplotlib.table.Table
        The table object returned by ``plt.table``.
    """
    shown_columns = ['year', 'empowerment', 'solidarity', 'gdp', 'envir']
    rows = dataset[dataset.country == countries].reindex(columns=shown_columns)

    # Two decimals for the indices, none for GDP and the environment score.
    precision = pd.Series({'empowerment': 2, 'solidarity': 2, 'gdp': 0, 'envir': 0})
    rounded = rows.round(precision)
    for int_column in ('gdp', 'envir'):
        rounded[int_column] = rounded[int_column].astype('Int64')
    rounded = rounded.fillna(0)

    # pyplot draws on the current axes, so select the target axes first.
    plt.sca(ax)

    header = ['Year', 'Emp. Ind.', 'Sol. Ind.', 'GDP p.C.', 'Env. Ind.']
    fig_table = plt.table(cellText=rounded.values, colLabels=header, loc='center')
    # Fix the font size explicitly, then stretch the cells for readability.
    fig_table.auto_set_font_size(False)
    fig_table.set_fontsize(10)
    fig_table.scale(1.8, 1.5)
    plt.axis('off')

    return fig_table


# Distinct country names found in the dataset; one PDF page is written for each.
country = dataset.country.unique()

with PdfPages('Summary.pdf') as pdf:

    for i in country:
        # One figure per country: line plot on top, table underneath.
        fig, axs = plt.subplots(2, 1, figsize=(12, 9))
        # The return values are not axes (a Figure and a Table), so they are
        # not bound to misleading ax1/ax2 names.
        prepare_line(axs[0], i)
        prepare_table(axs[1], i)
        fig.tight_layout()
        pdf.savefig(fig)
        # Close exactly the figure that was just saved instead of whichever
        # figure pyplot currently considers active.
        plt.close(fig)

编辑:添加雷达图。请注意,我给出的示例代码只是在您的代码基础上做了最小的修改,还应该进一步优化。

import pandas as pd
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.gridspec as gridspec


# Toy panel data: one row per (country, year) pair -- USA and UK, 2006-2008 --
# with one column per indicator that the plotting functions below read.
dataset = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                        'year': [2006,2007,2008,2006,2007,2008],
                        'gdp':   [10,13,7,8,2,10],
                        'empowerment':   [0.2,0.13,0.7,0.8,0.2,0.10],
                        'solidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                        'envir':   [55,34,79,65,59,88]})

def prepare_line(ax, countries):
    """Draw one country's index and GDP time series onto an existing axes.

    The empowerment and solidarity indices are plotted on ``ax``; GDP is
    plotted on a twin y-axis created from ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the line plot.
    countries : str
        Value compared against ``dataset.country`` to select the rows to plot.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that owns ``ax``.
    """
    select_country = dataset[dataset.country == countries].round(4)
    # The pyplot-level calls below (gca/title) act on the current axes.
    plt.sca(ax)

    # Index lines on the primary (left) y-axis.
    ind1 = ax.plot(select_country.year, select_country.empowerment, color="blue",
                   label="Empowerment Index")
    ind2 = ax.plot(select_country.year, select_country.solidarity, color="red",
                   label="Solidarity Index")

    ax.set_xlabel("year", fontsize=14)
    ax.set_ylabel("Solidarity & Agency Scores", fontsize=14)

    # GDP on a second y-axis that shares the x-axis with `ax`.
    ax2 = ax.twinx()
    # NOTE(review): after twinx() the current axes is presumably the twin, so
    # the grid and title below attach to it -- confirm if the grid should
    # follow the left-hand scale instead.
    plt.gca().yaxis.grid()
    ind3 = ax2.plot(select_country.year, select_country["gdp"], color="green",
                    label="GDP per Capita (const. US 2010)")
    ax2.set_ylabel("GDP per Capita", fontsize=14)
    plt.title(countries, fontsize=18)

    # Merge the line handles of both y-axes so a single legend lists them all.
    ind = ind1 + ind2 + ind3
    labs = [line.get_label() for line in ind]
    ax.legend(ind, labs, loc=2)

    # Return the figure through the axes rather than through a module-level
    # `fig`, which only exists when the caller happens to define one.
    return ax.figure


def prepare_table(ax, countries):
    """Render one country's rounded indicator values as a table on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes made current before the table is drawn; its axis decorations
        are switched off.
    countries : str
        Value compared against ``dataset.country`` to select the rows shown.

    Returns
    -------
    matplotlib.table.Table
        The table object returned by ``plt.table``.
    """
    shown_columns = ['year', 'empowerment', 'solidarity', 'gdp', 'envir']
    rows = dataset[dataset.country == countries].reindex(columns=shown_columns)

    # Two decimals for the indices, none for GDP and the environment score.
    precision = pd.Series({'empowerment': 2, 'solidarity': 2, 'gdp': 0, 'envir': 0})
    rounded = rows.round(precision)
    for int_column in ('gdp', 'envir'):
        rounded[int_column] = rounded[int_column].astype('Int64')
    rounded = rounded.fillna(0)

    # pyplot draws on the current axes, so select the target axes first.
    plt.sca(ax)

    header = ['Year', 'Emp. Ind.', 'Sol. Ind.', 'GDP p.C.', 'Env. Ind.']
    fig_table = plt.table(cellText=rounded.values, colLabels=header, loc='center')
    # Fix the font size explicitly, then stretch the cells for readability.
    fig_table.auto_set_font_size(False)
    fig_table.set_fontsize(10)
    fig_table.scale(1.8, 1.5)
    plt.axis('off')

    return fig_table


# Toy input for the radar chart: three rows per country, one column per spoke.
# prepare_spider relabels the placeholder year 3000 as "Baseline 2009".
spider_data = pd.DataFrame({'country':['USA','USA','USA','UK','UK','UK'],
                            'year': [2009,2019,3000,2009,2019,3000],
                            'spiderempowerment':   [0.5,0.6,0.7,0.8,0.2,0.10],
                            'spidersolidarity':   [0.4,0.63,0.3,0.66,0.85,0.9],
                            'spidergdp':   [0.10,0.13,0.7,0.8,0.2,0.10],
                            'spiderenvir':   [0.55,0.34,0.79,0.65,0.59,0.88]})

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D


def radar_factory(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it under the
    name 'radar'.

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes. Any other value raises ValueError
        when an axes with this projection is built, not when this function
        is called.

    Returns
    -------
    theta : numpy.ndarray
        The `num_vars` evenly spaced spoke angles in radians, starting at 0.

    """
    # calculate evenly-spaced axis angles
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

    class RadarAxes(PolarAxes):

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # rotate plot such that the first axis is at the top
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            # Positional arguments first: a keyword placed before *args is
            # legal but easy to misread.
            return super().fill(*args, closed=closed, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default"""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            # Hand the artists back like Axes.plot does, so callers can keep
            # using the return value (it used to be None).
            return lines

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            # An empty line has no first point to repeat.
            if len(x) > 0 and x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def draw(self, renderer):
            """ Draw. If frame is polygon, make gridlines polygon-shaped """
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)

                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta


def prepare_spider(ax, countries):
    """Draw one country's radar (spider) chart onto an existing radar axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; ``set_varlabels`` is called on it, so it has to be
        an axes created with the 'radar' projection.
    countries : str
        Value compared against ``spider_data['country']``; also used as title.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax`` that was passed in.
    """
    rows = spider_data.loc[spider_data['country'] == countries]

    # Legend entries are the years, with the placeholder 3000 relabelled.
    legend_labels = list(rows.replace({'year': {3000: "Baseline 2009"}}).year)

    # Column order must match the spoke labels set further down.
    spoke_columns = ['spidersolidarity',  'spidergdp', 'spiderempowerment', 'spiderenvir']
    series = rows.reindex(columns=spoke_columns).values.tolist()

    # Registers the 'radar' projection and returns the four spoke angles.
    theta = radar_factory(4, frame='polygon')

    ax.set_ylim(-3, 3)
    ax.set_rgrids([])
    ax.set_title(countries, position=(0.5, 1.1), ha='center')

    # One polygon per data row (i.e. per year).
    for values in series:
        ax.plot(theta, values)
    ax.set_varlabels(["Solidarity", "GDP", "Agency", "EPI"])

    ax.legend(legend_labels, loc=(0.8, .08),
              labelspacing=0.1, fontsize='small')

    return ax

# Distinct country names found in the dataset; one PDF page is written for each.
country = dataset.country.unique()

# An axes can only be created with projection='radar' once that projection
# is registered. prepare_spider() registers it too, but only after its axes
# already exists, so register it here before the first subplot is requested
# (4 spokes and a polygon frame, matching prepare_spider).
radar_factory(4, frame='polygon')

with PdfPages('Summary.pdf') as pdf:

    for i in country:
        # One figure per country: line plot, table and radar chart stacked.
        fig = plt.figure(figsize=(12, 12))
        prepare_line(plt.subplot(311), i)
        prepare_table(plt.subplot(312), i)
        prepare_spider(plt.subplot(313, projection='radar'), i)
        fig.tight_layout()
        pdf.savefig(fig)
        # Close exactly the figure that was just saved instead of whichever
        # figure pyplot currently considers active.
        plt.close(fig)