如何使用 ggplot 绘制 3d 超平面来说明决策边界?

How can I draw 3d hyperplane to illustrate decision boundary using ggplot?

我有数据框 df,它有 3d 输入数据:x1、x2、x3 和目标 t。我使用逻辑回归来创建决策边界

a0 + a1 * x1 + a2 * x2 + a3 * x3 = 0

我想知道是否有一种方法可以使用 ggplot 绘制 3d 超平面(以及 3d 输入数据)来说明逻辑回归创建的决策边界。

谢谢

您无法在 ggplot2 中获得真正的 3D 绘图,但可以使用等高线或颜色填充来表示 3d 平面。下面是一个使用彩色栅格图层表示平面的示例。

我从问题中假设您希望决策边界位于概率为 0.5 的位置(即对数赔率 = 0)

首先我们需要一个逻辑回归模型,因此在问题中没有任何数据的情况下,让我们创建一些模型来为我们提供一个很好的例子:

# Create dummy data for logistic regression
set.seed(69)

x1       <- sample(100, 1000, TRUE)
x2       <- sample(100, 1000, TRUE)
x3       <- sample(100, 1000, TRUE)
log_odds <- -1 + 0.02 * x1 + 0.005 * x2 - 0.03 * x3 + rnorm(1000, 0, 2)
odds     <- exp(log_odds)
probs    <- odds/(1 + odds)
y        <- rbinom(1000, 1, probs)
df       <- data.frame(y, x1, x2, x3)

现在我们有一个二进制结果,y,其值取决于三个自变量的值 x1x2x3,所以我们可以 运行 逻辑回归并获取其系数:

# Run logistic regression and extract coefficients
logistic_model <- glm(y ~ x1 + x2 + x3, data = df, family = binomial)
summary(logistic_model)
#> 
#> Call:
#> glm(formula = y ~ x1 + x2 + x3, family = binomial, data = df)
#> 
#> Deviance Residuals: 
#>     Min       1Q   Median       3Q      Max  
#> -1.5058  -0.8689  -0.6296   1.1264   2.3669  
#> 
#> Coefficients:
#>              Estimate Std. Error z value Pr(>|z|)    
#> (Intercept) -0.888782   0.232728  -3.819 0.000134 ***
#> x1           0.012369   0.002562   4.828 1.38e-06 ***
#> x2           0.008031   0.002478   3.241 0.001191 ** 
#> x3          -0.020676   0.002560  -8.076 6.67e-16 ***
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> (Dispersion parameter for binomial family taken to be 1)
#> 
#>     Null deviance: 1235.0  on 999  degrees of freedom
#> Residual deviance: 1129.9  on 996  degrees of freedom
#> AIC: 1137.9
#> 
#> Number of Fisher Scoring iterations: 4

coefs <- coef(logistic_model)

我们的绘图将在 x 轴上显示 x1,在 y 轴上显示 x2。每个点 (x1, x2) 的颜色将是 x3 的值,它产生的对数赔率为 0。我们可以通过重新排列您在问题中显示的公式 a0 + a1 * x1 + a2 * x2 + a3 * x3 = 0 来获得:

# Create a function that returns the value of x3 at p = 0.5, given x1 and x2
find_x3 <- function(x1, x2) (-coefs[1] -coefs[2] * x1 - coefs[3] * x2)/coefs[4]

现在我们可以创建一个数据框,其中包含 1 到 100 之间的所有 x1x2 值,并找到 x3 的适当值,使对数几率为 0对于此网格上的每个点:

# Create a data frame to plot the 3d plane where p = 0.5
plot_df    <- expand.grid(x2 = 1:100, x1 = 1:100)
plot_df$x3 <- find_x3(plot_df$x1, plot_df$x2)
head(plot_df)
#>   x2 x1        x3
#> 1  1  1 -41.99975
#> 2  2  1 -41.61133
#> 3  3  1 -41.22291
#> 4  4  1 -40.83450
#> 5  5  1 -40.44608
#> 6  6  1 -40.05766

我们可以通过 运行ning predict 将此数据框作为 newdata 来确认这给了我们决策边界的值。值应全部为 0(或非常接近 0):

head(predict(logistic_model, newdata = plot_df))
#>             1             2             3             4             5            
#>  0.000000e+00  0.000000e+00 -1.110223e-16  0.000000e+00  0.000000e+00

很好。

最后,我们可以用彩色发散标尺绘制结果,以显示 x1x2x3 的值,它们共同构成了您的决策边界:

library(ggplot2)

ggplot(plot_df, aes(x1, x2, fill = x3)) + 
  geom_raster() +
  scale_fill_gradientn(colours = c("deepskyblue4", "forestgreen", "gold", "red")) +
  coord_equal() +
  theme_classic()

如果您正在寻找真正的 3d 透视图,您可以尝试 base R 的 persp 函数:

persp(x = 1:100, y = 1:100, z = matrix(plot_df$x3, ncol = 100), 
      xlab = "x1", ylab = "x2", zlab = "x3", 
      theta = -45, , phi = 25, d = 5,
      col = "gold", border = "orange", 
      ticktype = "detailed")

reprex package (v0.3.0)

于 2020-08-16 创建