无法在 seaborn 热图动画中正确标记刻度线

Unable to correctly label tickmarks in seaborn heatplot animation

我有一个由 100 个 4×4 相干矩阵组成的序列,我希望使用 seaborn 对其热图进行动画处理。四个rows/columns分别对应通道'a''b''c''d',动画序列会嵌入一个tkinter图形canvas出现在 Toplevel window.

我正在使用方便的 Player class 定义 ,它创建带有播放按钮的动画。

我希望将所有这些组合成一个函数,它将主 tkinter window 名称、通道标签和矩阵列表作为其输入,其输出是所需的动画。相关代码:

M_list=[np.random.rand(4,4) for i in range(50)]
channels=['a','b','c','d']

root=Tk()
root.geometry('1000x1000')

def animate_coherence_matrices(root,channels,M_list):
    num_times=len(M_list)-1
    fig=Figure()
    plot_window = Toplevel(bg="lightgray")
    plot_window.geometry('700x700')

    canvas = FigureCanvasTkAgg(fig, master=plot_window)
    canvas.draw()
    canvas.get_tk_widget().pack(side=TOP,fill=BOTH,expand=1)


    def update_matrix(i):
        # clear current axes
        ax.cla()
        sns.heatmap(ax = ax, data = M_list[i], cmap = "coolwarm", cbar_ax = 
        cbar_ax,vmin=0,vmax=1)


    ax=fig.add_subplot(111)
    # set up axes labels OUTSIDE update function so we don't need to re-create for 
    # each frame. 
    ax.set_xticks(range(len(channels)))
    ax.set_xticklabels(channels,fontsize=10)
    ax.set_yticks(range(len(channels)))
    ax.set_yticklabels(channels,fontsize=10)

    # create heatmap color bar
    divider = make_axes_locatable(ax)
    cbar_ax = divider.append_axes("right", size="5%", pad=0.05)

    # Create animation with buttons as described at URL above.
    ani = Player(fig, update_matrix, maxi=num_times)


animate_coherence_matrices(root,channels,M_list)

root.mainloop()

除了每个帧中的轴标签是 0123 而不是通道标签 'a', 'b', 'c', 'd'.

显然我忽略了一些很容易纠正的东西。

只需移动行:

ax.set_xticks(range(len(channels)))
ax.set_xticklabels(channels, fontsize = 10)
ax.set_yticks(range(len(channels)))
ax.set_yticklabels(channels, fontsize = 10)

函数内 update_matrix:

def update_matrix(i):
    # clear current axes
    ax.cla()
    sns.heatmap(ax = ax, data = M_list[i], cmap = "coolwarm", cbar_ax = 
    cbar_ax,vmin=0,vmax=1)
    ax.set_xticks(range(len(channels)))
    ax.set_xticklabels(channels, fontsize = 10)
    ax.set_yticks(range(len(channels)))
    ax.set_yticklabels(channels, fontsize = 10)


如果您将热图单元格的标签居中,只需注释 ax.set_xticksax.set_yticks 行:

def update_matrix(i):
    # clear current axes
    ax.cla()
    sns.heatmap(ax = ax, data = M_list[i], cmap = "coolwarm", cbar_ax = 
    cbar_ax,vmin=0,vmax=1)
    # ax.set_xticks(range(len(channels)))
    ax.set_xticklabels(channels, fontsize = 10)
    # ax.set_yticks(range(len(channels)))
    ax.set_yticklabels(channels, fontsize = 10)