seaborn heatmap annotation ValueError: Unknown format code 'g' for object of type 'numpy.str_'

seaborn heatmap annotation ValueError: Unknown format code 'g' for object of type 'numpy.str_'

我想画一个seaborn.heatmap,只标注一些rows/columns。
所有单元格都有注释的示例:

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


n1 = 5
n2 = 10
M = np.random.random((n1, n2))   

fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = M, annot = True)

plt.show()

按照这些例子(段落Adding Value Annotations),可以通过seaborn.heatmap 一个数组,每个单元格的注释作为 annot 参数:

annot : bool or rectangular dataset, optional
If True, write the data value in each cell. If an array-like with the same shape as data, then use this to annotate the heatmap instead of the data. Note that DataFrames will match on position, not index.

如果我尝试生成一个 str 数组并将其作为 annot 参数传递给 seaborn.heatmap,我会收到以下错误:

Traceback (most recent call last):
  File "C:/.../myfile.py", line 16, in <module>
    sns.heatmap(ax = ax, data = M, annot = A)
  File "C:\venv\lib\site-packages\seaborn\_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "C:\venv\lib\site-packages\seaborn\matrix.py", line 558, in heatmap
    plotter.plot(ax, cbar_ax, kwargs)
  File "C:\venv\lib\site-packages\seaborn\matrix.py", line 353, in plot
    self._annotate_heatmap(ax, mesh)
  File "C:\venv\lib\site-packages\seaborn\matrix.py", line 262, in _annotate_heatmap
    annotation = ("{:" + self.fmt + "}").format(val)
ValueError: Unknown format code 'g' for object of type 'numpy.str_'

生成 ValueError 的代码(在本例中,我尝试删除第 4th 列的注释作为示例):

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


n1 = 5
n2 = 10
M = np.random.random((n1, n2))

A = np.array([[f'{M[i, j]:.2f}' for j in range(n2)] for i in range(n1)])
A[:, 3] = ''


fig, ax = plt.subplots(figsize = (6, 3))

sns.heatmap(ax = ax, data = M, annot = A)

plt.show()

这个错误的原因是什么?
如何生成 seaborn.heatmap 并仅注释选定的 rows/columns?

这是格式问题。如果您使用 non-numeric 标签(默认为:fmt='.2g'),则此处需要 fmt = '',它仅考虑数值并为文本格式的标签抛出错误。

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


n1 = 5
n2 = 10
M = np.random.random((n1, n2))

A = np.array([[f'{M[i, j]:.2f}' for j in range(n2)] for i in range(n1)])
A[:, 3] = ''


fig, ax = plt.subplots(figsize = (6, 3))

sns.heatmap(ax = ax, data = M, annot = A, fmt='')

plt.show()