sklearn train_test_split 在 pandas

sklearn train_test_split on pandas

我是 sklearn 的新用户,对使用 sklearn.model_selection 中的 train_test_split 有疑问。我有一个形状为 (96350, 156) 的大型数据框。在我的数据框中,名为 CountryName 的列包含 160 个国家/地区,每个国家/地区大约有 600 个实例。

输入:

df['CountryName'].unique()

输出:

array(['Aruba', 'Afghanistan', 'Angola', 'Albania', 'Andorra',
       'United Arab Emirates', 'Argentina', 'Australia', 'Austria',
       'Azerbaijan', 'Belgium', 'Benin', 'Burkina Faso', 'Bangladesh',
       'Bulgaria', 'Bahrain', 'Bahamas', 'Bosnia and Herzegovina',
...
       'Slovenia', 'Sweden', 'Eswatini', 'Seychelles', 'Chad', 'Togo',
       'Thailand', 'Trinidad and Tobago', 'Tunisia', 'Turkey', 'Taiwan',
       'Tanzania', 'Uganda', 'Ukraine', 'Uruguay', 'United States',
       'Uzbekistan', 'Venezuela', 'Vietnam', 'South Africa', 'Zambia',
       'Zimbabwe'], dtype=object)

如何在国家层面而非实例层面实施 train_test_split?为了更好地理解我的问题,我快速 table 这是我的数据框。我如何在国家/地区(例如阿鲁巴)执行 train_test_split(因此我们从这个阿鲁巴国家/地区获得 70% 的训练数据和 30% 的测试数据),并为所有国家/地区执行此操作,最后添加这些 trained/testing (X_train、X_test、y_train 和 y_test) 数据一起在另一个数据帧中?

可视化:

(____part of X dataset____)   (y dataset)   
CountryName  value1  value2 ... valueN
   Aruba       1       3    ...   3
   Aruba       2       4    ...   6
   Aruba       3       4    ...   1
    ...       ...     ...   ...  ...
   Sweden      5       3    ...   2
   Sweden      4       7    ...   2
    ...       ...     ...   ...  ...
  Zimbabwe     2       3    ...   9
  Zimbabwe     1       2    ...   8 
  Zimbabwe     5       1    ...   1
  Zimbabwe     5       3    ...   3
    ...       ...     ...   ...  ...

使用stratify作为train_test_split的参数:

类似于:

X_train, X_test = train_test_split(df, test_size=.3, stratify=df['CountryName'])

更新:使用您的数据:

>>> train_test_split(df, test_size=.3, stratify=df['CountryName'])
[  CountryName  value1  value2  valueN
 3      Sweden       5       3       2
 7    Zimbabwe       5       1       1
 0       Aruba       1       3       3
 1       Aruba       2       4       6
 8    Zimbabwe       5       3       3
 5    Zimbabwe       2       3       9,

   CountryName  value1  value2  valueN
 6    Zimbabwe       1       2       8
 2       Aruba       3       4       1
 4      Sweden       4       7       2]