如何动态更新现有的 seaborn 热图数据

How to dynamically update existing seaborn heatmap data

我正在用 PyQt5 构建一个应用程序。我想使用 seaborn 可视化矩阵并在矩阵数据更改时更新 seaborn 创建的热图。我这样创建原始图:

from matplotlib.figure import Figure
import numpy as np
import seaborn as sns
class MplWidget_pcolormesh(QWidget):
    """Qt widget embedding a matplotlib canvas that shows a seaborn heatmap.

    Exposes ``canvas`` and ``axes`` (the single subplot the heatmap is drawn
    on) so other code can clear and redraw the heatmap; ``plot`` holds the
    value returned by the most recent ``sns.heatmap`` call.
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)

        self.canvas = FigureCanvas(Figure())
        self.clb = []  # not used by this widget itself
        self.plot = []

        self.axes = self.canvas.figure.add_subplot(111)

        # Toolbar above the canvas. Each widget is added to the layout
        # exactly once: adding the canvas a second time only makes Qt pull
        # it back out of the layout and append it again.
        vertical_layout = QVBoxLayout()
        self.canvas.toolbar = NavigationToolbar(self.canvas, self)
        vertical_layout.addWidget(self.canvas.toolbar)
        vertical_layout.addWidget(self.canvas)
        self.setLayout(vertical_layout)

        # Placeholder 2x2 matrix shown until real data is plotted; the
        # colour limits are fixed to [0, 1].
        X = np.array([[0, 1], [1, 0]])
        self.plot = sns.heatmap(X, cmap='PuBu', square=True, linewidth=0.1,
                                linecolor=(0.1, 0.2, 0.2), ax=self.axes,
                                vmin=0, vmax=1)

这会创建一个简单的热图，之后再对它进行更新：
初始状态下的简单热图：

然后，在另一个 .py 文件中，我像这样更新它：

def matrixPlot(self, selected):
    """Replace the heatmap in ``self.view.MplWidget`` with new data.

    NOTE: this is the version that shows the problem described here.
    ``sns.heatmap`` is called without ``cbar_ax`` and only the heatmap axes
    is cleared, so every call leaves yet another colorbar on the figure.
    """
    widget = self.view.MplWidget
    data = ...  # placeholder: generate the new matrix here
    widget.axes.clear()
    heatmap_style = dict(cmap='PuBu', square=True, linewidth=0.1,
                         linecolor=(0.1, 0.2, 0.2))
    widget.plot = sns.heatmap(data, ax=widget.axes, vmin=0, vmax=1,
                              **heatmap_style)
    widget.canvas.draw()

它确实更新了图，但每次更新都会额外生成一个新的颜色条：

几次更新后的热图:

现在我尝试保存轴对象,这样我就可以更新它们,但是当我尝试像这样创建子图时:

# Fails with the TypeError shown below: add_subplot(111) returns a single
# Axes object, which cannot be unpacked into ``fig, (ax, cbarax)``.
fig, (ax, cbarax) = self.axes = self.canvas.figure.add_subplot(111)

我收到错误:

TypeError: cannot unpack non-iterable AxesSubplot object

我如何创建热图,以便稍后更新内容和颜色条值,而无需创建多个颜色条?


编辑 (@eyllanesc):
可以用下面这个项目复现同样的问题：

simpleMatrixGUI.py

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'simpleMatrixGUI.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_MainWindow(object):
    """pyuic5-generated builder for the main window's widget tree.

    Generated from 'simpleMatrixGUI.ui'; manual edits are lost when pyuic5
    is run again.
    """
    def setupUi(self, MainWindow):
        """Create and lay out all widgets on ``MainWindow``.

        Stores the created widgets as attributes of ``self`` (notably
        ``MplWidget`` and ``randMatrixGenerator``).
        """
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(800, 600)
        # Central widget with a vertical layout: heatmap widget on top,
        # "Generate new" button below it.
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.verticalLayout.setObjectName("verticalLayout")
        # Promoted custom widget (imported from mplwidget at module end).
        self.MplWidget = MplWidget(self.centralwidget)
        self.MplWidget.setObjectName("MplWidget")
        self.verticalLayout.addWidget(self.MplWidget)
        self.randMatrixGenerator = QtWidgets.QPushButton(self.centralwidget)
        self.randMatrixGenerator.setObjectName("randMatrixGenerator")
        self.verticalLayout.addWidget(self.randMatrixGenerator)
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        # Auto-connect slots named on_<objectName>_<signal>, if any.
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the user-visible (translatable) texts of the window."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.randMatrixGenerator.setText(_translate("MainWindow", "Generate new"))
from mplwidget import MplWidget

Windows.py

from PyQt5.QtWidgets import QMainWindow

from simpleMatrixGUI import Ui_MainWindow


class MainWindow(QMainWindow, Ui_MainWindow):
    """Main application window: a QMainWindow populated by the generated UI.

    Mixes in ``Ui_MainWindow`` so the widgets created by ``setupUi`` (e.g.
    ``MplWidget`` and ``randMatrixGenerator``) become attributes of the
    window itself.
    """
    def __init__(self, parent=None):
        super().__init__(parent)
        # Build the widget tree from simpleMatrixGUI.ui onto this window.
        self.setupUi(self)

Main.py

import sys
from PyQt5.QtWidgets import QApplication

from MainPresenter import MainPresenter
from Windows import MainWindow


if __name__ == '__main__':
    # A QApplication must exist before any widget is constructed.
    app = QApplication(sys.argv)

    # Assemble the MVP pieces; this demo has no model.
    model = None
    view = MainWindow()
    # Keep the presenter bound to a name while the event loop runs: the
    # view's button is connected to one of its bound methods.
    presenter = MainPresenter(view=view, model=model)

    # Run the Qt event loop and hand its return code to the OS.
    sys.exit(app.exec_())

mplwidget.py

from PyQt5.QtWidgets import *
from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg as
        FigureCanvas, NavigationToolbar2QT as NavigationToolbar)
from matplotlib.figure import Figure
import numpy as np
import seaborn as sns


class MplWidget(QWidget):
    """Qt widget embedding a matplotlib canvas that shows a seaborn heatmap.

    Exposes ``canvas`` and ``axes`` (the single subplot the heatmap is drawn
    on) so other code can clear and redraw the heatmap; ``plot`` holds the
    value returned by the most recent ``sns.heatmap`` call.
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)

        self.canvas = FigureCanvas(Figure())
        self.clb = []  # not used by this widget itself
        self.plot = []

        self.axes = self.canvas.figure.add_subplot(111)

        # Toolbar above the canvas. Each widget is added to the layout
        # exactly once: adding the canvas a second time only makes Qt pull
        # it back out of the layout and append it again.
        vertical_layout = QVBoxLayout()
        self.canvas.toolbar = NavigationToolbar(self.canvas, self)
        vertical_layout.addWidget(self.canvas.toolbar)
        vertical_layout.addWidget(self.canvas)
        self.setLayout(vertical_layout)

        # Initial random data; colour limits span the data range.
        X = np.random.randn(10, 8)
        self.plot = sns.heatmap(X, cmap='PuBu', square=True, linewidth=0.1,
                                linecolor=(0.1, 0.2, 0.2), ax=self.axes,
                                vmin=np.min(X), vmax=np.max(X))

MainPresenter.py

from typing import Optional
import numpy as np
from Windows import MainWindow
import seaborn as sns


class MainPresenter:
    """Presenter for the demo window: redraws the heatmap on button click.

    NOTE: ``matrixPlot`` reproduces the reported problem. It clears only the
    heatmap axes and lets ``sns.heatmap`` create its own colorbar, so a new
    colorbar appears on every click.
    """

    def __init__(self, view: MainWindow, model: Optional[int] = None):
        self.view, self.model = view, model
        self.view.show()
        self.view.randMatrixGenerator.clicked.connect(self.matrixPlot)

    def matrixPlot(self, selected):
        # ``selected`` comes from the button's clicked signal and is unused.
        widget = self.view.MplWidget
        matrix = np.random.randn(10, 8)
        low, high = np.min(matrix), np.max(matrix)
        widget.axes.clear()
        widget.plot = sns.heatmap(
            matrix,
            cmap='PuBu',
            square=True,
            linewidth=0.1,
            linecolor=(0.1, 0.2, 0.2),
            ax=widget.axes,
            vmin=low,
            vmax=high,
        )
        widget.canvas.draw()

您可以这样创建颜色条轴:

# One wide axes for the heatmap and a narrow one reserved for its colorbar.
grid_kws = dict(width_ratios=(0.9, 0.05), wspace=0.2)
fig, (ax, cbar_ax) = plt.subplots(1, 2, gridspec_kw=grid_kws, figsize=(10, 8))

然后,您可以将 cbar_ax 传递给 sns.heatmap:

# Replace ``...`` with the data matrix; add any other heatmap options as keywords.
sns.heatmap(..., ax=ax, cbar_ax=cbar_ax)

可参考原帖中给出的链接（该链接在转载时丢失）。


编辑

如果您将上述概念应用到您的文件中,则需要以这种方式编辑它们:

mplwidget.py

from PyQt5.QtWidgets import *
from matplotlib.backends.backend_qt5agg import (FigureCanvasQTAgg as
        FigureCanvas, NavigationToolbar2QT as NavigationToolbar)
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


class MplWidget(QWidget):
    """Qt widget embedding a seaborn heatmap next to a fixed colorbar axes.

    The figure is split by a 1x2 grid into a wide heatmap axes
    (``self.axes``) and a narrow colorbar axes (``self.cbar_ax``). Passing
    ``cbar_ax`` to ``sns.heatmap`` makes the colorbar reuse that axes, so
    redrawing the heatmap does not add another colorbar to the figure.
    ``self.plot`` holds the value returned by the most recent
    ``sns.heatmap`` call.
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)

        # Build the Figure with the object-oriented API rather than
        # plt.subplots(): a pyplot figure is also registered with pyplot's
        # global figure manager, which an embedded Qt canvas does not need
        # and which keeps the figure alive independently of this widget.
        from matplotlib.figure import Figure

        grid_kws = {'width_ratios': (0.9, 0.05), 'wspace': 0.2}
        self.figure = Figure(figsize=(10, 8))
        self.axes, self.cbar_ax = self.figure.subplots(1, 2, gridspec_kw=grid_kws)

        self.canvas = FigureCanvas(self.figure)
        self.clb = []  # not used by this widget itself
        self.plot = []

        # Toolbar above the canvas; each widget is added to the layout once.
        vertical_layout = QVBoxLayout()
        self.canvas.toolbar = NavigationToolbar(self.canvas, self)
        vertical_layout.addWidget(self.canvas.toolbar)
        vertical_layout.addWidget(self.canvas)
        self.setLayout(vertical_layout)

        # Initial random data; colour limits span the data range.
        X = np.random.randn(10, 8)
        self.plot = sns.heatmap(X, cmap='PuBu', square=True, linewidth=0.1,
                                linecolor=(0.1, 0.2, 0.2), ax=self.axes,
                                vmin=np.min(X), vmax=np.max(X),
                                cbar_ax=self.cbar_ax)

MainPresenter.py

from typing import Optional
import numpy as np
from Windows import MainWindow
import seaborn as sns


class MainPresenter:
    """Presenter that regenerates the heatmap shown in ``view.MplWidget``.

    Connects the view's "Generate new" button to :meth:`matrixPlot`.
    """

    def __init__(self, view: MainWindow, model: Optional[int] = None):
        self.view = view
        self.model = model  # not used by this presenter

        self.view.show()

        view.randMatrixGenerator.clicked.connect(self.matrixPlot)

    def matrixPlot(self, selected):
        """Redraw the heatmap and its colorbar with a fresh random 10x8 matrix.

        ``selected`` is whatever the button's ``clicked`` signal passes; it
        is ignored.
        """
        widget = self.view.MplWidget
        X = np.random.randn(10, 8)

        # Clear both axes before redrawing. sns.heatmap draws a new colorbar
        # into ``cbar_ax`` on every call, and matplotlib does not clear a
        # colorbar axes it is given, so without this the previous
        # colorbar's artists would stay behind and pile up.
        widget.axes.clear()
        widget.cbar_ax.clear()

        widget.plot = sns.heatmap(X, cmap='PuBu', square=True, linewidth=0.1,
                                  linecolor=(0.1, 0.2, 0.2), ax=widget.axes,
                                  vmin=np.min(X), vmax=np.max(X),
                                  cbar_ax=widget.cbar_ax)
        widget.canvas.draw()