子图重复同一张图 6 次并生成 6 个数字而不是一个

Subplots repeating the same graph 6 times and producing 6 figures instead of one

所以我有这个代码:

def scatter(df, column_name):
  values = {data: list(df[data]) for data in column_name}

  data = list(values.values())
  labels = list(values.keys())
  
  for i in range(len(data)):
    for j in range(len(data)):
      if i == j:
        continue
      elif (i == 1) & (j == 0):
        continue
      elif (i == 2) & ((j == 0)|(j == 1)):
        continue
      elif (i == 3) & ((j == 0)|(j == 1)|(j == 2)):
        continue
      else:
        for k in range(6):
          ax = plt.subplot(3, 2, k+1)
          plt.scatter(data[i], data[j])
          plt.xlabel(labels[i])
          plt.ylabel(labels[j])
          plt.title('{} vs {}'.format(labels[i], labels[j]))
        plt.show()
        plt.clf()

scatter(roller_coasters, ['speed', 'height', 'length', 'num_inversions'])

但它生成 6 个数字而不是 1 个,并且每个数字都有相同的图形重复 6 次。

请帮我解决这个问题。

每次进入循环的 else 部分时,您都会为给定的 i,j 组合创建 6 个子图。例如。 for i=0; j=1 for k 的循环创建了六个子图,但仅针对特定的 ij。创建后,图形再次关闭 (plt.clf())。以下 i=0; j=2 将创建下一组 6 个子图。

您可以通过让 j 的循环从 i+1 开始来简化事情,因此不需要测试。此外,接下来将为其创建子图的值可以是一个变量 k,每次添加子图时该值都会递增。

这是一些示例代码:

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

def scatter(df, column_names):
    fig = plt.figure(figsize=(10, 12)) # set a size for the surrounding plot
    n = len(column_names)
    total = n * (n - 1) // 2
    ncols = 2
    nrows = (total + (ncols - 1)) // ncols
    k = 1
    for i in range(n):
        col_i = column_names[i]
        for j in range(i + 1, n):
            col_j = column_names[j]
            ax = plt.subplot(nrows, ncols, k)
            plt.scatter(df[col_i], df[col_j])
            plt.xlabel(col_i)
            plt.ylabel(col_j)
            plt.title(f'{col_i} vs {col_j}')
            k += 1
    plt.tight_layout() # fit labels and ticks nicely together
    plt.show() # only called once, at the end of the function

columns = ['speed', 'height', 'length', 'num_inversions']
roller_coasters = pd.DataFrame(np.random.rand(20, len(columns)), columns=columns)
scatter(roller_coasters, ['speed', 'height', 'length', 'num_inversions'])