创建热图并根据标签在其顶部绘制三个不同的线
Create heatmap and plot three different lines on top of it ,based on label
我有以下数据框:
val1=np.random.rand(234)
val2=np.random.rand(234)
val3=np.random.rand(234)
wave=np.arange(start=300, stop=1000, step=3)
labels=['label1','label2','label3']
df=pd.DataFrame([val1,val2,val3],columns=wave)
df['labels']=['label1','label2','label3']
df=df.set_index('labels')
以及下一行:
line=np.random.uniform(low=-1.5, high=1.5, size=(234,))
line_df=pd.DataFrame([wave,line]).T
line_df.columns=['wave','val']
我想用标签创建热图,并在每个标签行中绘制线,作为辅助轴,如下所示:
我在这里画了热图顶部的线(理论上应该是同一条线)。
我可以用这种方式创建热图,但不能按标签画线:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize':(25,8.7)},font_scale = 2)
ax =sns.heatmap(df,cmap='Reds').set(title='heatmap')
ax2=plt.twinx()
ax2.plot(line_df['wave'], line_df['val'],color="blue",linewidth=3)
正在寻找为每个标签添加相同行的想法。
这是一种创建 3 个堆叠子图的方法。单独的子图允许 3 条曲线的单独 y 轴。
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import seaborn as sns
import pandas as pd
import numpy as np
val1 = np.random.rand(234)
val2 = np.random.rand(234)
val3 = np.random.rand(234)
wave = np.arange(start=300, stop=1000, step=3)
labels = ['label1', 'label2', 'label3']
df = pd.DataFrame([val1, val2, val3], columns=wave)
df['labels'] = ['label1', 'label2', 'label3']
df = df.set_index('labels')
line = np.random.uniform(low=-1.5, high=1.5, size=(234,))
line_df = pd.DataFrame([wave, line]).T
line_df.columns = ['wave', 'val']
sns.set(rc={'figure.figsize': (25, 8.7)}, font_scale=2)
fig, axs = plt.subplots(nrows=3, sharex=True, gridspec_kw={'hspace': 0})
norm = plt.Normalize(df.to_numpy().min(), df.to_numpy().max())
cmap = 'Reds'
for i, ax in enumerate(axs):
ax.imshow(df.iloc[i:i + 1, :].to_numpy(), extent=[wave[0], wave[-1], 0, 1], aspect='auto', cmap=cmap, norm=norm)
ax.set_yticks([0.5])
ax.set_yticklabels([df.index[i]])
if ax == axs[1]:
ax.set_ylabel('Label')
ax2 = ax.twinx()
ax2.plot(line_df['wave'], line_df['val'], color="blue", linewidth=3)
ax.grid(False)
ax2.grid(False)
axs[0].set_title('heatmap')
plt.colorbar(ScalarMappable(cmap=cmap, norm=norm), ax=axs)
plt.subplots_adjust(left=0.12, right=0.72)
plt.show()
我有以下数据框:
val1=np.random.rand(234)
val2=np.random.rand(234)
val3=np.random.rand(234)
wave=np.arange(start=300, stop=1000, step=3)
labels=['label1','label2','label3']
df=pd.DataFrame([val1,val2,val3],columns=wave)
df['labels']=['label1','label2','label3']
df=df.set_index('labels')
以及下一行:
line=np.random.uniform(low=-1.5, high=1.5, size=(234,))
line_df=pd.DataFrame([wave,line]).T
line_df.columns=['wave','val']
我想用标签创建热图,并在每个标签行中绘制线,作为辅助轴,如下所示:
我在这里画了热图顶部的线(理论上应该是同一条线)。
我可以用这种方式创建热图,但不能按标签画线:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize':(25,8.7)},font_scale = 2)
ax =sns.heatmap(df,cmap='Reds').set(title='heatmap')
ax2=plt.twinx()
ax2.plot(line_df['wave'], line_df['val'],color="blue",linewidth=3)
正在寻找为每个标签添加相同行的想法。
这是一种创建 3 个堆叠子图的方法。单独的子图允许 3 条曲线的单独 y 轴。
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import seaborn as sns
import pandas as pd
import numpy as np
val1 = np.random.rand(234)
val2 = np.random.rand(234)
val3 = np.random.rand(234)
wave = np.arange(start=300, stop=1000, step=3)
labels = ['label1', 'label2', 'label3']
df = pd.DataFrame([val1, val2, val3], columns=wave)
df['labels'] = ['label1', 'label2', 'label3']
df = df.set_index('labels')
line = np.random.uniform(low=-1.5, high=1.5, size=(234,))
line_df = pd.DataFrame([wave, line]).T
line_df.columns = ['wave', 'val']
sns.set(rc={'figure.figsize': (25, 8.7)}, font_scale=2)
fig, axs = plt.subplots(nrows=3, sharex=True, gridspec_kw={'hspace': 0})
norm = plt.Normalize(df.to_numpy().min(), df.to_numpy().max())
cmap = 'Reds'
for i, ax in enumerate(axs):
ax.imshow(df.iloc[i:i + 1, :].to_numpy(), extent=[wave[0], wave[-1], 0, 1], aspect='auto', cmap=cmap, norm=norm)
ax.set_yticks([0.5])
ax.set_yticklabels([df.index[i]])
if ax == axs[1]:
ax.set_ylabel('Label')
ax2 = ax.twinx()
ax2.plot(line_df['wave'], line_df['val'], color="blue", linewidth=3)
ax.grid(False)
ax2.grid(False)
axs[0].set_title('heatmap')
plt.colorbar(ScalarMappable(cmap=cmap, norm=norm), ax=axs)
plt.subplots_adjust(left=0.12, right=0.72)
plt.show()