Matplotlib 后端笔记本定制

Matplotlib backend notebook customizations

我在使用 matplotlib 的 'notebook' 后端(`%matplotlib notebook`),因为我在制作一些交互式图形,而这个后端很适合这种用途(尤其是我通过 Jupyter Notebook 向学生提供这些图形)。GUI 和交互部分我用 ipywidgets 实现。

然而,这个后端会在图形周围显示各种按钮,它们会干扰我的交互式图形。尤其是调整大小、缩放、平移和电源按钮,会让我的学生感到困惑……

我想禁用它们。有关我要禁用的内容,请参见此插图。

有人能告诉我相关的 API 文档,或者知道如何禁用/移除这些按钮吗?我试过其他一些后端,但它们在 Jupyter Notebook 中制作交互式图形时通常效果不好,所以我希望尽量继续使用 notebook 后端。

以下是 svm_helper.py 模块的内容:

from matplotlib import pyplot as plt
from matplotlib.backend_bases import MouseButton as mb
import ipywidgets as widgets
import sklearn.linear_model
import sklearn.metrics
import sklearn.svm
import numpy as np

def plot_decision_boundary_margin(X, y, model, ax=None, xlim=(-3, 3), ylim=(-3, 3)):
    """Draw the decision boundary and margins of a linear classifier.

    The solid line is the level set ``decision_function == 0``. The dashed
    lines are the margins ``decision_function == -1`` and ``== +1``. All three
    are evaluated on a 500 x 200 grid spanning ``xlim`` x ``ylim``.

    Parameters
    ----------
    X, y : array-like
        Training data. Currently unused, because the grid covers the fixed
        plot limits rather than the data range. Kept for API compatibility.
    model : object with a ``decision_function`` method
        Called on an (N, 2) array of grid points, e.g. a fitted
        ``sklearn.svm.LinearSVC``.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to pyplot's current axes.
    xlim, ylim : (float, float)
        Range of the evaluation grid along feature 1 and feature 2.

    Returns
    -------
    (C1, C2) : tuple
        The boundary and margin contour sets, so the caller can remove them.
    """
    if ax is None:
        ax = plt.gca()

    # np.meshgrid flattens its inputs, so plain 1-D linspaces suffice.
    x0, x1 = np.meshgrid(
        np.linspace(xlim[0], xlim[1], 500),
        np.linspace(ylim[0], ylim[1], 200),
    )
    grid = np.c_[x0.ravel(), x1.ravel()]
    zz = model.decision_function(grid).reshape(x0.shape)

    C1 = ax.contour(x0, x1, zz, levels=np.array([0]), colors='k')
    C2 = ax.contour(x0, x1, zz, levels=np.array([-1, 1]), colors='k', linestyles='dashed')
    return (C1, C2)

class LineBuilder2:
    """Collect clicked points for two classes and fit/display a linear SVM.

    Mouse presses (and drags) inside the axes of ``lineB`` append the clicked
    coordinates to the class selected in ``widgetcolor``:

    * 'green' adds to ``lineR`` with label +1;
    * 'blue' adds to ``lineB`` with label -1.

    ``train`` and ``clear`` are meant to be used as ipywidgets
    ``Button.on_click`` callbacks.

    Parameters
    ----------
    lineR, lineB : matplotlib Line2D
        Marker-only lines that display the green (+1) and blue (-1) points.
    widgetcolor : widget whose ``.value`` is 'green' or 'blue'
        Selects which class a click adds to.
    widgetC : widget whose ``.value`` is the soft-margin ``C``, or None
        Only read when ``need_seperable`` is False.
    my_out : ipywidgets.Output
        Status messages are printed inside this output area.
    need_seperable : bool
        If True, train a hard-margin SVM (C = inf) and refuse datasets that
        are not linearly separable.
    """

    def __init__(self, lineR, lineB, widgetcolor, widgetC, my_out, need_seperable):
        self.lineR = lineR
        self.xsR = list(lineR.get_xdata())
        self.ysR = list(lineR.get_ydata())

        self.lineB = lineB
        self.xsB = list(lineB.get_xdata())
        self.ysB = list(lineB.get_ydata())

        self.mywidgetcolor = widgetcolor
        self.widgetC = widgetC
        self.my_out = my_out
        self.need_seperable = need_seperable

        # Keep both connection ids so the handlers can be disconnected later.
        # Previously the press id was overwritten by the motion id.
        canvas = lineR.figure.canvas
        self.cid_press = canvas.mpl_connect('button_press_event', self)
        self.cid_motion = canvas.mpl_connect('motion_notify_event', self)
        # Backward-compatible alias: the original attribute held the motion id.
        self.cid = self.cid_motion

        # Counts motion events so a drag only refreshes the plot every 6th event.
        self.dragging_timer = 0
        self.trained = False

        self.model = None
        self.C1 = None
        self.C2 = None

    def disconnect(self):
        """Stop listening to mouse events on the figure canvas."""
        canvas = self.lineR.figure.canvas
        canvas.mpl_disconnect(self.cid_press)
        canvas.mpl_disconnect(self.cid_motion)

    @staticmethod
    def _remove_contour_set(contour_set):
        """Remove one contour set from the axes it was drawn on."""
        remove = getattr(contour_set, 'remove', None)
        if callable(remove):
            remove()
        else:
            # Fallback for Matplotlib versions where ContourSet has no remove().
            # ContourSet.collections is deprecated in newer Matplotlib.
            for coll in contour_set.collections:
                coll.remove()

    def remove_decision_boundary(self):
        """Remove the plotted decision boundary and margins, if any.

        The contours are removed from the axes they live on (not from
        ``plt.gca()``, which may belong to another figure). The stored
        references are then cleared, so a second call is a no-op.
        """
        if self.C1 is None or self.C2 is None:
            return

        for contour_set in (self.C1, self.C2):
            self._remove_contour_set(contour_set)
        self.C1 = None
        self.C2 = None

    def _redraw(self):
        """Request a redraw of the figure(s) holding the two point lines."""
        figures = [self.lineR.figure]
        if self.lineB.figure is not self.lineR.figure:
            figures.append(self.lineB.figure)
        for fig in figures:
            fig.canvas.draw_idle()

    def __call__(self, event):
        """Handle a matplotlib mouse event by adding a point to the selected class.

        Button presses and mouse motion are both routed here. While dragging,
        every event records a point, but the displayed line is only refreshed
        on every 6th motion event to keep the notebook responsive.
        """
        currently_dragging = event.name == 'motion_notify_event'
        if currently_dragging:
            self.dragging_timer += 1
            if self.dragging_timer > 5:
                self.dragging_timer = 0

        # Plain mouse motion (no button held) has no matching button here.
        if event.button not in (mb.LEFT, mb.MIDDLE, mb.RIGHT):
            return

        if event.inaxes != self.lineB.axes:
            return

        # While the toolbar's pan/zoom tool is active, the click is meant
        # for that tool, not for adding a data point.
        toolbar = getattr(getattr(event, 'canvas', None), 'toolbar', None)
        if toolbar is not None and getattr(toolbar, 'mode', ''):
            return

        refresh = (not currently_dragging) or self.dragging_timer == 0

        color = self.mywidgetcolor.value
        if color == 'green':
            self.xsR.append(event.xdata)
            self.ysR.append(event.ydata)
            if refresh:
                self.lineR.set_data(self.xsR, self.ysR)
        elif color == 'blue':
            self.xsB.append(event.xdata)
            self.ysB.append(event.ydata)
            if refresh:
                self.lineB.set_data(self.xsB, self.ysB)

    def clear(self, button):
        """Reset the points, model and plotted boundary (on_click callback).

        Only allowed after a model has been trained. Otherwise a message is
        printed and nothing else changes.
        """
        # Like train(): drop old messages so the small status area does not
        # overflow.
        self.my_out.clear_output()

        if not self.trained:
            with self.my_out:
                print('can only reset if trained')
            return

        with self.my_out:
            print('resetted the widget')

        self.trained = False

        self.remove_decision_boundary()
        self.C1 = None
        self.C2 = None
        self.model = None
        self.xsR, self.ysR = [], []
        self.xsB, self.ysB = [], []
        self.lineR.set_data(self.xsR, self.ysR)
        self.lineB.set_data(self.xsB, self.ysB)
        self._redraw()

    def export(self):
        """Return the collected points as an ``(X, y)`` training set.

        Returns
        -------
        X : ndarray, shape (n_green + n_blue, 2)
            Green points first, then blue points.
        y : ndarray, shape (n_green + n_blue,)
            +1.0 for green points and -1.0 for blue points.
        """
        data_r = np.column_stack([self.xsR, self.ysR])
        data_b = np.column_stack([self.xsB, self.ysB])
        X = np.concatenate((data_r, data_b))
        y = np.concatenate((np.ones(len(data_r)), -np.ones(len(data_b))))
        return (X, y)

    def train(self, button):
        """Fit a linear SVM to the collected points and plot it (on_click callback)."""
        self.my_out.clear_output()

        if not self.xsR or not self.xsB:
            with self.my_out:
                print('need at least one object in both classes to train')
            return

        (X, y) = self.export()

        # NOTE(review): the hard margin is emulated with C=inf. On
        # non-separable data liblinear may not converge; the accuracy check
        # below then rejects the dataset.
        C = float('inf') if self.need_seperable else self.widgetC.value

        model = sklearn.svm.LinearSVC(loss='hinge', C=C)
        model.fit(X, y)

        if self.need_seperable and model.score(X, y) < 0.99999:
            with self.my_out:
                print('this dataset is not seperable')
            return

        self.remove_decision_boundary()

        # The plotting helper draws on pyplot's current axes. Make sure that
        # is this builder's axes, not another figure opened in the meantime.
        plt.sca(self.lineR.axes)
        self.C1, self.C2 = plot_decision_boundary_margin(X, y, model)

        self.model = model
        self.trained = True
        self._redraw()

        with self.my_out:
            if self.need_seperable:
                print('trained hard margin SVM')
            else:
                print('trained soft margin SVM with C %f' % C)

def init(need_seperable=True, hide_controls=False):
    """Build and display the interactive SVM demo (figure plus widget bar).

    Parameters
    ----------
    need_seperable : bool, default True
        If True, the demo trains a hard-margin SVM and rejects datasets that
        are not linearly separable. If False, a log-scale slider for the
        soft-margin parameter ``C`` is shown.
    hide_controls : bool, default False
        If True, hide the figure toolbar, header and footer and disable
        resizing. These canvas attributes are understood by the ipympl
        backend (``%matplotlib widget``). With ``%matplotlib notebook`` they
        presumably have no effect, so a CSS workaround is needed there.

    Returns
    -------
    Whatever ``display`` returns.
    """
    # Build the figure with interactivity off, so it is not shown before the
    # widget bar is assembled. Interactivity is switched back on even if
    # building the GUI fails.
    plt.ioff()
    try:
        fig = plt.figure(figsize=(4, 4))
        ax = fig.add_subplot(111)

        # Fixed axes; plot_decision_boundary_margin uses the same [-3, 3]
        # range by default.
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)
        ax.set_title('click to add points')
        ax.set_xlabel('Feature 1')
        ax.set_ylabel('Feature 2')

        if hide_controls:
            fig.canvas.toolbar_visible = False  # hide the toolbar
            fig.canvas.header_visible = False   # hide the figure name above the plot
            fig.canvas.footer_visible = False
            fig.canvas.resizable = False

        # Marker-only lines that will hold the clicked points of each class.
        lineR, = ax.plot([], [], linestyle="none", marker="s", color="g", markersize=10)
        lineB, = ax.plot([], [], linestyle="none", marker="^", color="b", markersize=10)

        w_clear = widgets.Button(
            description='Clear all',
            disabled=False,
            button_style='danger',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Remove all data and start from scratch',
            icon='check',  # FontAwesome name without the `fa-` prefix
        )

        w_color = widgets.ToggleButtons(
            options=['green', 'blue'],
            description='Class:',
            disabled=False,
            button_style='',  # 'success', 'info', 'warning', 'danger' or ''
            tooltips=['Add points to the green class', 'Add points to the blue class'],
        )

        if need_seperable:
            w_C = None
        else:
            w_C = widgets.FloatLogSlider(
                value=1,
                base=10,
                min=-10,  # min exponent of base
                max=10,   # max exponent of base
                step=0.2,  # exponent step
                description='C:',
                continuous_update=False,
                orientation='horizontal',
                readout=True,
            )

        w_train = widgets.Button(
            description='Train SVM',
            disabled=False,
            button_style='warning',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Fit an SVM to the current points',
            icon='check',  # FontAwesome name without the `fa-` prefix
        )

        status = widgets.Output(layout={'border': '1px solid black'})
        status.layout.height = '40px'
        status.layout.width = '600px'

        # The Clear button was previously wired up but never shown.
        controls = [w_color] if need_seperable else [w_color, w_C]
        controls += [w_train, w_clear]
        bar = widgets.VBox([widgets.HBox(controls), status])

        linebuilder = LineBuilder2(lineR, lineB, w_color, w_C, status, need_seperable)
        w_clear.on_click(linebuilder.clear)
        w_train.on_click(linebuilder.train)
    finally:
        plt.ion()

    # NOTE(review): `display` is not imported in this module. This relies on
    # IPython providing it as a builtin; consider importing it from
    # IPython.display.
    return display(bar, fig)

为了开始交互,我在 Jupyter notebook 中使用了以下内容:

%matplotlib notebook
# Build and display the interactive SVM demo from the svm_helper module.
from svm_helper import init
init()

目前我发现,在 `%matplotlib notebook` 所在单元格的上方新增一个单元格,并放入以下代码(来自原帖中链接的一个回答),就可以隐藏这些按钮:

%%html
<style>
/* Hide `btn btn-default` buttons and `.ui-dialog-titlebar` elements inside
   notebook output areas; this hides the notebook-backend figure controls
   and the figure title bar. */
.output_wrapper button.btn.btn-default,
.output_wrapper .ui-dialog-titlebar {
  display: none;
}
</style>

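这可能并不理想,因为你得向学生解释为什么要运行这段代码,而不是直接让他们忽略这些按钮;但总算是个办法。注意,这些 CSS 选择器针对的是经典 Jupyter Notebook 输出区域的页面结构。如果可以改用 ipympl 后端(`%matplotlib widget`),还可以直接设置 `fig.canvas.toolbar_visible`、`header_visible`、`footer_visible` 和 `resizable` 等属性来隐藏这些控件。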