matplotlib 一直在写同一个数字。我需要回归结果作为单独的 png。我应该如何修改代码?

matplotlib keeps writing over the same figure. I need regression results as separate png's. How should I modify the code?

这是代码。在实际代码中还有另外两个回归,它们的结果也写在同一张图上,如下图所示

import pandas as pd
import os
import statsmodels.api as sm
import matplotlib.pyplot as plt

IN_PATH = os.path.join("data", "clean", "imdb_clean.csv")
OUTPUT_DIR = "quantitative analysis"
REVENUE_IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "revenue_imdb_ols_regression.png")
IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "imdb_ols_regression.png")

df = pd.read_csv(IN_PATH)
dummy_cols = df.columns[10:-1]


def revenue_imdb_ols_regression(out_path):
    '''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
    
    x_cols = ["IMDBRating", "ReleaseYear"]
    for col in dummy_cols:
        x_cols.append(col)

    x = df[x_cols]
    y = df["GrossRevenue"]
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    model_summary = model.summary()
    
    
    plt.rc("figure", figsize=(12, 7))
    plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path)

def imdb_ols_regression(out_path):
    '''Perform OLS regression of IMBD Rating on genre dummies and create csv'''
    
    x = df[dummy_cols]
    y = df["IMDBRating"]

    model = sm.OLS(y, sm.add_constant(x)).fit()
    model_summary = model.summary()
    
    
    plt.rc("figure", figsize=(12, 7))
    plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path)

if __name__ == "__main__":
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    revenue_imdb_ols_regression(REVENUE_IMDB_OLS_PATH)
    imdb_ols_regression(IMDB_OLS_PATH)

def revenue_imdb_ols_regression(out_path):
    '''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
    
    x_cols = ["IMDBRating", "ReleaseYear"]
    for col in dummy_cols:
        x_cols.append(col)

    x = df[x_cols]
    y = df["GrossRevenue"]
    
    model = sm.OLS(y, sm.add_constant(x)).fit()
    model_summary = model.summary()
    
    
    fig, ax = plt.subplots(figsize=(12, 7))
    
    ax.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
    ax.axis("off")
    plt.tight_layout()
    fig.savefig(out_path)

fig.set_tight_layout(True) 而不是 plt.tight_layout() 可能效果更好 - 试试看