使用 matplotlib 在循环内高效地擦除并重新创建子图(如果可能,只重建其中一部分)的方法?

Efficient way to erase and re-create (part of, if possible) a subplot inside a loop using matplotlib?

下面的代码根据 X 绘制散点图,并根据 w、b 的值在散点图上绘制直线。

我尝试了几种组合,例如:

fig.canvas.draw()
fig.canvas.flush_events()

plt.clf
plt.cla

但它们要么会在图上叠加绘制出多条线,要么会把整个图形/坐标轴都清除掉。

是否可以只绘制一次散点图,而让直线随着 w、b 的变化不断更新?

下面是我用过的代码:

from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import time
from IPython.display import display, clear_output

def get_hyperplane_value(x, w, b, offset):
    '''
    Solve the line w[0]*x0 + w[1]*x1 - b = offset for x1 at x0 = x.

    plot_now evaluates this with offset 0 (dashed centre line) and with
    offsets -1 / +1 (the two solid lines on either side).

    x      : x0 coordinate(s) at which to evaluate the line.
    w      : two-element weight vector; w[1] must be non-zero.
    b      : bias term.
    offset : right-hand side of the line equation.
    '''
    slope_term = -w[0] * x
    numerator = slope_term + b + offset
    return numerator / w[1]


def plot_now(ax, W,b):
    '''
    Draw the centre line, the two offset lines and the scatter of X on *ax*.

    ax : matplotlib Axes to draw on (it is NOT cleared here; the caller
         clears it first, otherwise lines accumulate).
    W  : weight pair forwarded to get_hyperplane_value.
    b  : bias forwarded to get_hyperplane_value.

    Reads the module-level X (points) and y (point colours) and returns *ax*.
    '''
    left = np.amin(X[:, 0])
    right = np.amax(X[:, 0])

    # Drawing order: dashed yellow centre line (offset 0), then the solid
    # black lines at offset -1 and +1.
    for offset, fmt in ((0, "y--"), (-1, "k"), (1, "k")):
        heights = [get_hyperplane_value(left, W, b, offset),
                   get_hyperplane_value(right, W, b, offset)]
        ax.plot([left, right], heights, fmt)

    # Pad the x1 range of the data by 3 units above and below.
    bottom = np.amin(X[:, 1])
    top = np.amax(X[:, 1])
    ax.set_ylim([bottom - 3, top + 3])

    ax.scatter(X[:, 0], X[:, 1], marker="o", c=y)
    return ax



X, y = datasets.make_blobs(
    n_samples=50,
    n_features=2,
    centers=2,
    cluster_std=1.05,
    random_state=40,
)
# Map label 0 to -1 and every other label to 1.
y = np.where(y == 0, -1, 1)


fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(1, 1, 1)


# Each iteration wipes the axes and redraws EVERYTHING (lines and scatter)
# for a fresh random (W, b), then re-shows the figure via IPython's
# display/clear_output before pausing for 0.25 s.
for i in range(50):
    W = np.random.randn(2)
    b = np.random.randn()

    ax.cla()
    ax = plot_now(ax, W, b)

    display(fig)
    clear_output(wait=True)
    plt.pause(0.25)

在我看来,您是想让图形动起来(即制作动画),因此应该使用 FuncAnimation。动画的基本思路是:先初始化线条,然后在每一帧中只更新它们的数据。

from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

def get_hyperplane_value(x, w, b, offset):
    '''
    Return the x1 value of the line w[0]*x0 + w[1]*x1 - b = offset at x0 = x.

    The animation callback uses offset 0 for line1 and offsets -1 / +1 for
    line2 / line3.

    x      : x0 coordinate(s) at which to evaluate the line.
    w      : two-element weight vector; w[1] must be non-zero.
    b      : bias term.
    offset : right-hand side of the line equation.
    '''
    slope_term = -w[0] * x
    numerator = slope_term + b + offset
    return numerator / w[1]

def get_weights_bias(i):
    '''
    Return a (W, b) pair for animation frame *i*.

    NOTE(review): *i* is currently ignored. This is a placeholder that
    draws W (length-2) and then b (scalar) from NumPy's global
    standard-normal RNG. Presumably it is meant to be replaced by the
    parameters recorded at training iteration i.
    '''
    # Tuple elements are evaluated left to right, so W is drawn before b.
    return np.random.randn(2), np.random.randn()

def plot_now(i):
    '''
    FuncAnimation callback: move the three hyperplane lines for frame *i*.

    Fetches (W, b) with get_weights_bias(i) and updates the existing Line2D
    objects with set_data: line1 at offset 0, line2 at offset -1 and line3
    at offset +1. No new artists are created and the scatter plot is not
    redrawn.

    Returns
    -------
    (line1, line2, line3)
        The modified artists. FuncAnimation requires this when it is
        created with blit=True; with blit=False the value is ignored.
    '''
    # retrieve weights and bias at iteration i
    W, b = get_weights_bias(i)

    # Every line spans the full x0 range of the data.
    x0_1 = np.amin(X[:, 0])
    x0_2 = np.amax(X[:, 0])

    for line, offset in ((line1, 0), (line2, -1), (line3, 1)):
        line.set_data([x0_1, x0_2],
                      [get_hyperplane_value(x0_1, W, b, offset),
                       get_hyperplane_value(x0_2, W, b, offset)])

    # Fixed y-range: the x1 data range padded by 3 on each side (the same
    # value on every frame).
    x1_min = np.amin(X[:, 1])
    x1_max = np.amax(X[:, 1])
    ax.set_ylim([x1_min - 3, x1_max + 3])

    return line1, line2, line3

X, y = datasets.make_blobs(n_samples=50, n_features=2, centers=2, cluster_std=1.05, random_state=40)
# Map label 0 to -1 and every other label to 1.
y = np.where(y == 0, -1, 1)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Draw the static data exactly once. Use the explicit Axes (like the ax.plot
# calls below) rather than pyplot's implicit "current axes".
ax.scatter(X[:, 0], X[:, 1], marker="o", c=y)

# initialize empty lines; plot_now only moves them with set_data each frame
line1, = ax.plot([], [], "y--")
line2, = ax.plot([], [], "k")
line3, = ax.plot([], [], "k")

# create an animation with 10 frames. Per matplotlib's animation docs, a
# reference to the Animation object (here `anim`) must be kept alive or it
# stops updating.
anim = FuncAnimation(fig, plot_now, frames=range(10), repeat=False)
plt.show()