CART: Classification and Regression Trees

Simulate Data

library(data.table)

# Simulate n subjects with seven binary covariates (var1..var7), seven
# continuous covariates (var1n..var7n), a constant baseline probability (p0)
# and one constant odds-ratio column per covariate pair (or1..or7).
# The columns are generated in the order written; because the random draws
# share one seeded RNG stream, reordering them would change the data.
set.seed(123456)
n <- 5000
dt <- data.table(
    p0    = rep(0.2, times = n),
    or1   = rep(1, times = n),
    var1  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.3, 0.7)),
    var1n = rnorm(n, mean = 0, sd = 1),
    or2   = rep(1.1, times = n),
    var2  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.4, 0.6)),
    var2n = rnorm(n, mean = 0, sd = 2),
    or3   = rep(1.2, times = n),
    var3  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.2, 0.8)),
    var3n = rnorm(n, mean = 0, sd = 2),
    or4   = rep(1.5, times = n),
    var4  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.3, 0.7)),
    var4n = rnorm(n, mean = 0, sd = 2),
    or5   = rep(1.7, times = n),
    var5  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)),
    var5n = rnorm(n, mean = 0, sd = 2),
    or6   = rep(2, times = n),
    var6  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.4, 0.6)),
    var6n = rnorm(n, mean = 0, sd = 2),
    or7   = rep(5, times = n),
    var7  = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.1, 0.9)),
    var7n = rnorm(n, mean = 0, sd = 2)
)


# Baseline odds implied by the baseline probability p0.
dt <- dt[, odds0 := p0 / (1 - p0)]

# Linear predictor on the log-odds scale: the baseline log-odds plus one term
# per covariate. Each binary varK and its continuous companion varKn are both
# weighted by log(orK). The terms are summed in a fixed left-to-right order.
dt <- dt[, log_odds := log(odds0) +
    var1 * log(or1) + var1n * log(or1) +
    var2 * log(or2) + var2n * log(or2) +
    var3 * log(or3) + var3n * log(or3) +
    var4 * log(or4) + var4n * log(or4) +
    var5 * log(or5) + var5n * log(or5) +
    var6 * log(or6) + var6n * log(or6) +
    var7 * log(or7) + var7n * log(or7)
]

# Map the log-odds back to a probability (inverse logit).
dt <- dt[, p := exp(log_odds) / (1 + exp(log_odds))]

# Draw one 0/1 outcome per element of `p`.
#
# For each probability, a single value is sampled from c(1, 0) with weights
# c(p, 1 - p), so 1 is returned with probability p. Vectorize() applies the
# scalar draw element-wise over `p`, one sample() call per element.
vsample <- Vectorize(function(p) {
    outcome_levels <- c(1, 0)
    level_weights <- c(p, 1 - p)
    sample(outcome_levels, size = 1, replace = TRUE, prob = level_weights)
})

# Simulate the binary outcome: one draw per row using that row's probability p.
dt <- dt[, `:=`(outcome = vsample(p))]

# Show the distinct odds-ratio settings used in the simulation as a captioned
# table. NOTE(review): prt() is defined outside this view; presumably a
# table-printing helper -- confirm.
dt[, .(or1, or2, or3, or4, or5, or6, or7)] %>%
    unique() %>%
    prt(caption = "Variables with Odds Ratios")
Variables with Odds Ratios
or1 or2 or3 or4 or5 or6 or7
1 1.1 1.2 1.5 1.7 2 5

GLM

  outcome
Predictors Odds Ratios CI p
(Intercept) 0.14 0.09 – 0.21 <0.001
var1 1.15 0.94 – 1.41 0.160
var2 1.11 0.92 – 1.34 0.271
var3 1.41 1.12 – 1.78 0.004
var4 1.76 1.44 – 2.16 <0.001
var5 1.97 1.64 – 2.38 <0.001
var6 2.34 1.93 – 2.85 <0.001
var7 5.44 4.01 – 7.42 <0.001
var1n 1.05 0.95 – 1.15 0.344
var2n 1.11 1.06 – 1.17 <0.001
var3n 1.20 1.14 – 1.26 <0.001
var4n 1.50 1.42 – 1.58 <0.001
var5n 1.69 1.61 – 1.79 <0.001
var6n 2.10 1.97 – 2.23 <0.001
var7n 5.15 4.67 – 5.71 <0.001
Observations 5000
R2 Tjur 0.620

Prune the CART Model

CP Table
CP nsplit rel error xerror xstd
0.4398767 0 1.0000000 1.0000000 0.0177165
0.0308325 1 0.5601233 0.5899281 0.0152822
0.0142172 3 0.4984584 0.5123330 0.0145182
0.0113052 6 0.4558068 0.4794450 0.0141563
0.0097636 7 0.4445015 0.4712230 0.0140620
0.0077081 8 0.4347379 0.4681398 0.0140262
0.0071942 10 0.4193217 0.4676259 0.0140202
0.0038541 13 0.3977390 0.4614594 0.0139479
0.0035971 15 0.3900308 0.4558068 0.0138808
0.0030832 17 0.3828366 0.4522097 0.0138376
0.0025694 18 0.3797533 0.4480987 0.0137880
0.0017986 21 0.3720452 0.4491264 0.0138004
0.0017129 27 0.3612539 0.4511819 0.0138253
0.0015416 31 0.3540596 0.4501542 0.0138129
0.0014560 46 0.3237410 0.4496403 0.0138066
0.0013703 53 0.3134635 0.4516958 0.0138315
0.0012847 58 0.3062693 0.4511819 0.0138253
0.0011990 70 0.2908530 0.4537513 0.0138562
0.0010277 73 0.2872559 0.4547790 0.0138685
0.0009250 92 0.2677287 0.4609455 0.0139418
0.0008993 108 0.2471737 0.4599178 0.0139297
0.0008565 116 0.2384378 0.4635149 0.0139721
0.0007708 127 0.2276465 0.4640288 0.0139781
0.0007194 143 0.2153135 0.4691675 0.0140381
0.0006852 148 0.2117163 0.4712230 0.0140620
0.0005139 156 0.2060637 0.4737924 0.0140916
0.0004111 189 0.1875642 0.4866393 0.0142376
0.0003426 198 0.1824255 0.4886948 0.0142606
0.0002936 211 0.1778006 0.4943474 0.0143233
0.0002569 218 0.1757451 0.4943474 0.0143233
0.0001713 240 0.1700925 0.5092497 0.0144853
0.0001285 252 0.1680370 0.5159301 0.0145563
0.0001028 260 0.1670092 0.5190134 0.0145888
0.0000856 270 0.1659815 0.5210689 0.0146103
0.0000000 276 0.1654676 0.5220966 0.0146211

AUC on the Training Data

R sessionInfo

R version 4.2.0 (2022-04-22) Platform: x86_64-pc-linux-gnu (64-bit) Running under: Ubuntu 20.04.3 LTS

Matrix products: default BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0 LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0

locale: [1] LC_CTYPE=C.UTF-8 LC_NUMERIC=C LC_TIME=C.UTF-8
[4] LC_COLLATE=C.UTF-8 LC_MONETARY=C.UTF-8 LC_MESSAGES=C.UTF-8
[7] LC_PAPER=C.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C

attached base packages: [1] stats graphics grDevices utils datasets methods base

other attached packages: [1] pROC_1.18.0 rattle_5.5.1 bitops_1.0-7
[4] tibble_3.1.7 RColorBrewer_1.1-3 rpart.plot_3.1.1
[7] rpart_4.1-15 sjPlot_2.8.10 Wu_0.0.0.9000
[10] flexdashboard_0.5.2 lme4_1.1-29 Matrix_1.4-0
[13] mgcv_1.8-38 nlme_3.1-152 png_0.1-7
[16] scales_1.2.0 nnet_7.3-16 labelled_2.9.1
[19] kableExtra_1.3.4 plotly_4.10.0 gridExtra_2.3
[22] ggplot2_3.3.6 DT_0.23 tableone_0.13.2
[25] magrittr_2.0.3 lubridate_1.8.0 dplyr_1.0.9
[28] plyr_1.8.7 data.table_1.14.2 rmdformats_1.0.4
[31] knitr_1.39

loaded via a namespace (and not attached): [1] TH.data_1.1-1 minqa_1.2.4 colorspace_2.0-3 ellipsis_0.3.2
[5] sjlabelled_1.2.0 estimability_1.4 parameters_0.18.1 rstudioapi_0.13
[9] fansi_1.0.3 mvtnorm_1.1-3 xml2_1.3.3 codetools_0.2-18
[13] splines_4.2.0 sjmisc_2.8.9 jsonlite_1.8.0 nloptr_2.0.3
[17] ggeffects_1.1.2 broom_0.8.0 effectsize_0.7.0 compiler_4.2.0
[21] httr_1.4.3 sjstats_0.18.1 emmeans_1.7.5 backports_1.4.1
[25] assertthat_0.2.1 fastmap_1.1.0 lazyeval_0.2.2 survey_4.1-1
[29] cli_3.3.0 htmltools_0.5.3 tools_4.2.0 gtable_0.3.0
[33] glue_1.6.2 Rcpp_1.0.8.3 jquerylib_0.1.4 vctrs_0.4.1
[37] svglite_2.1.0 crosstalk_1.2.0 insight_0.18.0 xfun_0.31
[41] stringr_1.4.0 rvest_1.0.2 lifecycle_1.0.1 klippy_0.0.0.9500
[45] MASS_7.3-54 zoo_1.8-10 hms_1.1.1 sandwich_3.0-2
[49] yaml_2.3.5 sass_0.4.1 stringi_1.7.8 highr_0.9
[53] bayestestR_0.12.1 boot_1.3-28 rlang_1.0.4 pkgconfig_2.0.3
[57] systemfonts_1.0.4 evaluate_0.15 lattice_0.20-45 purrr_0.3.4
[61] htmlwidgets_1.5.4 tidyselect_1.1.2 bookdown_0.27 R6_2.5.1
[65] generics_0.1.2 multcomp_1.4-19 DBI_1.1.2 pillar_1.7.0
[69] haven_2.5.0 withr_2.5.0 survival_3.2-13 datawizard_0.4.1
[73] performance_0.9.1 modelr_0.1.8 crayon_1.5.1 utf8_1.2.2
[77] rmarkdown_2.14 grid_4.2.0 forcats_0.5.1 digest_0.6.29
[81] webshot_0.5.3 xtable_1.8-4 tidyr_1.2.0 munsell_0.5.0
[85] viridisLite_0.4.0 bslib_0.3.1 mitools_2.4