带子图的 Pyplot 动画返回空白画布

Pyplot animation with subplots returns a blank canvas

我试图从 Fashion MNIST 数据集中随机生成 25 张图像的动画。

然而,当我这样做时,我的代码只返回一个空白画布。

import tensorflow as tf
from matplotlib.animation import FuncAnimation
from matplotlib import pyplot as plt
import numpy as np

# Load the Fashion MNIST dataset through tf.keras (loading code taken from the
# first official TensorFlow tutorial).
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data()
)

# Human-readable class names, looked up by integer label.
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

# One 10 x 10 inch figure that holds every subplot.
fig = plt.figure(figsize=(10, 10))

def draw25Random():
    """Draw 25 randomly chosen training images on a 5 x 5 grid of subplots.

    Indices are sampled from ``[0, 60000)`` (duplicates are possible) and
    shown in ascending order; each cell is captioned
    ``<index> // <class name>``.

    NOTE(review): this function accepts no arguments, whereas
    ``FuncAnimation`` passes the current frame value to its callback as the
    first positional argument -- confirm the signature before using this
    function as an animation callback.
    """
    picks = np.sort(np.random.randint(60000, size=25))
    for cell, ri in enumerate(picks, start=1):
        plt.subplot(5, 5, cell)
        # Hide ticks and grid lines so only the picture and caption remain.
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[ri], cmap=plt.cm.binary)
        plt.xlabel(f'{ri} // {class_names[train_labels[ri]]}')

# Animate the figure with draw25Random, waiting 1000 ms between frames.
anim = FuncAnimation(fig, draw25Random, interval=1000)

(渲染 25 张图像的那部分代码本身工作正常)

有什么想法吗?

按照 matplotlib.animation.FuncAnimation 的工作方式,回调函数需要接收一个参数 i:FuncAnimation 在每一帧都会调用该函数并传入 i,它是一个每帧加 1 的计数器。因此不需要在 draw25Random 函数中再写 for i in range(25) 循环。

def draw25Random(i):
    """Draw the ``i``-th pre-selected image in its cell of the 5 x 5 grid.

    Intended as a ``FuncAnimation`` callback, which passes the current frame
    value as ``i``.

    Args:
        i: Frame counter. It selects both the entry of the ``ris`` index
            array and the grid cell to draw into. Values past the end of
            ``ris`` wrap around instead of raising.
    """
    slot = i % len(ris)
    ri = ris[slot]
    ax = plt.subplot(5, 5, slot + 1)
    # A cell is visited again whenever the animation replays a frame; clear it
    # first so images do not pile up on the same axes.
    ax.clear()
    # Ticks and grid are reset by clear(), so hide them afterwards.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.imshow(train_images[ri], cmap=plt.cm.binary)
    ax.set_xlabel(f'{ri} // {class_names[train_labels[ri]]}')

完整代码

import tensorflow as tf
from matplotlib.animation import FuncAnimation
from matplotlib import pyplot as plt
import numpy as np

# Load the Fashion MNIST dataset through tf.keras (loading code taken from the
# first official TensorFlow tutorial).
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data()
)

# Human-readable class names, looked up by integer label.
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

# Pick 25 random training-image indices once, kept in ascending order.
ris = np.sort(np.random.randint(60000, size=25))

# One 10 x 10 inch figure that holds every subplot.
fig = plt.figure(figsize=(10, 10))

def draw25Random(i):
    """Draw the ``i``-th pre-selected image in its cell of the 5 x 5 grid.

    Intended as a ``FuncAnimation`` callback, which passes the current frame
    value as ``i``.

    Args:
        i: Frame counter. It selects both the entry of the module-level
            ``ris`` index array and the grid cell to draw into. Values past
            the end of ``ris`` wrap around instead of raising.
    """
    slot = i % len(ris)
    ri = ris[slot]
    ax = plt.subplot(5, 5, slot + 1)
    # A cell is visited again whenever the animation replays a frame; clear it
    # first so images do not pile up on the same axes.
    ax.clear()
    # Ticks and grid are reset by clear(), so hide them afterwards.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.imshow(train_images[ri], cmap=plt.cm.binary)
    ax.set_xlabel(f'{ri} // {class_names[train_labels[ri]]}')

# Play 25 frames, one every 100 ms; each frame calls draw25Random with the
# frame number.
anim = FuncAnimation(fig, draw25Random, frames=25, interval=100)

# Display the figure so the animation runs.
plt.show()