使用多索引数据帧时如何在 seaborn clustermap 中排列 y 标签?

How to arrange y-labels in seaborn clustermap when using a multiindex dataframe?

我正在尝试使用多索引数据框自定义 seaborn 的 clustermap 的 y 标签。所以我有一个看起来像这样的数据框:

                    Col1    Col2    ...
Idx1.A    Idx2.a    1.05    1.51    ...
          Idx2.b    0.94    0.88    ...
Idx1.B    Idx2.c    1.09    1.20    ...
          Idx2.d    0.90    0.79    ...
   ...       ...     ...     ...    ...

目标是得到与参考示例相同的 y 标签:在我的示例中,Idx1 是季节,Idx2 是月份,各列(Cols)是年份(不同之处在于这里是聚类图而不是热图,所以我认为 seaborn 中用于自定义刻度的函数会有所不同——尽管 clustermap 只是在热图的行或列上“添加”了层次聚类)。我的代码:

def do_clustermap():
    """Read a tab-separated file and draw it as a seaborn clustermap.

    The first two columns of the file are used as a two-level row index.
    Rows are not clustered (``row_cluster=False``); every row and column
    tick label is shown and shrunk to font size 4.
    """
    with open('/home/Documents/myfile.csv', 'r') as handle:
        df = pd.read_csv(handle, sep='\t', index_col=[0, 1])

        g = sns.clustermap(
            df,
            center=1,
            row_cluster=False,
            cmap="YlGnBu",
            yticklabels=True,
            xticklabels=True,
            linewidths=0.004,
        )

        # Apply the same small font to the x tick labels, then the y tick labels.
        heatmap_ax = g.ax_heatmap
        for axis in (heatmap_ax.xaxis, heatmap_ax.yaxis):
            plt.setp(axis.get_majorticklabels(), fontsize=4)

我试图按照此问题的答案进行操作,但它给出了以下错误消息:

UserWarning: Clustering large matrix with scipy. Installing `fastcluster` may give better performance.
Traceback (most recent call last):
  File "/home/ju/PycharmProjects/stage/figures.py", line 24, in <module>
  File "/home/ju/PycharmProjects/stage/figures.py", line 13, in do_heatmap
    ax = sns.clustermap(df, center=1, row_cluster=False, cmap="YlGnBu", yticklabels=True, xticklabels=True, linewidths=0.004)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1412, in clustermap
    tree_kws=tree_kws, **kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1223, in plot
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 1079, in plot_dendrograms
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/_decorators.py", line 46, in inner_f
    return f(**kwargs)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 776, in dendrogram
    label=label, rotate=rotate)
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 584, in __init__
    self.linkage = self.calculated_linkage
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 651, in calculated_linkage
    return self._calculate_linkage_scipy()
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/seaborn/matrix.py", line 620, in _calculate_linkage_scipy
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1038, in linkage
    y = _convert_to_double(np.asarray(y, order='c'))
  File "/home/ju/PycharmProjects/stage/venv/lib/python3.6/site-packages/scipy/cluster/hierarchy.py", line 1560, in _convert_to_double
    X = X.astype(np.double)
ValueError: could not convert string to float: 'Col1'

有人有想法吗? 这是我正在使用的文件的一个小例子:

        Robert  Jean    Lulu
Bar a   1.05    1.52    1.16
Bar b   0.94    0.49    0.83
Foo c   1.09    1.22    1.44
Foo d   0.92    0.79    0.55
Hop e   0.62    0.82    0.68
Hop f   0.52    0.18    0.31
Hop g   0.93    1.15    1.11


import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Build a 7-row demo frame with random values and promote the two label
# columns to a two-level row index (Idx1 outer, Idx2 inner).
df = pd.DataFrame(
    {
        'Idx1': ['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
        'Idx2': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
        'Col1': np.random.rand(7),
        'Col2': np.random.rand(7),
    }
).set_index(['Idx1', 'Idx2'])

# Cluster the columns only; keep the rows in their original order.
g = sns.clustermap(
    df,
    center=1,
    row_cluster=False,
    cmap="YlGnBu",
    yticklabels=True,
    xticklabels=True,
    linewidths=0.004,
)

# Give the x and y tick labels of the heatmap the same font size.
plt.setp(
    g.ax_heatmap.xaxis.get_majorticklabels() + g.ax_heatmap.yaxis.get_majorticklabels(),
    fontsize=10,
)


               Col1      Col2
Idx1 Idx2                    
Bar  a     0.366961  0.253956
     b     0.320457  0.807694
Foo  c     0.293184  0.337154
     d     0.868155  0.661968
Hop  e     0.908930  0.406291
     f     0.670220  0.668903
     g     0.683821  0.476246

使用 seaborn 0.11.1、matplotlib 3.4.2、pandas 1.2.4 和 scipy 1.6.3 生成以下图:


import matplotlib.pyplot as plt
from itertools import groupby
import seaborn as sns
import pandas as pd
import numpy as np

def add_line(ax, xpos, ypos):
    """Draw a short black horizontal segment in axes-fraction coordinates.

    Note the argument naming: ``xpos`` is used as the y-coordinate of both
    end points, and ``ypos`` is the x-coordinate where the segment starts.
    The segment is 0.2 axes-widths long.

    Parameters
    ----------
    ax : matplotlib axes
        Axes the segment is attached to; its ``transAxes`` transform is used.
    xpos : float
        Vertical position of the segment, as a fraction of the axes height.
    ypos : float
        Horizontal start of the segment, as a fraction of the axes width.
    """
    line = plt.Line2D([ypos, ypos + .2], [xpos, xpos], color='black', transform=ax.transAxes)
    # Callers place the segment at negative x (left of the axes rectangle),
    # so clipping to the axes must be disabled or nothing would be visible.
    line.set_clip_on(False)
    # Creating the artist is not enough: it has to be registered with the axes.
    ax.add_line(line)

def label_len(my_index, level):
    """Return run lengths of consecutive equal labels at one index level.

    Parameters
    ----------
    my_index : index object exposing ``get_level_values``
        The (multi-level) index to inspect.
    level : int or str
        Level passed through to ``get_level_values``.

    Returns
    -------
    list of (label, count) tuples
        One entry per run of adjacent identical labels, in index order.
        Labels that repeat non-adjacently produce separate entries.
    """
    runs = []
    for key, members in groupby(my_index.get_level_values(level)):
        runs.append((key, len(list(members))))
    return runs

def label_group_bar_table(ax, df):
    """Annotate the left side of ``ax`` with grouped labels from ``df.index``.

    For each index level, walks the runs of equal labels from the top of the
    axes downwards. For every run it calls ``add_line`` at the run's upper
    boundary and writes the label text centred on the run. After the last run
    it calls ``add_line`` once more for the bottom boundary. Level 0 is
    placed in the column starting at x = -0.2 (axes fraction); each further
    level is shifted another 0.2 to the left.

    Parameters
    ----------
    ax : matplotlib axes
        Target axes; positions are given in its ``transAxes`` coordinates.
    df : pandas.DataFrame
        Frame whose row index supplies the labels.
    """
    index = df.index
    n_rows = index.size
    row_height = 1. / n_rows  # height of one row as a fraction of the axes
    column_x = -.2
    for depth in range(index.nlevels):
        remaining = n_rows  # rows still below the current boundary
        for text, span in label_len(index, depth):
            add_line(ax, remaining * row_height, column_x)
            remaining -= span
            centre_y = (remaining + .5 * span) * row_height
            ax.text(column_x + .1, centre_y, text, ha='center', transform=ax.transAxes)
        # Closing boundary below the last run of this level.
        add_line(ax, remaining * row_height, column_x)
        column_x -= .2

# Build a 7-row demo frame with random values; the index is set with Idx2 as
# level 0 and Idx1 as level 1.
df = pd.DataFrame(
    {
        'Idx1': ['Bar', 'Bar', 'Foo', 'Foo', 'Hop', 'Hop', 'Hop'],
        'Idx2': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
        'Col1': np.random.rand(7),
        'Col2': np.random.rand(7),
    }
).set_index(['Idx2', 'Idx1'])

# Cluster the columns only; keep the rows in their original order.
g = sns.clustermap(
    df,
    center=1,
    row_cluster=False,
    cmap="YlGnBu",
    yticklabels=True,
    xticklabels=True,
    linewidths=0.004,
    figsize=(10, 5),
)

plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), fontsize=10)
# Add the grouped index labels next to the heatmap axes.
label_group_bar_table(g.ax_heatmap, df)