Not getting legend in figure/plot

I am sharing the Y axis in two subplots with the code below, but the legend is missing from both shared plots.

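    import os
    import pickle
    import matplotlib.pyplot as plt
    # smooth_curve is the asker's own smoothing helper (not shown here)

    projectDir = r'/media/DATA/banikr_D_drive/model/2021-04-28-01-18-15_5_fold_114sub'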
    logPath = os.path.join(projectDir, '2021-04-28-01-18-15_fold_1_Mixed_loss_dice.bin')
    with open(logPath, 'rb') as pfile:
        h = pickle.load(pfile)
    print(h.keys())

    fig, ax = plt.subplots(2, figsize=(20, 20), dpi=100)
    ax[0].plot(h['dice_sub_train'], color='tab:cyan', linewidth=2.0, label="Train")
    ax[0].plot(smooth_curve(h['dice_sub_train']), color='tab:purple')
    ax[0].set_xlabel('Epoch/iterations', fontsize=20)
    ax[0].set_ylabel('Dice Score', fontsize=20)
    ax[0].legend(loc='lower right', fontsize=20)#, frameon=False)
    ax1 = ax[0].twiny()
    ax1.plot(h['dice_sub_valid'], color='tab:orange', linewidth=2.0, alpha=0.9, label="Validation" )
    ax1.plot(smooth_curve(h['dice_sub_valid']), color='tab:red')
    # , bbox_to_anchor = (0.816, 0.85)
    ax[1].plot(h['loss_sub_train'], color='tab:cyan', linewidth=2.0, label="Train")
    ax[1].plot(smooth_curve(h['loss_sub_train']), color='tab:purple')
    ax2 = ax[1].twiny()
    ax2.plot(h['loss_sub_valid'], color='tab:orange', linewidth=2.0, label="Validation", alpha=0.6)
    ax2.plot(smooth_curve(h['loss_sub_valid']), color='tab:red')
    ax[1].set_xlabel('Epoch/iterations', fontsize=20)
    ax[1].set_ylabel('loss(a.u.)', fontsize=20)
    ax[1].legend(loc='upper right', fontsize=20)
    # ,bbox_to_anchor = (0.8, 0.9)
    plt.suptitle('Subject wise dice score and loss', fontsize=30)
    plt.setp(ax[0].get_xticklabels(), fontsize=20, fontweight="normal", horizontalalignment="center") #fontweight="bold"
    plt.setp(ax[0].get_yticklabels(), fontsize=20, fontweight='normal', horizontalalignment="right")
    plt.setp(ax[1].get_xticklabels(), fontsize=20, fontweight="normal", horizontalalignment="center")
    plt.setp(ax[1].get_yticklabels(), fontsize=20, fontweight="normal", horizontalalignment="right")
    plt.show()

Any idea how to fix this?

[1]: https://i.stack.imgur.com/kg7PY.png

ax1 and ax[0] are twinned (they share the y-axis), but they are two separate axes. That's why ax[0].legend() doesn't know about the lines plotted on ax1.
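To see this concretely, here is a minimal sketch with made-up data (the values and the Agg backend are assumptions for illustration, not from the question). Each Axes only reports the labelled lines drawn on it, so the legend built on ax[0] never sees the twin's "Validation" line:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2)
ax[0].plot([0, 1, 2], [0.1, 0.5, 0.7], label="Train")
ax1 = ax[0].twiny()  # a new Axes sharing ax[0]'s y-axis, with its own x-axis
ax1.plot([0, 1, 2], [0.2, 0.4, 0.6], label="Validation")

# Each Axes collects legend entries only from its own artists
_, main_labels = ax[0].get_legend_handles_labels()
_, twin_labels = ax1.get_legend_handles_labels()
print(main_labels)  # ['Train']
print(twin_labels)  # ['Validation']

# Hence a legend on ax[0] lists only "Train"
leg = ax[0].legend()
legend_texts = [t.get_text() for t in leg.get_texts()]
print(legend_texts)  # ['Train']
```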

To get Train and Validation on the same legend, plot empty lines on the main axes ax[0] and ax[1] with the desired color and label. This generates dummy Validation entries in the main legend:

    ...
    # empty (dummy) line on the main axes: draws nothing, only adds a legend entry
    ax[0].plot([], [], color='tab:orange', label="Validation")
    ax[0].legend(loc='lower right', fontsize=20)
    ...
    ax[1].plot([], [], color='tab:orange', label="Validation")
    ax[1].legend(loc='upper right', fontsize=20)
    ...
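Putting the dummy-line trick into a complete, runnable script might look like the sketch below. It assumes hypothetical synthetic values in place of the pickled history h, drops the smooth_curve calls (that helper isn't shown in the question), and uses the Agg backend so it runs headless; the loop over specs is an illustrative restructuring, not the original layout.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for the pickled history dict h from the question
epochs = 50
h = {
    "dice_sub_train": np.linspace(0.3, 0.9, epochs),
    "dice_sub_valid": np.linspace(0.25, 0.85, epochs),
    "loss_sub_train": np.linspace(1.0, 0.1, epochs),
    "loss_sub_valid": np.linspace(1.1, 0.2, epochs),
}

fig, ax = plt.subplots(2, figsize=(10, 10))

# (history key prefix, y-axis label, legend location) for each subplot
specs = [("dice", "Dice Score", "lower right"), ("loss", "loss(a.u.)", "upper right")]
for axis, (key, ylabel, loc) in zip(ax, specs):
    axis.plot(h[f"{key}_sub_train"], color="tab:cyan", linewidth=2.0, label="Train")
    twin = axis.twiny()
    twin.plot(h[f"{key}_sub_valid"], color="tab:orange", linewidth=2.0, label="Validation")
    # Dummy empty line on the main axes: draws nothing, but gives the
    # main legend a "Validation" entry styled like the twin's line
    axis.plot([], [], color="tab:orange", linewidth=2.0, label="Validation")
    axis.set_xlabel("Epoch/iterations")
    axis.set_ylabel(ylabel)
    axis.legend(loc=loc)

labels = [[t.get_text() for t in axis.get_legend().get_texts()] for axis in ax]
print(labels)  # [['Train', 'Validation'], ['Train', 'Validation']]
```

Because the dummy line has no data, it adds nothing to the plot itself; it only supplies a handle for the main legend, so its color and line width should match the real validation line.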