Torchvision.transforms implementation of Flatten()
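I have grayscale images, but I need to turn them into a dataset of 1D vectors.
How can I do this? I couldn't find a suitable method among the transforms: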
import torch
import torchvision
from torchvision import transforms
# ImageFolder has no `train` argument; separate train/test folders are assumed here
train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=transforms.ToTensor())
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)
Here is how to do it with `Lambda`:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as T
# without flatten
dataset = MNIST(root='.', download=True, transform=T.ToTensor())
print(dataset[0][0].shape)
# >>> torch.Size([1, 28, 28])
# with flatten (using Lambda, but you can do it in many other ways)
dataset_flatten = MNIST(root='.', download=True, transform=T.Compose([T.ToTensor(), T.Lambda(lambda x: torch.flatten(x))]))
print(dataset_flatten[0][0].shape)
# >>> torch.Size([784])
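The `lambda` is unnecessary, since `torch.flatten` is already a callable that can be passed to `T.Lambda` directly, and the wrapper triggers PyLint's unnecessary-lambda (W0108) warning.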
This version of @Berriel's solution is therefore cleaner:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as T
# without flatten
dataset = MNIST(root='.', download=True, transform=T.ToTensor())
print(dataset[0][0].shape)
# >>> torch.Size([1, 28, 28])
# with flatten (using Lambda, but you can do it in many other ways)
dataset_flatten = MNIST(root='.', download=True,
                        transform=T.Compose([T.ToTensor(), T.Lambda(torch.flatten)]))
print(dataset_flatten[0][0].shape)
# >>> torch.Size([784])