如何更改 seaborn 散点图中大小类别的数量

How to change the number of size categories in seaborn scatterplot

我努力浏览所有 documentation and examples 但我无法弄明白。如何更改类别数 = size 气泡的数量,以及它们在 seaborn 散点图中的边界? sizes 参数在这里没有帮助。

无论我尝试什么,它总是给我 6 个(这里是 8、16、...、48):

import seaborn as sns

tips = sns.load_dataset("tips")

sns.scatterplot(data=tips, x="total_bill", y="tip", size="total_bill")

penguins = sns.load_dataset("penguins")

sns.scatterplot(data=penguins, x="bill_length_mm", y="bill_depth_mm", size="body_mass_g")

我该如何改变他们的界限? IE。如果我想在第一种情况下使用 10、20、30、40、50,或者在第二种情况下使用 3000、4000、5000、6000?

我知道四处走动并在数据框中创建另一个列是可行的,但这不是我们想要的(添加了不必要的列,即使我即时进行,这也不是我想要的)。

解决方法:

def myfunc(mass):
    if mass <3500:
        return 3000
    elif mass <4500:
        return 4000
    elif mass <5500:
        return 5000
    return 6000

penguins["mass"] = penguins.apply(lambda x: myfunc(x['body_mass_g']), axis=1)

sns.scatterplot(data=penguins, x="bill_length_mm", y="bill_depth_mm", size="mass")

我认为 seaborn 没有 fine-grained 控件,它只是试图想出一些在许多情况下都可以直观地工作的东西,但并非适用于所有情况。 legend='full' 参数显示 size 列的 所有 值,但这可能过于庞大。

建议创建一个具有合并大小的新列的缺点是,这也会更改散点图中使用的大小。

一种方法是创建您自己的自定义图例。请注意,当图例还包含其他元素时,此方法需要稍微调整一下。

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

tips = sns.load_dataset("tips")

ax = sns.scatterplot(data=tips, x="total_bill", y="tip", size="total_bill", legend='full')
handles, labels = ax.get_legend_handles_labels()
labels = np.array([float(l) for l in labels])
desired_labels = [10, 20, 30, 40, 50]
desired_handles = [handles[np.argmin(np.abs(labels - d))] for d in desired_labels]
ax.legend(handles=desired_handles, labels=desired_labels, title=ax.legend_.get_title().get_text())
plt.show()

代码可以包装成一个函数,例如应用于企鹅:

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np

def sizes_legend(desired_sizes, ax=None):
    ax = ax or plt.gca()
    handles, labels = ax.get_legend_handles_labels()
    labels = np.array([float(l) for l in labels])
    desired_handles = [handles[np.argmin(np.abs(labels - d))] for d in desired_sizes]
    ax.legend(handles=desired_handles, labels=desired_sizes, title=ax.legend_.get_title().get_text())

penguins = sns.load_dataset("penguins")
ax = sns.scatterplot(data=penguins, x="bill_length_mm", y="bill_depth_mm", size="body_mass_g", legend='full')
sizes_legend([3000, 4000, 5000, 6000], ax)
plt.show()