Skip fitting the final model with caret
Sometimes when I fit a model with caret, I really just want to see how it performs under the resampling scheme I have chosen (e.g. cross-validation).
When I am not interested in the "final model" built on the full training data, I would like to avoid fitting it. That would simply save valuable time, many times over, during development.
Is there a way to skip fitting the final model when using caret? I don't see any argument for this in caret::trainControl or caret::train.
There indeed seems to be no argument that implements this directly. There are, however, a few candidate workarounds.
The selectionFunction argument of trainControl selects the final model based on the performance (accuracy, RMSE, etc.) of the candidate models (with no parameter tuning there is only one candidate). Setting selectionFunction to something like function(x, ...) NA or function(x, ...) NULL fails. However, something like function(x, ...) -1 does partially work: no warning or error is returned, but an attempt is still made to fit the final model. The end result appears to be model-dependent.
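For concreteness, a minimal sketch of that partial workaround; the names ctrl_hack and fit_try are mine, and the call is wrapped in try() because, as said, whether it completes cleanly depends on the model:
library(caret)
data(iris)
# the partial hack from above: a selection function that returns -1
ctrl_hack <- trainControl(method = "cv", number = 5,
                          selectionFunction = function(x, ...) -1)
# wrapped in try() since the outcome is model-dependent
fit_try <- try(
  train(iris[, 1:4], iris[, 5], method = "knn", trControl = ctrl_hack),
  silent = TRUE
)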
Another trainControl argument of interest is indexFinal:
an optional vector of integers indicating which samples are used to
fit the final model after resampling. If NULL, then entire data set is
used.
Setting it to NA seems to fail for most models, kNN being an exception. Setting it to 1:10 fits the final model using only those ten observations, provided the model has few parameters. Hence, setting it to something like 1:100 should work in many cases while taking very little time.
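For example, a small sketch of the indexFinal workaround; sampling 100 rows (rather than taking 1:100) is my choice, so that all iris classes stay represented in the throwaway final fit:
library(caret)
data(iris)
set.seed(1)
# fit the (unwanted) final model on only 100 sampled rows
ctrl_small <- trainControl(method = "cv", number = 10,
                           indexFinal = sample(nrow(iris), 100))
knnFit <- train(iris[, 1:4], iris[, 5], method = "knn",
                preProcess = c("center", "scale"),
                tuneLength = 10, trControl = ctrl_small)
# resampling estimates are unaffected; only the final fit used the subset
knnFit$results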
You can of course modify the train function itself. Below I simply added an argument fitFinal, defaulting to TRUE, and check it right before the final model is fitted. If fitFinal == FALSE, then
finalModel <- list(fit = NULL, preProc = NULL)
finalTime <- 0
Everything else seems to run smoothly. To override the actual train.default function, you should afterwards run
environment(myTrain) <- environment(caret:::train.default)
assignInNamespace("train.default", myTrain, ns = "caret")
So, we have
myTrain <- function (x, y, method = "rf", preProcess = NULL, ..., weights = NULL, fitFinal = TRUE,
metric = ifelse(is.factor(y), "Accuracy", "RMSE"), maximize = ifelse(metric %in%
c("RMSE", "logLoss", "MAE"), FALSE, TRUE), trControl = trainControl(),
tuneGrid = NULL, tuneLength = ifelse(trControl$method ==
"none", 1, 3))
{
startTime <- proc.time()
rs_seed <- sample.int(.Machine$integer.max, 1L)
if (is.null(colnames(x)))
stop("Please use column names for `x`", call. = FALSE)
if (is.character(y))
y <- as.factor(y)
if (!is.numeric(y) & !is.factor(y)) {
stop("Please make sure `y` is a factor or numeric value.",
call. = FALSE)
}
if (is.list(method)) {
minNames <- c("library", "type", "parameters", "grid",
"fit", "predict", "prob")
nameCheck <- minNames %in% names(method)
if (!all(nameCheck))
stop(paste("some required components are missing:",
paste(minNames[!nameCheck], collapse = ", ")),
call. = FALSE)
models <- method
method <- "custom"
}
else {
models <- getModelInfo(method, regex = FALSE)[[1]]
if (length(models) == 0)
stop(paste("Model", method, "is not in caret's built-in library"),
call. = FALSE)
}
checkInstall(models$library)
for (i in seq(along = models$library)) do.call("requireNamespaceQuietStop",
list(package = models$library[i]))
if (any(names(models) == "check") && is.function(models$check)) {
software_check <- models$check(models$library)
}
paramNames <- as.character(models$parameters$parameter)
funcCall <- match.call(expand.dots = TRUE)
modelType <- get_model_type(y)
if (!(modelType %in% models$type))
stop(paste("wrong model type for", tolower(modelType)),
call. = FALSE)
if (grepl("^svm", method) & grepl("String$", method)) {
if (is.vector(x) && is.character(x)) {
stop("'x' should be a character matrix with a single column for string kernel methods",
call. = FALSE)
}
if (is.matrix(x) && is.numeric(x)) {
stop("'x' should be a character matrix with a single column for string kernel methods",
call. = FALSE)
}
if (is.data.frame(x)) {
stop("'x' should be a character matrix with a single column for string kernel methods",
call. = FALSE)
}
}
if (modelType == "Regression" & length(unique(y)) == 2)
warning(paste("You are trying to do regression and your outcome only has",
"two possible values Are you trying to do classification?",
"If so, use a 2 level factor as your outcome column."))
if (modelType != "Classification" & !is.null(trControl$sampling))
stop("sampling methods are only implemented for classification problems",
call. = FALSE)
if (!is.null(trControl$sampling)) {
trControl$sampling <- parse_sampling(trControl$sampling)
}
if (any(class(x) == "data.table"))
x <- as.data.frame(x)
check_dims(x = x, y = y)
n <- if (class(y)[1] == "Surv")
nrow(y)
else length(y)
parallel_check("RWeka", models)
parallel_check("keras", models)
if (!is.null(preProcess) && !(all(names(preProcess) %in%
ppMethods)))
stop(paste("pre-processing methods are limited to:",
paste(ppMethods, collapse = ", ")), call. = FALSE)
if (modelType == "Classification") {
classLevels <- levels(y)
attributes(classLevels) <- list(ordered = is.ordered(y))
xtab <- table(y)
if (any(xtab == 0)) {
xtab_msg <- paste("'", names(xtab)[xtab == 0], "'",
collapse = ", ", sep = "")
stop(paste("One or more factor levels in the outcome has no data:",
xtab_msg), call. = FALSE)
}
if (trControl$classProbs && any(classLevels != make.names(classLevels))) {
stop(paste("At least one of the class levels is not a valid R variable name;",
"This will cause errors when class probabilities are generated because",
"the variables names will be converted to ",
paste(make.names(classLevels), collapse = ", "),
". Please use factor levels that can be used as valid R variable names",
" (see ?make.names for help)."), call. = FALSE)
}
if (metric %in% c("RMSE", "Rsquared"))
stop(paste("Metric", metric, "not applicable for classification models"),
call. = FALSE)
if (!trControl$classProbs && metric == "ROC")
stop(paste("Class probabilities are needed to score models using the",
"area under the ROC curve. Set `classProbs = TRUE`",
"in the trainControl() function."), call. = FALSE)
if (trControl$classProbs) {
if (!is.function(models$prob)) {
warning("Class probabilities were requested for a model that does not implement them")
trControl$classProbs <- FALSE
}
}
}
else {
if (metric %in% c("Accuracy", "Kappa"))
stop(paste("Metric", metric, "not applicable for regression models"),
call. = FALSE)
classLevels <- NA
if (trControl$classProbs) {
warning("cannnot compute class probabilities for regression")
trControl$classProbs <- FALSE
}
}
if (trControl$method == "oob" & is.null(models$oob))
stop("Out of bag estimates are not implemented for this model",
call. = FALSE)
trControl <- withr::with_seed(rs_seed, make_resamples(trControl,
outcome = y))
if (is.logical(trControl$savePredictions)) {
trControl$savePredictions <- if (trControl$savePredictions)
"all"
else "none"
}
else {
if (!(trControl$savePredictions %in% c("all", "final",
"none")))
stop("`savePredictions` should be either logical or \"all\", \"final\" or \"none\"",
call. = FALSE)
}
if (!is.null(preProcess)) {
ppOpt <- list(options = preProcess)
if (length(trControl$preProcOptions) > 0)
ppOpt <- c(ppOpt, trControl$preProcOptions)
}
else ppOpt <- NULL
if (is.null(tuneGrid)) {
if (!is.null(ppOpt) && length(models$parameters$parameter) >
1 && as.character(models$parameters$parameter) !=
"parameter") {
pp <- list(method = ppOpt$options)
if ("ica" %in% pp$method)
pp$n.comp <- ppOpt$ICAcomp
if ("pca" %in% pp$method)
pp$thresh <- ppOpt$thresh
if ("knnImpute" %in% pp$method)
pp$k <- ppOpt$k
pp$x <- x
ppObj <- do.call("preProcess", pp)
tuneGrid <- models$grid(x = predict(ppObj, x), y = y,
len = tuneLength, search = trControl$search)
rm(ppObj, pp)
}
else {
tuneGrid <- models$grid(x = x, y = y, len = tuneLength,
search = trControl$search)
if (trControl$search != "grid" && tuneLength < nrow(tuneGrid))
tuneGrid <- tuneGrid[1:tuneLength, , drop = FALSE]
}
}
if (grepl("adaptive", trControl$method) & nrow(tuneGrid) ==
1) {
stop(paste("For adaptive resampling, there needs to be more than one",
"tuning parameter for evaluation"), call. = FALSE)
}
dotNames <- hasDots(tuneGrid, models)
if (dotNames)
colnames(tuneGrid) <- gsub("^\\.", "", colnames(tuneGrid))
tuneNames <- as.character(models$parameters$parameter)
goodNames <- all.equal(sort(tuneNames), sort(names(tuneGrid)))
if (!is.logical(goodNames) || !goodNames) {
stop(paste("The tuning parameter grid should have columns",
paste(tuneNames, collapse = ", ", sep = "")), call. = FALSE)
}
if (trControl$method == "none" && nrow(tuneGrid) != 1)
stop("Only one model should be specified in tuneGrid with no resampling",
call. = FALSE)
trControl$yLimits <- if (is.numeric(y))
get_range(y)
else NULL
if (trControl$method != "none") {
if (is.function(models$loop) && nrow(tuneGrid) > 1) {
trainInfo <- models$loop(tuneGrid)
if (!all(c("loop", "submodels") %in% names(trainInfo)))
stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'",
call. = FALSE)
lengths <- unlist(lapply(trainInfo$submodels, nrow))
if (all(lengths == 0))
trainInfo$submodels <- NULL
}
else trainInfo <- list(loop = tuneGrid)
num_rs <- if (trControl$method != "oob")
length(trControl$index)
else 1L
if (trControl$method %in% c("boot632", "optimism_boot",
"boot_all"))
num_rs <- num_rs + 1L
if (is.null(trControl$seeds) || all(is.na(trControl$seeds))) {
seeds <- sample.int(n = 1000000L, size = num_rs *
nrow(trainInfo$loop) + 1L)
seeds <- lapply(seq(from = 1L, to = length(seeds),
by = nrow(trainInfo$loop)), function(x) {
seeds[x:(x + nrow(trainInfo$loop) - 1L)]
})
seeds[[num_rs + 1L]] <- seeds[[num_rs + 1L]][1L]
trControl$seeds <- seeds
}
else {
if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds))) {
numSeeds <- unlist(lapply(trControl$seeds, length))
badSeed <- (length(trControl$seeds) < num_rs +
1L) || (any(numSeeds[-length(numSeeds)] < nrow(trainInfo$loop))) ||
(numSeeds[length(numSeeds)] < 1L)
if (badSeed)
stop(paste("Bad seeds: the seed object should be a list of length",
num_rs + 1, "with", num_rs, "integer vectors of size",
nrow(trainInfo$loop), "and the last list element having at least a",
"single integer"), call. = FALSE)
if (any(is.na(unlist(trControl$seeds))))
stop("At least one seed is missing (NA)", call. = FALSE)
}
}
if (trControl$method == "oob") {
perfNames <- metric
}
else {
testSummary <- evalSummaryFunction(y, wts = weights,
ctrl = trControl, lev = classLevels, metric = metric,
method = method)
perfNames <- names(testSummary)
}
if (!(metric %in% perfNames)) {
oldMetric <- metric
metric <- perfNames[1]
warning(paste("The metric \"", oldMetric, "\" was not in ",
"the result set. ", metric, " will be used instead.",
sep = ""))
}
if (trControl$method == "oob") {
tmp <- oobTrainWorkflow(x = x, y = y, wts = weights,
info = trainInfo, method = models, ppOpts = preProcess,
ctrl = trControl, lev = classLevels, ...)
performance <- tmp
perfNames <- colnames(performance)
perfNames <- perfNames[!(perfNames %in% as.character(models$parameters$parameter))]
if (!(metric %in% perfNames)) {
oldMetric <- metric
metric <- perfNames[1]
warning(paste("The metric \"", oldMetric, "\" was not in ",
"the result set. ", metric, " will be used instead.",
sep = ""))
}
}
else {
if (trControl$method == "LOOCV") {
tmp <- looTrainWorkflow(x = x, y = y, wts = weights,
info = trainInfo, method = models, ppOpts = preProcess,
ctrl = trControl, lev = classLevels, ...)
performance <- tmp$performance
}
else {
if (!grepl("adapt", trControl$method)) {
tmp <- nominalTrainWorkflow(x = x, y = y, wts = weights,
info = trainInfo, method = models, ppOpts = preProcess,
ctrl = trControl, lev = classLevels, ...)
performance <- tmp$performance
resampleResults <- tmp$resample
}
else {
tmp <- adaptiveWorkflow(x = x, y = y, wts = weights,
info = trainInfo, method = models, ppOpts = preProcess,
ctrl = trControl, lev = classLevels, metric = metric,
maximize = maximize, ...)
performance <- tmp$performance
resampleResults <- tmp$resample
}
}
}
trControl$indexExtra <- NULL
if (!(trControl$method %in% c("LOOCV", "oob"))) {
if (modelType == "Classification" && length(grep("^\cell",
colnames(resampleResults))) > 0) {
resampledCM <- resampleResults[, !(names(resampleResults) %in%
perfNames)]
resampleResults <- resampleResults[, -grep("^\\.cell",
colnames(resampleResults))]
}
else resampledCM <- NULL
}
else resampledCM <- NULL
if (trControl$verboseIter) {
cat("Aggregating results\n")
flush.console()
}
perfCols <- names(performance)
perfCols <- perfCols[!(perfCols %in% paramNames)]
if (all(is.na(performance[, metric]))) {
cat(paste("Something is wrong; all the", metric,
"metric values are missing:\n"))
print(summary(performance[, perfCols[!grepl("SD$",
perfCols)], drop = FALSE]))
stop("Stopping", call. = FALSE)
}
if (!is.null(models$sort))
performance <- models$sort(performance)
if (any(is.na(performance[, metric])))
warning("missing values found in aggregated results")
if (trControl$verboseIter && nrow(performance) > 1) {
cat("Selecting tuning parameters\n")
flush.console()
}
selectClass <- class(trControl$selectionFunction)[1]
if (grepl("adapt", trControl$method)) {
perf_check <- subset(performance, .B == max(performance$.B))
}
else perf_check <- performance
if (selectClass == "function") {
bestIter <- trControl$selectionFunction(x = perf_check,
metric = metric, maximize = maximize)
}
else {
if (trControl$selectionFunction == "oneSE") {
bestIter <- oneSE(perf_check, metric, length(trControl$index),
maximize)
}
else {
bestIter <- do.call(trControl$selectionFunction,
list(x = perf_check, metric = metric, maximize = maximize))
}
}
if (is.na(bestIter) || length(bestIter) != 1)
stop("final tuning parameters could not be determined",
call. = FALSE)
if (grepl("adapt", trControl$method)) {
best_perf <- perf_check[bestIter, as.character(models$parameters$parameter),
drop = FALSE]
performance$order <- 1:nrow(performance)
bestIter <- merge(performance, best_perf)$order
performance$order <- NULL
}
bestTune <- performance[bestIter, paramNames, drop = FALSE]
}
else {
bestTune <- tuneGrid
performance <- evalSummaryFunction(y, wts = weights,
ctrl = trControl, lev = classLevels, metric = metric,
method = method)
perfNames <- names(performance)
performance <- as.data.frame(t(performance))
performance <- cbind(performance, tuneGrid)
performance <- performance[-1, , drop = FALSE]
tmp <- resampledCM <- NULL
}
if (!(trControl$method %in% c("LOOCV", "oob", "none"))) {
byResample <- switch(trControl$returnResamp, none = NULL,
all = {
out <- resampleResults
colnames(out) <- gsub("^\\.", "", colnames(out))
out
}, final = {
out <- merge(bestTune, resampleResults)
out <- out[, !(names(out) %in% names(tuneGrid)),
drop = FALSE]
out
})
}
else {
byResample <- NULL
}
orderList <- list()
for (i in seq(along = paramNames)) orderList[[i]] <- performance[,
paramNames[i]]
performance <- performance[do.call("order", orderList), ]
if (trControl$verboseIter) {
bestText <- paste(paste(names(bestTune), "=", format(bestTune,
digits = 3)), collapse = ", ")
if (nrow(performance) == 1)
bestText <- "final model"
cat("Fitting", bestText, "on full training set\n")
flush.console()
}
indexFinal <- if (is.null(trControl$indexFinal))
seq(along = y)
else trControl$indexFinal
if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds)))
set.seed(trControl$seeds[[length(trControl$seeds)]][1])
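# fitFinal is the argument added in this modified copy: when FALSE, the fit on the full training data is skipped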
if (fitFinal) {
finalTime <- system.time(finalModel <- createModel(x = subset_x(x,
indexFinal), y = y[indexFinal], wts = weights[indexFinal],
method = models, tuneValue = bestTune, obsLevels = classLevels,
pp = ppOpt, last = TRUE, classProbs = trControl$classProbs,
sampling = trControl$sampling, ...))
} else {
finalModel <- list(fit = NULL, preProc = NULL)
finalTime <- 0
}
if (trControl$trim && !is.null(models$trim)) {
if (trControl$verboseIter)
old_size <- object.size(finalModel$fit)
finalModel$fit <- models$trim(finalModel$fit)
if (trControl$verboseIter) {
new_size <- object.size(finalModel$fit)
reduction <- format(old_size - new_size, units = "Mb")
if (reduction == "0 Mb")
reduction <- "< 0 Mb"
p_reduction <- (unclass(old_size) - unclass(new_size))/unclass(old_size) *
100
p_reduction <- if (p_reduction < 1)
"< 1%"
else paste0(round(p_reduction, 0), "%")
cat("Final model footprint reduced by", reduction,
"or", p_reduction, "\n")
}
}
pp <- finalModel$preProc
finalModel <- finalModel$fit
if (method == "pls")
finalModel$bestIter <- bestTune
if (method == "glmnet")
finalModel$lambdaOpt <- bestTune$lambda
if (trControl$returnData) {
outData <- if (!is.data.frame(x))
try(as.data.frame(x), silent = TRUE)
else x
if (inherits(outData, "try-error")) {
warning("The training data could not be converted to a data frame for saving")
outData <- NULL
}
else {
outData$.outcome <- y
if (!is.null(weights))
outData$.weights <- weights
}
}
else outData <- NULL
if (trControl$savePredictions == "final")
tmp$predictions <- merge(bestTune, tmp$predictions)
endTime <- proc.time()
times <- list(everything = endTime - startTime, final = finalTime)
out <- structure(list(method = method, modelInfo = models,
modelType = modelType, results = performance, pred = tmp$predictions,
bestTune = bestTune, call = funcCall, dots = list(...),
metric = metric, control = trControl, finalModel = finalModel,
preProcess = pp, trainingData = outData, resample = byResample,
resampledCM = resampledCM, perfNames = perfNames, maximize = maximize,
yLimits = trControl$yLimits, times = times, levels = classLevels),
class = "train")
trControl$yLimits <- NULL
if (trControl$timingSamps > 0) {
pData <- x[sample(1:nrow(x), trControl$timingSamps, replace = TRUE),
, drop = FALSE]
out$times$prediction <- system.time(predict(out, pData))
}
else out$times$prediction <- rep(NA, 3)
out
}
This gives
data(iris)
TrainData <- iris[,1:4]
TrainClasses <- iris[,5]
knnFit1 <- train(TrainData, TrainClasses,
method = "knn",
preProcess = c("center", "scale"),
tuneLength = 10,
trControl = trainControl(method = "cv"), fitFinal = FALSE)
knnFit1$finalModel
# NULL