如何覆盖 mpl_toolkits.mplot3d.Axes3D.draw() 方法?

How to override mpl_toolkits.mplot3d.Axes3D.draw() method?

我正在做一个小项目,需要绕过 matplotlib 中的一个错误,以修正一些 ax.patches 和 ax.collections 的 zorder。更确切地说,ax.patches 是可以在三维空间中旋转的文字符号,而 ax.collections 是 ax.voxels 的各个面(因此文字必须显示在它们上面)。目前我已经确定,问题出在 mpl_toolkits.mplot3d.Axes3D 的 draw 方法中:每次移动(旋转)图表时,zorder 都会被重新计算,而结果不是我想要的。所以我决定把 draw 方法中的以下几行改成这样:

    # Patched zorder section of Axes3D.draw (a fragment: `self` is the Axes3D
    # instance and `renderer` is the draw() argument).  Artists are still
    # sorted by the value do_3d_projection(renderer) returns, largest first,
    # but the base of each zorder is now the artist's own ``stable_zorder``
    # instead of the shared ``zorder_offset``.  Every collection and patch
    # must therefore carry a ``stable_zorder`` attribute, otherwise this
    # raises AttributeError.
    # NOTE(review): newer Matplotlib releases call do_3d_projection() without a
    # renderer argument -- check against the installed version.
    for i, col in enumerate(
            sorted(self.collections,
                   key=lambda col: col.do_3d_projection(renderer),
                   reverse=True)):
        #col.zorder = zorder_offset + i  # original line: comment it out
        col.zorder = col.stable_zorder + i  # added line
    for i, patch in enumerate(
            sorted(self.patches,
                   key=lambda patch: patch.do_3d_projection(renderer),
                   reverse=True)):
        #patch.zorder = zorder_offset + i  # original line: comment it out
        patch.zorder = patch.stable_zorder + i  # added line

这里假定 ax.collections 和 ax.patches 中的每个对象都有一个 stable_zorder 属性,它是在我的项目中手动赋值的。因此,每次运行项目之前,我都必须手动修改 mpl_toolkits.mplot3d.Axes3D.draw 方法(它位于我的项目之外,即 matplotlib 的安装目录中)。有没有办法避免这种修改,而是在我的项目内部以某种方式覆盖这个方法?

这是我项目的最小可运行示例(MWE):

import matplotlib.pyplot as plt
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.patches import PathPatch

class VisualArray:
    """Draw a NumPy array of up to three dimensions as a block of labelled voxels.

    Every element becomes one voxel, and the element values on the six outer
    faces are drawn as 3-D text patches.  Each text patch and each voxel
    collection is given a ``stable_zorder`` attribute.  Stock Matplotlib
    ignores that attribute; it only takes effect together with a patched
    ``Axes3D.draw`` that adds it to the depth-sorted z-order.

    Parameters
    ----------
    arr : array_like
        1-, 2- or 3-D data.  Lower-dimensional input is promoted to 3-D by
        prepending length-1 axes: ``(n,) -> (1, 1, n)``, ``(m, n) -> (1, m, n)``.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into; a new one is created when omitted.
    ax : Axes3D, optional
        Axes to draw into.  When omitted, the figure's current axes is reused
        if it is 3-D, otherwise a new 3-D subplot is added.

    Raises
    ------
    ValueError
        If *arr* is 0-dimensional.
    NotImplementedError
        If *arr* has more than 3 dimensions.
    """

    # Magnitude of the stable z-order given to face labels.  The sign decides
    # whether (with the patched draw) a label sorts above (+) or below (-)
    # the voxel collections, whose stable_zorder is their small default zorder.
    _LABEL_ZORDER = 10000

    # Extra 2-D affine matrices applied to text drawn on the '-x', '-y' and
    # '-z' faces (the "dark" sides, which need an additional rotation).
    _FLIP_MATRICES = {
        '-x': np.array([[-1.0, 0.0, 0], [0.0, 1.0, 0.0], [0, 0.0, -1]]),
        '-y': np.array([[0.0, 1.0, 0], [-1.0, 0.0, 0.0], [0, 0.0, 1]]),
        '-z': np.array([[1.0, 0.0, 0], [0.0, -1.0, 0.0], [0, 0.0, -1]]),
    }

    # Signs of the stable z-order of (side3, side2, side6, side5) for each
    # azimuth range: azim < -90, -90 <= azim < 0, 0 <= azim < 90, azim >= 90.
    _AZIM_SIGNS = ((1, 1, -1, -1), (-1, 1, 1, -1), (-1, -1, 1, 1), (1, -1, -1, 1))

    def __init__(self, arr, fig=None, ax=None):
        arr = np.asarray(arr)
        if arr.ndim == 0:
            raise ValueError('arr must have at least one dimension')
        if arr.ndim > 3:
            raise NotImplementedError('More than 3 dimensions is not supported')
        if arr.ndim < 3:
            # Same result as arr[None, None, :] / arr[None, :, :].
            arr = arr.reshape((1,) * (3 - arr.ndim) + arr.shape)
        self.arr = arr
        self.fig = plt.figure() if fig is None else fig
        self.ax = self._default_3d_axes(self.fig) if ax is None else ax
        self.ax.azim, self.ax.elev = -120, 30
        self.colors = None

    @staticmethod
    def _default_3d_axes(fig):
        """Return *fig*'s current axes if it is 3-D, otherwise add a 3-D subplot."""
        # Replaces fig.gca(projection='3d'): passing keyword arguments to gca
        # was deprecated in Matplotlib 3.4 and is rejected by newer releases.
        if fig.axes:
            current = fig.gca()
            if getattr(current, 'name', None) == '3d':
                return current
        return fig.add_subplot(projection='3d')

    def text3d(self, xyz, s, zdir="z", zorder=1, size=None, angle=0, usetex=False, **kwargs):
        """Add the string *s* as a filled text outline in a plane of the 3-D axes.

        Parameters
        ----------
        xyz : sequence of 3 floats
            Anchor position in data coordinates.
        s : str
            Text to draw (rendered through ``TextPath``).
        zdir : {'x', 'y', 'z', '-x', '-y', '-z'}
            Axis along which the text plane is offset; the '-' variants get
            an extra rotation from ``_FLIP_MATRICES``.
        zorder : int
            Stored as ``stable_zorder`` on the patch; the Matplotlib
            ``zorder`` itself is not changed here.
        size, usetex
            Forwarded to ``TextPath``.
        angle : float
            In-plane rotation passed to ``Affine2D.rotate`` (radians).
        **kwargs
            Forwarded to ``PathPatch`` (e.g. ``ec``, ``fc``).

        Returns
        -------
        PathPatch
            The patch, already added to ``self.ax`` and converted to 3-D.
        """
        x, y, z = xyz
        # Re-order the coordinates so (x, y) is the position inside the text
        # plane and z the offset along zdir, as pathpatch_2d_to_3d expects.
        if "y" in zdir:
            x, y, z = x, z, y
        elif "x" in zdir:
            x, y, z = y, z, x

        text_path = TextPath((-0.5, -0.5), s, size=size, usetex=usetex)
        trans = Affine2D().rotate(angle)
        if '-' in zdir:
            # Pre-multiply by the extra rotation via the public constructor
            # instead of overwriting the private ``_mtx`` attribute.
            trans = Affine2D(self._FLIP_MATRICES[zdir] @ trans.get_matrix())
        trans = trans.translate(x, y)
        patch = PathPatch(trans.transform_path(text_path), **kwargs)
        self.ax.add_patch(patch)
        art3d.pathpatch_2d_to_3d(patch, z=z, zdir=zdir)
        patch.stable_zorder = zorder
        return patch

    def on_rotation(self, event):
        """Re-assign the labels' ``stable_zorder`` for the current view angle.

        Registered by :meth:`labelize` as a ``motion_notify_event`` callback.
        *event* is not used; only ``self.ax.elev`` and ``self.ax.azim`` are read.
        """
        # Top (side1) vs. bottom (side4) labels depend on the elevation sign.
        top, bottom = (1, -1) if self.ax.elev > 0 else (-1, 1)
        for side, sign in ((self.side1, top), (self.side4, bottom)):
            for patch in side:
                patch.stable_zorder = sign * self._LABEL_ZORDER

        # The four vertical faces depend on the azimuth range.
        azim = self.ax.azim
        if azim < -90:
            quadrant = 0
        elif azim < 0:
            quadrant = 1
        elif azim < 90:
            quadrant = 2
        else:
            quadrant = 3
        sides = (self.side3, self.side2, self.side6, self.side5)
        for side, sign in zip(sides, self._AZIM_SIGNS[quadrant]):
            for patch in side:
                patch.stable_zorder = sign * self._LABEL_ZORDER

    def voxelize(self):
        """Draw one filled voxel per array element and record stable z-orders.

        The boolean grid has ``self.arr``'s shape reversed, so voxel axes
        (x, y, z) follow array axes (2, 1, 0).  Afterwards every collection
        on ``self.ax`` gets ``stable_zorder`` equal to its current zorder.
        """
        filled = np.ones(self.arr.shape[::-1], dtype=bool)
        self.ax.voxels(filled, facecolors=self.colors, edgecolor='k')
        for col in self.ax.collections:
            col.stable_zorder = col.zorder

    def _label_face(self, positions, values, zdir, zorder, side):
        """Create one TeX label per (position, value) pair and append it to *side*."""
        for xyz, value in zip(positions, values.flatten()):
            side.append(self.text3d(xyz, f'${value}$', zdir=zdir, zorder=zorder,
                                    size=1, usetex=True, ec="none", fc="k"))

    def labelize(self):
        """Write the values found on the six outer faces as 3-D text labels.

        Fills ``self.side1`` ... ``self.side6`` with the created patches and
        registers :meth:`on_rotation` (only once per instance) so their
        ``stable_zorder`` follows the view angle.  Labels use ``usetex=True``,
        so a working LaTeX installation is required.
        """
        s = self.arr.shape
        label_z = self._LABEL_ZORDER
        self.side1, self.side2, self.side3, self.side4, self.side5, self.side6 = [], [], [], [], [], []

        # side1 (z = s[0], values of arr[0]) and side4 (z = 0, values of arr[-1])
        surf = np.indices((s[2], s[1])).T[::-1].reshape(-1, 2) + 0.5
        self._label_face(np.insert(surf, 2, s[0], axis=1), self.arr[0], "z", label_z, self.side1)
        self._label_face(np.insert(surf, 2, 0, axis=1), self.arr[-1], "-z", -label_z, self.side4)

        # side2 (y = 0) and side5 (y = s[1])
        surf = np.indices((s[2], s[0])).T[::-1].reshape(-1, 2) + 0.5
        self._label_face(np.insert(surf, 1, 0, axis=1), self.arr[:, -1], "y", label_z, self.side2)
        surf = np.indices((s[0], s[2])).T[::-1].reshape(-1, 2) + 0.5
        self._label_face(np.insert(surf, 1, s[1], axis=1), self.arr[::-1, 0].T[::-1],
                         "-y", -label_z, self.side5)

        # side6 (x = s[2]) and side3 (x = 0)
        surf = np.indices((s[1], s[0])).T[::-1].reshape(-1, 2) + 0.5
        self._label_face(np.insert(surf, 0, s[2], axis=1), self.arr[:, ::-1, -1], "x", -label_z, self.side6)
        self._label_face(np.insert(surf, 0, 0, axis=1), self.arr[:, ::-1, 0], "-x", label_z, self.side3)

        # Connect once; repeated labelize() calls would otherwise stack callbacks.
        if getattr(self, '_rotation_cid', None) is None:
            self._rotation_cid = self.fig.canvas.mpl_connect('motion_notify_event', self.on_rotation)

    def vizualize(self):
        """Draw the voxels and labels, then hide the axis decorations of ``self.ax``."""
        self.voxelize()
        self.labelize()
        # plt.axis('off') would act on the *current* axes, which need not be self.ax.
        self.ax.set_axis_off()

    # Correctly spelled alias; ``vizualize`` is kept for existing callers.
    visualize = vizualize
        plt.axis('off')

# Demo: number the cells of a 2 x 6 x 5 block from 0 to 59 and show it.
arr = np.arange(2 * 6 * 5).reshape(2, 6, 5)
va = VisualArray(arr)
va.vizualize()
plt.show()

这是我在项目外部修改 ...\mpl_toolkits\mplot3d\axes3d.py 文件后得到的输出:

如果不做任何修改,我得到的则是下面这个(不理想的)输出:

你想做的事情叫做猴子补丁(Monkey Patching)。

这种做法有缺点,必须谨慎使用(以这个关键词搜索可以找到大量相关资料)。一种可行的写法如下:

from matplotlib import artist
from mpl_toolkits.mplot3d import Axes3D

# Create a new draw function that will replace Axes3D.draw, wrapped with
# matplotlib's artist.allow_rasterization decorator.
@artist.allow_rasterization
def draw(self, renderer):
    """Replacement for ``Axes3D.draw`` (template).

    Put your version of the original ``Axes3D.draw`` body here, e.g. with the
    patched zorder loops.  The call at the end dispatches to the class that
    follows ``Axes3D`` in the MRO, so the original ``Axes3D.draw`` is skipped
    entirely: this body must do everything from it that you still need.
    """
    # Your version
    # ...

    # Name Axes3D explicitly: this function is defined outside a class body,
    # so it has no __class__ cell and a zero-argument super() would raise
    # RuntimeError here.
    super(Axes3D, self).draw(renderer)

# Overwrite the old draw function on the class, so every Axes3D instance
# (that does not define its own draw) uses the replacement.
Axes3D.draw = draw

# The rest of your code
# ...

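这里需要注意两点:一是要导入 artist,以便使用其中的装饰器;二是必须显式写成 super(Axes3D, self).method(),而不能只写 super().method()——因为 draw 是在类体之外定义的函数,没有 __class__ 单元,无参数的 super() 在这里会抛出 RuntimeError。另外,super(Axes3D, self).draw(renderer) 会跳过原来的 Axes3D.draw,直接调用其父类的 draw,所以你的自定义版本必须包含原 Axes3D.draw 中所有仍然需要的逻辑(通常的做法是把原函数完整复制过来,只修改 zorder 相关的那几行)。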

根据你的使用场景,为了与其余代码保持兼容,你也可以先保存原始的 draw 函数,只在需要时临时换成自定义函数,用完再恢复:

def draw_custom(self, renderer):
    """Temporary replacement for ``Axes3D.draw``.

    It is installed as a method on the class, so it must accept the same
    arguments as the original: the axes instance and the renderer.
    """
    ...


draw_org = Axes3D.draw
Axes3D.draw = draw_custom
try:
    # Do custom stuff.  Matplotlib draws lazily, so only renders that actually
    # happen inside this block (e.g. fig.savefig(...) or fig.canvas.draw())
    # use draw_custom; a plt.show() after the restore below does not.
    ...
finally:
    # Restore the original even if the custom section raises.
    Axes3D.draw = draw_org

# Do normal stuff