matplotlib 一直在写同一个数字。我需要回归结果作为单独的 png。我应该如何修改代码?
matplotlib keeps writing over the same figure. I need regression results as separate png's. How should I modify the code?
这是代码。在实际代码中还有另外两个回归,它们的结果也写在同一张图上,如下图所示
import pandas as pd
import os
import statsmodels.api as sm
import matplotlib.pyplot as plt
IN_PATH = os.path.join("data", "clean", "imdb_clean.csv")
OUTPUT_DIR = "quantitative analysis"
REVENUE_IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "revenue_imdb_ols_regression.png")
IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "imdb_ols_regression.png")
df = pd.read_csv(IN_PATH)
dummy_cols = df.columns[10:-1]
def revenue_imdb_ols_regression(out_path):
'''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
x_cols = ["IMDBRating", "ReleaseYear"]
for col in dummy_cols:
x_cols.append(col)
x = df[x_cols]
y = df["GrossRevenue"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
plt.rc("figure", figsize=(12, 7))
plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
plt.axis("off")
plt.tight_layout()
plt.savefig(out_path)
def imdb_ols_regression(out_path):
'''Perform OLS regression of IMBD Rating on genre dummies and create csv'''
x = df[dummy_cols]
y = df["IMDBRating"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
plt.rc("figure", figsize=(12, 7))
plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
plt.axis("off")
plt.tight_layout()
plt.savefig(out_path)
if __name__ == "__main__":
os.makedirs(OUTPUT_DIR, exist_ok=True)
revenue_imdb_ols_regression(REVENUE_IMDB_OLS_PATH)
imdb_ols_regression(IMDB_OLS_PATH)
def revenue_imdb_ols_regression(out_path):
'''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
x_cols = ["IMDBRating", "ReleaseYear"]
for col in dummy_cols:
x_cols.append(col)
x = df[x_cols]
y = df["GrossRevenue"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
fig, ax = plt.subplots(figsize=(12, 7))
ax.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
ax.axis("off")
plt.tight_layout()
fig.savefig(out_path)
fig.set_tight_layout(True)
而不是 plt.tight_layout()
可能效果更好 - 试试看
这是代码。在实际代码中还有另外两个回归,它们的结果也写在同一张图上,如下图所示
import pandas as pd
import os
import statsmodels.api as sm
import matplotlib.pyplot as plt
IN_PATH = os.path.join("data", "clean", "imdb_clean.csv")
OUTPUT_DIR = "quantitative analysis"
REVENUE_IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "revenue_imdb_ols_regression.png")
IMDB_OLS_PATH = os.path.join(OUTPUT_DIR, "imdb_ols_regression.png")
df = pd.read_csv(IN_PATH)
dummy_cols = df.columns[10:-1]
def revenue_imdb_ols_regression(out_path):
'''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
x_cols = ["IMDBRating", "ReleaseYear"]
for col in dummy_cols:
x_cols.append(col)
x = df[x_cols]
y = df["GrossRevenue"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
plt.rc("figure", figsize=(12, 7))
plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
plt.axis("off")
plt.tight_layout()
plt.savefig(out_path)
def imdb_ols_regression(out_path):
'''Perform OLS regression of IMBD Rating on genre dummies and create csv'''
x = df[dummy_cols]
y = df["IMDBRating"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
plt.rc("figure", figsize=(12, 7))
plt.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
plt.axis("off")
plt.tight_layout()
plt.savefig(out_path)
if __name__ == "__main__":
os.makedirs(OUTPUT_DIR, exist_ok=True)
revenue_imdb_ols_regression(REVENUE_IMDB_OLS_PATH)
imdb_ols_regression(IMDB_OLS_PATH)
def revenue_imdb_ols_regression(out_path):
'''Perform OLS regression of movie Revenue on IMBD Rating, Release Year, and genre dummies and create csv'''
x_cols = ["IMDBRating", "ReleaseYear"]
for col in dummy_cols:
x_cols.append(col)
x = df[x_cols]
y = df["GrossRevenue"]
model = sm.OLS(y, sm.add_constant(x)).fit()
model_summary = model.summary()
fig, ax = plt.subplots(figsize=(12, 7))
ax.text(0.01, 0.05, str(model_summary), {"fontsize": 10}, fontproperties = "monospace")
ax.axis("off")
plt.tight_layout()
fig.savefig(out_path)
fig.set_tight_layout(True)
而不是 plt.tight_layout()
可能效果更好 - 试试看