PyTorch 从视频中读取帧进行图像分类,所有帧都被预测为同一个类别

PyTorch: frames read from a video for image classification are all predicted as the same category

我有3类鱼图片,每类1000张图片,我训练模型的时候,分类准确率是97%,用文件夹图片做预测,准确率没问题。

但是当我把图片换成视频,从视频中切出每一帧图像进行分类时,无论是哪个类别的视频,所有帧都被预测为类别1:"stingray"。为什么?

#!/usr/bin/env python
# coding: utf-8

import torch
from torchvision import transforms
import torchvision.models as models
import cv2
import torch.nn.functional as F


# Class index -> label; the inference loop looks the predicted index up here.
CLASSES = dict(enumerate(("goldfish", "stingray", "tench")))
BATCH_SIZE = 4
IMG_SIZE = (400, 400)

# Per-channel statistics handed to Normalize.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every frame: convert to a tensor, resize to
# IMG_SIZE, then normalize each channel.
TRANSFORM_IMG = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(IMG_SIZE),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# model
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.vgg19_bn(pretrained=False, num_classes=3)
# map_location remaps the saved tensors onto `device`, so a checkpoint that
# was written on a CUDA machine also loads on a CPU-only one.
state_dict = torch.load('checkpoint.pt', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Switch to evaluation mode before running inference.
model.eval()


# Open the input clip and fail loudly if that does not work; otherwise the
# frame loop would simply never run and an empty output file would be left.
videoCapture = cv2.VideoCapture(r'D:/video/Goldfish.mp4')
if not videoCapture.isOpened():
    raise OSError("cannot open input video: D:/video/Goldfish.mp4")

# The annotated output reuses the input's frame rate and frame size.
fps = videoCapture.get(cv2.CAP_PROP_FPS)
size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))

fourcc = cv2.VideoWriter_fourcc(*'DIVX')
videoWriter = cv2.VideoWriter("D:/goldfish.mp4", fourcc, fps, size)
if not videoWriter.isOpened():
    # Do not leave the capture handle open when we bail out.
    videoCapture.release()
    raise OSError("cannot open output video: D:/goldfish.mp4")

# Classify every frame of the input video, draw "label: probability" on the
# frame and append it to the output video. Gradients are disabled because this
# is inference only.
with torch.no_grad():
    success, frame = videoCapture.read()
    # read() reports success=False once no further frame can be read.
    while success:
        # NOTE(review): the BGR -> RGB conversion below is disabled, so the
        # frame reaches TRANSFORM_IMG in the channel order OpenCV delivered it.
        # This is the version of the script that labels every frame "stingray".
        # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image_tensor = TRANSFORM_IMG(frame)
        # Add a leading batch dimension: the model is fed one frame at a time.
        image_tensor = image_tensor.unsqueeze(0) 
        test_input = image_tensor.to(device)
        outputs = model(test_input)
        # Index of the highest-scoring class for the single frame in the batch.
        _, predicted = torch.max(outputs, 1)
        # Softmax turns the raw scores into probabilities; keep the largest one.
        probability =  F.softmax(outputs, dim=1)
        top_probability, top_class = probability.topk(1, dim=1)
        # Convert the one-element tensors into plain Python values.
        predicted = predicted.cpu().detach().numpy()
        predicted = predicted.tolist()[0]
        label = CLASSES[predicted]
        top_probability = top_probability.cpu().detach().numpy()
        top_probability = top_probability.tolist()[0][0]
        # Format as a percentage with two decimals, e.g. "97.31%".
        top_probability = '%.2f%%' % (top_probability * 100)
        print(top_probability)
        print(label)  # observed: every frame prints "stingray"
        # Draw the prediction in red at (50, 50) on the frame that gets written.
        frame = cv2.putText(frame, label+': '+top_probability, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
        videoWriter.write(frame)
        success, frame = videoCapture.read()
    # Finalize the output file once all frames have been written.
    videoWriter.release()

当我使用下面的代码时,没有问题。OpenCV 读取的帧是 BGR 通道顺序,而训练时的图片是 RGB 顺序,所以我在送入 PyTorch 之前先将图像通道从 BGR 转换为 RGB。

import torch
from torchvision import transforms
import torchvision.models as models
import cv2
import torch.nn.functional as F
import copy

# Label for each class index produced by the classifier.
_LABELS = ("goldfish", "stingray", "tench")
CLASSES = {index: name for index, name in enumerate(_LABELS)}
BATCH_SIZE = 4
IMG_SIZE = (400, 400)

# Frame preprocessing, applied in order: tensor conversion, resize to
# IMG_SIZE, per-channel normalization.
_to_tensor = transforms.ToTensor()
_resize = transforms.Resize(IMG_SIZE)
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
TRANSFORM_IMG = transforms.Compose([_to_tensor, _resize, _normalize])

# model
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.vgg19_bn(pretrained=False, num_classes=3)
# map_location remaps the saved tensors onto `device`, so a checkpoint that
# was written on a CUDA machine also loads on a CPU-only one.
state_dict = torch.load('checkpoint.pt', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Switch to evaluation mode before running inference.
model.eval()


# Open the input clip and fail loudly if that does not work; otherwise the
# frame loop would simply never run and an empty output file would be left.
videoCapture = cv2.VideoCapture(r'video/Goldfish.mp4')
if not videoCapture.isOpened():
    raise OSError("cannot open input video: video/Goldfish.mp4")

# The annotated output reuses the input's frame rate and frame size.
fps = videoCapture.get(cv2.CAP_PROP_FPS)
size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))

fourcc = cv2.VideoWriter_fourcc(*'DIVX')
videoWriter = cv2.VideoWriter(r"D:/goldfish.mp4", fourcc, fps, size)
if not videoWriter.isOpened():
    # Do not leave the capture handle open when we bail out.
    videoCapture.release()
    raise OSError("cannot open output video: D:/goldfish.mp4")

# Classify every frame of the input video, draw "label: probability" on the
# frame and append it to the output video. Gradients are disabled because this
# is inference only.
with torch.no_grad():
    try:
        success, frame = videoCapture.read()
        # read() reports success=False once no further frame can be read.
        while success:
            # OpenCV delivers frames as BGR; convert to RGB for the model.
            # cvtColor returns a new array, so `frame` itself stays BGR for
            # drawing and writing and no explicit copy is needed.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # unsqueeze(0) adds the batch dimension: one frame per forward pass.
            test_input = TRANSFORM_IMG(rgb_frame).unsqueeze(0).to(device)
            outputs = model(test_input)

            # Predicted class index and its softmax probability, as plain
            # Python values (the batch holds exactly one frame).
            predicted = outputs.argmax(dim=1).item()
            label = CLASSES[predicted]
            top_probability = F.softmax(outputs, dim=1).max().item()
            # Format as a percentage with two decimals, e.g. "97.31%".
            top_probability = f"{top_probability * 100:.2f}%"
            print(top_probability)
            print(label)

            # Draw the prediction in red at (50, 50) on the original frame.
            frame = cv2.putText(frame, label+': '+top_probability, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
            videoWriter.write(frame)
            success, frame = videoCapture.read()
    finally:
        # Release both handles even if a frame fails mid-way, so the input is
        # closed and the output file is finalized.
        videoCapture.release()
        videoWriter.release()