Chapter 15 - Machine Learning in Survival Analysis

Slides

Lecture slides here. (To convert html to pdf, press E \(\to\) Print \(\to\) Destination: Save to pdf)

Chapter Summary

Machine learning methods for survival analysis offer flexible and scalable alternatives to traditional models, especially in the presence of many covariates, nonlinear effects, or complex interactions. Two main approaches are considered: regularized Cox models for high-dimensional variable selection, and survival trees for nonparametric prediction. Ensemble methods such as bagging and random forests further enhance stability and prediction accuracy.

Regularized Cox models

When there are many candidate covariates, regularization can improve generalizability and interpretability.

  • Penalized partial likelihood (lasso-penalized Cox model): \[ \hat\beta(\lambda) = \arg\min_\beta \left\{ -pl_n(\beta) + \lambda \sum_{j=1}^p |\beta_j| \right\}, \] where \(pl_n(\beta)\) is the partial log-likelihood and \(\lambda\) controls the degree of shrinkage.

  • Variable selection:
    The \(L_1\)-penalty (lasso) produces sparse solutions by shrinking some coefficients exactly to zero.

  • Cross-validation:
    The penalty parameter \(\lambda\) is selected by minimizing the partial-likelihood deviance across validation folds.

  • Prediction error:
    Assessed using inverse probability of censoring weighting (IPCW) estimators of the Brier score.

  • Computation:
    A coordinate descent algorithm maximizes the penalized partial likelihood via iteratively reweighted least squares with soft-thresholding.
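
The soft-thresholding step can be sketched in a few lines of R. This is a minimal illustration under stated assumptions, not glmnet's implementation: the function names `soft_threshold` and `lasso_wls` are hypothetical, and the working response `z` and weights `w` (which in the Cox setting would come from a quadratic approximation to the partial log-likelihood at the current estimate) are simply taken as given inputs.

```r
# Soft-thresholding operator: shrink z toward zero by gamma, and set it
# exactly to zero when |z| <= gamma. This is what produces sparsity.
soft_threshold <- function(z, gamma) sign(z) * pmax(abs(z) - gamma, 0)

# Coordinate descent for a lasso-penalized weighted least-squares problem:
#   minimize (1 / 2n) * sum(w * (z - X %*% beta)^2) + lambda * sum(abs(beta))
# (hypothetical helper; z and w are assumed to be supplied by the caller)
lasso_wls <- function(X, z, w, lambda, n_iter = 100) {
  beta <- numeric(ncol(X))
  for (it in seq_len(n_iter)) {
    for (j in seq_len(ncol(X))) {
      # Partial residual: remove the fit from all covariates except j
      r_j <- z - drop(X[, -j, drop = FALSE] %*% beta[-j])
      num <- mean(w * X[, j] * r_j)
      den <- mean(w * X[, j]^2)
      beta[j] <- soft_threshold(num, lambda) / den
    }
  }
  beta
}

soft_threshold(c(-3, 0.5, 2), 1)  # -2 0 1
```

Each coordinate update has a closed form, which is why cycling through the coefficients one at a time is cheap even when the number of covariates is large.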

Survival trees

Survival trees offer a nonparametric alternative, capable of automatically capturing nonlinearities and interactions without prespecification.

  • Tree construction:
    • Recursive binary splits based on covariates
    • Splitting criterion: minimize within-node deviance residuals
    • Fully grown tree followed by pruning based on cost-complexity and cross-validation.
  • Terminal nodes:
    • Within-node Kaplan–Meier curves provide survival predictions for new subjects.
  • Advantages:
    • No need to specify functional forms
    • Flexible to complex data structures
    • Naturally captures covariate interactions
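
To make the splitting step concrete, the sketch below searches for the best binary cut on a single covariate by minimizing the summed within-node deviance of a constant-hazard (exponential) model. It is a simplified illustration under stated assumptions: `node_deviance` and `best_split` are hypothetical names, the toy data are made up, and the time rescaling that rpart is understood to apply before computing its deviance is omitted.

```r
# Deviance of a constant-hazard (exponential) model within one node.
# With event indicators in {0, 1} and the fitted rate d / sum(time),
# the Poisson-type deviance reduces to -2 * sum(log(rate * time)) over events.
node_deviance <- function(time, status) {
  d <- sum(status)
  if (d == 0) return(0)          # no events: the fitted rate is 0, deviance 0
  rate <- d / sum(time)
  -2 * sum(log(rate * time[status == 1]))
}

# Try every midpoint between consecutive distinct values of x and keep
# the cut with the smallest total (left + right) deviance.
best_split <- function(x, time, status) {
  ux   <- sort(unique(x))
  cuts <- (head(ux, -1) + tail(ux, -1)) / 2
  dev  <- sapply(cuts, function(cc) {
    left <- x <= cc
    node_deviance(time[left], status[left]) +
      node_deviance(time[!left], status[!left])
  })
  list(cut = cuts[which.min(dev)], deviance = min(dev))
}

# Toy data: subjects with x > 3 fail much earlier
x      <- 1:6
time   <- c(10, 12, 11, 1, 2, 1.5)
status <- rep(1, 6)
best_split(x, time, status)$cut  # 3.5 for this toy data
```

A full tree applies this search over all covariates at every node and recurses on the two child nodes until a stopping rule is met.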

Ensemble methods

Stability and prediction performance are improved by aggregating many trees.

  • Bagging:
    • Grow multiple trees on bootstrapped datasets
    • Average survival predictions across trees
  • Random forests:
    • Further randomize splits by selecting a random subset of covariates at each node
    • Reduce correlation among trees for better generalization
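
The bagging steps above can be sketched with rpart and the `veteran` data that ships with the survival package. This is an illustrative sketch, not the chapter's own implementation: the helper names `leaf_of` and `bag_surv`, the choice of covariates, `minbucket = 10`, and `B = 25` are all assumptions made for the example.

```r
library(survival)  # Surv(), survfit(), veteran data
library(rpart)

# Leaf (row of fit$frame) into which each subject in newdata falls.
# Overwriting yval with the row index makes predict() return that index.
leaf_of <- function(fit, newdata) {
  fit$frame$yval <- seq_len(nrow(fit$frame))
  as.integer(predict(fit, newdata = newdata, type = "vector"))
}

# Bagged estimate of S(tau | covariates) for the subjects in newdata
bag_surv <- function(data, newdata, tau, B = 25) {
  preds <- sapply(seq_len(B), function(b) {
    # 1. Bootstrap the training data and grow a tree on the resample
    boot <- data[sample(nrow(data), replace = TRUE), ]
    fit  <- rpart(Surv(time, status) ~ trt + celltype + karno + age,
                  data = boot, control = rpart.control(minbucket = 10))
    # 2. Kaplan-Meier estimate of S(tau) within each terminal node
    s_leaf <- tapply(seq_len(nrow(boot)), fit$where, function(idx) {
      km <- survfit(Surv(time, status) ~ 1, data = boot[idx, ])
      summary(km, times = tau, extend = TRUE)$surv
    })
    # 3. Look up the node-level estimate for each new subject
    unname(s_leaf[as.character(leaf_of(fit, newdata))])
  })
  # 4. Average the predictions across the B trees
  rowMeans(matrix(preds, nrow = nrow(newdata)))
}

set.seed(1)
bag_surv(veteran, veteran[1:5, ], tau = 100, B = 25)
```

A random forest would additionally restrict each split to a random subset of covariates, which rpart does not do directly; dedicated packages such as randomForestSRC or ranger provide this for survival outcomes.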

Example R code

########################################
# 1. Regularized Cox model
########################################
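library(survival)  # Surv()
library(glmnet)
# Z: n x p covariate matrix; time, status: follow-up time and event indicator
obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
plotmo::plot_glmnet(obj) # Coefficient paths

# Cross-validation to select lambda
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
plot(obj.cv)                    # CV partial-likelihood deviance along the lambda path
coef(obj.cv, s = "lambda.min")  # coefficients at the selected lambda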

########################################
# 2. Survival tree
########################################
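library(rpart)
library(rpart.plot)  # rpart.plot()
# `covariates` is a placeholder for the model terms
obj <- rpart(Surv(time, status) ~ covariates,
             control = rpart.control(xval = 10, minbucket = 2, cp = 0))

# Prune the tree at the cp value with the smallest cross-validated error
fit <- prune(obj, cp = obj$cptable[which.min(obj$cptable[, "xerror"]), "CP"])
rpart.plot(fit)

# Terminal node of each training subject, then Kaplan–Meier within nodes
fit$where
survfit(Surv(time, status) ~ fit$where)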

Conclusion

Machine learning methods provide flexible, data-driven tools for survival prediction and variable selection. Regularized Cox models yield interpretable sparse models and handle high-dimensional data efficiently. Survival trees offer a nonparametric alternative that automatically detects complex nonlinear effects and interactions. Ensemble methods like bagging and random forests further improve stability and predictive performance, making them powerful complements to traditional survival analysis techniques.

R Code

###############################################################################
# Chapter 15 R Code
#
# This script reproduces all major numerical results in Chapter 15, including:
#   1. Cox-Lasso analysis on GBC data (train/test split)
#   2. Survival tree modeling (pruning, terminal node curves)
#   3. Brier Score comparisons (Cox-lasso, full Cox, survival tree)
###############################################################################


#==============================================================================
# (A) Cox-Lasso on GBC Data
#==============================================================================
library(survival)
library(glmnet)
# library("glmpath") # alternative approach for lasso, commented out

#------------------------------------------------------------------------------
# 1. Data Reading & Preparation
#------------------------------------------------------------------------------
# The "gbc.txt" file contains the complete German Breast Cancer Study data
gbc <- read.table("Data//German Breast Cancer Study//gbc.txt")

# Sort the data by (id, time) so that each subject’s rows appear chronologically
o    <- order(gbc$id, gbc$time)
gbc  <- gbc[o, ]

# Keep only the first row per subject => “first event” data
#    i.e., we assume each subject either has 1 event or is censored.
data.CE <- gbc[!duplicated(gbc$id), ]

# Convert status so that status=1 if status is in {1,2}, else 0
#    i.e., 0 = censored, 1 = relapse or death
data.CE$status <- as.integer(data.CE$status > 0)

# Create a binary variable age40 = 1 if age <= 40, else 0
#    This categorizes younger patients distinctly
data.CE$age40 <- as.integer(data.CE$age <= 40)

# Store total sample size
n <- nrow(data.CE)

# Quick peek at first few rows
head(data.CE)


#------------------------------------------------------------------------------
# 2. Train/Test Split
#------------------------------------------------------------------------------
# Randomly sample 400 subjects for training; the remaining subjects form the test set
set.seed(1234)
ind   <- sample(1:n, size = 400)
train <- data.CE[ind, ]
test  <- data.CE[-ind, ]


#------------------------------------------------------------------------------
# 3. Fitting Cox-Lasso with glmnet
#------------------------------------------------------------------------------
# Define the set of predictors we want to consider
pred_list <- c("hormone", "age40", "meno", "size", "grade",
               "nodes", "prog", "estrg")

# Construct the design matrix Z from train data
Z <- as.matrix(train[, pred_list])

# Extract time and status vectors
time   <- train$time
status <- train$status

# Fit a Cox model with L1 penalty (alpha=1 => lasso), over a sequence of lambda
obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)

# Print the path summary: number of non-zero coefficients, % deviance
# explained, and lambda at each step
print(obj)

# Perform 10-fold cross-validation to select optimal lambda
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)


#------------------------------------------------------------------------------
# 4. Visualizing Coefficient Paths & CV Error
#------------------------------------------------------------------------------
library(plotmo)  # Provides plot_glmnet() for convenient path plots

par(mfrow = c(1, 2))

# plot_glmnet() => coefficient paths vs. log(lambda)
plot_glmnet(obj, lwd = 2)

# plot() on the cv.glmnet object => partial-likelihood deviance vs. lambda
plot(obj.cv)

par(mfrow = c(1, 1))  # reset plotting layout

# Identify the optimal lambda minimizing the CV error
lambda.opt <- obj.cv$lambda.min

# Extract the coefficients at lambda.min
beta.opt   <- coef(obj.cv, s = "lambda.min")

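# Identify which coefficients are non-zero
# (convert the sparse matrix to a dense named vector first, so that names
#  are kept even if only one variable is selected)
beta.vec       <- as.matrix(beta.opt)[, 1]
beta.selected  <- beta.vec[beta.vec != 0]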
beta.selected  # show the non-zero variables
length(beta.selected)  # how many are non-zero?


#------------------------------------------------------------------------------
# 5. Refit Cox Model Using Selected Variables
#------------------------------------------------------------------------------
# We'll pull the variable names and re-fit a coxph with these in the training set
selected_vars <- names(beta.selected)
obj_lasso     <- coxph(Surv(time, status) ~ as.matrix(train[, selected_vars]))


#------------------------------------------------------------------------------
# 6. A Utility Function to Get Predicted Survival Curves from a Cox Model
#------------------------------------------------------------------------------
predsurv_cox <- function(obj, Z) {
  # obj: coxph object
  # Z: (n x p) matrix of covariates for n subjects
  # Returns a list with $St (n x m matrix of S_i(t_j)) and $t (length m of times)
  
  beta <- obj$coefficients
  bhaz <- basehaz(obj, centered = FALSE)  # baseline hazard info
  L    <- bhaz$hazard                     # cumulative baseline hazard
  tt   <- c(0, bhaz$time)                 # prepend 0 to times
  
  # We'll store S_i(t_j) in a matrix of dimension (n x length(tt))
  Smat <- matrix(NA, nrow = nrow(Z), ncol = length(tt))
  
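  for (i in seq_len(nrow(Z))) {
    # Relative risk exp(Z_i' beta) for subject i; drop() turns the 1 x 1
    # matrix into a scalar so it recycles cleanly against the vector L
    eXb <- drop(exp(Z[i, ] %*% beta))
    # S(0) = 1, S(t_j) = exp(-eXb * L_j)
    Smat[i, ] <- c(1, exp(-eXb * L))
  }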
  
  list(St = Smat, t = tt)
}


#==============================================================================
# (B) Brier Score Calculation for Lasso & Full Cox
#==============================================================================
#------------------------------------------------------------------------------
# 1. Lasso-Fitted Model's Predictions on Test Data
#------------------------------------------------------------------------------
pred_lasso <- predsurv_cox(
  obj_lasso,
  as.matrix(test[, selected_vars])
)
St_lasso <- pred_lasso$St
t_lasso  <- pred_lasso$t

# KM for censoring distribution in the test set
G_obj <- survfit(Surv(time, 1 - status) ~ 1, data = test)

tc  <- c(0, G_obj$time)   # times in the censoring KM
Gtc <- c(1, G_obj$surv)   # survival prob => G(t)

#------------------------------------------------------------------------------
# 2. Brier Score Function
#------------------------------------------------------------------------------
BSfun <- function(tau, time, status, St, t, Gtc, tc) {
  # tau: time at which we evaluate the Brier Score
  # time, status: observed times & event indicators
  # St: (n x m) matrix of predicted survival rates from the model
  # t: vector of times associated with columns of St
  # Gtc, tc: censoring distribution estimates (KM) at times in tc
  
  n    <- length(time)
  pos  <- sum(t <= tau)  # columns in St up to tau
  BSvec<- numeric(n)     # store Brier Score components per subject
  
  for (i in seq_len(n)) {
    X_i <- time[i]
    S_i <- St[i, pos]    # predicted survival at tau
    # Evaluate G(t_i) or G(tau) for weighting
    if (X_i <= tau && status[i] == 1) {
      # subject i had the event by time X_i <= tau
      # denominator => G(X_i)
      G_i <- Gtc[sum(tc <= X_i)]
      # contributes (S_i)^2 / G(X_i)
      BSvec[i] <- (S_i^2 / G_i)
    } else if (X_i > tau) {
      # subject i is still at risk at tau
      # => contributes (1-S_i)^2 / G(tau)
      G_i <- Gtc[sum(tc <= tau)]
      BSvec[i] <- ((1 - S_i)^2 / G_i)
    } else {
      # X_i <= tau but censored => status at tau is unknown, so weight 0
      # (the IPCW weights on the other subjects compensate)
      BSvec[i] <- 0
    }
  }
  
  mean(BSvec, na.rm = TRUE)
}

#------------------------------------------------------------------------------
# 3. Evaluate Brier Score for Cox-Lasso at times 12..60
#------------------------------------------------------------------------------
tau_list     <- 12:60
BS_tau_lasso <- numeric(length(tau_list))

for (i in seq_along(tau_list)) {
  BS_tau_lasso[i] <- BSfun(
    tau    = tau_list[i],
    time   = test$time,
    status = test$status,
    St     = St_lasso,
    t      = t_lasso,
    Gtc    = Gtc,
    tc     = tc
  )
}

#------------------------------------------------------------------------------
# 4. Full Cox Model
#------------------------------------------------------------------------------
obj_full <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, pred_list]))

# Predicted survival in test set
pred_full <- predsurv_cox(obj_full, as.matrix(test[, pred_list]))
St_full   <- pred_full$St
t_full    <- pred_full$t

# Brier Score for the full Cox
BS_tau_full <- numeric(length(tau_list))

for (i in seq_along(tau_list)) {
  BS_tau_full[i] <- BSfun(
    tau    = tau_list[i],
    time   = test$time,
    status = test$status,
    St     = St_full,
    t      = t_full,
    Gtc    = Gtc,
    tc     = tc
  )
}


#==============================================================================
# (C) Survival Tree Analysis
#==============================================================================
library(rpart)
library(rpart.plot)

#------------------------------------------------------------------------------
# 1. Building a Survival Tree (Train Set)
#------------------------------------------------------------------------------
set.seed(12345)  # for reproducibility
obj_tree <- rpart(
  Surv(time, status) ~ hormone + meno + size + grade + nodes + prog + estrg + age,
  control = rpart.control(xval = 10, minbucket = 2, cp = 0),
  data    = train
)

printcp(obj_tree)  # shows the cross-validation results

# Identify the complexity parameter that yields minimal xerror
cptable  <- obj_tree$cptable
cp.opt   <- cptable[which.min(cptable[, "xerror"]), "CP"]

# Prune the tree using cp.opt
fit_tree <- prune(obj_tree, cp = cp.opt)

#------------------------------------------------------------------------------
# 2. Visualize the Pruned Tree and KM Curves
#------------------------------------------------------------------------------
par(mfrow = c(1, 2))

# Plot the tree structure
rpart.plot(fit_tree)

# Fit a KM in each terminal node => helpful to see the survival in each leaf
km_fit <- survfit(Surv(time, status) ~ fit_tree$where, data = train)
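# Note: time is recorded in months (cf. tau_list / 12 labeled as years below)
plot(
  km_fit,
  lty      = 1:4,
  mark.time= FALSE,
  xlab     = "Time (months)",
  ylab     = "Relapse-free survival probability",
  lwd      = 2,
  cex.lab  = 1.2,
  cex.axis = 1.2
)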
legend(
  "bottomleft",
  paste("Node", sort(unique(fit_tree$where))),
  lty = 1:4,
  lwd = 2,
  cex = 1.2
)

par(mfrow = c(1, 1))

#------------------------------------------------------------------------------
# 3. Extracting Leaf-Specific Survival Functions
#------------------------------------------------------------------------------
tmp       <- summary(km_fit)
tmp.strata<- as.integer(sub(".*=", "", tmp$strata)) # node labels
tmp.t     <- tmp$time
tmp.surv  <- tmp$surv

# Terminal node IDs
TN <- sort(unique(tmp.strata))
N  <- length(TN)

# Sort the unique times from tmp.t
t_unique <- sort(unique(tmp.t))
m        <- length(t_unique)

# fitted_surv[j,k] => survival at time t_unique[j] for node k
fitted_surv <- matrix(NA, nrow = m, ncol = N)

for (j in seq_len(m)) {
  tj <- t_unique[j]
  for (k in seq_len(N)) {
    # times within that node
    node_times <- c(0, tmp.t[tmp.strata == TN[k]])
    node_survs <- c(1, tmp.surv[tmp.strata == TN[k]])
    
    idx <- sum(node_times <= tj)
    fitted_surv[j, k] <- node_survs[idx]
  }
}

#------------------------------------------------------------------------------
# 4. Apply the Tree to the Test Set
#------------------------------------------------------------------------------
library(treeClust)
# rpart.predict.leaves() => which leaf each test subject lands in
test_term <- rpart.predict.leaves(fit_tree, test)
n_test    <- nrow(test)

# Construct an (n_test x m) matrix of survival probabilities
St_tree <- matrix(NA, nrow = n_test, ncol = m)

for (k in seq_len(N)) {
  # Index test subjects in node k
  ind <- which(test_term == TN[k])
  # replicate the node-k survival curve for these subjects
  if (length(ind) > 0) {
    St_tree[ind, ] <- matrix(fitted_surv[, k], nrow = length(ind), ncol = m, byrow = TRUE)
  }
}

#------------------------------------------------------------------------------
# 5. Brier Score for the Survival Tree
#------------------------------------------------------------------------------
BS_tau_tree <- numeric(length(tau_list))

for (i in seq_along(tau_list)) {
  BS_tau_tree[i] <- BSfun(
    tau   = tau_list[i],
    time  = test$time,
    status= test$status,
    St    = St_tree,
    t     = t_unique,
    Gtc   = Gtc,
    tc    = tc
  )
}


#==============================================================================
# (D) Comparing Brier Score Curves
#==============================================================================
par(mfrow = c(1, 1))

# Plot Brier Score for the survival tree (red), Cox-lasso (blue), full Cox (black)
plot(
  tau_list / 12, BS_tau_tree,
  type  = "l",
  lwd   = 2,
  col   = "red",
  cex.axis= 1.2,
  cex.lab = 1.2,
  xlab  = "t (years)",
  ylab  = "BS(t)",
  ylim  = c(0, 0.25)
)
lines(tau_list / 12, BS_tau_lasso, lwd = 2, col = "blue")
lines(tau_list / 12, BS_tau_full,  lwd = 2, col = "black")

legend(
  "bottomright",
  lty  = 1,
  col  = c("red", "blue", "black"),
  lwd  = 2,
  legend = c("Survival Tree", "Cox-lasso", "Cox-full"),
  cex = 1.2
)