glmnet,多项式预测返回对象

glmnet, multinomial prediction returned object

我正在尝试使用 glmnet 进行分类预测,但是我无法推断出 "glmnet.predict" 的 return 对象应该代表什么。使用代码

mlogit_r<-glmnet(train_x, cbind(cns_label, renal_label,breast_label,nsclc_label,ovarian_label,leuk_label,colon_label, mela_label),
            family="multinomial", alpha=0)
pred <- predict(mlogit_r, train_x, type="class")

其中 train_x 为 57(n) x 6830(p),y 对象为 57(n) x 8 (num 类)。 returned 预测对象是一个带有标签的 57 x 100 矩阵。其中哪些是预测标签?

它没有显示在文档中,因为它只是说

The object returned depends the . . . argument which is passed on to the predict method for glmnet objects.

当您在未指定 lambda 值的情况下拟合 glmnet 模型时,默认情况下拟合包含 100 个 lambda 值的范围。当您在不指定 lambda 的情况下调用此类模型的预测时,将针对所有 lambda 进行预测,因此您会收到来自 100 个不同模型的 100 个不同的预测。

通常 运行s 交叉验证来选择一个最好的 lambda,然后使用它进行预测:

library(glmnet)
data(iris)

让我们使用 120 行进行训练:

z <- sample(1:nrow(iris), 120)

现在 运行 使用未命中分类错误的 5 折交叉验证来选择最佳 lambda:

cv_fit <- cv.glmnet(as.matrix(iris[z,-5]),
                   iris[z,5],
                   nfolds = 5,
                   type.measure = "class",
                   alpha = 0,
                   grouped = FALSE,
                   family = "multinomial")

plot(cv_fit)

这里可以看到左边虚线对应的lambda.min(5折交叉验证中误差最低的lambda)和lambda.1se(误差为1 se的lambda,附近误差最低它稍微靠右。

这些值位于:

cv_fit$lambda.min
#[1] 0.05560455

cv_fit$lambda.1se
#[1] 0.09717054

现在,当您知道最佳 lambda 时,您可以在 100 个 lambda 值上构建模型:

fit <- glmnet(as.matrix(iris[z,-5]),
              iris[z, 5],
              alpha = 0,
              family = "multinomial")

并预测特定的:

predict(fit, as.matrix(iris[-z,-5]), s = cv_fit$lambda.min, type = "class")

或在一个 lambda 上构建模型

fit1 <- glmnet(as.matrix(iris[z,-5]),
              iris[z, 5],
              alpha = 0,
              lambda = cv_fit$lambda.min,
              family = "multinomial")

并在不指定 lambda 的情况下进行预测:

all.equal(as.vector(predict(fit, as.matrix(iris[-z,-5]), s = cv_fit$lambda.min, type = "class")),
          as.vector(predict(fit1, as.matrix(iris[-z,-5]), type = "class")))

#TRUE

要查看系数的约束程度,您可以绘制模型和使用的 lambda:

plot(fit, xvar = "lambda")
abline(v = log(cv_fit$lambda.min), lty = 2)