Module 5. Machine Learning
Department of Biostatistics & Medical Informatics
University of Wisconsin-Madison
Aug 3, 2025
- glmnet package
- rpart package
- aorsf package
- tidymodels Workflows

tidymodels and censored
- tidymodels: a collection of packages for modeling and machine learning in R
- parsnip: the tidymodels package for model specification
- censored: a parsnip extension package for survival data
- Surv object as response: Surv(time, event)
- Data splitting
- initial_split(): splits data into training and testing sets
- survival_reg(): parametric AFT models
- proportional_hazards(penalty = tune()): (regularized) Cox PH models
- decision_tree(cost_complexity = tune()): decision trees
- rand_forest(mtry = tune()): random forests
- set_engine("survival"): for AFT models
- set_engine("glmnet"): for Cox PH models
- set_engine("aorsf"): for random forests
- set_mode("censored regression"): for survival models
- recipe(response ~ ., data = df): specify response and predictors
- step_mutate(): transform or rescale numeric predictors
- step_dummy(): convert categorical variables to dummy variables
- workflow() |> add_model(model_spec) |> add_recipe(recipe)
# Create a recipe
model_recipe <- recipe(surv_obj ~ ., data = df_train) |> # specify formula
  step_mutate(z1 = z1 / 1000) |>  # rescale z1
  step_other(z2, z3, threshold = 0.02) |> # group levels with prop < .02 into "other"
  step_dummy(all_nominal_predictors())  # convert categorical variables to dummy variables
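The chunks in this overview use df_split, df_train, and model_spec, which are not created in the excerpt; a minimal sketch of how they might be set up (the AFT specification is just one of the options listed above):

library(tidymodels) # core tidymodels packages
library(censored)   # parsnip extension for survival data

df_split <- initial_split(df)   # split data into training and testing sets (default 3:1)
df_train <- training(df_split)  # training set
df_test  <- testing(df_split)   # testing set

model_spec <- survival_reg() |>       # parametric AFT model
  set_engine("survival") |>           # fit with the survival package
  set_mode("censored regression")     # survival (censored) outcome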
# Create a workflow by combining model and recipe
model_wflow <- workflow() |> 
  add_model(model_spec) |>   # add model specification
  add_recipe(model_recipe)   # add recipe
- df_train_folds <- vfold_cv(df_train, v = k): create k folds on the training data (default v = 10)
- tune_grid(model_wflow, resamples = df_train_folds): tune hyperparameters using cross-validation
# k-fold cross-validation
df_train_folds <- vfold_cv(df_train, v = 10) # 10-fold cross-validation
# Tune hyperparameters
model_res <- tune_grid(
  model_wflow, 
  resamples = df_train_folds, 
  grid = 10, # number of hyperparameter combinations to try
  metrics = metric_set(brier_survival, brier_survival_integrated,  # specify metrics
                       roc_auc_survival, concordance_survival), 
  eval_time = seq(0, 84, by = 12) # evaluation time points
)
- collect_metrics(model_res): collect metrics from tuning results
- show_best(model_res, metric = "brier_survival_integrated", n = 5): show top 5 models based on Brier score
- param_best <- select_best(model_res, metric = "brier_survival_integrated"): select best hyperparameters based on Brier score
- final_wl <- finalize_workflow(model_wflow, param_best): finalize workflow with best hyperparameters
- final_mod <- last_fit(final_wl, split = df_split): fit the finalized workflow on the training data and evaluate it on the testing set
- collect_metrics(final_mod): collect metrics of final model on test data
- predict(extract_workflow(final_mod), new_data = new_data, type = "time"): extract the fitted workflow and predict survival times on new data
# Fit the finalized workflow and evaluate it on the testing set
final_mod <- last_fit(final_wl, split = df_split)
# Collect metrics of final model on test data
collect_metrics(final_mod) |> 
  filter(.metric == "brier_survival_integrated")
# Make predictions on new data
new_data <- testing(df_split) |>  slice(1:5) # take first 5 rows of test data
predict(extract_workflow(final_mod), new_data = new_data, type = "time") # extract the fitted workflow, then predict
library(tidymodels) # load tidymodels
library(censored)
gbc <- read.table("data/gbc.txt", header = TRUE) # Load GBC dataset
df <- gbc |>  # calculate time to first event (relapse or death)
  group_by(id) |> # group by id
  arrange(time) |> # sort rows by time
  slice(1) |>      # get the first row within each id
  ungroup() |> 
  mutate(
    surv_obj = Surv(time, status), # create the Surv object as response variable
    .after = id, # place surv_obj right after the id column
    .keep = "unused" # discard original time and status columns
  )
# A tibble: 6 × 10
     id   surv_obj hormone   age  meno  size grade nodes  prog estrg
  <int>     <Surv>   <int> <int> <int> <int> <int> <int> <int> <int>
1     1 43.836066        1    38     1    18     3     5   141   105
2     2 46.557377        1    52     1    20     1     1    78    14
3     3 41.934426        1    47     1    30     2     1   422    89
4     4  4.852459+       1    40     1    24     1     3    25    11
5     5 61.081967+       2    64     2    19     2     1    19     9
6     6 63.377049+       2    49     2    56     1     3   356    64
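Later chunks use gbc_split and gbc_train, which do not appear in this excerpt; a minimal sketch of the split (the seed and the default 3:1 proportion are assumptions):

set.seed(123) # set seed for reproducibility
gbc_split <- initial_split(df)   # split into training and testing sets
gbc_train <- training(gbc_split) # training set
gbc_test  <- testing(gbc_split)  # testing set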
- proportional_hazards(penalty = tune()) with the glmnet engine for fitting
- rand_forest(mtry = tune(), min_n = tune()) with the aorsf engine for fitting
gbc_recipe <- recipe(surv_obj ~ ., data = gbc_train) |> # specify formula
  step_mutate(
    grade = factor(grade),
    age40 = as.numeric(age >= 40), # create a binary variable for age >= 40
    prog = prog / 100, # rescale prog
    estrg = estrg / 100 # rescale estrg
  ) |> 
  step_dummy(grade) |> 
  step_rm(id) # remove id
# gbc_recipe # print recipe information

# Regularized Cox model specification
cox_spec <- proportional_hazards(penalty = tune()) |>  # tune lambda
  set_engine("glmnet") |>  # set engine to glmnet
  set_mode("censored regression") # set mode to censored regression
cox_spec # print model specification
Proportional Hazards Model Specification (censored regression)
Main Arguments:
  penalty = tune()
Computational engine: glmnet 
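The tuning call below references cox_wflow, which is not shown in this excerpt; a minimal sketch of how it could be assembled from the specification and recipe above:

cox_wflow <- workflow() |> 
  add_model(cox_spec) |>   # regularized Cox specification
  add_recipe(gbc_recipe)   # preprocessing recipe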
set.seed(123) # set seed for reproducibility
gbc_folds <- vfold_cv(gbc_train, v = 10) # 10-fold cross-validation
# Set evaluation metrics
gbc_metrics <- metric_set(brier_survival, brier_survival_integrated, 
                          roc_auc_survival, concordance_survival)
gbc_metrics # evaluation metrics info
A metric set, consisting of:
- `brier_survival()`, a dynamic survival metric               | direction:
minimize
- `brier_survival_integrated()`, a integrated survival metric | direction:
minimize
- `roc_auc_survival()`, a dynamic survival metric             | direction:
maximize
- `concordance_survival()`, a static survival metric          | direction:
maximize
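The tuning and final-fit calls below use time_points, which is not defined in this excerpt; one plausible definition, matching the evaluation grid in the generic example above (an assumption, not from the source):

time_points <- seq(0, 84, by = 12) # evaluation time points (months)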
- tune_grid() to perform hyperparameter tuning
set.seed(123) # set seed for reproducibility
# Tune the regularized Cox model (this will take some time)
cox_res <- tune_grid(
  cox_wflow, 
  resamples = gbc_folds, 
  grid = 10, # number of hyperparameter combinations to try
  metrics = gbc_metrics, # evaluation metrics
  eval_time = time_points, # evaluation time points
  control = control_grid(save_workflow = TRUE) # save workflow
)
collect_metrics(cox_res) |>  # collect metrics from tuning results
  filter(.metric == "brier_survival_integrated") |>  # filter for Brier score
  ggplot(aes(log(penalty), mean)) + # plot log-lambda vs Brier score
  geom_line() +  # plot line
  labs(x = "Log-lambda", y = "Overall Brier Score") + # labels
  theme_classic() # classic theme
Show best models
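The tibble below lists the top candidates; the call that produces it is not shown in the excerpt, but it would look roughly like:

show_best(cox_res, metric = "brier_survival_integrated", n = 5) # top 5 by integrated Brier score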
# A tibble: 5 × 8
       penalty .metric         .estimator .eval_time  mean     n std_err .config
         <dbl> <chr>           <chr>           <dbl> <dbl> <int>   <dbl> <chr>  
1 0.00147      brier_survival… standard           NA 0.160    10 0.00775 Prepro…
2 0.0000127    brier_survival… standard           NA 0.160    10 0.00775 Prepro…
3 0.0000000105 brier_survival… standard           NA 0.160    10 0.00775 Prepro…
4 0.000754     brier_survival… standard           NA 0.160    10 0.00775 Prepro…
5 0.00000409   brier_survival… standard           NA 0.160    10 0.00775 Prepro…
# Random forest model specification
rf_spec <- rand_forest(mtry = tune(), min_n = tune()) |>  # tune mtry and min_n
  set_engine("aorsf") |>  # set engine to aorsf
  set_mode("censored regression") # set mode to censored regression
rf_spec # print model specification
Random Forest Model Specification (censored regression)
Main Arguments:
  mtry = tune()
  min_n = tune()
Computational engine: aorsf 
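As with the Cox model, the tuning call below uses rf_wflow, which is not shown in this excerpt; a minimal sketch reusing the same recipe:

rf_wflow <- workflow() |> 
  add_model(rf_spec) |>    # random forest specification
  add_recipe(gbc_recipe)   # preprocessing recipe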
set.seed(123) # set seed for reproducibility
# Tune the random forest model (this will take some time)
rf_res <- tune_grid(
  rf_wflow, 
  resamples = gbc_folds, 
  grid = 10, # number of hyperparameter combinations to try
  metrics = gbc_metrics, # evaluation metrics
  eval_time = time_points # evaluation time points
)
# A tibble: 6 × 9
   mtry min_n .metric         .estimator .eval_time   mean     n std_err .config
  <int> <int> <chr>           <chr>           <dbl>  <dbl> <int>   <dbl> <chr>  
1     3    30 brier_survival  standard            0 0         10 0       Prepro…
2     3    30 roc_auc_surviv… standard            0 0.5       10 0       Prepro…
3     3    30 brier_survival  standard           12 0.0635    10 0.00706 Prepro…
4     3    30 roc_auc_surviv… standard           12 0.827     10 0.0314  Prepro…
5     3    30 brier_survival  standard           24 0.163     10 0.0114  Prepro…
6     3    30 roc_auc_surviv… standard           24 0.747     10 0.0475  Prepro…
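The next tibble shows the best random-forest candidates; the producing call is not shown here, but it would be along the lines of:

show_best(rf_res, metric = "brier_survival_integrated", n = 5) # top 5 RF candidates by integrated Brier score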
# A tibble: 5 × 9
   mtry min_n .metric          .estimator .eval_time  mean     n std_err .config
  <int> <int> <chr>            <chr>           <dbl> <dbl> <int>   <dbl> <chr>  
1     6    24 brier_survival_… standard           NA 0.155    10 0.00765 Prepro…
2     5    27 brier_survival_… standard           NA 0.155    10 0.00782 Prepro…
3     4    20 brier_survival_… standard           NA 0.156    10 0.00777 Prepro…
4     9    36 brier_survival_… standard           NA 0.156    10 0.00743 Prepro…
5     2     7 brier_survival_… standard           NA 0.156    10 0.00829 Prepro…
# Select best RF hyperparameters (mtry, min_n) based on Brier score
param_best <- select_best(rf_res, metric = "brier_survival_integrated") 
param_best # view results
# A tibble: 1 × 3
   mtry min_n .config              
  <int> <int> <chr>                
1     6    24 Preprocessor1_Model07
# Finalize the workflow with the best hyperparameters
rf_final_wflow <- finalize_workflow(rf_wflow, param_best) # finalize workflow
# Fit the finalized workflow on the training data and evaluate it on the testing set
set.seed(123) # set seed for reproducibility
final_rf_fit <- last_fit(
  rf_final_wflow, 
  split = gbc_split, # use the original split
  metrics = gbc_metrics, # evaluation metrics
  eval_time = time_points # evaluation time points
)
collect_metrics(final_rf_fit) |> # collect overall performance metrics
  filter(.metric %in% c("concordance_survival", "brier_survival_integrated"))
# A tibble: 2 × 5
  .metric                   .estimator .eval_time .estimate .config             
  <chr>                     <chr>           <dbl>     <dbl> <chr>               
1 brier_survival_integrated standard           NA     0.237 Preprocessor1_Model1
2 concordance_survival      standard           NA     0.655 Preprocessor1_Model1
- extract_workflow() to get the final model
gbc_rf <- extract_workflow(final_rf_fit) # extract the fitted workflow
# Predict on new data
gbc_5 <- testing(gbc_split) |>  slice(1:5) # take first 5 rows of test data
predict(gbc_rf, new_data = gbc_5, type = "time") # predict survival times
# A tibble: 5 × 1
  .pred_time
       <dbl>
1       47.6
2       67.1
3       67.2
4       36.1
5       49.7
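Beyond point predictions of survival time, the fitted workflow can also return survival probabilities over a grid of time points; a sketch (output omitted, and time_points is the assumed evaluation grid defined earlier):

predict(gbc_rf, new_data = gbc_5, type = "survival", eval_time = time_points) # predicted survival curves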
- Select the best model from cox_res and fit it to test data
# Select best Cox hyperparameters (penalty) based on Brier score
param_best_cox <- select_best(cox_res, metric = "brier_survival_integrated")
# Finalize the workflow with the best hyperparameters
cox_final_wflow <- finalize_workflow(cox_wflow, param_best_cox) # finalize workflow
# Fit the finalized workflow on the training data and evaluate it on the testing set
final_cox_fit <- last_fit(
  cox_final_wflow, 
  split = gbc_split, # use the original split
  metrics = gbc_metrics, # evaluation metrics
  eval_time = time_points # evaluation time points
)
# Collect metrics on test data
collect_metrics(final_cox_fit) |> # collect overall performance metrics
  filter(.metric %in% c("concordance_survival", "brier_survival_integrated"))
# Extract test ROC AUC over time
roc_test_cox <- collect_metrics(final_cox_fit) |>  
  filter(.metric == "roc_auc_survival") |>  # filter for ROC AUC
  rename(mean = .estimate) # rename .estimate to mean
roc_test_cox |>  # pass the test ROC AUC data
  ggplot(aes(.eval_time, mean)) +  # plot evaluation time vs mean ROC AUC
  geom_line() + # plot line
  labs(x = "Time (months)", y = "ROC AUC") + # labels
  theme_classic()
# Predict on new data
# Extract fitted workflow
gbc_cox <- extract_workflow(final_cox_fit) # extract the fitted workflow
gbc_5 <- testing(gbc_split) |>  slice(1:5) # take first 5 rows of test data
predict(gbc_cox, new_data = gbc_5, type = "time") # predict survival times
- tidy() function from the broom package
tidy(gbc_cox) # tidy the coefficients
# # A tibble: 10 × 3
#    term     estimate penalty
#    <chr>       <dbl>   <dbl>
#  1 hormone  -0.227   0.00147
#  2 age       0.00908 0.00147
#  3 meno     -0.00736 0.00147
#  4 size      0.0127  0.00147
#  5 nodes     0.0411  0.00147
#  6 prog     -0.270   0.00147
#  7 estrg     0.00771 0.00147
#  8 age40    -0.540   0.00147
#  9 grade_X2  0.573   0.00147
# 10 grade_X3  0.624   0.00147
- decision_tree() with set_engine("rpart") (see the sketch after this list)
- tune cost_complexity using tune()
- tidymodels: a consistent interface for modeling and machine learning
- parsnip for model specification and tuning
- censored package for survival data
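A minimal sketch of the decision-tree specification referenced above, following the same pattern as the Cox and random-forest models (the object names are illustrative, and the workflow/tuning steps would mirror those shown earlier):

# Survival decision tree specification
tree_spec <- decision_tree(cost_complexity = tune()) |>  # tune the pruning penalty
  set_engine("rpart") |>              # fit with rpart
  set_mode("censored regression")     # survival (censored) outcome

tree_wflow <- workflow() |> 
  add_model(tree_spec) |>    # decision tree specification
  add_recipe(gbc_recipe)     # reuse the preprocessing recipe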