从 DataFrame 中删除强相关的列
Remove strongly correlated columns from DataFrame
我有一个像这样的 DataFrame
dict_ = {'Date':['2018-01-01','2018-01-02','2018-01-03','2018-01-04','2018-01-05'],'Col1':[1,2,3,4,5],'Col2':[1.1,1.2,1.3,1.4,1.5],'Col3':[0.33,0.98,1.54,0.01,0.99]}
df = pd.DataFrame(dict_, columns=dict_.keys())
然后我计算列之间的皮尔逊相关性并过滤掉相关性高于我的阈值 0.95
的列
def trimm_correlated(df_in, threshold):
df_corr = df_in.corr(method='pearson', min_periods=1)
df_not_correlated = ~(df_corr.mask(np.eye(len(df_corr), dtype=bool)).abs() > threshold).any()
un_corr_idx = df_not_correlated.loc[df_not_correlated[df_not_correlated.index] == True].index
df_out = df_in[un_corr_idx]
return df_out
产生
uncorrelated_factors = trimm_correlated(df, 0.95)
print uncorrelated_factors
Col3
0 0.33
1 0.98
2 1.54
3 0.01
4 0.99
到目前为止,我对结果很满意,但我想保留每个相关对的一列,因此在上面的示例中,我想包括 Col1 或 Col2。得到s.th。像这样
Col1 Col3
0 1 0.33
1 2 0.98
2 3 1.54
3 4 0.01
4 5 0.99
另外,我是否可以做任何进一步的评估来确定要保留哪些相关列?
感谢
您可以使用 np.tril()
而不是 np.eye()
作为遮罩:
def trimm_correlated(df_in, threshold):
df_corr = df_in.corr(method='pearson', min_periods=1)
df_not_correlated = ~(df_corr.mask(np.tril(np.ones([len(df_corr)]*2, dtype=bool))).abs() > threshold).any()
un_corr_idx = df_not_correlated.loc[df_not_correlated[df_not_correlated.index] == True].index
df_out = df_in[un_corr_idx]
return df_out
输出:
Col1 Col3
0 1 0.33
1 2 0.98
2 3 1.54
3 4 0.01
4 5 0.99
直接在数据框上使用它来排序最高的相关值。
import pandas as pd
import numpy as np
def correl(X_train):
cor = X_train.corr()
corrm = np.corrcoef(X_train.transpose())
corr = corrm - np.diagflat(corrm.diagonal())
print("max corr:",corr.max(), ", min corr: ", corr.min())
c1 = cor.stack().sort_values(ascending=False).drop_duplicates()
high_cor = c1[c1.values!=1]
## change this value to get more correlation results
thresh = 0.9
display(high_cor[high_cor>thresh])
correl(X)
output:
max corr: 0.9821068918331252 , min corr: -0.2993837739125243
object at 0x0000017712D504E0>
count_rech_2g_8 sachet_2g_8 0.982107
count_rech_2g_7 sachet_2g_7 0.979492
count_rech_2g_6 sachet_2g_6 0.975892
arpu_8 total_rech_amt_8 0.946617
arpu_3g_8 arpu_2g_8 0.942428
isd_og_mou_8 isd_og_mou_7 0.938388
arpu_2g_6 arpu_3g_6 0.933158
isd_og_mou_6 isd_og_mou_8 0.931683
arpu_3g_7 arpu_2g_7 0.930460
total_rech_amt_6 arpu_6 0.930103
isd_og_mou_7 isd_og_mou_6 0.926571
arpu_7 total_rech_amt_7 0.926111
dtype: float64
我有一个像这样的 DataFrame
dict_ = {'Date':['2018-01-01','2018-01-02','2018-01-03','2018-01-04','2018-01-05'],'Col1':[1,2,3,4,5],'Col2':[1.1,1.2,1.3,1.4,1.5],'Col3':[0.33,0.98,1.54,0.01,0.99]}
df = pd.DataFrame(dict_, columns=dict_.keys())
然后我计算列之间的皮尔逊相关性并过滤掉相关性高于我的阈值 0.95
的列def trimm_correlated(df_in, threshold):
df_corr = df_in.corr(method='pearson', min_periods=1)
df_not_correlated = ~(df_corr.mask(np.eye(len(df_corr), dtype=bool)).abs() > threshold).any()
un_corr_idx = df_not_correlated.loc[df_not_correlated[df_not_correlated.index] == True].index
df_out = df_in[un_corr_idx]
return df_out
产生
uncorrelated_factors = trimm_correlated(df, 0.95)
print uncorrelated_factors
Col3
0 0.33
1 0.98
2 1.54
3 0.01
4 0.99
到目前为止,我对结果很满意,但我想保留每个相关对的一列,因此在上面的示例中,我想包括 Col1 或 Col2。得到s.th。像这样
Col1 Col3
0 1 0.33
1 2 0.98
2 3 1.54
3 4 0.01
4 5 0.99
另外,我是否可以做任何进一步的评估来确定要保留哪些相关列?
感谢
您可以使用 np.tril()
而不是 np.eye()
作为遮罩:
def trimm_correlated(df_in, threshold):
df_corr = df_in.corr(method='pearson', min_periods=1)
df_not_correlated = ~(df_corr.mask(np.tril(np.ones([len(df_corr)]*2, dtype=bool))).abs() > threshold).any()
un_corr_idx = df_not_correlated.loc[df_not_correlated[df_not_correlated.index] == True].index
df_out = df_in[un_corr_idx]
return df_out
输出:
Col1 Col3
0 1 0.33
1 2 0.98
2 3 1.54
3 4 0.01
4 5 0.99
直接在数据框上使用它来排序最高的相关值。
import pandas as pd
import numpy as np
def correl(X_train):
cor = X_train.corr()
corrm = np.corrcoef(X_train.transpose())
corr = corrm - np.diagflat(corrm.diagonal())
print("max corr:",corr.max(), ", min corr: ", corr.min())
c1 = cor.stack().sort_values(ascending=False).drop_duplicates()
high_cor = c1[c1.values!=1]
## change this value to get more correlation results
thresh = 0.9
display(high_cor[high_cor>thresh])
correl(X)
output:
max corr: 0.9821068918331252 , min corr: -0.2993837739125243
object at 0x0000017712D504E0>
count_rech_2g_8 sachet_2g_8 0.982107
count_rech_2g_7 sachet_2g_7 0.979492
count_rech_2g_6 sachet_2g_6 0.975892
arpu_8 total_rech_amt_8 0.946617
arpu_3g_8 arpu_2g_8 0.942428
isd_og_mou_8 isd_og_mou_7 0.938388
arpu_2g_6 arpu_3g_6 0.933158
isd_og_mou_6 isd_og_mou_8 0.931683
arpu_3g_7 arpu_2g_7 0.930460
total_rech_amt_6 arpu_6 0.930103
isd_og_mou_7 isd_og_mou_6 0.926571
arpu_7 total_rech_amt_7 0.926111
dtype: float64