使用 pandas 对多标签 DataFrame 进行欠采样

Undersampling a multi-label DataFrame using pandas

我有这样一个 DataFrame:

file_name                                                label

../input/image-classification-screening/train/...         1
../input/image-classification-screening/train/...         7
../input/image-classification-screening/train/...         9
../input/image-classification-screening/train/...         9
../input/image-classification-screening/train/...         6

它有 11 个 classes(0 到 10)并且有很高的 class 不平衡。以下是 train['label'].value_counts():

的输出
6     6285
3     4139
9     3933
7     3664
2     2778
5     2433
8     2338
0     2166
4     2052
10    1039
1      922

如何在 pandas 中对这些数据进行欠采样,以便每个 class 的样本数少于 2500 个?我想从大多数 class 中随机删除数据点,例如 6、3、9、7 和 2。

您可以在 groupby.apply 中使用 sample。这是一个具有 4 个不平衡标签的可重现示例。

np.random.seed(1)
df = pd.DataFrame({
    'a':range(100), 
    'label':np.random.choice(range(4), size=100, p=[0.5,0.3,0.18,0.02])})
print(df['label'].value_counts())
# 0    51
# 1    30
# 2    18
# 3     1
# Name: label, dtype: int64

现在每个标签 select 最多 25 个(为您替换为 2500 个),您需要:

nMax = 25 #change to 2500

res = df.groupby('label').apply(lambda x: x.sample(n=min(nMax, len(x))))

print(res['label'].value_counts())
# 0    25 # see how label 0 and 1 are now 25
# 1    25
# 2    18 # and the smaller groups stay the same
# 3     1
# Name: label, dtype: int64

您可以创建一个掩码来标识哪些“标签”具有超过 2500 个项目,然后使用 groupby+sample(通过设置 n=n 对所需数量的项目进行采样项目)超过 2500 项的标签和 select 所有少于 2500 项的标签。这将创建两个数据帧,一个采样到 2500,另一个 select 整体。然后使用 pd.concat:

连接两组
n = 2500
msk = df.groupby('label')['label'].transform('size') >= n
df = pd.concat((df[msk].groupby('label').sample(n=n), df[~msk]), ignore_index=True)

例如,如果您有一个像这样的 DataFrame:

df = pd.DataFrame({'ID': range(30),
                   'label': ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
                             'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B', 'B', 'B', 
                             'B', 'B', 'B', 'B', 'C', 'C', 'D', 'F', 'F', 'G']})

>>> df['label'].value_counts()

A    13
B    11
C     2
F     2
D     1
G     1
Name: label, dtype: int64            
        

然后上面的代码加上 n=3,得到:

    ID label
0    7     A
1    0     A
2   10     A
3   20     B
4   18     B
5   21     B
6   24     C
7   25     C
8   26     D
9   27     F
10  28     F
11  29     G

>>> df['label'].value_counts()

A    3
B    3
C    2
F    2
D    1
G    1
Name: label, dtype: int64