CART: Classification and Regression Trees
Simulate Data
library(data.table)
# Simulation settings: a fixed seed so the draws below are reproducible,
# and the number of simulated rows.
set.seed(123456)
n <- 5000

# Simulated design, one row per observation.
#   p0    : constant baseline probability (0.2), converted to odds later.
#   orK   : constant odds ratio attached to predictor pair K (K = 1..7).
#   varK  : binary 0/1 predictor drawn with the listed probabilities.
#   varKn : normal predictor with mean 0 (sd 1 for var1n, sd 2 for the rest).
# The columns are generated in this exact order; reordering the random draws
# would change the simulated data.
dt <- data.table(
  p0    = rep(0.2, n),
  or1   = rep(1, n),
  var1  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.3, 0.7)),
  var1n = rnorm(n, mean = 0, sd = 1),
  or2   = rep(1.1, n),
  var2  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.4, 0.6)),
  var2n = rnorm(n, mean = 0, sd = 2),
  or3   = rep(1.2, n),
  var3  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.2, 0.8)),
  var3n = rnorm(n, mean = 0, sd = 2),
  or4   = rep(1.5, n),
  var4  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.3, 0.7)),
  var4n = rnorm(n, mean = 0, sd = 2),
  or5   = rep(1.7, n),
  var5  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)),
  var5n = rnorm(n, mean = 0, sd = 2),
  or6   = rep(2, n),
  var6  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.4, 0.6)),
  var6n = rnorm(n, mean = 0, sd = 2),
  or7   = rep(5, n),
  var7  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.1, 0.9)),
  var7n = rnorm(n, mean = 0, sd = 2)
)
# Derive the true event probability for every row.
#   odds0    : baseline odds implied by the baseline probability p0.
#   log_odds : log baseline odds plus, for each predictor pair K, the binary
#              and the normal predictor each multiplied by log(orK).
#              Since or1 is 1, log(or1) is 0 and var1 / var1n contribute
#              nothing to the linear predictor.
#   p        : inverse logit of log_odds. plogis() is used rather than
#              exp(x) / (1 + exp(x)) because the latter evaluates to
#              Inf / Inf = NaN once exp(x) overflows; plogis() stays finite.
dt <- dt[, odds0 := p0 / (1 - p0)
][, log_odds := log(odds0) +
    var1 * log(or1) + var1n * log(or1) +
    var2 * log(or2) + var2n * log(or2) +
    var3 * log(or3) + var3n * log(or3) +
    var4 * log(or4) + var4n * log(or4) +
    var5 * log(or5) + var5n * log(or5) +
    var6 * log(or6) + var6n * log(or6) +
    var7 * log(or7) + var7n * log(or7)
][, p := plogis(log_odds)]
# Draw one 0/1 value per element of `p`: each element is used as the
# probability of drawing 1 (and 1 - p as the probability of drawing 0).
# Vectorize() applies the single-draw function element by element, so the
# random number stream is consumed once per element, in order.
vsample <- Vectorize(
  function(p) {
    sample(c(1, 0), size = 1, replace = TRUE, prob = c(p, 1 - p))
  }
)
# Simulate the binary outcome: one draw per row using that row's probability p.
# set() adds the column to dt by reference.
set(dt, j = "outcome", value = vsample(dt$p))

# Show the distinct odds-ratio settings used in the simulation (or1..or7).
unique(dt[, .SD, .SDcols = paste0("or", 1:7)]) %>%
  prt(caption = "Variables with Odds Ratios")
or1 | or2 | or3 | or4 | or5 | or6 | or7 |
---|---|---|---|---|---|---|
1 | 1.1 | 1.2 | 1.5 | 1.7 | 2 | 5 |
GLM
# Logistic regression of the simulated outcome on all 14 predictors
# (binary var1..var7 first, then normal var1n..var7n), so the estimated
# odds ratios can be compared with the or1..or7 values used in the simulation.
m <- glm(
  outcome ~
    var1 + var2 + var3 + var4 + var5 + var6 + var7 +
    var1n + var2n + var3n + var4n + var5n + var6n + var7n,
  family = binomial(link = "logit"),
  data = dt
)

# Tabulate the fitted model (the rendered table below reports odds ratios,
# confidence intervals and p-values).
library(sjPlot)
tab_model(m)
 | outcome | ||
---|---|---|---|
Predictors | Odds Ratios | CI | p |
(Intercept) | 0.14 | 0.09 – 0.21 | <0.001 |
var1 | 1.15 | 0.94 – 1.41 | 0.160 |
var2 | 1.11 | 0.92 – 1.34 | 0.271 |
var3 | 1.41 | 1.12 – 1.78 | 0.004 |
var4 | 1.76 | 1.44 – 2.16 | <0.001 |
var5 | 1.97 | 1.64 – 2.38 | <0.001 |
var6 | 2.34 | 1.93 – 2.85 | <0.001 |
var7 | 5.44 | 4.01 – 7.42 | <0.001 |
var1n | 1.05 | 0.95 – 1.15 | 0.344 |
var2n | 1.11 | 1.06 – 1.17 | <0.001 |
var3n | 1.20 | 1.14 – 1.26 | <0.001 |
var4n | 1.50 | 1.42 – 1.58 | <0.001 |
var5n | 1.69 | 1.61 – 1.79 | <0.001 |
var6n | 2.10 | 1.97 – 2.23 | <0.001 |
var7n | 5.15 | 4.67 – 5.71 | <0.001 |
Observations | 5000 | ||
R² Tjur | 0.620 |
CART: the Full Model
library(rpart)
library(rpart.plot)
library(RColorBrewer)
library(rattle)
# Names of the 14 model predictors: the binary columns var1..var7 followed by
# the normal columns var1n..var7n.
predictors <- c(
  paste0("var", 1:7),
  paste0("var", 1:7, "n")
)
# Model formula built from the outcome and predictor names.
# NOTE(review): Wu::wu_formula is an external helper; presumably it returns
# `outcome ~ var1 + ... + var7n` - confirm against the Wu package.
frml <- Wu::wu_formula(outcome = "outcome", predictors = predictors)

# Re-seed before fitting so the 20-fold cross-validation (xval = 20) is
# reproducible.
set.seed(123456)

# Classification tree using the information splitting criterion. cp = 0 is
# paired with small minsplit / minbucket values; the resulting CP table is
# shown in the pruning section that follows.
tr <- rpart(
  frml,
  data = dt,
  method = "class",
  model = TRUE,
  x = TRUE,
  y = TRUE,
  parms = list(split = "information"),
  control = rpart.control(
    minsplit = 10,
    minbucket = 5,
    cp = 0,
    maxdepth = 30,
    xval = 20
  )
)

# Plot the fitted (unpruned) tree.
rpart.plot(tr, type = 2, extra = 106, under = TRUE, tweak = 2)
Prune the CART Model
CP | nsplit | rel error | xerror | xstd |
---|---|---|---|---|
0.4398767 | 0 | 1.0000000 | 1.0000000 | 0.0177165 |
0.0308325 | 1 | 0.5601233 | 0.5899281 | 0.0152822 |
0.0142172 | 3 | 0.4984584 | 0.5123330 | 0.0145182 |
0.0113052 | 6 | 0.4558068 | 0.4794450 | 0.0141563 |
0.0097636 | 7 | 0.4445015 | 0.4712230 | 0.0140620 |
0.0077081 | 8 | 0.4347379 | 0.4681398 | 0.0140262 |
0.0071942 | 10 | 0.4193217 | 0.4676259 | 0.0140202 |
0.0038541 | 13 | 0.3977390 | 0.4614594 | 0.0139479 |
0.0035971 | 15 | 0.3900308 | 0.4558068 | 0.0138808 |
0.0030832 | 17 | 0.3828366 | 0.4522097 | 0.0138376 |
0.0025694 | 18 | 0.3797533 | 0.4480987 | 0.0137880 |
0.0017986 | 21 | 0.3720452 | 0.4491264 | 0.0138004 |
0.0017129 | 27 | 0.3612539 | 0.4511819 | 0.0138253 |
0.0015416 | 31 | 0.3540596 | 0.4501542 | 0.0138129 |
0.0014560 | 46 | 0.3237410 | 0.4496403 | 0.0138066 |
0.0013703 | 53 | 0.3134635 | 0.4516958 | 0.0138315 |
0.0012847 | 58 | 0.3062693 | 0.4511819 | 0.0138253 |
0.0011990 | 70 | 0.2908530 | 0.4537513 | 0.0138562 |
0.0010277 | 73 | 0.2872559 | 0.4547790 | 0.0138685 |
0.0009250 | 92 | 0.2677287 | 0.4609455 | 0.0139418 |
0.0008993 | 108 | 0.2471737 | 0.4599178 | 0.0139297 |
0.0008565 | 116 | 0.2384378 | 0.4635149 | 0.0139721 |
0.0007708 | 127 | 0.2276465 | 0.4640288 | 0.0139781 |
0.0007194 | 143 | 0.2153135 | 0.4691675 | 0.0140381 |
0.0006852 | 148 | 0.2117163 | 0.4712230 | 0.0140620 |
0.0005139 | 156 | 0.2060637 | 0.4737924 | 0.0140916 |
0.0004111 | 189 | 0.1875642 | 0.4866393 | 0.0142376 |
0.0003426 | 198 | 0.1824255 | 0.4886948 | 0.0142606 |
0.0002936 | 211 | 0.1778006 | 0.4943474 | 0.0143233 |
0.0002569 | 218 | 0.1757451 | 0.4943474 | 0.0143233 |
0.0001713 | 240 | 0.1700925 | 0.5092497 | 0.0144853 |
0.0001285 | 252 | 0.1680370 | 0.5159301 | 0.0145563 |
0.0001028 | 260 | 0.1670092 | 0.5190134 | 0.0145888 |
0.0000856 | 270 | 0.1659815 | 0.5210689 | 0.0146103 |
0.0000000 | 276 | 0.1654676 | 0.5220966 | 0.0146211 |
AUC on the Training Data
R sessionInfo
R version 4.2.0 (2022-04-22) Platform: x86_64-pc-linux-gnu (64-bit) Running under: Ubuntu 20.04.3 LTS
Matrix products: default BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0 LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
locale: [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
[4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
[7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C
attached base packages: [1] stats graphics grDevices utils datasets methods base
other attached packages: [1] pROC_1.18.0 rattle_5.5.1 bitops_1.0-7
[4] tibble_3.1.7 RColorBrewer_1.1-3 rpart.plot_3.1.1
[7] rpart_4.1-15 sjPlot_2.8.10 Wu_0.0.0.9000
[10] flexdashboard_0.5.2 lme4_1.1-29 Matrix_1.4-0
[13] mgcv_1.8-38 nlme_3.1-152 png_0.1-7
[16] scales_1.2.0 nnet_7.3-16 labelled_2.9.1
[19] kableExtra_1.3.4 plotly_4.10.0 gridExtra_2.3
[22] ggplot2_3.3.6 DT_0.23 tableone_0.13.2
[25] magrittr_2.0.3 lubridate_1.8.0 dplyr_1.0.9
[28] plyr_1.8.7 data.table_1.14.2 rmdformats_1.0.4
[31] knitr_1.39
loaded via a namespace (and not attached): [1] TH.data_1.1-1 minqa_1.2.4 colorspace_2.0-3 ellipsis_0.3.2
[5] sjlabelled_1.2.0 estimability_1.4 parameters_0.18.1 rstudioapi_0.13
[9] fansi_1.0.3 mvtnorm_1.1-3 xml2_1.3.3 codetools_0.2-18
[13] splines_4.2.0 sjmisc_2.8.9 jsonlite_1.8.0 nloptr_2.0.3
[17] ggeffects_1.1.2 broom_0.8.0 effectsize_0.7.0 compiler_4.2.0
[21] httr_1.4.3 sjstats_0.18.1 emmeans_1.7.5 backports_1.4.1
[25] assertthat_0.2.1 fastmap_1.1.0 lazyeval_0.2.2 survey_4.1-1
[29] cli_3.3.0 htmltools_0.5.3 tools_4.2.0 gtable_0.3.0
[33] glue_1.6.2 Rcpp_1.0.8.3 jquerylib_0.1.4 vctrs_0.4.1
[37] svglite_2.1.0 crosstalk_1.2.0 insight_0.18.0 xfun_0.31
[41] stringr_1.4.0 rvest_1.0.2 lifecycle_1.0.1 klippy_0.0.0.9500
[45] MASS_7.3-54 zoo_1.8-10 hms_1.1.1 sandwich_3.0-2
[49] yaml_2.3.5 sass_0.4.1 stringi_1.7.8 highr_0.9
[53] bayestestR_0.12.1 boot_1.3-28 rlang_1.0.4 pkgconfig_2.0.3
[57] systemfonts_1.0.4 evaluate_0.15 lattice_0.20-45 purrr_0.3.4
[61] htmlwidgets_1.5.4 tidyselect_1.1.2 bookdown_0.27 R6_2.5.1
[65] generics_0.1.2 multcomp_1.4-19 DBI_1.1.2 pillar_1.7.0
[69] haven_2.5.0 withr_2.5.0 survival_3.2-13 datawizard_0.4.1
[73] performance_0.9.1 modelr_0.1.8 crayon_1.5.1 utf8_1.2.2
[77] rmarkdown_2.14 grid_4.2.0 forcats_0.5.1 digest_0.6.29
[81] webshot_0.5.3 xtable_1.8-4 tidyr_1.2.0 munsell_0.5.0
[85] viridisLite_0.4.0 bslib_0.3.1 mitools_2.4