是什么导致 PyTorch 中的预训练模型对图像分类错误？
What makes a pre-trained model in PyTorch misclassify an image?
我在 CIFAR-10 数据集上成功训练了 Data-efficient Image Transformer (DeiT)，准确率约为 95%，并将模型保存下来以备后用。我编写了一个单独的类来加载模型，并只对一张图像进行推理。然而每次运行时，得到的预测结果都不相同。
import torch
from models.deit import deit_small_patch16_224
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.transforms import transforms as transforms
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Fixed version. Previously the raw checkpoint was passed straight to
# load_state_dict(..., strict=False). Its weights are nested under the "model"
# key and their names carry an extra leading prefix, so nothing matched and
# every key was silently skipped. The head was also replaced AFTER loading,
# which left it randomly initialised: hence a different prediction each run.

# Build the DeiT-Small architecture; strict loading below overwrites all of
# its parameters with the fine-tuned checkpoint.
model = deit_small_patch16_224(pretrained=True, use_top_n_heads=8, use_patch_outputs=False)

# Swap in a 10-way classifier (one output per entry of class_names) BEFORE
# loading, so the fine-tuned head weights stored in the checkpoint are restored.
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=10)

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; nothing in this script moves the model to a GPU.
checkpoint = torch.load("./checkpoint/deit224.t7", map_location="cpu")
# The training script stored the weights under the "model" key.
state_dict = checkpoint["model"]

# The saved keys carry one extra leading component (presumably the "module."
# prefix added by nn.DataParallel -- TODO confirm). Strip it only when every
# key shares the same dotted prefix, so an unwrapped checkpoint loads as-is.
prefixes = {key.split('.', 1)[0] for key in state_dict}
if len(prefixes) == 1 and all('.' in key for key in state_dict):
    new_state_dict = {key.split('.', 1)[1]: value for key, value in state_dict.items()}
else:
    new_state_dict = state_dict

# Strict loading (the default) raises on missing or unexpected keys instead
# of silently ignoring them as strict=False did.
model.load_state_dict(new_state_dict)
# Switch to inference mode (affects layers such as dropout).
model.eval()

# convert("RGB") guarantees 3 channels; greyscale or RGBA files would
# otherwise produce 1 or 4. The context manager closes the file handle.
with Image.open("cats.jpeg") as raw_img:
    img = raw_img.convert("RGB")

# The "_224" in the constructor/checkpoint names suggests the model expects
# 224x224 inputs (TODO confirm). ToTensor scales pixels to [0, 1] and
# reorders HWC -> CHW.
# NOTE(review): if training also applied transforms.Normalize (or a different
# resize/crop), apply exactly the same steps here -- TODO confirm.
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
img_tensor = trans(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

with torch.no_grad():
    output = model(img_tensor)

# Assumes output is a (1, 10) tensor of class scores for the single image
# (TODO confirm for this custom use_top_n_heads variant).
predicted_class = output.argmax(dim=1).item()
print(predicted_class, class_names[predicted_class])
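是的，我找到了错误，更新后的代码如下。检查点中的权重保存在 "model" 键下，并且每个键名都多了一层前缀，因此原先用 strict=False 加载时，这些权重全部被静默忽略了。此外，分类头必须在加载权重之前替换，否则它会保持随机初始化，导致每次运行的预测结果都不同。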
import torch
from models.deit import deit_small_patch16_224
from torch.utils.data import dataset
import torchvision.datasets
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.transforms import transforms as transforms
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Build the DeiT-Small architecture; strict loading below overwrites all of
# its parameters with the fine-tuned checkpoint.
model = deit_small_patch16_224(pretrained=True, use_top_n_heads=8, use_patch_outputs=False)

# Swap in a 10-way classifier (one output per entry of class_names) BEFORE
# loading, so the fine-tuned head weights stored in the checkpoint are
# restored. Replacing the head after loading would leave it randomly
# initialised, giving a different prediction on every run.
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=10)

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; nothing in this script moves the model to a GPU.
checkpoint = torch.load("./checkpoint/deit224.t7", map_location="cpu")
# The training script stored the weights under the "model" key.
state_dict = checkpoint["model"]

# The saved keys carry one extra leading component (presumably the "module."
# prefix added by nn.DataParallel -- TODO confirm). Strip it only when every
# key shares the same dotted prefix, instead of unconditionally dropping the
# first component, so an unwrapped checkpoint loads as-is.
prefixes = {key.split('.', 1)[0] for key in state_dict}
if len(prefixes) == 1 and all('.' in key for key in state_dict):
    new_state_dict = {key.split('.', 1)[1]: value for key, value in state_dict.items()}
else:
    new_state_dict = state_dict

# Strict loading (the default) raises on missing or unexpected keys instead
# of silently ignoring them.
model.load_state_dict(new_state_dict)
# Switch to inference mode (affects layers such as dropout).
model.eval()

# convert("RGB") guarantees 3 channels; greyscale or RGBA files would
# otherwise produce 1 or 4. The context manager closes the file handle.
with Image.open("cats.jpeg") as raw_img:
    img = raw_img.convert("RGB")

# The "_224" in the constructor/checkpoint names suggests the model expects
# 224x224 inputs (TODO confirm). ToTensor scales pixels to [0, 1] and
# reorders HWC -> CHW, replacing the manual /255.0 + permute.
# NOTE(review): if training also applied transforms.Normalize (or a different
# resize/crop), apply exactly the same steps here -- TODO confirm.
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
img_tensor = trans(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

with torch.no_grad():
    output = model(img_tensor)

# Assumes output is a (1, 10) tensor of class scores for the single image
# (TODO confirm for this custom use_top_n_heads variant).
predicted_class = output.argmax(dim=1).item()
print(predicted_class, class_names[predicted_class])
我在 CIFAR-10 数据集上成功训练了 Data-efficient Image Transformer (DeiT)，准确率约为 95%，并将模型保存下来以备后用。我编写了一个单独的类来加载模型，并只对一张图像进行推理。然而每次运行时，得到的预测结果都不相同。
import torch
from models.deit import deit_small_patch16_224
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.transforms import transforms as transforms
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Fixed version. Previously the raw checkpoint was passed straight to
# load_state_dict(..., strict=False). Its weights are nested under the "model"
# key and their names carry an extra leading prefix, so nothing matched and
# every key was silently skipped. The head was also replaced AFTER loading,
# which left it randomly initialised: hence a different prediction each run.

# Build the DeiT-Small architecture; strict loading below overwrites all of
# its parameters with the fine-tuned checkpoint.
model = deit_small_patch16_224(pretrained=True, use_top_n_heads=8, use_patch_outputs=False)

# Swap in a 10-way classifier (one output per entry of class_names) BEFORE
# loading, so the fine-tuned head weights stored in the checkpoint are restored.
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=10)

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; nothing in this script moves the model to a GPU.
checkpoint = torch.load("./checkpoint/deit224.t7", map_location="cpu")
# The training script stored the weights under the "model" key.
state_dict = checkpoint["model"]

# The saved keys carry one extra leading component (presumably the "module."
# prefix added by nn.DataParallel -- TODO confirm). Strip it only when every
# key shares the same dotted prefix, so an unwrapped checkpoint loads as-is.
prefixes = {key.split('.', 1)[0] for key in state_dict}
if len(prefixes) == 1 and all('.' in key for key in state_dict):
    new_state_dict = {key.split('.', 1)[1]: value for key, value in state_dict.items()}
else:
    new_state_dict = state_dict

# Strict loading (the default) raises on missing or unexpected keys instead
# of silently ignoring them as strict=False did.
model.load_state_dict(new_state_dict)
# Switch to inference mode (affects layers such as dropout).
model.eval()

# convert("RGB") guarantees 3 channels; greyscale or RGBA files would
# otherwise produce 1 or 4. The context manager closes the file handle.
with Image.open("cats.jpeg") as raw_img:
    img = raw_img.convert("RGB")

# The "_224" in the constructor/checkpoint names suggests the model expects
# 224x224 inputs (TODO confirm). ToTensor scales pixels to [0, 1] and
# reorders HWC -> CHW.
# NOTE(review): if training also applied transforms.Normalize (or a different
# resize/crop), apply exactly the same steps here -- TODO confirm.
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
img_tensor = trans(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

with torch.no_grad():
    output = model(img_tensor)

# Assumes output is a (1, 10) tensor of class scores for the single image
# (TODO confirm for this custom use_top_n_heads variant).
predicted_class = output.argmax(dim=1).item()
print(predicted_class, class_names[predicted_class])
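是的，我找到了错误，更新后的代码如下。检查点中的权重保存在 "model" 键下，并且每个键名都多了一层前缀，因此原先用 strict=False 加载时，这些权重全部被静默忽略了。此外，分类头必须在加载权重之前替换，否则它会保持随机初始化，导致每次运行的预测结果都不同。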
import torch
from models.deit import deit_small_patch16_224
from torch.utils.data import dataset
import torchvision.datasets
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.transforms import transforms as transforms
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Build the DeiT-Small architecture; strict loading below overwrites all of
# its parameters with the fine-tuned checkpoint.
model = deit_small_patch16_224(pretrained=True, use_top_n_heads=8, use_patch_outputs=False)

# Swap in a 10-way classifier (one output per entry of class_names) BEFORE
# loading, so the fine-tuned head weights stored in the checkpoint are
# restored. Replacing the head after loading would leave it randomly
# initialised, giving a different prediction on every run.
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=10)

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; nothing in this script moves the model to a GPU.
checkpoint = torch.load("./checkpoint/deit224.t7", map_location="cpu")
# The training script stored the weights under the "model" key.
state_dict = checkpoint["model"]

# The saved keys carry one extra leading component (presumably the "module."
# prefix added by nn.DataParallel -- TODO confirm). Strip it only when every
# key shares the same dotted prefix, instead of unconditionally dropping the
# first component, so an unwrapped checkpoint loads as-is.
prefixes = {key.split('.', 1)[0] for key in state_dict}
if len(prefixes) == 1 and all('.' in key for key in state_dict):
    new_state_dict = {key.split('.', 1)[1]: value for key, value in state_dict.items()}
else:
    new_state_dict = state_dict

# Strict loading (the default) raises on missing or unexpected keys instead
# of silently ignoring them.
model.load_state_dict(new_state_dict)
# Switch to inference mode (affects layers such as dropout).
model.eval()

# convert("RGB") guarantees 3 channels; greyscale or RGBA files would
# otherwise produce 1 or 4. The context manager closes the file handle.
with Image.open("cats.jpeg") as raw_img:
    img = raw_img.convert("RGB")

# The "_224" in the constructor/checkpoint names suggests the model expects
# 224x224 inputs (TODO confirm). ToTensor scales pixels to [0, 1] and
# reorders HWC -> CHW, replacing the manual /255.0 + permute.
# NOTE(review): if training also applied transforms.Normalize (or a different
# resize/crop), apply exactly the same steps here -- TODO confirm.
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
img_tensor = trans(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

with torch.no_grad():
    output = model(img_tensor)

# Assumes output is a (1, 10) tensor of class scores for the single image
# (TODO confirm for this custom use_top_n_heads variant).
predicted_class = output.argmax(dim=1).item()
print(predicted_class, class_names[predicted_class])