创建热图并根据标签在其顶部绘制三个不同的线

Create heatmap and plot three different lines on top of it ,based on label

我有以下数据框:

val1=np.random.rand(234)
val2=np.random.rand(234)
val3=np.random.rand(234)
wave=np.arange(start=300, stop=1000, step=3)
labels=['label1','label2','label3']

df=pd.DataFrame([val1,val2,val3],columns=wave)
df['labels']=['label1','label2','label3']
df=df.set_index('labels')

以及下一行:

line=np.random.uniform(low=-1.5, high=1.5, size=(234,))
line_df=pd.DataFrame([wave,line]).T
line_df.columns=['wave','val']

我想用标签创建热图,并在每个标签行中绘制线,作为辅助轴,如下所示:

我在这里画了热图顶部的线(理论上应该是同一条线)。

我可以用这种方式创建热图,但不能按标签画线:

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize':(25,8.7)},font_scale = 2)
ax =sns.heatmap(df,cmap='Reds').set(title='heatmap')

ax2=plt.twinx()
ax2.plot(line_df['wave'], line_df['val'],color="blue",linewidth=3)

正在寻找为每个标签添加相同行的想法。

这是一种创建 3 个堆叠子图的方法。单独的子图允许 3 条曲线的单独 y 轴。

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import seaborn as sns
import pandas as pd
import numpy as np

val1 = np.random.rand(234)
val2 = np.random.rand(234)
val3 = np.random.rand(234)
wave = np.arange(start=300, stop=1000, step=3)
labels = ['label1', 'label2', 'label3']

df = pd.DataFrame([val1, val2, val3], columns=wave)
df['labels'] = ['label1', 'label2', 'label3']
df = df.set_index('labels')

line = np.random.uniform(low=-1.5, high=1.5, size=(234,))
line_df = pd.DataFrame([wave, line]).T
line_df.columns = ['wave', 'val']

sns.set(rc={'figure.figsize': (25, 8.7)}, font_scale=2)
fig, axs = plt.subplots(nrows=3, sharex=True, gridspec_kw={'hspace': 0})
norm = plt.Normalize(df.to_numpy().min(), df.to_numpy().max())
cmap = 'Reds'
for i, ax in enumerate(axs):
    ax.imshow(df.iloc[i:i + 1, :].to_numpy(), extent=[wave[0], wave[-1], 0, 1], aspect='auto', cmap=cmap, norm=norm)
    ax.set_yticks([0.5])
    ax.set_yticklabels([df.index[i]])
    if ax == axs[1]:
        ax.set_ylabel('Label')
    ax2 = ax.twinx()
    ax2.plot(line_df['wave'], line_df['val'], color="blue", linewidth=3)
    ax.grid(False)
    ax2.grid(False)
axs[0].set_title('heatmap')
plt.colorbar(ScalarMappable(cmap=cmap, norm=norm), ax=axs)
plt.subplots_adjust(left=0.12, right=0.72)
plt.show()