在 matplotlib 3d 散点图中更改数据点的颜色并通过按键将其删除

Change colour of data points on selection and remove them with key press in matplotlib 3d scatter plot

我在 matplotlib 中有一个 3d 散点图,并设置了注释,灵感来自答案 here, particularly that by Don Cristobal

我已经设置了一些基本的事件捕获代码,但经过几天的尝试,我仍然没有实现我的目标。它们是:

(i) 使用鼠标左键选择时,将点(点)的颜色从蓝色更改为例如黑暗 blue/green.

(ii) 在按下 'delete' 键后删除在 (i) 中选择的任何选定点,包括任何注释

(iii) Select (i) 中的多个点使用 selection rectangle 并使用 'delete' 键删除

我尝试了很多方法,包括动画图表以根据数据变化进行更新、操纵艺术家参数、通过例如更改数据点xs, ys, zs = graph._offsets3d(似乎没有记录),但无济于事。

我已经尝试在 onpick(event) 函数中:

(i) 通过 event.ind 与点交互以使用 event.artist.set_face_colour()

改变颜色

(ii) 使用 artist.remove()

删除点

(iii) Remove points using xs, ys, zs = graph._offsets3d ,从xs, ys, zs中按索引(event.ind[0])移除相关点,然后重置图点通过 graph._offsets3d = xs_new, ys_new, zs_new

(iv) 重绘图表或仅重绘图表的相关部分(blitting?)

没有成功!

我现在的代码大致如下。事实上,我有几百个点,而不是下面简化示例中的 3 个。如果可能的话,我希望图表能够顺利更新,尽管只是得到可用的东西会很棒。执行此操作的大部分代码可能应该驻留在 'onpick' 中,因为这是处理拾取事件的函数(请参阅 event handler)。我保留了一些我的代码尝试,注释掉了,我希望这可能会有一些用处。 'forceUpdate' 函数旨在在事件触发时更新图形对象,但我不相信它目前有任何作用。功能 on_key(event) 目前似乎也不起作用:大概必须有一个设置才能确定要删除的点,例如所有具有已从默认更改的面色的艺术家(例如,删除所有具有深色 blue/green 而不是浅蓝色的点)。

非常感谢任何帮助。

调用代码(下方):

visualize3DData (Y, ids, subindustry)

一些示例数据点如下:

#Datapoints
Y = np.array([[ 4.82250000e+01,  1.20276889e-03,  9.14501289e-01], [ 6.17564688e+01,  5.95020883e-02, -1.56770827e+00], [ 4.55139000e+01,  9.13454423e-02, -8.12277299e+00]])

#Annotations
ids = ['a', 'b', 'c']

subindustry =  'example'

我当前的代码在这里:

import matplotlib.pyplot as plt, numpy as np
from mpl_toolkits.mplot3d import proj3d

def visualize3DData (X, ids, subindus):
    """Visualize data in 3d plot with popover next to mouse position.

    Args:
        X (np.array) - array of points, of shape (numPoints, 3)
    Returns:
        None
    """
    fig = plt.figure(figsize = (16,10))
    ax = fig.add_subplot(111, projection = '3d')
    graph  = ax.scatter(X[:, 0], X[:, 1], X[:, 2], depthshade = False, picker = True)  

    def distance(point, event):
        """Return distance between mouse position and given data point

        Args:
            point (np.array): np.array of shape (3,), with x,y,z in data coords
            event (MouseEvent): mouse event (which contains mouse position in .x and .xdata)
        Returns:
            distance (np.float64): distance (in screen coords) between mouse pos and data point
        """
        assert point.shape == (3,), "distance: point.shape is wrong: %s, must be (3,)" % point.shape

        # Project 3d data space to 2d data space
        x2, y2, _ = proj3d.proj_transform(point[0], point[1], point[2], plt.gca().get_proj())
        # Convert 2d data space to 2d screen space
        x3, y3 = ax.transData.transform((x2, y2))

        return np.sqrt ((x3 - event.x)**2 + (y3 - event.y)**2)


    def calcClosestDatapoint(X, event):
        """"Calculate which data point is closest to the mouse position.

        Args:
            X (np.array) - array of points, of shape (numPoints, 3)
            event (MouseEvent) - mouse event (containing mouse position)
        Returns:
            smallestIndex (int) - the index (into the array of points X) of the element closest to the mouse position
        """
        distances = [distance (X[i, 0:3], event) for i in range(X.shape[0])]
        return np.argmin(distances)


    def annotatePlot(X, index, ids):
        """Create popover label in 3d chart

        Args:
            X (np.array) - array of points, of shape (numPoints, 3)
            index (int) - index (into points array X) of item which should be printed
        Returns:
            None
        """
        # If we have previously displayed another label, remove it first
        if hasattr(annotatePlot, 'label'):
            annotatePlot.label.remove()
        # Get data point from array of points X, at position index
        x2, y2, _ = proj3d.proj_transform(X[index, 0], X[index, 1], X[index, 2], ax.get_proj())
        annotatePlot.label = plt.annotate( ids[index],
            xy = (x2, y2), xytext = (-20, 20), textcoords = 'offset points', ha = 'right', va = 'bottom',
            bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
            arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
        fig.canvas.draw()


    def onMouseMotion(event):
        """Event that is triggered when mouse is moved. Shows text annotation over data point closest to mouse."""
        closestIndex = calcClosestDatapoint(X, event)
        annotatePlot (X, closestIndex, ids) 


    def onclick(event):
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))

    def on_key(event):
        """
        Function to be bound to the key press event
        If the key pressed is delete and there is a picked object,
        remove that object from the canvas
        """
        if event.key == u'delete':
            ax = plt.gca()
            if ax.picked_object:
                ax.picked_object.remove()
                ax.picked_object = None
                ax.figure.canvas.draw()

    def onpick(event):

        xmouse, ymouse = event.mouseevent.xdata, event.mouseevent.ydata
        artist = event.artist
        # print(dir(event.mouseevent))
        ind = event.ind
        # print('Artist picked:', event.artist)
        # # print('{} vertices picked'.format(len(ind)))
        print('ind', ind)
        # # print('Pick between vertices {} and {}'.format(min(ind), max(ind) + 1))
        # print('x, y of mouse: {:.2f},{:.2f}'.format(xmouse, ymouse))
        # # print('Data point:', x[ind[0]], y[ind[0]])
        #
        # # remove = [artist for artist in pickable_artists if     artist.contains(event)[0]]
        # remove = [artist for artist in X if artist.contains(event)[0]]
        #
        # if not remove:
        #     # add a pt
        #     x, y = ax.transData.inverted().transform_point([event.x,     event.y])
        #     pt, = ax.plot(x, y, 'o', picker=5)
        #     pickable_artists.append(pt)
        # else:
        #     for artist in remove:
        #         artist.remove()
        # plt.draw()
        # plt.draw_idle()

        xs, ys, zs = graph._offsets3d
        print(xs[ind[0]])
        print(ys[ind[0]])
        print(zs[ind[0]])
        print(dir(artist))

        # xs[ind[0]] = 0.5
        # ys[ind[0]] = 0.5
        # zs[ind[0]] = 0.5   
        # graph._offsets3d = (xs, ys, zs)

        # print(artist.get_facecolor())
        # artist.set_facecolor('red')
        graph._facecolors[ind, :] = (1, 0, 0, 1)

        plt.draw()

    def forceUpdate(event):
        global graph
        graph.changed()

    fig.canvas.mpl_connect('motion_notify_event', onMouseMotion)  # on mouse motion    
    fig.canvas.mpl_connect('button_press_event', onclick)
    fig.canvas.mpl_connect('pick_event', onpick)
    fig.canvas.mpl_connect('draw_event', forceUpdate)

    plt.tight_layout()

    plt.show()

好的,我至少为您提供了部分解决方案,没有矩形 selection,但您可以 select 多个点并用一个 key_event 删除它们。

要更改颜色,您需要更改 graph._facecolor3d,提示在 this 关于 set_facecolor 未设置 _facecolor3d 的错误报告中。

将您的函数重写为 class 以摆脱任何需要的 global 变量可能也是个好主意。

我的解决方案有一些不太漂亮的部分,我需要在删除数据点后重新绘制图形,我无法删除和更新工作。 还有 (见下面的编辑 2)。 我还没有实现如果删除最后一个数据点会发生什么。

您的功能 on_key(event) 不起作用的原因很简单:您忘记连接它了。

所以这是一个应该满足目标 (i) 和 (ii) 的解决方案:

import matplotlib.pyplot as plt, numpy as np
from mpl_toolkits.mplot3d import proj3d

class Class3DDataVisualizer:    
    def __init__(self, X, ids, subindus, drawNew = True):

        self.X = X;
        self.ids = ids
        self.subindus = subindus

        self.disconnect = False
        self.ind = []
        self.label = None

        if drawNew:        
            self.fig = plt.figure(figsize = (7,5))
        else:
            self.fig.delaxes(self.ax)
        self.ax = self.fig.add_subplot(111, projection = '3d')
        self.graph  = self.ax.scatter(self.X[:, 0], self.X[:, 1], self.X[:, 2], depthshade = False, picker = True, facecolors=np.repeat([[0,0,1,1]],X.shape[0], axis=0) )         
        if drawNew and not self.disconnect:
            self.fig.canvas.mpl_connect('motion_notify_event', lambda event: self.onMouseMotion(event))  # on mouse motion    
            self.fig.canvas.mpl_connect('pick_event', lambda event: self.onpick(event))
            self.fig.canvas.mpl_connect('key_press_event', lambda event: self.on_key(event))

        self.fig.tight_layout()
        self.fig.show()


    def distance(self, point, event):
        """Return distance between mouse position and given data point

        Args:
            point (np.array): np.array of shape (3,), with x,y,z in data coords
            event (MouseEvent): mouse event (which contains mouse position in .x and .xdata)
        Returns:
            distance (np.float64): distance (in screen coords) between mouse pos and data point
        """
        assert point.shape == (3,), "distance: point.shape is wrong: %s, must be (3,)" % point.shape

        # Project 3d data space to 2d data space
        x2, y2, _ = proj3d.proj_transform(point[0], point[1], point[2], plt.gca().get_proj())
        # Convert 2d data space to 2d screen space
        x3, y3 = self.ax.transData.transform((x2, y2))

        return np.sqrt ((x3 - event.x)**2 + (y3 - event.y)**2)


    def calcClosestDatapoint(self, event):
        """"Calculate which data point is closest to the mouse position.

        Args:
            X (np.array) - array of points, of shape (numPoints, 3)
            event (MouseEvent) - mouse event (containing mouse position)
        Returns:
            smallestIndex (int) - the index (into the array of points X) of the element closest to the mouse position
        """
        distances = [self.distance (self.X[i, 0:3], event) for i in range(self.X.shape[0])]
        return np.argmin(distances)


    def annotatePlot(self, index):
        """Create popover label in 3d chart

        Args:
            X (np.array) - array of points, of shape (numPoints, 3)
            index (int) - index (into points array X) of item which should be printed
        Returns:
            None
        """
        # If we have previously displayed another label, remove it first
        if self.label is not None:
            self.label.remove()
        # Get data point from array of points X, at position index
        x2, y2, _ = proj3d.proj_transform(self.X[index, 0], self.X[index, 1], self.X[index, 2], self.ax.get_proj())
        self.label = plt.annotate( self.ids[index],
            xy = (x2, y2), xytext = (-20, 20), textcoords = 'offset points', ha = 'right', va = 'bottom',
            bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
            arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
        self.fig.canvas.draw()


    def onMouseMotion(self, event):
        """Event that is triggered when mouse is moved. Shows text annotation over data point closest to mouse."""
        closestIndex = self.calcClosestDatapoint(event)
        self.annotatePlot (closestIndex) 


    def on_key(self, event):
        """
        Function to be bound to the key press event
        If the key pressed is delete and there is a picked object,
        remove that object from the canvas
        """
        if event.key == u'delete':
            if self.ind:
                self.X = np.delete(self.X, self.ind, axis=0)
                self.ids = np.delete(ids, self.ind, axis=0)
                self.__init__(self.X, self.ids, self.subindus, False)
            else:
                print('nothing selected')

    def onpick(self, event):
        self.ind.append(event.ind)
        self.graph._facecolor3d[event.ind] = [1,0,0,1]



#Datapoints
Y = np.array([[ 4.82250000e+01,  1.20276889e-03,  9.14501289e-01], [ 6.17564688e+01,  5.95020883e-02, -1.56770827e+00], [ 4.55139000e+01,  9.13454423e-02, -8.12277299e+00], [3,  8, -8.12277299e+00]])
#Annotations
ids = ['a', 'b', 'c', 'd']

subindustries =  'example'

Class3DDataVisualizer(Y, ids, subindustries)

要实现矩形 selection,您必须覆盖当前在拖动(旋转 3D 图)期间发生的情况,或者更简单的解决方案是通过连续两次单击来定义矩形。

然后使用 proj3d.proj_transform 查找该矩形内的数据,找到所述数据的索引并使用 self.graph._facecolor3d[idx] 函数重新着色并用这些索引填充 self.ind,之后点击删除将负责删除由 self.ind.

指定的所有数据

编辑: 我在 __init__ 中添加了两行,删除了 ax/subplot,然后在删除数据点后添加新行。我注意到在删除了几个数据点后绘图交互变得缓慢,因为该图只是在每个子图中绘制。

编辑 2: 我发现了如何修改数据而不是重新绘制整个图,如 中所述,您必须修改 _offsets3d,这很奇怪 return xy 的元组,但是 z.

的数组

您可以使用

修改它
(x,y,z) = self.graph._offsets3d # or event.artist._offsets3d
xNew = x[:int(idx)] + x[int(idx)+1:]
yNew = y[:int(idx)] + y[int(idx)+1:]
z = np.delete(z, int(idx))
self.graph._offsets3d = (xNew,yNew,z) # or event.artist._offsets3d

但是你会 运行 在循环中删除几个数据点时遇到问题,因为你之前存储的索引在第一个循环后将不再适用,你必须更新 _facecolor3d, 标签列表...所以我选择保留答案原样,因为用新数据重新绘制图表似乎比那更容易和更清晰。