PyTorch TensorBoard SummaryWriter.add_video() Produces Bad Videos

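I'm trying to create a video by generating a sequence of 500 matplotlib plots, converting each plot to a numpy array, stacking the arrays, and passing the result to SummaryWriter()'s add_video(). When I do this, the colorbar changes from color to black and white, and only a small number (~3-4) of the matplotlib plots appear, repeated over and over. I confirmed that my numpy arrays are correct by recreating the matplotlib plots from them.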

My input tensor has shape (B,C,T,H,W), dtype np.uint8, and values in [0, 255].

A minimal working example is below. To be clear, the code runs without any errors. My problem is that the resulting video is wrong.

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter


tensorboard_writer = SummaryWriter()
print(tensorboard_writer.get_logdir())


def fig2data(fig):

    # draw the renderer
    fig.canvas.draw()

    # Get the RGB buffer from the figure
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


size = 500
x = np.random.uniform(0, 2., size=500)
y = np.random.uniform(0, 2., size=500)
trajectory_len = len(x)
trajectory_indices = np.arange(trajectory_len)
width, height = 3, 2

# tensorboard takes video of shape (B,C,T,H,W)
video_array = np.zeros(
    shape=(1, 3, trajectory_len, height*100, width*100),
    dtype=np.uint8)

for trajectory_idx in trajectory_indices:

    fig, axes = plt.subplots(
        1,
        2,
        figsize=(width, height),
        gridspec_kw={'width_ratios': [1, 0.05]})
    fig.suptitle('Example Trajectory')
    # plot the first trajectory
    sc = axes[0].scatter(
        x=[x[trajectory_idx]],
        y=[y[trajectory_idx]],
        c=[trajectory_indices[trajectory_idx]],
        s=4,
        vmin=0,
        vmax=trajectory_len,
        cmap=plt.cm.jet)

    axes[0].set_xlim(-0.25, 2.25)
    axes[0].set_ylim(-0.25, 2.25)

    colorbar = fig.colorbar(sc, cax=axes[1])
    colorbar.set_label('Trajectory Index Number')

    # extract numpy array of figure
    data = fig2data(fig)

    # UNCOMMENT IF YOU WANT TO VERIFY THAT THE NUMPY ARRAY WAS CORRECTLY EXTRACTED
    # plt.show()
    # fig2 = plt.figure()
    # ax2 = fig2.add_subplot(111, frameon=False)
    # ax2.imshow(data)
    # plt.show()

    # close figure to save memory
    plt.close(fig=fig)

    video_array[0, :, trajectory_idx, :, :] = np.transpose(data, (2, 0, 1))

# tensorboard takes video_array of shape (B,C,T,H,W)
tensorboard_writer.add_video(
    tag='sampled_trajectory',
    vid_tensor=torch.from_numpy(video_array),
    global_step=0,
    fps=4)

print('Added video')

tensorboard_writer.close()

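According to the PyTorch docs, the video tensor should have shape (N,T,C,H,W), which I take to mean: batch, time, channels, height and width. You said your tensor has shape (B,C,T,H,W). So it looks like your channel and time axes are swapped.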
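
To illustrate the axis swap, here is a minimal sketch using NumPy only. The helper name to_ntchw and the small array sizes are hypothetical, chosen for illustration; it assumes the array was built as (B,C,T,H,W), as in the question. If add_video() reads the question's (1, 3, 500, H, W) array as (N,T,C,H,W), it would see 3 frames of 500 channels each, which seems consistent with only ~3 distinct frames showing up.

```python
import numpy as np


def to_ntchw(video_bcthw: np.ndarray) -> np.ndarray:
    """Reorder a (B, C, T, H, W) video array to (N, T, C, H, W).

    Hypothetical helper for illustration; not part of PyTorch or TensorBoard.
    """
    if video_bcthw.ndim != 5:
        raise ValueError(f"expected a 5-D array, got {video_bcthw.ndim}-D")
    # Swap axis 1 (channels) with axis 2 (time); batch, height, width stay put.
    swapped = np.transpose(video_bcthw, (0, 2, 1, 3, 4))
    # Copy into contiguous memory so the result has a plain C layout.
    return np.ascontiguousarray(swapped)


# Small stand-in for the question's array: 1 video, 3 channels, 5 frames, 4x6 pixels.
video_array = np.zeros(shape=(1, 3, 5, 4, 6), dtype=np.uint8)
# Mark the green channel (index 1) of frame 2 so we can track where it ends up.
video_array[0, 1, 2, :, :] = 255

fixed = to_ntchw(video_array)
print(fixed.shape)
```

In the question's script, the reordered array could then be passed as vid_tensor=torch.from_numpy(fixed). Alternatively, video_array could be allocated as (1, trajectory_len, 3, height*100, width*100) from the start, with each transposed frame assigned to video_array[0, trajectory_idx].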