如何理解拆分数据的函数

How to understand a function which splits the data

谁能帮我理解这个函数的作用?

我理解了行打印,但在那之后我有点迷路了。从 train_data.

开始
def stratifiedShuffleSplit_data(X, y):
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    for train_index, test_index in sss.split(X, y):
        print("len(TRAIN):", len(train_index), "len(TEST):", len(test_index))
        print("TRAIN:", train_index, "TEST:", test_index)

        train_data = [df.loc[ind] for ind in train_index]
        test_data = [df.loc[ind] for ind in test_index]
        save_datarows(train_data, datafile+".train")
        save_datarows(test_data, datafile+".test")

假设您使用的是熊猫包,

 pd.DataFrame.loc 

是一种基于位置的索引器 - 这是一个过于简化的版本。我将 post 一些可以帮助您更好地理解它的资源。

train_data = [df.loc[ind] for ind in train_index]

在这里,您基本上遍历了列表 ind 并存储了各自的值 train_data test_data

的情况类似

我假设 save_datarows 是一个自定义函数,用于将 train_data 存储到扩展名为 .train

的文件中

希望对您有所帮助。

这是一个非常好的参考资料,可以提供更多说明:

https://www.geeksforgeeks.org/python-pandas-dataframe-loc/