在 PyQt 中根据边界框（BB）坐标选择特定区域

Selecting a region based on bounding-box (BB) coordinates in PyQt

我在 PyQt 中创建了一个简单的 GUI：通过一个按钮上传 CSV 文件，文件中包含图像路径和边界框坐标值；另一个按钮用于切换到下一张图像。标签（QLabel）区域显示图像，并在对象周围绘制边界框，如下所示。

现在我想为带有边界框的对象指定名称，为此我准备了另一个按钮。但当图像中有多个对象时，我希望先单击其中一个边界框，再为该对象指定名称。问题是我不知道如何让边界框区域可以被点击。

我看过点击图像时获取像素值或 (x, y) 坐标的例子，但这个需求对我来说似乎更难。

下面是我的代码：

from PyQt5 import QtGui, QtWidgets
from PyQt5.QtWidgets import QFileDialog
from PyQt5.QtWidgets import QApplication
import csv
from pygui import Ui_MainWindow
from collections import namedtuple
import sys
import cv2

# One CSV row: image path followed by the box's x, y, width and height.
Row = namedtuple("Row", "image_path x y w h")

class mainProgram(QtWidgets.QMainWindow, Ui_MainWindow):
    """Main window that shows CSV-listed images with one bounding box each.

    The widgets used here (``Upload``, ``Next``, ``label``) come from the
    Designer-generated ``Ui_MainWindow`` mixin.
    """

    def __init__(self, parent=None):

        super(mainProgram, self).__init__(parent)
        self.setupUi(self)
        # Parsed CSV rows (list of Row), or None before the first upload.
        self.data = None
        # Index of the row currently shown.
        self.count = 0

    def all_callbacks(self):
        """Connect the UI buttons to their handlers."""
        # Open directory callback
        self.Upload.clicked.connect(self.on_click_upload)
        # Next button callback
        self.Next.clicked.connect(self.on_click_next)

    def convert_cv_image_to_qt(self, cv_img):
        """Convert an OpenCV BGR image array into a QPixmap."""
        rgb_image = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        h, w, ch = rgb_image.shape
        bytes_per_line = ch * w
        convert_to_Qt_format = QtGui.QImage(rgb_image.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
        return QtGui.QPixmap.fromImage(convert_to_Qt_format)

    def draw_bb_on_image(self, image_data, color=(0, 0, 255), thickness=2):
        """Load ``image_data.image_path`` and draw its bounding box on it.

        Args:
            image_data: a Row whose x, y, w, h fields are int-convertible
                (csv.reader yields them as strings).
            color: BGR colour tuple passed to cv2.rectangle.
            thickness: rectangle line thickness passed to cv2.rectangle.

        Returns:
            A QPixmap of the annotated image, or an empty QPixmap when the
            image file cannot be read.
        """
        self.image_path = image_data.image_path
        self.x, self.y = int(image_data.x), int(image_data.y)
        self.w, self.h = int(image_data.w), int(image_data.h)
        image = cv2.imread(self.image_path)
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; cv2.rectangle would then fail obscurely.
        if image is None:
            print(f"Could not read image {self.image_path}")
            return QtGui.QPixmap()
        output_img = cv2.rectangle(image, (self.x, self.y), (self.x + self.w, self.y + self.h), color, thickness)
        return self.convert_cv_image_to_qt(output_img)

    def on_click_upload(self):
        """Ask for a CSV file, load its rows and show the first image.

        Every non-empty row must have exactly five columns:
        image_path, x, y, w, h.
        """
        csv_path, _ = QFileDialog.getOpenFileName(self, "Import CSV", "", "CSV data files (*.csv)")
        # An empty path means the user cancelled the dialog.
        if not csv_path:
            return

        try:
            # newline="" is what the csv module expects for file objects.
            with open(csv_path, newline="") as fp:
                reader = csv.reader(fp, delimiter=',')
                # csv.reader yields [] for blank lines; skip them.
                data = [Row(*r) for r in reader if r]
        except PermissionError:
            print("You don't seem to have the rights to open the file")
            return
        except OSError as err:
            print(f"Could not open {csv_path}: {err}")
            return
        except TypeError:
            # Row(*r) raises TypeError when a row has the wrong column count.
            print("Every row must have 5 columns: image_path, x, y, w, h")
            return

        if not data:
            print("File is empty, select another file")
            return

        self.count = 0
        self.data = data
        upload_image = self.draw_bb_on_image(data[0])
        self.label.setPixmap(upload_image)
        self.label.show()

    def next_image(self, offset=1):
        """Move ``offset`` rows forward (wrapping around) and show that image."""
        if not self.data:
            return
        self.count = (self.count + offset) % len(self.data)
        next_image = self.draw_bb_on_image(self.data[self.count])
        self.label.setPixmap(next_image)
        self.label.show()

    def on_click_next(self):
        """Show the next image, wrapping to the first after the last."""
        self.next_image(offset=1)

    def on_click_previous(self):
        """Show the previous image, wrapping to the last before the first."""
        self.next_image(offset=-1)


def execute_pipeline():
    """Build the Qt application, show the annotation window and run it.

    The process exits with the event loop's return code.
    """
    qt_app = QApplication(sys.argv)

    window = mainProgram()
    window.show()
    window.all_callbacks()

    # Exit the window
    exit_code = qt_app.exec_()
    sys.exit(exit_code)


if __name__ == "__main__":
    execute_pipeline()

我想为对象指定名称，为此需要让边界框区域可以被点击。

由于问题描述与现有实现不符，我对需求做了一些调整：每张图像可能有多个边界框，而 .csv 格式难以表示这种结构，因此我改用 .json。

与其使用 QLabel，不如使用 QGraphicsView 配合 QGraphicsPixmapItem，因为这样可以方便地获取点击位置；边界框则用 QGraphicsRectItem 绘制。

考虑到上述情况，JSON 文件必须具有以下结构（注意：下面的代码依赖第三方库 dataclasses-json；由于使用了 functools.cached_property，还需要 Python 3.8 及以上版本）：

[
    {
        "filename": "/path/of/filename1.png",
        "boxes": [
            {
                "width": 100,
                "x": 10,
                "y": 10,
                "name": "foo",
                "height": 100
            },
            {
                "width": 100,
                "x": 110,
                "y": 110,
                "name": "bar",
                "height": 100
            }
        ]
    },
    {
        "filename": "/path/of/filename2.png",
        "boxes": [
            {
                "width": 800,
                "x": 30,
                "y": 50,
                "name": "baz",
                "height": 200
            }
        ]
    }
]
from functools import cached_property
import json
import random
from typing import List

from dataclasses import dataclass
from dataclasses_json import dataclass_json

from PyQt5.QtCore import QRectF, Qt
from PyQt5.QtGui import QBrush, QColor, QPainter, QPalette, QPen, QPixmap
from PyQt5.QtWidgets import (
    QApplication,
    QFileDialog,
    QGraphicsPixmapItem,
    QGraphicsRectItem,
    QGraphicsScene,
    QGraphicsView,
    QGridLayout,
    QInputDialog,
    QLineEdit,
    QMainWindow,
    QPushButton,
    QWidget,
)

# Key under which each BoxItem stores (via QGraphicsItem.setData) the index
# of its Box inside the current ImageItem's ``boxes`` list, so a click on
# the item can be mapped back to the data model.
KEY_INDEX: int = 0


@dataclass_json
@dataclass
class Box:
    """One annotated bounding box: top-left corner, size and a label.

    Coordinates are in the image's own coordinate system; ``name`` is empty
    until the user assigns one.
    """

    x: int
    y: int
    width: int
    height: int
    name: str = ""

    def to_rect(self):
        """Return the box as a QRectF (left, top, width, height)."""
        left, top = self.x, self.y
        return QRectF(left, top, self.width, self.height)


@dataclass_json
@dataclass
class ImageItem:
    """An image file together with the bounding boxes annotated on it."""

    filename: str  # image path; loaded with QPixmap when displayed
    boxes: List[Box]  # boxes drawn over this image, in display order


def load_items(filename):
    """Read a JSON annotation file and return a list of ImageItem.

    The file must hold a JSON array of objects with ``filename`` and
    ``boxes`` keys, the same layout that ``save_items`` writes.
    """
    # JSON interchange text is UTF-8; don't depend on the locale's default
    # encoding, which breaks non-ASCII paths/names on e.g. Windows.
    with open(filename, "r", encoding="utf-8") as fp:
        return ImageItem.schema().loads(fp.read(), many=True)


def save_items(items, filename):
    """Serialize a list of ImageItem to ``filename`` as a JSON array.

    The output can be read back with ``load_items``.
    """
    # Write UTF-8 explicitly so the file round-trips regardless of locale.
    with open(filename, "w", encoding="utf-8") as fp:
        fp.write(ImageItem.schema().dumps(items, many=True))


class BoxItem(QGraphicsRectItem):
    """Rectangle item that is tinted translucent red while the cursor is over it."""

    def __init__(self, parent_item):
        super(BoxItem, self).__init__(parent_item)
        # Graphics items ignore hover events unless explicitly enabled.
        self.setAcceptHoverEvents(True)

    def hoverEnterEvent(self, event):
        """Fill the box with translucent red when the cursor enters."""
        highlight = QBrush(QColor(255, 0, 0, 100))
        self.setBrush(highlight)
        super(BoxItem, self).hoverEnterEvent(event)

    def hoverLeaveEvent(self, event):
        """Remove the fill again when the cursor leaves."""
        no_fill = QBrush(Qt.NoBrush)
        self.setBrush(no_fill)
        super(BoxItem, self).hoverLeaveEvent(event)


class ImageViewer(QGraphicsView):
    """Graphics view that shows one ImageItem at a time with clickable boxes.

    The boxes of the current image are BoxItem children of a single
    QGraphicsPixmapItem. Clicking a box opens a dialog that renames the
    corresponding Box in place.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._image_items = list()
        self._current_index = -1

        self.setRenderHints(QPainter.Antialiasing | QPainter.SmoothPixmapTransform)
        self.setAlignment(Qt.AlignCenter)
        self.setBackgroundRole(QPalette.Dark)

        # Parent the scene to the view so its lifetime is tied to the view.
        scene = QGraphicsScene(self)
        self.setScene(scene)

        self._pixmap_item = QGraphicsPixmapItem()
        scene.addItem(self._pixmap_item)

    @property
    def image_items(self):
        """The list of ImageItem being browsed; names are edited in place."""
        return self._image_items

    @image_items.setter
    def image_items(self, items):
        # Materialise first: this accepts any iterable and keeps
        # `viewer.image_items = viewer.image_items` from wiping the list
        # (clear() would otherwise empty `items` before it is copied).
        items = list(items)
        self._image_items.clear()
        self._image_items.extend(items)
        self._current_index = 0 if items else -1
        self._load_image_item()

    @property
    def current_image_item(self):
        """The ImageItem at the current index, or None if there is none."""
        if 0 <= self._current_index < len(self._image_items):
            return self._image_items[self._current_index]
        return None

    @property
    def current_index(self):
        """Index of the displayed item; -1 when nothing is loaded."""
        return self._current_index

    def previous_item(self):
        """Step back one image, stopping at the first one."""
        if not self._image_items:
            return
        self._current_index = max(self._current_index - 1, 0)
        self._load_image_item()

    def next_item(self):
        """Step forward one image, stopping at the last one."""
        if not self._image_items:
            return
        self._current_index = min(self._current_index + 1, len(self._image_items) - 1)
        self._load_image_item()

    def _fit_to_window(self):
        """Fit everything in the scene into the viewport, keeping aspect ratio."""
        self.setSceneRect(self.scene().itemsBoundingRect())
        self.fitInView(self.sceneRect(), Qt.KeepAspectRatio)

    def _clear_boxes(self):
        """Take every BoxItem drawn over the pixmap out of the scene."""
        for child_item in self._pixmap_item.childItems():
            if isinstance(child_item, BoxItem):
                # setParentItem(None) only turns the box into a top-level
                # item that is still in the scene; removeItem really
                # removes it (and detaches it from its parent).
                self.scene().removeItem(child_item)

    def _load_image_item(self):
        """Show the current image (or nothing) and draw its boxes."""
        self._clear_boxes()
        image_item = self.current_image_item
        if image_item is None:
            self._pixmap_item.setPixmap(QPixmap())
        else:
            self._pixmap_item.setPixmap(QPixmap(image_item.filename))
            for i, box in enumerate(image_item.boxes):
                rect_item = BoxItem(self._pixmap_item)
                rect_item.setRect(box.to_rect())
                rect_item.setPen(QPen(QColor(*random.sample(range(255), 3)), 5))
                # Remember which Box this item stands for so a click can be
                # mapped back to the data model.
                rect_item.setData(KEY_INDEX, i)
        self._fit_to_window()

    def mousePressEvent(self, event):
        """Open a rename dialog when the user clicks on a box."""
        super().mousePressEvent(event)
        image_item = self.current_image_item
        if image_item is None:
            return
        scene_pos = self.mapToScene(event.pos())
        items = self.scene().items(scene_pos)
        # With the default sort order the topmost item comes first, so the
        # box drawn on top wins when boxes overlap.
        if not items or not isinstance(items[0], BoxItem):
            return
        i = items[0].data(KEY_INDEX)
        box_item = image_item.boxes[i]
        text, ok = QInputDialog.getText(
            self,
            self.tr("Change name"),
            self.tr("Name:"),
            QLineEdit.Normal,
            box_item.name,
        )
        if ok:
            box_item.name = text

    def resizeEvent(self, event):
        """Refit the image whenever the view is resized."""
        super().resizeEvent(event)
        self._fit_to_window()


class MainWindow(QMainWindow):
    """Top-level window: Load / Previous / Next buttons above an ImageViewer."""

    def __init__(self, parent=None):
        super().__init__(parent)

        container = QWidget()
        self.setCentralWidget(container)
        layout = QGridLayout(container)
        layout.addWidget(self.load_button, 0, 0, 1, 2)
        layout.addWidget(self.previous_button, 1, 0)
        layout.addWidget(self.next_button, 1, 1)
        layout.addWidget(self.viewer, 2, 0, 1, 2)

        self.previous_button.setEnabled(False)
        self.next_button.setEnabled(False)
        self.load_button.clicked.connect(self.handle_load_button_clicked)
        self.previous_button.clicked.connect(self.handle_previous_button_clicked)
        self.next_button.clicked.connect(self.handle_next_button_clicked)

    @cached_property
    def load_button(self):
        return QPushButton("Load")

    @cached_property
    def previous_button(self):
        return QPushButton("Previous")

    @cached_property
    def next_button(self):
        return QPushButton("Next")

    @cached_property
    def viewer(self):
        return ImageViewer()

    def handle_load_button_clicked(self):
        """Ask for a JSON annotation file and show its first image."""
        from PyQt5.QtWidgets import QMessageBox

        filename, _ = QFileDialog.getOpenFileName(
            self, "Import JSON", "", "JSON data files (*.json)"
        )
        if filename:
            try:
                items = load_items(filename)
            except Exception as err:  # UI boundary: report, don't crash
                # An exception escaping a slot aborts a PyQt5 (>= 5.5)
                # application, and load_items can fail in several ways
                # (I/O, malformed JSON, schema validation).
                QMessageBox.warning(
                    self, "Import JSON", f"Could not load {filename}:\n{err}"
                )
            else:
                self.viewer.image_items = items
        self._update_buttons()

    def handle_previous_button_clicked(self):
        """Show the previous image and refresh the navigation buttons."""
        self.viewer.previous_item()
        self._update_buttons()

    def handle_next_button_clicked(self):
        """Show the next image and refresh the navigation buttons."""
        self.viewer.next_item()
        self._update_buttons()

    def _update_buttons(self):
        """Enable Previous/Next only when there is somewhere to go."""
        self.previous_button.setEnabled(self.viewer.current_index > 0)
        self.next_button.setEnabled(
            self.viewer.current_index < (len(self.viewer.image_items) - 1)
        )

    def closeEvent(self, event):
        """Offer to save the (possibly renamed) annotations before closing."""
        super().closeEvent(event)
        # Nothing was loaded, so there is nothing to save.
        if not self.viewer.image_items:
            return
        filename, _ = QFileDialog.getSaveFileName(
            self, "Export JSON", "", "JSON data files (*.json)"
        )
        if not filename:
            return
        try:
            save_items(self.viewer.image_items, filename)
        except OSError as err:
            from PyQt5.QtWidgets import QMessageBox

            QMessageBox.warning(
                self, "Export JSON", f"Could not save {filename}:\n{err}"
            )
            # Keep the window open so the edits are not lost; cancelling
            # the save dialog on the next attempt closes without saving.
            event.ignore()


def main():
    """Create the application, show a 640x480 MainWindow and run the event loop."""
    import sys

    qt_app = QApplication(sys.argv)

    window = MainWindow()
    window.resize(640, 480)
    window.show()

    exit_code = qt_app.exec()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()