如何在 R 中直接绘制 h2o 模型对象的 ROC

How to directly plot ROC of h2o model object in R

如果我遗漏了一些明显的东西,我深表歉意。在过去的几天里,我一直非常喜欢使用 R 界面与 h2o 一起工作。我想通过绘制 ROC 来评估我的模型,比如随机森林。文档似乎表明有一种直接的方法可以做到这一点:

Interpreting a DRF Model

  • By default, the following output displays:
  • Model parameters (hidden)
  • A graph of the scoring history (number of trees vs. training MSE)
  • A graph of the ROC curve (TPR vs. FPR)
  • A graph of the variable importances ...

我还看到在 python 中您可以应用 roc 函数 here。但我似乎无法找到在 R 界面中执行相同操作的方法。目前我正在使用 h2o.cross_validation_holdout_predictions 从模型中提取预测,然后使用 R 中的 pROC 包来绘制 ROC。但我希望能够直接从 H2O 模型对象或 H2OModelMetrics 对象来完成。

非常感谢!

H2O R 或 Python 客户端目前没有直接绘制 ROC 曲线的功能。 roc method in Python returns the data neccessary to plot the ROC curve, but does not plot the curve itself. ROC curve plotting directly from R and Python seems like a useful thing to add, so I've created a JIRA ticket for it here: https://0xdata.atlassian.net/browse/PUBDEV-4449

文档中对 ROC 曲线的引用是指 H2O Flow GUI,它将自动为 H2O 集群中的任何二元分类模型绘制 ROC 曲线。然而,该列表中的所有其他项目实际上都可以直接在 R 和 Python 中使用。

如果您在 R 中训练模型,可以访问 Flow 界面(例如 localhost:54321)并单击二项式模型以查看其 ROC 曲线(训练、验证和交叉验证版本)。它看起来像这样:

一个天真的解决方案是使用 plot() 通用函数来绘制 H2OMetrics 对象:

logit_fit <- h2o.glm(colnames(training)[-1],'y',training_frame =
    training.hex,validation_frame=validation.hex,family = 'binomial')
plot(h2o.performance(logit_fit),valid=T),type='roc')

这会给我们一个情节:

但是很难定制,尤其是改变线型,因为type参数已经被当作'roc'了。此外,我还没有找到一种方法可以在一个图上绘制多个模型的 ROC 曲线。我想出了一种从 H2OMetrics 对象中提取真阳性率和假阳性率的方法,并使用 ggplot2 自己在一个图上绘制 ROC 曲线。这是示例代码(使用了很多 tidyverse 语法):

# for example I have 4 H2OModels
list(logit_fit,dt_fit,rf_fit,xgb_fit) %>% 
  # map a function to each element in the list
  map(function(x) x %>% h2o.performance(valid=T) %>% 
        # from all these 'paths' in the object
        .@metrics %>% .$thresholds_and_metric_scores %>% 
        # extracting true positive rate and false positive rate
        .[c('tpr','fpr')] %>% 
        # add (0,0) and (1,1) for the start and end point of ROC curve
        add_row(tpr=0,fpr=0,.before=T) %>% 
        add_row(tpr=0,fpr=0,.before=F)) %>% 
  # add a column of model name for future grouping in ggplot2
  map2(c('Logistic Regression','Decision Tree','Random Forest','Gradient Boosting'),
        function(x,y) x %>% add_column(model=y)) %>% 
  # reduce four data.frame to one
  reduce(rbind) %>% 
  # plot fpr and tpr, map model to color as grouping
  ggplot(aes(fpr,tpr,col=model))+
  geom_line()+
  geom_segment(aes(x=0,y=0,xend = 1, yend = 1),linetype = 2,col='grey')+
  xlab('False Positive Rate')+
  ylab('True Positive Rate')+
  ggtitle('ROC Curve for Four Models')

则ROC曲线为:

您可以通过将模型性能指标传递给 H2O 的 plot 函数来获得 roc 曲线。

缩短的代码片段假设您创建了一个模型,将其命名为 glm,并将您的数据集拆分为训练集和验证集:

perf <- h2o.performance(glm, newdata = validation)
h2o.plot(perf)

下面是完整的代码片段:

h2o.init()

# Run GLM of CAPSULE ~ AGE + RACE + PSA + DCAPS
prostatePath = system.file("extdata", "prostate.csv", package = "h2o")
prostate.hex = h2o.importFile(path = prostatePath, destination_frame = "prostate.hex")
glm = h2o.glm(y = "CAPSULE", x = c("AGE","RACE","PSA","DCAPS"), training_frame = prostate.hex, family = "binomial", nfolds = 0, alpha = 0.5, lambda_search = FALSE)

perf <- h2o.performance(glm, newdata = prostate.hex)
h2o.plot(perf)

这将产生以下结果:

建立@Lauren 的示例,在您运行 model.performance 之后,您可以从perf@metrics$thresholds_and_metric_scores 中提取ggplot 的所有必要信息。此代码生成 ROC 曲线,但您也可以将 precision, recall 添加到所选变量以绘制 PR 曲线。

下面是一些使用与上述相同模型的示例代码。

library(h2o)
library(dplyr)
library(ggplot2)

h2o.init()

# Run GLM of CAPSULE ~ AGE + RACE + PSA + DCAPS
prostatePath <- system.file("extdata", "prostate.csv", package = "h2o")
prostate.hex <- h2o.importFile(
    path = prostatePath, 
    destination_frame = "prostate.hex"
    )
glm <- h2o.glm(
    y = "CAPSULE",
    x = c("AGE", "RACE", "PSA", "DCAPS"), 
    training_frame = prostate.hex, 
    family = "binomial", 
    nfolds = 0, 
    alpha = 0.5, 
    lambda_search = FALSE
)

# Model performance
perf <- h2o.performance(glm, newdata = prostate.hex)

# Extract info for ROC curve
curve_dat <- data.frame(perf@metrics$thresholds_and_metric_scores) %>%
    select(c(tpr, fpr))

# Plot ROC curve
ggplot(curve_dat, aes(x = fpr, y = tpr)) +
    geom_point() +
    geom_line() +
    geom_segment(
        aes(x = 0, y = 0, xend = 1, yend = 1),
        linetype = "dotted",
        color = "grey50"
        ) +
    xlab("False Positive Rate") +
    ylab("True Positive Rate") +
    ggtitle("ROC Curve") +
    theme_bw()

产生这个情节的是:

roc_plot