Chapter 15 - Machine Learning in Survival Analysis
Slides
Lecture slides here. (To convert HTML to PDF, press E \(\to\) Print \(\to\) Destination: Save to PDF)
Chapter Summary
Machine learning methods for survival analysis offer flexible and scalable alternatives to traditional models, especially in the presence of many covariates, nonlinear effects, or complex interactions. Two main approaches are considered: regularized Cox models for high-dimensional variable selection, and survival trees for nonparametric prediction. Ensemble methods such as bagging and random forests further enhance stability and prediction accuracy.
Regularized Cox models
When there are many candidate covariates, regularization can improve generalizability and interpretability.
Penalized partial likelihood (lasso-penalized Cox model): \[ \hat\beta(\lambda) = \arg\min_\beta \left\{ -pl_n(\beta) + \lambda \sum_{j=1}^p |\beta_j| \right\}, \] where \(pl_n(\beta)\) is the partial log-likelihood and \(\lambda\) controls the degree of shrinkage.
Variable selection: The \(L_1\)-penalty (lasso) produces sparse solutions by shrinking some coefficients exactly to zero.
Cross-validation: The penalty parameter \(\lambda\) is selected by minimizing the partial-likelihood deviance across validation folds.
Prediction error: Assessed using inverse probability of censoring weighting (IPCW) estimators of the Brier score.
Computation: A coordinate descent algorithm solves the penalized likelihood via iteratively reweighted least squares with soft-thresholding (see the sketch below).
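To see why the lasso sets coefficients exactly to zero, here is a minimal sketch of the soft-thresholding operator applied in each coordinate descent update (soft_threshold is an illustrative helper, not a glmnet function):
# Soft-thresholding operator: S(z, lambda) = sign(z) * max(|z| - lambda, 0)
# (illustrative helper, not part of glmnet)
soft_threshold <- function(z, lam) {
  sign(z) * pmax(abs(z) - lam, 0)
}

# Inputs with |z| <= lam are set exactly to zero; the rest shrink toward zero
soft_threshold(c(-2, -0.5, 0.3, 1.5), lam = 1) # returns -1.0 0.0 0.0 0.5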
Survival trees
Survival trees offer a nonparametric alternative, capable of automatically capturing nonlinearities and interactions without prespecification.
- Tree construction:
- Recursive binary splits based on covariates
- Splitting criterion: minimize the within-node deviance
- A fully grown tree is then pruned by cost-complexity pruning guided by cross-validation
- Terminal nodes:
- Within-node Kaplan–Meier curves provide survival predictions for new subjects.
- Advantages:
- No need to specify functional forms
- Flexible to complex data structures
- Naturally captures covariate interactions
Ensemble methods
Stability and prediction performance are improved by aggregating many trees.
- Bagging:
- Grow multiple trees on bootstrapped datasets
- Average survival predictions across trees
- Random forests:
- Further randomize splits by selecting a random subset of covariates at each node
- Reduce correlation among trees for better generalization
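As a minimal sketch of these ideas (assuming the randomForestSRC package, which this chapter's script does not use), a random survival forest can be fit and applied to test data as follows; the variable names mirror the GBC training/test sets constructed in the R code below, and ntree and mtry are illustrative values:
library(randomForestSRC) # assumed available; not part of the chapter's script

set.seed(123)
rsf <- rfsrc(
  Surv(time, status) ~ hormone + meno + size + grade + nodes + prog + estrg + age,
  data = train, # training set as constructed in the R code below
  ntree = 500,  # number of bootstrapped trees (bagging)
  mtry = 3      # e.g., try a random subset of 3 covariates at each split
)

# Ensemble survival probabilities for the test set, averaged across trees
pred <- predict(rsf, newdata = test)
dim(pred$survival) # rows = test subjects, columns = event times in pred$time.interest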
Example R code
########################################
# 1. Regularized Cox model
########################################
library(survival)
library(glmnet)

obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
plotmo::plot_glmnet(obj) # Coefficient paths

# Cross-validation to select lambda
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
summary(obj.cv)

########################################
# 2. Survival tree
########################################
library(rpart)
library(rpart.plot)

obj <- rpart(Surv(time, status) ~ covariates,
             control = rpart.control(xval = 10, minbucket = 2, cp = 0))

# Prune the tree
fit <- prune(obj, cp = obj$cptable[which.min(obj$cptable[, "xerror"]), "CP"])
rpart.plot(fit)

# Predict terminal nodes and Kaplan–Meier within nodes
fit$where
Conclusion
Machine learning methods provide flexible, data-driven tools for survival prediction and variable selection. Regularized Cox models yield interpretable sparse models and handle high-dimensional data efficiently. Survival trees offer a nonparametric alternative that automatically detects complex nonlinear effects and interactions. Ensemble methods like bagging and random forests further improve stability and predictive performance, making them powerful complements to traditional survival analysis techniques.
R Code
###############################################################################
# Chapter 15 R Code
#
# This script reproduces all major numerical results in Chapter 15, including:
# 1. Cox-Lasso analysis on GBC data (train/test split)
# 2. Survival tree modeling (pruning, terminal node curves)
# 3. Brier Score comparisons (Cox-lasso, full Cox, survival tree)
###############################################################################
#==============================================================================
# (A) Cox-Lasso on GBC Data
#==============================================================================
library(survival)
library(glmnet)
# library("glmpath") # alternative approach for lasso, commented out
#------------------------------------------------------------------------------
# 1. Data Reading & Preparation
#------------------------------------------------------------------------------
# The "gbc.txt" file contains the complete German Breast Cancer Study data
gbc <- read.table("Data//German Breast Cancer Study//gbc.txt")

# Sort the data by (id, time) so that each subject's rows appear chronologically
o <- order(gbc$id, gbc$time)
gbc <- gbc[o, ]
# Keep only the first row per subject => “first event” data
# i.e., we assume each subject either has 1 event or is censored.
data.CE <- gbc[!duplicated(gbc$id), ]

# Convert status so that status=1 if status is in {1,2}, else 0
# i.e., 0 = censored, 1 = relapse or death
data.CE$status <- as.integer(data.CE$status > 0)

# Create a binary variable age40 = 1 if age <= 40, else 0
# This categorizes younger patients distinctly
data.CE$age40 <- as.integer(data.CE$age <= 40)

# Store total sample size
n <- nrow(data.CE)
# Quick peek at first few rows
head(data.CE)
#------------------------------------------------------------------------------
# 2. Train/Test Split
#------------------------------------------------------------------------------
# We'll do a random sample of 400 for training, the rest are test
set.seed(1234)
ind <- sample(1:n, size = 400)
train <- data.CE[ind, ]
test <- data.CE[-ind, ]
#------------------------------------------------------------------------------
# 3. Fitting Cox-Lasso with glmnet
#------------------------------------------------------------------------------
# Define the set of predictors we want to consider
<- c("hormone", "age40", "meno", "size", "grade",
pred_list "nodes", "prog", "estrg")
# Construct the design matrix Z from train data
<- as.matrix(train[, pred_list])
Z
# Extract time and status vectors
<- train$time
time <- train$status
status
# Fit a Cox model with L1 penalty (alpha=1 => lasso), over a sequence of lambda
<- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
obj
# Summarize the glmnet fit => mostly high-level info about the path
summary(obj)
# Perform 10-fold cross-validation to select optimal lambda
<- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
obj.cv
#------------------------------------------------------------------------------
# 4. Visualizing Coefficient Paths & CV Error
#------------------------------------------------------------------------------
library(plotmo) # Provides plot_glmnet() for convenient path plots
par(mfrow = c(1, 2))
# plot_glmnet() => coefficient paths vs. log(lambda)
plot_glmnet(obj, lwd = 2)
# plot() on the cv.glmnet object => partial-likelihood deviance vs. lambda
plot(obj.cv)
par(mfrow = c(1, 1)) # reset plotting layout
# Identify the optimal lambda minimizing the CV error
lambda.opt <- obj.cv$lambda.min

# Extract the coefficients at lambda.min
beta.opt <- coef(obj.cv, s = "lambda.min")

# Identify which coefficients are non-zero
beta.selected <- beta.opt[abs(beta.opt[, 1]) > 0, ]
beta.selected # show the non-zero variables
length(beta.selected) # how many are non-zero?
#------------------------------------------------------------------------------
# 5. Refit Cox Model Using Selected Variables
#------------------------------------------------------------------------------
# We'll pull the variable names and re-fit a coxph with these in the training set
selected_vars <- names(beta.selected)
obj_lasso <- coxph(Surv(time, status) ~ as.matrix(train[, selected_vars]))
#------------------------------------------------------------------------------
# 6. A Utility Function to Get Predicted Survival Curves from a Cox Model
#------------------------------------------------------------------------------
predsurv_cox <- function(obj, Z) {
  # obj: coxph object
  # Z: (n x p) matrix of covariates for n subjects
  # Returns a list with $St (n x m matrix of S_i(t_j)) and $t (length-m vector of times)
  beta <- obj$coefficients
  bhaz <- basehaz(obj, centered = FALSE) # baseline hazard info
  L <- bhaz$hazard # cumulative baseline hazard
  tt <- c(0, bhaz$time) # prepend 0 to times

  # We'll store S_i(t_j) in a matrix of dimension (n x length(tt))
  Smat <- matrix(NA, nrow = nrow(Z), ncol = length(tt))

  for (i in seq_len(nrow(Z))) {
    eXb <- exp(drop(Z[i, ] %*% beta)) # linear predictor exp(Z_i'beta) for subject i
    # S(0) = 1, S(t_j) = exp(-eXb * L_j)
    Smat[i, ] <- c(1, exp(-eXb * L))
  }
  list(St = Smat, t = tt)
}
#==============================================================================
# (B) Brier Score Calculation for Lasso & Full Cox
#==============================================================================
#------------------------------------------------------------------------------
# 1. Lasso-Fitted Model's Predictions on Test Data
#------------------------------------------------------------------------------
pred_lasso <- predsurv_cox(
  obj_lasso,
  as.matrix(test[, selected_vars])
)
St_lasso <- pred_lasso$St
t_lasso <- pred_lasso$t

# KM for censoring distribution in the test set
G_obj <- survfit(Surv(time, 1 - status) ~ 1, data = test)

tc <- c(0, G_obj$time) # times in the censoring KM
Gtc <- c(1, G_obj$surv) # survival prob => G(t)
#------------------------------------------------------------------------------
# 2. Brier Score Function
#------------------------------------------------------------------------------
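# For reference, the function below computes the IPCW estimate of the Brier
# score at time tau:
#   BS(tau) = (1/n) * sum_i [ S_i(tau)^2 * I(X_i <= tau, status_i = 1) / G(X_i)
#                           + {1 - S_i(tau)}^2 * I(X_i > tau) / G(tau) ],
# where G is the Kaplan-Meier estimate of the censoring survival function.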
BSfun <- function(tau, time, status, St, t, Gtc, tc) {
  # tau: time at which we evaluate the Brier Score
  # time, status: observed times & event indicators
  # St: (n x m) matrix of predicted survival rates from the model
  # t: vector of times associated with columns of St
  # Gtc, tc: censoring distribution estimates (KM) at times in tc
  n <- length(time)
  pos <- sum(t <= tau) # column of St corresponding to tau
  BSvec <- numeric(n) # store Brier Score components per subject

  for (i in seq_len(n)) {
    X_i <- time[i]
    S_i <- St[i, pos] # predicted survival at tau
    # Evaluate G(X_i) or G(tau) for weighting
    if (X_i <= tau && status[i] == 1) {
      # subject i had the event by time X_i <= tau
      # denominator => G(X_i)
      G_i <- Gtc[sum(tc <= X_i)]
      # contributes S_i^2 / G(X_i)
      BSvec[i] <- S_i^2 / G_i
    } else if (X_i > tau) {
      # subject i is still at risk at tau
      # => contributes (1 - S_i)^2 / G(tau)
      G_i <- Gtc[sum(tc <= tau)]
      BSvec[i] <- (1 - S_i)^2 / G_i
    } else {
      # X_i <= tau but no event => subject was censored, no direct contribution
      BSvec[i] <- 0
    }
  }
  mean(BSvec, na.rm = TRUE)
}
#------------------------------------------------------------------------------
# 3. Evaluate Brier Score for Cox-Lasso at times 12..60
#------------------------------------------------------------------------------
tau_list <- 12:60
BS_tau_lasso <- numeric(length(tau_list))

for (i in seq_along(tau_list)) {
  BS_tau_lasso[i] <- BSfun(
    tau = tau_list[i],
    time = test$time,
    status = test$status,
    St = St_lasso,
    t = t_lasso,
    Gtc = Gtc,
    tc = tc
  )
}
#------------------------------------------------------------------------------
# 4. Full Cox Model
#------------------------------------------------------------------------------
obj_full <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, pred_list]))

# Predicted survival in test set
pred_full <- predsurv_cox(obj_full, as.matrix(test[, pred_list]))
St_full <- pred_full$St
t_full <- pred_full$t

# Brier Score for the full Cox
BS_tau_full <- numeric(length(tau_list))

for (i in seq_along(tau_list)) {
  BS_tau_full[i] <- BSfun(
    tau = tau_list[i],
    time = test$time,
    status = test$status,
    St = St_full,
    t = t_full,
    Gtc = Gtc,
    tc = tc
  )
}
#==============================================================================
# (C) Survival Tree Analysis
#==============================================================================
library(rpart)
library(rpart.plot)
#------------------------------------------------------------------------------
# 1. Building a Survival Tree (Train Set)
#------------------------------------------------------------------------------
set.seed(12345) # for reproducibility
obj_tree <- rpart(
  Surv(time, status) ~ hormone + meno + size + grade + nodes + prog + estrg + age,
  control = rpart.control(xval = 10, minbucket = 2, cp = 0),
  data = train
)
printcp(obj_tree) # shows the cross-validation results

# Identify the complexity parameter that yields minimal xerror
cptable <- obj_tree$cptable
cp.opt <- cptable[which.min(cptable[, "xerror"]), "CP"]

# Prune the tree using cp.opt
fit_tree <- prune(obj_tree, cp = cp.opt)
#------------------------------------------------------------------------------
# 2. Visualize the Pruned Tree and KM Curves
#------------------------------------------------------------------------------
par(mfrow = c(1, 2))
# Plot the tree structure
rpart.plot(fit_tree)
# Fit a KM in each terminal node => helpful to see the survival in each leaf
km_fit <- survfit(Surv(time, status) ~ fit_tree$where, data = train)
plot(
  km_fit,
  lty = 1:4,
  mark.time = FALSE,
  xlab = "Time (months)",
  ylab = "Progression",
  lwd = 2,
  cex.lab = 1.2,
  cex.axis = 1.2
)
legend(
"bottomleft",
paste("Node", sort(unique(fit_tree$where))),
lty = 1:4,
lwd = 2,
cex = 1.2
)
par(mfrow = c(1, 1))
#------------------------------------------------------------------------------
# 3. Extracting Leaf-Specific Survival Functions
#------------------------------------------------------------------------------
tmp <- summary(km_fit)
tmp.strata <- as.integer(sub(".*=", "", tmp$strata)) # node labels
tmp.t <- tmp$time
tmp.surv <- tmp$surv

# Terminal node IDs
TN <- sort(unique(tmp.strata))
N <- length(TN)

# Sort the unique times from tmp.t
t_unique <- sort(unique(tmp.t))
m <- length(t_unique)
# fitted_surv[j,k] => survival at time t_unique[j] for node k
fitted_surv <- matrix(NA, nrow = m, ncol = N)

for (j in seq_len(m)) {
  tj <- t_unique[j]
  for (k in seq_len(N)) {
    # times and survival within that node
    node_times <- c(0, tmp.t[tmp.strata == TN[k]])
    node_survs <- c(1, tmp.surv[tmp.strata == TN[k]])

    idx <- sum(node_times <= tj)
    fitted_surv[j, k] <- node_survs[idx]
  }
}
#------------------------------------------------------------------------------
# 4. Apply the Tree to the Test Set
#------------------------------------------------------------------------------
library(treeClust)
# rpart.predict.leaves() => which leaf each test subject lands in
test_term <- rpart.predict.leaves(fit_tree, test)
n_test <- nrow(test)

# Construct an (n_test x m) matrix of survival probabilities
St_tree <- matrix(NA, nrow = n_test, ncol = m)

for (k in seq_len(N)) {
  # Index test subjects in node k
  ind <- which(test_term == TN[k])
  # replicate the node-k survival curve for these subjects
  if (length(ind) > 0) {
    St_tree[ind, ] <- matrix(fitted_surv[, k], nrow = length(ind), ncol = m, byrow = TRUE)
  }
}
#------------------------------------------------------------------------------
# 5. Brier Score for the Survival Tree
#------------------------------------------------------------------------------
BS_tau_tree <- numeric(length(tau_list))

for (i in seq_along(tau_list)) {
  BS_tau_tree[i] <- BSfun(
    tau = tau_list[i],
    time = test$time,
    status = test$status,
    St = St_tree,
    t = t_unique,
    Gtc = Gtc,
    tc = tc
  )
}
#==============================================================================
# (D) Comparing Brier Score Curves
#==============================================================================
par(mfrow = c(1, 1))
# Plot Brier Score for the survival tree (red), Cox-lasso (blue), full Cox (black)
plot(
  tau_list / 12, BS_tau_tree,
  type = "l",
  lwd = 2,
  col = "red",
  cex.axis = 1.2,
  cex.lab = 1.2,
  xlab = "t (years)",
  ylab = "BS(t)",
  ylim = c(0, 0.25)
)
lines(tau_list / 12, BS_tau_lasso, lwd = 2, col = "blue")
lines(tau_list / 12, BS_tau_full, lwd = 2, col = "black")
legend(
"bottomright",
lty = 1,
col = c("red", "blue", "black"),
lwd = 2,
legend = c("Survival Tree", "Cox-lasso", "Cox-full"),
cex = 1.2
)