基于分组列的多个列之间具有相关性的DataFrame

DataFrame with correlation between several columns based on grouping column

我有一个具有以下结构的数据框:

|file_id|metric_a|metric_b|metric_c|
|'1.xml'| 1      | 0.5    | 50     |
|'1.xml'| 1.5    | 0.55   | 65     | 
|'2.xml'| 2      | 0.7    | 75     |
|'2.xml'| 2.5    | 0.75   | 80     | 

因此,我想获得 'metric_c' 与其他列之间的 table 相关性:

|file_id|correlation_a_c|correlation_b_c|
|'1.xml'| 0.7           |  0.8          |
|'2.xml'| 0.75          |  0.85         | 

我用下面的代码来做,但它看起来很糟糕:

metric_a_vs_metric_c_df = source_df.groupby('file_id')[
                                  ['metric_a', 'metric_c']].corr(method='spearman').iloc[0::2,-1].reset_index().rename(
    columns={'metric_a': 'correlation_a_c'}
)
metric_b_vs_metric_c_df = source_df.groupby('file_id')[
                                 ['metric_b', 'metric_c']].corr(method='spearman').iloc[0::2,-1].reset_index().rename(
    columns={'metric_b': 'correlation_b_c'}
)
joined_df = metric_a_vs_metric_c_df.set_index('file_id').join(metric_b_vs_metric_c_df.set_index('file_id'), lsuffix='_caller', rsuffix='_other')
print(joined_df)

是否存在使它更具可读性的方法?

您可以使用crosstab它returns一个DataFrame,然后在其上应用correlation

metric_a_vs_metric_c_df = pd.crosstab(df['metric_a'],df['metric_c'])

这是一个解决方案。由于示例数据是天真的,结果也是如此 - 但它也适用于真实数据。

df = df.groupby("file_id").corr().reset_index().melt(id_vars = ["file_id", "level_1"])
ac = df[(df.level_1 == "metric_a") & (df.variable == "metric_c")]
bc = df[(df.level_1 == "metric_b") & (df.variable == "metric_c")]
df = pd.concat([ac, bc])
df["metrics"] = df.level_1 + "_" + df.variable

df = pd.pivot_table(df, index="file_id", columns="metrics")
df.columns = [c[1] for c in df.columns]

结果是:

         metric_a_metric_c  metric_b_metric_c
file_id                                      
'1.xml'                1.0                1.0
'2.xml'                1.0                1.0

您想分别计算 'a'-'c'、'b'-'c' 列之间的 (Spearman) 相关性 crosstab。 这是带有 crosstab 的单行代码,它允许您传递自定义聚合函数。类似于:

df[['a']].apply(lambda s: df['c'].corr('spearman',s.values), axis=1)

# (this is nearly working, you get the idea)

这里是从 PSV(管道分隔值,即分隔符为“|”)读取数据帧的样板,并在您的列名中替换 'metric_' -> ''。

import pandas as pd
from io import StringIO

df = """|file_id|metric_a|metric_b|metric_c|
|'1.xml'| 1      | 0.5    | 50     |
|'2.xml'| 2      | 0.7    | 75     |"""

df = pd.read_csv(StringIO(df), sep='|', index_col=[0], usecols=[1,2,3,4])

df.columns = [s.replace('metric_', '') for s in df.columns]

顺便说一下,pandas 数据帧也有一个 corr 函数,它计算所有列相关:

>>> df.corr(method='spearman')
     a    b    c
a  1.0  1.0  1.0
b  1.0  1.0  1.0
c  1.0  1.0  1.0