通过替换 sklearn.cross_validation 从 sklearn.model_selection 导入 StratifiedShuffleSplit 时应该对参数进行哪些更改

What changes in parameters should make when import StratifiedShuffleSplit from sklearn.model_selection by replacing sklearn.cross_validation

我正在尝试 运行 用于隔离语音识别的 python3 代码,但我收到了使用的 DeprecationWarning:

from sklearn.cross_validation import StratifiedShuffleSplit

为了消除此警告,我只是从 sklearn.model_selection 而不是 sklearn.cross_validation 导入了 StratifiedShuffleSplit,在 运行 代码之后,我得到:

TypeError: 'StratifiedShuffleSplit' object is not iterable

可能是因为在

class sklearn.cross_validation.StratifiedShuffleSplit(y, n_iter=10, test_size=0.1, train_size=None, random_state=None)

y是一个数组。

同时在:

class sklearn.cross_validation.StratifiedShuffleSplit(y, n_iter=10, test_size=0.1, train_size=None, random_state=None)

没有数组:

from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(all_labels, test_size=0.1, random_state=0)

for n,i in enumerate(all_obs):
    all_obs[n] /= all_obs[n].sum(axis=0)

for train_index, test_index in sss:
    X_train, X_test = all_obs[train_index, ...], all_obs[test_index, ...]
    y_train, y_test = all_labels[train_index], all_labels[test_index]
ys = set(all_labels)
ms = [gmmhmm(7) for y in ys]

如何替换 all_labels 因为它是根据 sklearn.cross_validation 的数组,但 sklearn.model_selection 不接受数组参数。

两者的区别在于

  • sklearn.model_selection.StratifiedShuffleSplit 是交叉验证器

  • sklearn.cross_validation.StratifiedShuffleSplit是交叉验证器迭代器

因此,您示例中的正确用法是

from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(test_size=0.1, random_state=0)

for n,i in enumerate(all_obs):
    all_obs[n] /= all_obs[n].sum(axis=0)

for train_index, test_index in sss.split(all_obs, all_labels):
     print(train_index, test_index)

阅读 sklearn.model_selection.StratifiedShuffleSplit

文档中的示例可能会有所帮助