修复由赛璐珞制作的动画中的图例

Fix a legend in an animation created by celluloid

我想动画化通过不同的梯度下降优化方法寻找函数最小点的过程。为此,我使用了 matplotlib 和 celluloid 包。问题是无法修复动画中情节的图例,并且在每个循环中都会在先前的图例下方添加一个新图例,如下图所示。有什么方法可以修复图例并避免这个问题吗?

from celluloid import Camera
fig,ax = plt.subplots(1, 1,figsize=(10, 10))
camera = Camera(fig)
for i in range(path1.shape[1])
  ax.contour(x_mesh, y_mesh, z, levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap=plt.cm.jet)
  ax.plot(*minima_, 'r*', markersize=18)

  line, = ax.plot([], [], 'k', label='Simple SGD', lw=2)
  point, = ax.plot([], [], 'ko')
  line.set_data(path1[::,:i])
  point.set_data(path1[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with momentum', lw=2)
  point, = ax.plot([], [], 'ro')
  line.set_data(*path2[::,:i])
  point.set_data(*path2[::,i-1:i])

  line, = ax.plot([], [], 'g', label='SGD with Nesterov', lw=2)
  point, = ax.plot([], [], 'go')
  line.set_data(*path3[::,:i])
  point.set_data(*path3[::,i-1:i])

  line, = ax.plot([], [], 'b', label='SGD with Adagrad', lw=2)
  point, = ax.plot([], [], 'bo')
  line.set_data(*path4[::,:i])
  point.set_data(*path4[::,i-1:i])

  line, = ax.plot([], [], 'c', label='SGD with Adadelta', lw=2)
  point, = ax.plot([], [], 'co')
  line.set_data(*path5[::,:i])
  point.set_data(*path5[::,i-1:i]) 

  line, = ax.plot([], [], 'm', label='SGD with RMSprob', lw=2)
  point, = ax.plot([], [], 'mo')
  line.set_data(*path6[::,:i])
  point.set_data(*path6[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adam', lw=2)
  point, = ax.plot([], [], 'yo')
  line.set_data(*path7[::,:i])
  point.set_data(*path7[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adamax', lw=2)
  point, = ax.plot([], [], 'y*')
  line.set_data(*path8[::,:i])
  point.set_data(*path8[::,i-1:i])

  line, = ax.plot([], [], 'k', label='SGD with Nadam', lw=2)
  point, = ax.plot([], [], 'kp')
  line.set_data(*path9[::,:i])
  point.set_data(*path9[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with AMSGrad', lw=2)
  point, = ax.plot([], [], 'rD')
  line.set_data(*path10[::,:i])
  point.set_data(*path10[::,i-1:i])

  ax.legend(loc='upper left') 
  camera.snap()
animation = camera.animate()
animation.save('2D_animation_overlap.gif', writer='imagemagick')

此处的最佳做法是创建自定义图例,而不是自动生成图例,在这种情况下,可以通过

完成
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

labels = ['Single SGD', 'SGD with momentum', 'SGD with Nesterov', 
          'SGD with Adagrad', 'SGD with Adadelta', 'SGD with RMSprob', 'SGD with Adam', 
          'SGD with Adamax', 'SGD with Nadam', 'SGD with AMSgrad']
colors = ['k', 'r', 'g', 'b', 'c', 'm', 'y', 'y', 'k', 'r']
handles = []
for c, l in zip(colors, labels):
    handles.append(Line2D([0], [0], color = c, label = l))

plt.legend(handles = handles, loc = 'upper left')

这会给你一个这样的图例:

你不需要在循环中有任何这些,你可以在之前或之后做,它仍然有效。它也可以在循环中工作,但不需要每次都重新绘制图例。

用 if 语句简单地保护图例创建而不是手动创建图例也足够了。即

    # ...
    if i == 0:
        ax.legend(loc = 'upper left')

但我建议不要鼓励自动生成图例,而建议直接创建图例。