我如何使用 Pytorch 为每个 class 保存图像? (无网格)
How do i save images for each class using Pytorch? (no grid)
我正在使用 acgan 进行图像增强。目前,示例图像以网格格式生成。但我想分别为每个 class 保存图像。 (例如 1.png; 2.png ...)我应该如何修改这段代码?或者有没有我想参考的答案?
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)
self.init_size = opt.img_size // 4 # Initial size before upsampling
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
nn.Tanh(),
)
..
generator = Generator()
..
def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
在 def sample_image
中,您有一行定义生成器的目标标签:
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
.
不要使用由于从范围内采样而改变的 num,而是使用您传递的常量作为参数(下面的 class_id
):
def sample_image(n_row, batches_done, class_id):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([class_id for _ in range(n_row) for __ in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
这样您将得到一个矩形数组,其中包含您请求的 class 图片。
此外,如果只有一张图片,您可以将 n_row
设置为 1。请注意,您没有提供 save_image
函数的代码,可能会有一些技巧。
我正在使用 acgan 进行图像增强。目前,示例图像以网格格式生成。但我想分别为每个 class 保存图像。 (例如 1.png; 2.png ...)我应该如何修改这段代码?或者有没有我想参考的答案?
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)
self.init_size = opt.img_size // 4 # Initial size before upsampling
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
nn.Tanh(),
)
..
generator = Generator()
..
def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
在 def sample_image
中,您有一行定义生成器的目标标签:
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
.
不要使用由于从范围内采样而改变的 num,而是使用您传递的常量作为参数(下面的 class_id
):
def sample_image(n_row, batches_done, class_id):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([class_id for _ in range(n_row) for __ in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
这样您将得到一个矩形数组,其中包含您请求的 class 图片。
此外,如果只有一张图片,您可以将 n_row
设置为 1。请注意,您没有提供 save_image
函数的代码,可能会有一些技巧。