由于调色板导致 matplotlib 散点图出错

Error with matplotlib scatter plot due to color palette

我使用这种方法为 mnist 数据集的另一个模型创建了一个散点图,它对另一个模型工作正常,但我无法弄清楚我在另一个模型上做错了什么。

方法是

def scatter(x, labels, subtitle=None):
    # Create a scatter plot of all the 
    # the embeddings of the model.
    # We choose a color palette with seaborn.
    palette = np.array(sns.color_palette("hls", 10))
    # We create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0,alpha = 0.5, s=40,
                c=palette[labels.astype(np.int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')

我使用它来使用来自 keras 的 mnist 数据集为绘图创建数据

# Using the newly trained model compute the embeddings 
# for a number images
sample_size = 5000
X_train_trm = model.predict(X_train[:sample_size].reshape(-1,28,28,1))
X_test_trm = model.predict(X_test[:sample_size].reshape(-1,28,28,1))
# TSNE to use dimensionality reduction to visulaise the resultant embeddings
tsne = TSNE()
train_tsne_embeds = tsne.fit_transform(X_train_trm)
scatter(train_tsne_embeds, y_train[:sample_size])

这会给出这个错误,当我检查调色板和 c 的大小时我不明白它应该是 5000 而不是 150000。 错误是这样的

ValueError: 'c' argument has 150000 elements, which is inconsistent with 'x' and 'y' with size 5000.

经过一些谷歌搜索和 运行 进入一些死机后,我发现了问题所在。 我发布的代码工作正常。我使用的标签使用

转换为分类标签
y_train = keras.utils.to_categorical(y_train, 10)

这就是错误包含 150000 个元素的原因。这都是单热编码。 为了解决这个问题,我在开始时复制了标签,然后对它们进行了 onehot 编码。

# convert class vectors to binary class matrices
Y_train_raw = y_train

y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

然后在散点图中使用原始标签

scatter(train_tsne_embeds, Y_train_raw[:sample_size])