Matplotlib ValueError: num must be 1 <= num <= 20, not 0

I'm working through the Building Autoencoders in Keras tutorial on the MNIST handwritten digits. Here is the code:

from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model

input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format

x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
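The comment that the encoded representation is (4, 4, 8) can be checked by hand. Below is a minimal plain-Python sketch of the shape arithmetic. It assumes Keras' usual rules: stride-1 3x3 convolutions, and 'same' max pooling that rounds up. The helper names are hypothetical, and Keras itself is not needed. The sketch also shows why the one Conv2D(16, (3, 3)) layer without padding='same' is what brings the decoder back to 28x28.

```python
import math

def same_pool(size, pool=2):
    # MaxPooling2D with padding='same' and stride == pool size: output = ceil(size / pool)
    return math.ceil(size / pool)

def valid_conv(size, kernel=3):
    # Conv2D with the default padding='valid' and stride 1: output = size - kernel + 1
    return size - kernel + 1

def upsample(size, factor=2):
    # UpSampling2D repeats each row/column `factor` times
    return size * factor

# Encoder: 'same' convolutions keep the spatial size, so only the three poolings shrink it.
h = 28
for _ in range(3):
    h = same_pool(h)            # 28 -> 14 -> 7 -> 4 (7 / 2 is rounded up)
encoded_shape = (h, h, 8)
print("encoded:", encoded_shape, "->", h * h * 8, "values")

# Decoder: the 'same' convolutions keep the size; only these steps change it.
h = upsample(h)                 # 4 -> 8
h = upsample(h)                 # 8 -> 16
h = valid_conv(h)               # 16 -> 14 (Conv2D(16, (3, 3)) has no padding='same')
h = upsample(h)                 # 14 -> 28
decoded_shape = (h, h, 1)
print("decoded:", decoded_shape)
```

Without that unpadded convolution the decoder would output 32x32, which would not match the 28x28 input.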

After loading the MNIST dataset and training the model, I plot the original and reconstructed images:

decoded_imgs = autoencoder.predict(x_test)

n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i)
    plt.imshow(x_test[i].reshape(28, 28))

    # display reconstruction
    ax = plt.subplot(2, n, i + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))



ValueError                                Traceback (most recent call last)

<ipython-input-35-d0a536786436> in <module>()
      5 for i in range(n):
      6     # display original
----> 7     ax = plt.subplot(2, n, i)
      8     plt.imshow(x_test[i].reshape(28, 28))
      9     plt.gray()


/usr/local/lib/python3.6/dist-packages/matplotlib/axes/ in __init__(self, fig, *args, **kwargs)
     64                 if num < 1 or num > rows*cols:
     65                     raise ValueError(
---> 66                         f"num must be 1 <= num <= {rows*cols}, not {num}")
     67                 self._subplotspec = GridSpec(
     68                         rows, cols, figure=self.figure)[int(num) - 1]

ValueError: num must be 1 <= num <= 20, not 0

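On the first iteration i == 0, because range(10) yields 0, 1, 2, 3, 4, 5, 6, 7, 8, 9. plt.subplot() numbers its positions starting from 1 (up to rows*cols, here 2*10 = 20), so 0 is not a valid index, which is exactly what the error says. Add 1 to the index in both calls to get the correct axes: plt.subplot(2, n, i + 1) for the originals (positions 1 to 10) and plt.subplot(2, n, i + 1 + n) for the reconstructions (positions 11 to 20).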
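Here is a minimal runnable sketch of the corrected loop. It assumes random arrays as hypothetical stand-ins for x_test and decoded_imgs, so it runs without Keras or a trained model. It uses the non-interactive Agg backend so it also works headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for x_test and decoded_imgs: random 28x28x1 "images".
rng = np.random.default_rng(0)
x_test = rng.random((10, 28, 28, 1))
decoded_imgs = rng.random((10, 28, 28, 1))

n = 10
fig = plt.figure(figsize=(20, 4))
for i in range(n):
    # display original: top row, 1-based positions 1..n
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.axis("off")

    # display reconstruction: bottom row, 1-based positions n+1..2n
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.axis("off")
```

If 0-based indexing feels more natural, fig, axes = plt.subplots(2, n) returns a 2 x n array of axes that can be indexed directly as axes[0, i] and axes[1, i], which avoids the off-by-one entirely.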