Module 5. Machine Learning

Slides

Section slides here. (To convert the HTML slides to PDF, press E \(\to\) Print \(\to\) Destination: Save as PDF.)

R code

Show the code
# -------------------------------------------
# Survival Analysis: Module 5 Code
# -------------------------------------------

# ---------------------------
# Load Packages
# ---------------------------

library(tidymodels) # load tidymodels
library(censored)

# ---------------------------
# Prepare GBC Data
# ---------------------------

# Read the GBC data; the file must provide at least the `id`, `time`, and
# `status` columns used below.
gbc <- read.table("data/gbc.txt", header = TRUE)

# Reduce to one row per subject: the earliest record, i.e. the time to the
# first event (relapse or death) or to censoring.
df <- gbc |>
  group_by(id) |>   # one group per subject
  arrange(time) |>  # sort rows by time
  slice(1) |>       # keep the earliest row within each id
  ungroup() |>
  mutate(
    # Composite endpoint: any non-zero status counts as an event. Passing the
    # raw codes is unsafe when status takes the values 0/1/2, because Surv()
    # re-bases a numeric status whose maximum is 2 (reading it as 1/2 coding),
    # which turns the 0s into NA with a warning and mislabels one event type
    # as censored. For data already coded 0/1 the result is unchanged.
    # NOTE(review): assumes status == 0 means censored -- confirm against the
    # data dictionary for gbc.txt.
    surv_obj = Surv(time, status > 0),
    .after = id,      # place surv_obj immediately after the id column
    .keep = "unused"  # drop the original time and status columns
  )

# ---------------------------
# Data Splitting
# ---------------------------

# Hold out a test set. The seed fixes the random assignment; prop = 3/4 is
# spelled out here and equals the initial_split() default.
set.seed(123)
gbc_split <- df |> initial_split(prop = 3 / 4)
gbc_train <- gbc_split |> training()  # training portion of the split

# ---------------------------
# Preprocessing Recipe
# ---------------------------

# Preprocessing recipe defined on the training set: surv_obj is the outcome
# and every other column starts out as a predictor.
gbc_recipe <-
  recipe(surv_obj ~ ., data = gbc_train) |>
  step_mutate(
    grade = factor(grade),          # treat grade as categorical
    age40 = as.numeric(age >= 40),  # 0/1 indicator for age >= 40
    prog = prog / 100,              # rescale prog
    estrg = estrg / 100             # rescale estrg
  ) |>
  step_dummy(grade) |>  # expand grade into indicator columns
  step_rm(id)           # drop id so it is not used as a predictor

# ---------------------------
# Regularized Cox Model
# ---------------------------

# Penalized Cox model fitted with glmnet; the penalty (lambda) is left as a
# tuning parameter to be chosen by cross-validation.
cox_spec <- proportional_hazards(
  mode = "censored regression",
  engine = "glmnet",
  penalty = tune()
)

# Bundle the preprocessing recipe with the model specification.
cox_wflow <- workflow(preprocessor = gbc_recipe, spec = cox_spec)

# ---------------------------
# Cross-Validation Setup
# ---------------------------

# 10-fold cross-validation on the training set (seeded for reproducibility).
set.seed(123)
gbc_folds <- gbc_train |> vfold_cv(v = 10)

# Metrics computed for every candidate model.
gbc_metrics <- metric_set(
  brier_survival,
  brier_survival_integrated,
  roc_auc_survival,
  concordance_survival
)

# Evaluation times passed to the metrics: 0, 12, ..., 84.
time_points <- seq(from = 0, to = 84, by = 12)

# ---------------------------
# Cox model CV
# ---------------------------

# Tune the penalized Cox model over the CV folds (this can take a while).
set.seed(123)
cox_res <- cox_wflow |>
  tune_grid(
    resamples = gbc_folds,
    grid = 10,                # number of candidate penalty values
    metrics = gbc_metrics,
    eval_time = time_points,
    control = control_grid(save_workflow = TRUE)  # keep workflow in result
  )

# ---------------------------
# Plot Cox Model Brier Score
# ---------------------------

# Cross-validated integrated Brier score as a function of log(lambda).
cox_res |>
  collect_metrics() |>
  filter(.metric == "brier_survival_integrated") |>
  ggplot(aes(x = log(penalty), y = mean)) +
  geom_line() +
  labs(x = "Log-lambda", y = "Overall Brier Score") +
  theme_classic()

# ---------------------------
# Best Cox Models
# ---------------------------

# Five best penalty values ranked by integrated Brier score.
cox_res |> show_best(metric = "brier_survival_integrated", n = 5)

# ---------------------------
# Random Forest Model
# ---------------------------

# Random survival forest using the aorsf engine; mtry and min_n are left as
# tuning parameters.
rf_spec <- rand_forest(
  mode = "censored regression",
  engine = "aorsf",
  mtry = tune(),
  min_n = tune()
)

# Same preprocessing recipe as the Cox workflow, paired with the forest spec.
rf_wflow <- workflow(preprocessor = gbc_recipe, spec = rf_spec)

# ---------------------------
# RF CV
# ---------------------------

# Tune the random forest over the same CV folds (this can take a while).
set.seed(123)
rf_res <- rf_wflow |>
  tune_grid(
    resamples = gbc_folds,
    grid = 10,                # number of (mtry, min_n) candidates
    metrics = gbc_metrics,
    eval_time = time_points
  )

# ---------------------------
# View RF Validation Metrics
# ---------------------------

# Preview the first rows of the cross-validated metrics table.
head(collect_metrics(rf_res))

# ---------------------------
# Best RF Models
# ---------------------------

# Five best (mtry, min_n) candidates ranked by integrated Brier score.
rf_res |> show_best(metric = "brier_survival_integrated", n = 5)

# ---------------------------
# Finalize and Fit Best RF Model
# ---------------------------

# Pick the candidate with the best integrated Brier score and plug its
# values into the random forest workflow.
param_best <- rf_res |> select_best(metric = "brier_survival_integrated")
rf_final_wflow <- rf_wflow |> finalize_workflow(param_best)

# Final fit and evaluation via last_fit() on the original train/test split.
set.seed(123)
final_rf_fit <- rf_final_wflow |>
  last_fit(
    split = gbc_split,
    metrics = gbc_metrics,
    eval_time = time_points
  )

# ---------------------------
# Test Performance - RF Model
# ---------------------------

# Overall metrics from the final fit: concordance and integrated Brier score.
final_rf_fit |>
  collect_metrics() |>
  filter(.metric %in% c("concordance_survival", "brier_survival_integrated"))

# Time-dependent ROC AUC; the estimate column is renamed to `mean`.
roc_test <- final_rf_fit |>
  collect_metrics() |>
  filter(.metric == "roc_auc_survival") |>
  rename(mean = .estimate)

# ROC AUC plotted against evaluation time.
ggplot(roc_test, aes(x = .eval_time, y = mean)) +
  geom_line() +
  labs(x = "Time (months)", y = "ROC AUC") +
  theme_classic()

# ---------------------------
# Predict with Final RF Model
# ---------------------------

# Pull the fitted workflow out of the last_fit() result and predict survival
# times for the first five rows of the test set.
gbc_rf <- final_rf_fit |> extract_workflow()
gbc_5 <- gbc_split |>
  testing() |>
  slice_head(n = 5)
gbc_rf |> predict(new_data = gbc_5, type = "time")