预训练闪电 VAE 未对训练数据集进行正确推理
Pretrained lightning-bolts VAE not doing proper inference on training dataset
我正在使用来自 lightning-bolts. It should be able to regenerate images with the quality shown on this picture taken from the docs 的 CIFAR-10 预训练 VAE(LHS 是真实图像,RHS 是生成图像)
但是,当我编写一个简单的脚本来加载模型、权重并在 training 集上对其进行测试时,我得到了更糟糕的重建(顶行是真实的图片,底行是生成的图片):
这是一个独立的 colab 笔记本 link,它再现了我制作图片所遵循的步骤。
我是不是在推理过程中做错了什么?会不会是权重没有文档说的那么“好”?
谢谢!
首先,您显示的文档中的图像是针对 AE,而不是 VAE。 VAE 的结果看起来更糟:
https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/vae_output.png
其次,文档指出“输入和生成的图像都是标准化版本,因为训练是使用此类图像完成的。”因此,当您加载数据时,您应该指定 normalize=True。绘制数据时,您还需要 'unnormalize' 数据:
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
torch.manual_seed(17)
np.random.seed(17)
vae = VAE(32, lr=0.00001)
vae = vae.from_pretrained("cifar10-resnet18")
dm = CIFAR10DataModule(".", normalize=True)
dm.prepare_data()
dm.setup("fit")
dataloader = dm.train_dataloader()
print(dm.default_transforms())
mean = torch.tensor(dm.default_transforms().transforms[1].mean)
std = torch.tensor(dm.default_transforms().transforms[1].std)
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
X, _ = next(iter(dataloader))
vae.eval()
X_hat = vae(X)
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(10):
ax_real = axes[0][i]
ax_real.imshow(np.transpose(unnormalize(X[i]), (1, 2, 0)))
ax_real.get_xaxis().set_visible(False)
ax_real.get_yaxis().set_visible(False)
ax_gen = axes[1][i]
ax_gen.imshow(np.transpose(unnormalize(X_hat[i]).detach().numpy(), (1, 2, 0)))
ax_gen.get_xaxis().set_visible(False)
ax_gen.get_yaxis().set_visible(False)
这给出了这样的东西:
没有规范化它看起来像:
我正在使用来自 lightning-bolts. It should be able to regenerate images with the quality shown on this picture taken from the docs 的 CIFAR-10 预训练 VAE(LHS 是真实图像,RHS 是生成图像)
但是,当我编写一个简单的脚本来加载模型、权重并在 training 集上对其进行测试时,我得到了更糟糕的重建(顶行是真实的图片,底行是生成的图片):
这是一个独立的 colab 笔记本 link,它再现了我制作图片所遵循的步骤。
我是不是在推理过程中做错了什么?会不会是权重没有文档说的那么“好”?
谢谢!
首先,您显示的文档中的图像是针对 AE,而不是 VAE。 VAE 的结果看起来更糟:
https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/vae_output.png
其次,文档指出“输入和生成的图像都是标准化版本,因为训练是使用此类图像完成的。”因此,当您加载数据时,您应该指定 normalize=True。绘制数据时,您还需要 'unnormalize' 数据:
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
torch.manual_seed(17)
np.random.seed(17)
vae = VAE(32, lr=0.00001)
vae = vae.from_pretrained("cifar10-resnet18")
dm = CIFAR10DataModule(".", normalize=True)
dm.prepare_data()
dm.setup("fit")
dataloader = dm.train_dataloader()
print(dm.default_transforms())
mean = torch.tensor(dm.default_transforms().transforms[1].mean)
std = torch.tensor(dm.default_transforms().transforms[1].std)
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
X, _ = next(iter(dataloader))
vae.eval()
X_hat = vae(X)
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(10):
ax_real = axes[0][i]
ax_real.imshow(np.transpose(unnormalize(X[i]), (1, 2, 0)))
ax_real.get_xaxis().set_visible(False)
ax_real.get_yaxis().set_visible(False)
ax_gen = axes[1][i]
ax_gen.imshow(np.transpose(unnormalize(X_hat[i]).detach().numpy(), (1, 2, 0)))
ax_gen.get_xaxis().set_visible(False)
ax_gen.get_yaxis().set_visible(False)
这给出了这样的东西:
没有规范化它看起来像: