使用 matplotlib 的 imshow 绘制具有相同颜色分​​配的多个图像

Plot multiple images with identical color assignments using matplotlib's imshow

我有多个图像(numpy 数组),其数据值对应于 N 个不同的 classes。每个图像不一定包含每个 class 的示例。例如,可能总共有 12 个不同的 classes (0:11),但是,一张图像可能只包含 classes 1:9.

我想绘制每个图像,以便分配给每个 class 的颜色在所有图像中都相同。

我研究了几个答案: the accepted and popular answers didn't work across multiple images. here 似乎可行,但我真的很想使用颜色图 (from matplotlib import cm) 以免手动设置颜色。我还想要一种创建包含所有 classes.

的适当颜色条的方法

我试过的代码如下:

import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt

t1 = np.arange(9).reshape(3,3)
t2 = t1.copy()
t2[1,1] = 10
t3 = t2.copy()
t3[1,1] = 11

cmap = cm.get_cmap('tab20', 11)

fig, axs = plt.subplots(1,3)

axs[0].imshow(t1, cmap = cmap, vmin = 0, vmax = 11)
axs[1].imshow(t2, cmap = cmap, vmin = 0, vmax = 11)
axs[2].imshow(t3, cmap = cmap, vmin = 0, vmax = 11)

看起来 cm.get_cmap 需要调整以处理图像中所有可能的 categories/classes。以下代码有效:

import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt

t1 = np.arange(9).reshape(3,3)
t2 = t1.copy()
t2[1,1] = 10
t3 = t2.copy()
t3[1,1] = 11

cmap = cm.get_cmap('tab20', 12)

fig, axs = plt.subplots(1,3)

axs[0].imshow(t1, cmap = cmap, vmin = 0, vmax = 11)
axs[1].imshow(t2, cmap = cmap, vmin = 0, vmax = 11)
axs[2].imshow(t3, cmap = cmap, vmin = 0, vmax = 11)

为了将来参考,如果您想定义自己的颜色而不是预定义的 cmap,我前段时间专门为此创建了以下代码。

import matplotlib as mpl
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np

C_p = 11 # Classes

colour_names = [ # Your predefined colours
    "blue",
    "red",
    "yellow",
    "orange",
    "black",
    "purple",
    "green",
    "turquoise",
    "grey",
    "maroon",
    "silver",
    "white"
]

colour_dict = { # Color mapping (class -> colour)
    i: mpl.colors.to_rgb(colour_names[i])
    for i in range(C_p + 1)
}

# Create a colormap (optional)
colours_rgb = [colour_dict[i] for i in range(C_p)]
colours = mpl.colors.ListedColormap(colours_rgb)

norm = mpl.colors.BoundaryNorm(np.arange(C_p + 1) - 0.5, C_p)

plt.figure() # If you only want to plot one
plt.imshow(t2, cmap=colours, norm=norm)
cb = plt.colorbar(ticks=np.arange(C_p))
plt.axis("off")

以您的 t1t2t3 为例:

fig, axs = plt.subplots(1,3)
axs[0].imshow(t1, cmap = colours, norm=norm)
axs[0].set_title("t1")
axs[0].axis('off')
axs[1].imshow(t2, cmap = colours, norm=norm)
axs[1].set_title("t2")
axs[1].axis('off')
im = axs[2].imshow(t3, cmap = colours, norm=norm)
axs[2].set_title("t3")
axs[2].axis('off')
p0 = axs[0].get_position().get_points().flatten()
p1 = axs[1].get_position().get_points().flatten()
p2 = axs[2].get_position().get_points().flatten()
ax_cbar = fig.add_axes([p0[0], 0.08, p2[0], 0.05])
plt.colorbar(im, cax=ax_cbar, ticks=np.arange(C_p), orientation='horizontal')
fig.tight_layout()