不能将自定义非线性颜色图与 imshow 结合使用

Cannot use custom non linear colormap in combination with imshow

我正在尝试参照这个回答(this answer)使用自定义颜色图来显示 ConfusionMatrixDisplay 对象,使 0 到 50 之间的颜色变化比 50 到 100 之间的更精细。

from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np

# Global matplotlib defaults for every figure created below:
# a 15x15 inch canvas and a bold 22pt DejaVu Sans font.
plt.rcParams["figure.figsize"] = (15, 15)
font = {
    "family": "DejaVu Sans",
    "weight": "bold",
    "size": 22,
}
plt.rc("font", **font)


class nlcmap(LinearSegmentedColormap):
    """Colormap wrapper that spaces colors evenly over uneven levels.

    Each entry of ``levels`` is mapped to an equally spaced position between
    0 and the largest level; values in between are interpolated linearly.
    The resulting position, divided by the largest level, is looked up in
    the wrapped colormap.

    Note that the parent class initializer is not invoked: only the
    attributes assigned in ``__init__`` below are set on the instance.
    """

    def __init__(self, cmap, levels):
        """Store the wrapped colormap and precompute the level mapping.

        Parameters
        ----------
        cmap : colormap
            Base colormap; its ``N`` and ``monochrome`` attributes are copied.
        levels : sequence of numbers
            Level boundaries, converted to a float64 array.
        """
        self.cmap = cmap
        self.N = cmap.N
        self.monochrome = self.cmap.monochrome
        level_array = np.asarray(levels, dtype='float64')
        self.levels = level_array
        # Interpolation x-coordinates are the raw levels themselves.
        self._x = level_array
        top_level = level_array.max()
        self.levmax = top_level
        # One equally spaced target position per level, spanning 0..max.
        self.transformed_levels = np.linspace(
            0.0, top_level, len(level_array)
        )

    def __call__(self, xi, alpha=1.0, **kw):
        """Return the wrapped colormap's colors for ``xi``.

        ``xi`` is interpolated from the level scale onto the equally spaced
        scale and then rescaled by the largest level before the lookup.

        Extra keyword arguments are accepted but NOT forwarded to the
        wrapped colormap, so any option passed that way (``**kw``) has no
        effect on the returned colors.
        """
        evenly_spaced = np.interp(xi, self._x, self.transformed_levels)
        fraction = evenly_spaced / self.levmax
        return self.cmap(fraction, alpha)


# Level boundaries: steps of 5 up to 50, then a single coarse step from 50 to
# 100, intended to give the 0-50 range a finer color resolution.
levels = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100]
cmap_nonlin = nlcmap(plt.cm.viridis, levels)
# Reference to the unmodified viridis colormap (not used by the plot below).
lin_cmap = plt.cm.viridis

# Train a small SVM on synthetic data to obtain a confusion matrix to display.
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = SVC(random_state=0)
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=clf.classes_)

# Draw on the axes created here. Without ax=ax, plot() creates its own figure
# and the one from plt.subplots() is left behind empty.
fig, ax = plt.subplots()
im = disp.plot(ax=ax, cmap=cmap_nonlin, colorbar=False)
# Pin the color scale to 0-100 and attach a horizontal colorbar to the image.
disp.im_.set_clim(0, 100)
disp.figure_.colorbar(disp.im_, orientation="horizontal", pad=0.1)
plt.savefig("test.png")

产生以下错误:

Traceback (most recent call last):
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/backends/backend_macosx.py", line 61, in _draw
    self.figure.draw(renderer)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/artist.py", line 41, in draw_wrapper
    return draw(artist, renderer, *args, **kwargs)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/figure.py", line 1864, in draw
    renderer, self, artists, self.suppressComposite)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/image.py", line 131, in _draw_list_compositing_images
    a.draw(renderer)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/artist.py", line 41, in draw_wrapper
    return draw(artist, renderer, *args, **kwargs)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py", line 411, in wrapper
    return func(*inner_args, **inner_kwargs)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/axes/_base.py", line 2747, in draw
    mimage._draw_list_compositing_images(renderer, self, artists)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/image.py", line 131, in _draw_list_compositing_images
    a.draw(renderer)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/artist.py", line 41, in draw_wrapper
    return draw(artist, renderer, *args, **kwargs)
  File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/image.py", line 646, in draw
    renderer.draw_image(gc, l, b, im)
TypeError: Cannot cast array data from dtype('float64') to dtype('uint8') according to the rule 'safe'

似乎错误与 imshow 和自定义颜色图有关,因为我可以在没有 sklearn 的情况下重现:

# Minimal reproduction without scikit-learn: show a 2x2 array with the
# custom colormap through imshow.
fig, ax = plt.subplots()
sample = np.array([[10, 15], [20, 30]])
ax.imshow(sample, cmap=cmap_nonlin)

有什么想法吗?如果可能的话,我希望修改颜色图而不是数据本身。

根据 matplotlib 关于 LinearSegmentedColormap 的文档,可以通过以下做法来调整颜色快速变化段与缓慢变化段之间的对比。

针对我的问题,下面让 0 到 50 之间的颜色变化比 50 到 100 之间的更精细;通过修改 levels,该方案还可以扩展到任意数量、变化速率各不相同的分段:

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as colors
# A dict {fraction_of_max_value: fraction_of_colormap_variation}. Keys and
# values are all expected to lie strictly between 0 and 1, and the values
# should increase together with the keys.
# In this example 90% of the colormap's variation is spent on the first half
# of the data range (up to 0.5) and the remaining 10% on the second half.
levels = {0.5: 0.9}
# More than one breakpoint is allowed, e.g.
# levels = {0.4: 0.8, 0.5: 0.9}
# spends 80% of the variation between 0 and 40% of the maximum, 10% between
# 40% and 50%, and the remaining 10% on the rest.
num_values_per_segment = 50
cmap = plt.cm.viridis

# Breakpoints of the piecewise-linear mapping, including both fixed ends.
# Sorting makes the result independent of the dict's insertion order, and an
# empty `levels` simply degenerates to the unmodified (linear) colormap.
breakpoints = [(0.0, 0.0), *sorted(levels.items()), (1.0, 1.0)]

# Sample every segment with the same number of points: `positions` are the
# x-coordinates in the new colormap, `cmap_positions` the matching
# coordinates in the base colormap.
positions = []
cmap_positions = []
for (left_val, left_cmap_val), (right_val, right_cmap_val) in zip(
        breakpoints, breakpoints[1:]):
    positions.extend(
        np.linspace(left_val, right_val, num_values_per_segment).tolist())
    cmap_positions.extend(
        np.linspace(left_cmap_val, right_cmap_val,
                    num_values_per_segment).tolist())

# Look the base colormap up once for all sample points; columns are R, G, B, A.
rgba = cmap(np.asarray(cmap_positions))

# Segment data in the (x, y0, y1) form expected by LinearSegmentedColormap.
# y0 == y1 everywhere, so the resulting colormap has no discontinuities.
cdict = {
    channel: tuple(
        (x, float(c), float(c)) for x, c in zip(positions, rgba[:, column])
    )
    for column, channel in enumerate(("red", "green", "blue"))
}

cmap_nonlin = colors.LinearSegmentedColormap('MyCustomCMap', cdict)
fig, ax = plt.subplots()
my_image = np.array([[30, 45], [25, 10]])
confusion = ax.imshow(my_image, cmap=cmap_nonlin, vmin=0, vmax=100)
plt.colorbar(confusion, ax=ax)
plt.waitforbuttonpress()

并且生成的 cmap_nonlin 对象可以与 imshow 结合使用,没有任何问题: