如何向 seaborn.heatmap 中的单元格添加影线

how to add hatches to cells in seaborn.heatmap

我尝试使用 seaborn.heatmap 可视化我的数据。

但是,我遇到的问题是,当我用 grayscle 打印出来时,图像很难看清。

我关注了很多类似的问题,但都没有用。

是否可以在 seaborn.heatmap 中的单元格上添加影线?

我的代码如下:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

df = pd.read_csv("file.csv")

sns.heatmap(df, annot=False, fmt='.0f', square=True,
    cmap="coolwarm", linewidths=1, cbar=False)

plt.show()

您可以创建一个循环,将值分成例如4 组并通过应用于子集的 pcolor 为每个组分配一个填充图案。

这是一个从随机测试数据开始的例子:

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

column_names = [f'{c:.2f}' for c in np.arange(0, 1.5001, 0.05)]
row_names = ['Alkaid', 'Mizar', 'Alioth', 'Megrez', 'Phecda', 'Merak', 'Dubhe']
df = pd.DataFrame(np.random.normal(0.3, 1, (len(row_names), len(column_names))).cumsum(axis=1) + 5,
                  columns=column_names, index=row_names)

values = df.values
vmin = values.min()
vmax = values.max()
patterns = ['', 'oo', '////', 'XXX']
bounds = np.linspace(vmin, vmax, len(patterns) + 1)
bounds[-1] += 1
sns.set_style('white')
fig, ax = plt.subplots(figsize=(12, 5))
sns.heatmap(data=df, linewidths=1, square=True, cmap='coolwarm', linecolor='white', cbar=False, ax=ax)
x = np.arange(df.shape[1] + 1)
y = np.arange(df.shape[0] + 1)
handles = []
norm = plt.Normalize(vmin, vmax)
cmap = plt.get_cmap('coolwarm')
for pattern, b0, b1 in zip(patterns, bounds[:-1], bounds[1:]):
    ax.pcolor(x, y, np.where((values >= b0) & (values < b1), values, np.nan), cmap=cmap, norm=norm,
              hatch=pattern, ec='black', lw=1)
    handles.append(plt.Rectangle((0, 0), 0, 0, color=cmap(norm((b0 + b1) / 2)), ec='black',
                                 hatch=pattern, label=f'{b0:5.2f}-{b1:5.2f}'))
ax.hlines(y, 0, x.max(), color='w', lw=2)
ax.vlines(x, 0, y.max(), color='w', lw=2)
ax.legend(handles=handles, bbox_to_anchor=(1.01, 1.02), loc='upper left',
          handlelength=2, handleheight=2, frameon=False)
plt.tight_layout()
plt.show()