Chapter 15 - Machine Learning in Survival Analysis
Slides
Lecture slides here. (To convert html to pdf, press E \(\to\) Print \(\to\) Destination: Save to pdf)
Chapter Summary
Machine learning methods for survival analysis offer flexible and scalable alternatives to traditional models, especially in the presence of many covariates, nonlinear effects, or complex interactions. Two main approaches are considered: regularized Cox models for high-dimensional variable selection, and survival trees for nonparametric prediction. Ensemble methods such as bagging and random forests further enhance stability and prediction accuracy.
Regularized Cox models
When there are many candidate covariates, regularization can improve generalizability and interpretability.
Penalized partial likelihood (lasso-penalized Cox model): \[ \hat\beta(\lambda) = \arg\min_\beta \left\{ -pl_n(\beta) + \lambda \sum_{j=1}^p |\beta_j| \right\}, \] where \(pl_n(\beta)\) is the partial log-likelihood and \(\lambda\) controls the degree of shrinkage.
Variable selection:
The \(L_1\)-penalty (lasso) produces sparse solutions by shrinking some coefficients exactly to zero.
Cross-validation:
The penalty parameter \(\lambda\) is selected by minimizing the partial-likelihood deviance across validation folds.
Prediction error:
Assessed using inverse probability of censoring weighting (IPCW) estimators of the Brier score.
Computation:
A coordinate descent algorithm solves the penalized likelihood via iteratively reweighted least squares with soft-thresholding.
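For intuition, the soft-thresholding step at the core of the coordinate-descent update can be sketched in one line of R (a schematic illustration, not glmnet's actual implementation):
# S(z, g) = sign(z) * max(|z| - g, 0): shrinks z toward 0; exactly 0 when |z| <= g
soft_threshold <- function(z, gamma) sign(z) * pmax(abs(z) - gamma, 0)
soft_threshold(c(-1.2, 0.3, 2.0), gamma = 0.5) # -0.7 0.0 1.5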
Survival trees
Survival trees offer a nonparametric alternative, capable of automatically capturing nonlinearities and interactions without prespecification.
- Tree construction:
- Recursive binary splits based on covariates
- Splitting criterion: choose the split that most reduces within-node deviance
- Fully grown tree followed by pruning based on cost-complexity and cross-validation.
- Terminal nodes:
- Within-node Kaplan–Meier curves provide survival predictions for new subjects (see the sketch after this list).
- Advantages:
- No need to specify functional forms
- Flexible to complex data structures
- Naturally captures covariate interactions
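As a sketch of the terminal-node prediction step (assuming a pruned rpart fit fit and a data frame df containing time and status; the full analysis appears in the R code below):
# Map each subject to its terminal node, then fit a KM curve within each node
library(survival)
node <- rownames(fit$frame)[fit$where] # terminal-node labels per subject
km <- survfit(Surv(time, status) ~ node, data = transform(df, node = node))
plot(km) # one survival curve per leaf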
Ensemble methods
Stability and prediction performance are improved by aggregating many trees.
- Bagging:
- Grow multiple trees on bootstrapped datasets
- Average survival predictions across trees
- Random forests:
- Further randomize splits by selecting a random subset of covariates at each node
- Reduce correlation among trees for better generalization (see the sketch below)
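A minimal random survival forest sketch, using the randomForestSRC package (named here for illustration; the full code below uses aorsf for oblique forests):
library(randomForestSRC)
library(survival)
# Bootstrap trees with a random subset of covariates tried at each split
rsf <- rfsrc(Surv(time, status) ~ ., data = df, ntree = 500)
print(rsf)                          # out-of-bag (OOB) error summary
pred <- predict(rsf, newdata = df)  # aggregated survival predictions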
Example R code
########################################
# 1. Regularized Cox model
########################################
library(survival)
library(glmnet)
obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
plotmo::plot_glmnet(obj) # coefficient paths
# Cross-validation to select lambda
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
plot(obj.cv) # CV deviance vs. log(lambda)
########################################
# 2. Survival tree
########################################
library(rpart)
library(rpart.plot)
obj <- rpart(Surv(time, status) ~ covariates, data = df,
control = rpart.control(xval = 10, minbucket = 2, cp = 0))
# Prune the tree at the CP minimizing cross-validated error
fit <- prune(obj, cp = obj$cptable[which.min(obj$cptable[, "xerror"]), "CP"])
rpart.plot(fit)
# Terminal-node assignments; fit Kaplan–Meier within nodes
fit$where
Conclusion
Machine learning methods provide flexible, data-driven tools for survival prediction and variable selection. Regularized Cox models yield interpretable sparse models and handle high-dimensional data efficiently. Survival trees offer a nonparametric alternative that automatically detects complex nonlinear effects and interactions. Ensemble methods like bagging and random forests further improve stability and predictive performance, making them powerful complements to traditional survival analysis techniques.
R Code
Show the code
###############################################################################
# Chapter 15 R Code
#
# This script reproduces all major numerical results in Chapter 15, including:
# 1. Cox-Lasso analysis on GBC data (train/test split)
# 2. Survival tree modeling (pruning, terminal node curves)
# 3. Brier Score comparisons (Cox-lasso, full Cox, survival tree)
###############################################################################
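#------------------------------------------------------------------------------
# 1. Cox Survival Prediction Helper
#------------------------------------------------------------------------------
# NOTE: predsurv_cox() is called throughout this script but not defined in this
# excerpt. The sketch below assumes its interface (a list with St = n x m
# survival matrix and t = the corresponding event times) and uses the Breslow
# baseline hazard; it is an assumption, not necessarily the original helper.
predsurv_cox <- function(fit, Znew) {
  bh <- survival::basehaz(fit, centered = FALSE) # cumulative baseline hazard at Z = 0
  lp <- as.vector(Znew %*% coef(fit))            # linear predictors for new data
  St <- exp(-outer(exp(lp), bh$hazard))          # S(t|Z) = exp{-Lambda0(t) * e^(Z'beta)}
  list(St = St, t = bh$time)
}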
#------------------------------------------------------------------------------
# 2. Brier Score Function
#------------------------------------------------------------------------------
BSfun <- function(timepoints, time, status, St, t) {
# timepoints: evaluation times
# time, status: observed time X and event indicator Delta (1=event, 0=censor)
# St: n x m predicted survival matrix, columns correspond to t
# t: time grid for the columns of St
n <- nrow(St)
# censoring survival G(t) estimated via KM on censoring process
# censor indicator: 1 if censored
cens <- 1 - status
fitG <- survival::survfit(survival::Surv(time, cens) ~ 1)
# Right-continuous step function for G, with G(0)=1
G_step <- stats::stepfun(fitG$time, c(1, fitG$surv))
# G(t) for each evaluation time
G_t <- G_step(timepoints)
# Avoid division by 0
G_t[G_t <= 0] <- NA_real_
BSmat <- matrix(0, nrow = n, ncol = length(timepoints))
# small epsilon for left-limit evaluation
eps <- max(1e-8, 1e-8 * max(c(timepoints, time, t), na.rm = TRUE))
for (i in seq_len(n)) {
S_step_i <- stepfun(t, c(1, St[i, ]))
S_i <- S_step_i(timepoints)
# left-limit of G at observed time
G_xi <- G_step(time[i] - eps)
if (!is.finite(G_xi) || G_xi <= 0) G_xi <- NA_real_
# indicators over timepoints
I_event_by_t <- (time[i] <= timepoints) & (status[i] == 1)
I_at_risk_t <- (time[i] > timepoints)
# event term: S(t)^2 / G(X_i-)
BSmat[i, I_event_by_t] <- (S_i[I_event_by_t]^2) / G_xi
# survival term: (1 - S(t))^2 / G(t)
BSmat[i, I_at_risk_t] <- ((1 - S_i[I_at_risk_t])^2) / G_t[I_at_risk_t]
# censored before t: contributes 0 (already 0)
}
colMeans(BSmat, na.rm = TRUE)
}
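# Quick sanity check (hypothetical toy data): with no censoring, G(t) = 1, and
# a perfect predictor (S = 0 after the event, S = 1 while event-free), the
# Brier score at t = 1 should be exactly 0:
# BSfun(timepoints = 1, time = c(0.5, 2), status = c(1, 1),
#       St = matrix(c(0, 1), nrow = 2), t = 1)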
# C-index: Harrell's and Uno's
Cindex_fun <- function(timepoints = NULL, riskscore, time, status) {
# timepoints: NULL -> Harrell's C
# numeric vector -> Uno's time-restricted C at each tau in timepoints
# riskscore: higher = higher risk (shorter survival)
# time, status: observed time X and event indicator Delta (1=event, 0=censor)
stopifnot(length(riskscore) == length(time),
length(time) == length(status))
n <- length(time)
if (n < 2L) stop("Need at least two observations.")
# Helper: Harrell's C (pairwise definition)
harrell_c <- function(risk, time, status) {
conc <- 0 # concordant pairs
ties <- 0 # tied pairs
comp <- 0 # comparable pairs
for (i in 1:(n - 1L)) {
for (j in (i + 1L):n) {
# comparable if the shorter observed time is an event
if (time[i] < time[j] && status[i] == 1) {
comp <- comp + 1
if (risk[i] > risk[j]) conc <- conc + 1 # risk higher for shorter time
else if (risk[i] == risk[j]) ties <- ties + 1 # tied risk scores
} else if (time[j] < time[i] && status[j] == 1) { # Same for j < i
comp <- comp + 1
if (risk[j] > risk[i]) conc <- conc + 1
else if (risk[i] == risk[j]) ties <- ties + 1
}
# if equal times, we skip (common choice; avoids ambiguity with ties)
}
}
if (comp == 0) return(NA_real_)
(conc + 0.5 * ties) / comp # Harrell's C
}
# Helper: Uno's C restricted to tau
# Uses IPCW weights 1/G(T_i-)^2 for event times T_i <= tau.
uno_c_tau <- function(tau, risk, time, status) {
n <- length(time)
# censoring survival G(t) estimated via KM on censoring process
# censor indicator: 1 if censored
cens <- 1 - status
fitG <- survival::survfit(survival::Surv(time, cens) ~ 1)
# Right-continuous step function for G, with G(0)=1
G_step <- stats::stepfun(fitG$time, c(1, fitG$surv))
# left-limit at time t via tiny epsilon
eps <- max(1e-8, 1e-8 * max(time, tau, na.rm = TRUE))
num <- 0 # numerator
den <- 0 # denominator
for (i in 1:n) {
# i must be an event observed by tau
if (status[i] == 1 && time[i] <= tau) {
Gi <- G_step(time[i] - eps)
if (!is.finite(Gi) || Gi <= 0) next
wi <- 1 / Gi^2
# compare i against those who are still at risk after time[i]
# (includes those censored after time[i], and those failing later)
for (j in 1:n) {
if (j == i) next
if (time[j] > time[i]) {
den <- den + wi
if (risk[i] > risk[j]) num <- num + wi
else if (risk[i] == risk[j]) num <- num + 0.5 * wi
}
}
}
}
if (den == 0) return(NA_real_)
num / den
}
# Dispatch
if (is.null(timepoints)) {
return(harrell_c(riskscore, time, status))
}
if (!is.numeric(timepoints) || length(timepoints) == 0L) {
stop("timepoints must be NULL or a non-empty numeric vector.")
}
# Uno's C for each tau
out <- vapply(timepoints, function(tau) uno_c_tau(tau, riskscore, time, status), numeric(1))
data.frame(tau = timepoints, C = out)
}
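# Example usage (lp, X, D are placeholder vectors; see the test-set
# evaluations below for actual calls):
#   Cindex_fun(riskscore = lp, time = X, status = D)            # Harrell's C
#   Cindex_fun(timepoints = c(36, 60),
#              riskscore = lp, time = X, status = D)            # Uno's C at 36, 60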
#------------------------------------------------------------------------------
# update glmnet package
# install.packages("glmnet")
#==============================================================================
# (A) Cox-Lasso on GBC Data
#==============================================================================
library(survival)
library(glmnet)
library(tidyverse)
library(patchwork)
library(SurvMetrics)
# library("glmpath") # alternative approach for lasso, commented out
#------------------------------------------------------------------------------
# 1. Data Reading & Preparation
#------------------------------------------------------------------------------
# The "gbc.txt" file contains the complete German Breast Cancer Study data
gbc <- read.table("Data/German Breast Cancer Study/gbc.txt")
# Sort the data by (id, time) so that each subject’s rows appear chronologically
o <- order(gbc$id, gbc$time)
gbc <- gbc[o, ]
# Keep only the first row per subject => “first event” data
# i.e., we assume each subject either has 1 event or is censored.
df <- gbc[!duplicated(gbc$id), ] |>
# Transformations
mutate(
status = as.integer(status > 0), # event: relapse or death
age40 = as.integer(age <= 40) # age <= 40
)
# df$age40 <- as.integer(df$age <= 40)
# df$status <- as.integer(df$status > 0)
# Store total sample size
n <- nrow(df)
# Quick peek at first few rows
head(df)
#------------------------------------------------------------------------------
# 2. Train/Test Split
#------------------------------------------------------------------------------
# We'll do a random sample of 400 for training, the rest are test
set.seed(1234)
ind <- sample(1:n, size = 400)
train <- df[ind, ]
test <- df[-ind, ]
#------------------------------------------------------------------------------
# 3. Fitting Cox-Lasso with glmnet
#------------------------------------------------------------------------------
# Define the set of predictors we want to consider
pred_list <- c("hormone", "age40", "meno", "size", "grade",
"nodes", "prog", "estrg")
# Construct the design matrix Z from train data
Z <- as.matrix(train[, pred_list])
# Extract time and status vectors
time <- train$time
status <- train$status
# Fit a Cox model with L1 penalty (alpha=1 => lasso), over a sequence of lambda
obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
# Summarize the glmnet fit => mostly high-level info about the path
obj$beta # coefficient matrix (p x m) for m lambda values
obj$lambda # lambda sequence used (m-vector)
plot(obj) # basic plot of coefficient paths vs. log(lambda)
# Perform 10-fold cross-validation to select optimal lambda
set.seed(1234) # for reproducibility
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
# ?cv.glmnet
# ?plot.cv.glmnet
obj.cv
plot(obj.cv) # plot CV error vs. log(lambda)
# Alternative CV metrics
# C-index,
# obj.cv_c <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1,
# type.measure = "C")
# plot(obj.cv_c)
#------------------------------------------------------------------------------
# 4. Visualizing Coefficient Paths & CV Error
#------------------------------------------------------------------------------
# Remake the coefficient paths with standardized Z
# and CV error plot side by side for consistent style
# For coefficient paths, standardize Z for better comparison
obj_std <- glmnet(scale(Z), Surv(time, status), family = "cox", alpha = 1)
# plot the pathwise solutions
df_paths <-
tibble(
log_lambda = log(obj_std$lambda),
t(obj_std$beta) |> as.matrix() |> as_tibble()
) |>
pivot_longer(
cols = -log_lambda,
names_to = "variable",
values_to = "coefficient"
)
# Labels and position for variable names at the right end of paths
df_paths_label <- df_paths |>
filter(log_lambda == min(log_lambda)) |>
# estrg and size are too close, adjust y-positions
mutate(
coefficient = case_when(
variable == "nodes" ~ coefficient + 0.01,
variable == "age40" ~ coefficient - 0.01,
variable == "grade" ~ coefficient + 0.01,
variable == "estrg" ~ coefficient - 0.015,
variable == "size" ~ coefficient + 0.015,
variable == "meno" ~ coefficient - 0.012,
TRUE ~ coefficient
),
variable = if_else(variable == "hormone", "horm", variable)
)
# Plot coefficient paths
gbc_path <- ggplot(df_paths, aes(x = log_lambda, y = coefficient, color = variable)) +
geom_line(linewidth = 0.8) +
geom_vline(xintercept = log(obj.cv$lambda.min), linetype = "dashed") +
# label variable names at the end
geom_text(
data = df_paths_label,
aes(label = variable),
hjust = -0.05,
vjust = 0.5,
size = 3.5
) +
scale_x_reverse(expand = expansion(c(0.05, 0.15))) +
labs(
x = expression(log(lambda)),
y = "Coefficients"
) +
theme_minimal(base_size = 12) +
theme(
legend.position = "none",
axis.text.y = element_blank(),
axis.text.x = element_text(size = 11)
)
# Plot CV
df_cv <- data.frame(
log_lambda = log(obj.cv$lambda),
cvm = obj.cv$cvm,
cvup = obj.cv$cvup,
cvlo = obj.cv$cvlo,
nzero = obj.cv$nzero
)
gbc_cv <- ggplot(df_cv, aes(x = log_lambda, y = cvm)) +
geom_line(linewidth = 1) +
geom_point(size = 3, alpha = 0.5) +
# omit any ribbon / error bars to suppress SE
geom_vline(xintercept = log(obj.cv$lambda.min), linetype = "dashed") +
scale_x_reverse() +
labs(x = expression(log(lambda)), y = obj.cv$name) +
theme_minimal(base_size = 12) +
theme(
legend.position = "none",
# axis.text.y = element_blank(),
axis.text.x = element_text(size = 11)
)
# Combine the two plots side by side
gbc_combined <- gbc_path + gbc_cv
gbc_combined
# ggsave("images/ml_glmnet_gbc.png", gbc_combined,
# width = 8, height = 4, dpi = 300)
# ggsave("images/ml_glmnet_gbc.eps", gbc_combined, device = cairo_ps,
# width = 8, height = 4)
# Identify the optimal lambda minimizing the CV error
lambda.opt <- obj.cv$lambda.min
lambda.opt
# [1] 0.02472102
# Extract the coefficients at lambda.min
beta.opt <- coef(obj.cv, s = "lambda.min")
# Identify which coefficients are non-zero
beta.selected <- beta.opt[abs(beta.opt[, 1]) > 0, ]
beta.selected # show the non-zero variables
length(beta.selected) # how many are non-zero?
#------------------------------------------------------------------------------
# 5. Refit Cox Model Using Selected Variables
#------------------------------------------------------------------------------
# We'll pull the variable names and re-fit a coxph with these in the training set
selected_vars <- names(beta.selected)
# Refit the Cox model using only selected variables
obj_lasso <- coxph(Surv(time, status) ~ as.matrix(train[, selected_vars]),
data = train)
# Get risk scores in test set
beta_lasso <- coef(obj_lasso) # fitted coefficients
test_lasso_rs <- as.matrix(test[, selected_vars]) %*% beta_lasso |> as.vector()
# Metrics data frames ----------------------------------------------------------
# 6 models: Lasso, Full Cox, inadequate Cox (without prog and nodes)
# survival tree (rpart), random survival forest (rsf), SSVM
# For each model, we will compute C-index (Harrell's and Uno's) and Brier Score
c_index_results0 <- data.frame(
Model = character(),
Harrell_C = numeric(),
Uno_C_60 = numeric()
)
brier_score_results0 <- data.frame(
Model = character(),
Time = integer(),
Brier_Score = numeric()
)
# C index
# Harrell's C
lasso_harrell_c <- Cindex_fun(riskscore = test_lasso_rs,
time = test$time,
status = test$status)
# [1] 0.6583023
# Uno's C at tau = 60
lasso_uno_c_60 <- Cindex_fun(timepoints = 60,
riskscore = test_lasso_rs,
time = test$time,
status = test$status)
# tau C
# 1 60 0.6449649
# Store results
c_index_results1 <- rbind(
c_index_results0,
data.frame(
Model = "Cox-Lasso",
Harrell_C = lasso_harrell_c,
Uno_C_60 = lasso_uno_c_60$C
)
)
#==============================================================================
# (B) Brier Score Calculation for Lasso
#==============================================================================
#------------------------------------------------------------------------------
# 1. Lasso-Fitted Model's Predictions on Test Data
#------------------------------------------------------------------------------
pred_lasso <- predsurv_cox(
obj_lasso,
as.matrix(test[, selected_vars])
)
St_lasso <- pred_lasso$St
t_lasso <- pred_lasso$t
#------------------------------------------------------------------------------
# 3. Evaluate Brier Score for Cox-Lasso at times 1..60
#------------------------------------------------------------------------------
timepoints <- 1:60
brier_score_lasso <- BSfun(
timepoints = timepoints,
time = test$time,
status = test$status,
St = St_lasso,
t = t_lasso
)
brier_score_results1 <- rbind(
brier_score_results0,
data.frame(
Model = rep("Cox-Lasso", length(timepoints)),
Time = timepoints,
Brier_Score = brier_score_lasso
)
)
#------------------------------------------------------------------------------
# 4. Full Cox Model
#------------------------------------------------------------------------------
obj_full <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, pred_list]))
# Get risk scores in test set
beta_full <- coef(obj_full) # fitted coefficients
test_full_rs <- as.matrix(test[, pred_list]) %*% beta_full |> as.vector()
# C index
# Harrell's C
full_harrell_c <- Cindex_fun(riskscore = test_full_rs,
time = test$time,
status = test$status)
# [1] 0.6568314
# Uno's C at tau = 60
full_uno_c_60 <- Cindex_fun(timepoints = 60,
riskscore = test_full_rs,
time = test$time,
status = test$status)
# tau C
# 1 60 0.648307
c_index_results2 <- rbind(
c_index_results1,
data.frame(
Model = "Full Cox",
Harrell_C = full_harrell_c,
Uno_C_60 = full_uno_c_60$C
)
)
# 5. Brier Score for Full Cox
# Predicted survival in test set
pred_full <- predsurv_cox(obj_full, as.matrix(test[, pred_list]))
St_full <- pred_full$St
t_full <- pred_full$t
# Brier Score for the full Cox
brier_score_full <- BSfun(
timepoints = timepoints,
time = test$time,
status = test$status,
St = St_full,
t = t_full
)
brier_score_results2 <- rbind(
brier_score_results1,
data.frame(
Model = rep("Full Cox", length(timepoints)),
Time = timepoints,
Brier_Score = brier_score_full
)
)
# An inadequate Cox model (without prog and nodes)
obj_inadequate <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, c("hormone", "age40", "meno", "size", "grade", "estrg")]))
# Get risk scores in test set
beta_inadequate <- coef(obj_inadequate) # fitted coefficients
test_inadequate_rs <- as.matrix(test[, c("hormone", "age40", "meno", "size", "grade", "estrg")]) %*% beta_inadequate |> as.vector()
# C index
# Harrell's C
inadequate_harrell_c <- Cindex_fun(riskscore = test_inadequate_rs,
time = test$time,
status = test$status)
# [1] 0.602081
# Uno's C at tau = 60
inadequate_uno_c_60 <- Cindex_fun(timepoints = 60,
riskscore = test_inadequate_rs,
time = test$time,
status = test$status)
# Store results
c_index_results3 <- rbind(
c_index_results2,
data.frame(
Model = "Inadequate Cox",
Harrell_C = inadequate_harrell_c,
Uno_C_60 = inadequate_uno_c_60$C
)
)
# 5. Brier Score for Inadequate Cox
# Predicted survival in test set
pred_inadequate <- predsurv_cox(obj_inadequate, as.matrix(test[, c("hormone", "age40", "meno", "size", "grade", "estrg")]))
St_inadequate <- pred_inadequate$St
t_inadequate <- pred_inadequate$t
brier_score_inadequate <- BSfun(
timepoints = timepoints,
time = test$time,
status = test$status,
St = St_inadequate,
t = t_inadequate
)
brier_score_results3 <- rbind(
brier_score_results2,
data.frame(
Model = rep("Inadequate Cox", length(timepoints)),
Time = timepoints,
Brier_Score = brier_score_inadequate
)
)
#==============================================================================
# (C) Survival Tree Analysis
#==============================================================================
library(rpart)
library(rpart.plot)
#------------------------------------------------------------------------------
# 1. Building a Survival Tree (Train Set)
#------------------------------------------------------------------------------
set.seed(12345) # for reproducibility
obj_tree <- rpart(
Surv(time, status) ~ hormone + meno + size + grade + nodes + prog + estrg + age,
control = rpart.control(xval = 10, minbucket = 2, cp = 0),
data = train
)
# Summarize the tree object
printcp(obj_tree) # shows the cross-validation results
cptable <- obj_tree$cptable
colnames(cptable)
#> [1] "CP" "nsplit" "rel error" "xerror" "xstd"
# Identify the complexity parameter that yields minimal xerror
cptable <- obj_tree$cptable |>
as.data.frame() |>
slice(1:10)
# Identify optimal CP (min CV error)
cp_opt <- cptable[which.min(cptable$xerror), "CP"]
# Plot CV error vs CP
tree_cv <- ggplot(cptable, aes(x = log(CP), y = xerror)) +
geom_line(linewidth = 1) +
geom_point(size = 3, alpha = 0.5) +
geom_vline(xintercept = log(cp_opt), linetype = "dashed") +
scale_x_reverse() +
labs(
x = expression(log(CP)),
y = "Cross-validated relative error"
) +
theme_minimal(base_size = 12) +
theme(
legend.position = "none",
axis.text.x = element_text(size = 11)
)
# Prune the tree using cp.opt
fit_tree <- prune(obj_tree, cp = cp_opt)
class(fit_tree)
#> [1] "rpart"
# ?rpart
# fit_tree$frame
#------------------------------------------------------------------------------
# 2. Visualize the Pruned Tree and KM Curves
#------------------------------------------------------------------------------
# Plot the tree structure
rpart.plot(fit_tree)
# Combine tree_cv and p_tree_manual by 40% vs 60% width
# fig_gbc_tree <- tree_cv + p_tree_manual +
# plot_layout(widths = c(0.4, 0.6))
#
# fig_gbc_tree
# ggsave("images/ml_rpart_gbc.png", fig_gbc_tree,
# width = 8, height = 4, dpi = 300)
# ggsave("images/ml_rpart_gbc.eps", fig_gbc_tree, device = cairo_ps,
# width = 8, height = 4)
# Variable importance
print(fit_tree$variable.importance)
fit_tree
fit_tree$where # terminal node assignments for training subjects
# Map row numbers in fit_tree$frame to node numbers
where_to_nodes <- rownames(fit_tree$frame) # node numbers
term_nodes_train <- where_to_nodes[fit_tree$where] # terminal node number for each subject
train$.term_node <- term_nodes_train # add node info to training set
# Fit a KM in each terminal node => helpful to see the survival in each leaf
km_fit <- survfit(Surv(time, status) ~ .term_node, data = train)
plot(
km_fit,
lty = 1:4,
mark.time = FALSE,
xlab = "Time (months)",
ylab = "Relapse-free survival",
lwd = 2,
cex.lab = 1.2,
cex.axis = 1.2
)
legend(
"bottomleft",
paste(sort(unique(term_nodes_train))),
lty = 1:4,
lwd = 2,
cex = 1.2
)
# install.packages("ggsurvfit")
# ggplot version
library(ggsurvfit)
fig_km_tree <- survfit2(Surv(time, status) ~ .term_node,
data = train |>
mutate(.term_node = str_c("Node ", .term_node))) |>
ggsurvfit(linetype_aes = TRUE, linewidth = 1) +
## --- manual labels + arrows on the KM panel -----------------------------
# High risk (assume lowest curve)
annotate(
"text",
x = 18, y = 0.25,
label = "High risk",
hjust = 1,
size = 4
) +
geom_curve(
data = data.frame(
x = rep(19, 2),
y = rep(0.25, 2),
xend = c(26, 26),
yend = c(0.25, 0.4)
),
aes(x = x, y = y, xend = xend, yend = yend),
arrow = arrow(length = unit(0.15, "cm")),
curvature = 0
) +
# Moderate risk
annotate(
"text",
x = 42, y = 0.5,
label = "Moderate risk",
hjust = 0.5,
size = 4
) +
# Low risk (assume top curve)
annotate(
"text",
x = 48, y = 0.75,
label = "Low risk",
hjust = 0.5,
size = 4
) +
## ------------------------------------------------------------------------
scale_x_continuous(limits = c(0, 72), breaks = seq(0, 72, by = 12)) +
scale_linetype_manual(
values = 1:4
) +
# number at risk
add_risktable(
risktable_stats = "n.risk",
size = 4
) +
labs(
x = "Time (months)",
y = "Relapse-free survival",
color = "Leaf",
linetype = "Leaf"
) +
theme_minimal(base_size = 12) +
theme(
legend.text = element_text(size = 11),
legend.position = "top",
legend.key.width = unit(1, "cm"),
)
fig_km_tree
# ggsave("images/ml_km_tree_gbc.png", fig_km_tree,
# width = 8, height = 5, dpi = 300)
# ggsave("images/ml_km_tree_gbc.eps", fig_km_tree,
# width = 8, height = 5)
# combine nodes 5 and 7 into high risk group
train <- train |>
mutate(
risk_group = case_when(
.term_node %in% c("5", "7") ~ "High risk",
.term_node == "6" ~ "Moderate risk",
.term_node == "4" ~ "Low risk"
)
)
# KM plot by risk group
km_fit2 <- survfit(Surv(time, status) ~ risk_group, data = train)
# Tabulate using gtsummary::tbl_survfit
library(gtsummary)
tbl_surv <- km_fit2 |>
tbl_survfit( # Pass `survfit` object
times = c(12, 24, 60), # Time points for estimates
label_header = "Month {time}" # Column label: "Month xx"
)
tbl_surv
## Variable importance
vi_surrog <- fit_tree$variable.importance
vi_surrog <- vi_surrog / sum(vi_surrog) * 100 # percent importance
# without surrogate splits
set.seed(12345) # for reproducibility
obj_tree_no_surrog <- rpart(
Surv(time, status) ~ hormone + meno + size + grade + nodes + prog + estrg + age,
control = rpart.control(xval = 10, minbucket = 2, cp = 0, maxsurrogate = 0),
data = train
)
fit_tree_no_surrog <- prune(obj_tree_no_surrog, cp = cp_opt)
vi_no_surrog <- fit_tree_no_surrog$variable.importance
vi_no_surrog <- vi_no_surrog / sum(vi_no_surrog) * 100 # percent importance
# Compare variable importance with and without surrogate splits
# ggplot
# make tibble
feature_list <- c("hormone", "age", "meno", "size", "grade", "nodes", "prog", "estrg")
fig_tree_vi <- tibble(
variable = names(vi_surrog),
vi = vi_surrog) |>
# give variables absent from vi_surrog a small bar (0.1) so they render
complete(variable = feature_list, fill = list(vi = 0.1)) |>
ggplot(aes(y = reorder(variable, vi), x = vi)) +
geom_bar(stat = "identity", fill = "steelblue") +
scale_x_continuous(limits = c(0, max(c(vi_surrog, vi_no_surrog)) * 1.0)) +
labs(
y = NULL,
x = "Feature importance (%)",
title = "With surrogate splits"
) +
theme_minimal(base_size = 12) +
theme(
plot.title = element_text(hjust = 0.5, size = 12),
panel.grid.major.y = element_blank()
)
fig_tree_vi_no_surrog <- tibble(
variable = names(vi_no_surrog),
vi = vi_no_surrog) |>
# give variables absent from vi_no_surrog a small bar (0.1) so they render
complete(variable = feature_list, fill = list(vi = 0.1)) |>
ggplot(aes(y = reorder(variable, vi), x = vi)) +
geom_bar(stat = "identity", fill = "steelblue") +
scale_x_continuous(limits = c(0, max(c(vi_surrog, vi_no_surrog)) * 1.0)) +
labs(
y = NULL,
x = "Feature importance (%)",
title = "Without surrogate splits"
) +
theme_minimal(base_size = 12) +
theme(
plot.title = element_text(hjust = 0.5, size = 12),
panel.grid.major.y = element_blank()
)
fig_tree_vi_comp <- fig_tree_vi + fig_tree_vi_no_surrog +
plot_layout(ncol = 2)
# ggsave("images/ml_tree_vi_gbc.png",
# fig_tree_vi_comp,
# width = 8, height = 4.5, dpi = 300)
# ggsave("images/ml_tree_vi_gbc.eps",
# fig_tree_vi_comp,
# width = 8, height = 4.5)
# Check surrogate split info in fit_tree
fit_tree$splits
#> count ncat improve index adj
#> prog 400 1 41.929765 24.5 0.00000000
#> nodes 400 -1 27.853517 3.5 0.00000000
#> estrg 400 1 22.339746 8.5 0.00000000
#> grade 400 -1 12.676483 1.5 0.00000000
#> age 400 1 12.347645 31.5 0.00000000
#------------------------------------------------------------------------------
# 3. Extracting Risk-Group-Specific Survival Functions
#------------------------------------------------------------------------------
tmp <- summary(km_fit2)
tmp.strata<- sub(".*=", "", tmp$strata) # node labels
tmp.t <- tmp$time
tmp.surv <- tmp$surv
# Risk-group labels (strata of km_fit2)
TN <- sort(unique(tmp.strata))
N <- length(TN)
# Sort the unique times from tmp.t
t_unique <- sort(unique(tmp.t))
m <- length(t_unique)
# fitted_surv[j,k] => survival at time t_unique[j] for risk group k
fitted_surv <- matrix(NA, nrow = m, ncol = N)
for (j in seq_len(m)) {
tj <- t_unique[j]
for (k in seq_len(N)) {
# event times within that risk group
node_times <- c(0, tmp.t[tmp.strata == TN[k]])
node_survs <- c(1, tmp.surv[tmp.strata == TN[k]])
idx <- sum(node_times <= tj)
fitted_surv[j, k] <- node_survs[idx]
}
}
#------------------------------------------------------------------------------
# 4. Apply the Tree to the Test Set
#------------------------------------------------------------------------------
library(treeClust)
# ?rpart.predict.leaves
# rpart.predict.leaves() => which leaf each test subject lands in
# Terminal nodes ("where" labels) for test data
where_test <- rpart.predict.leaves(fit_tree, newdata = test)
term_node_test <- where_to_nodes[where_test] # map to node numbers
pred_test <- test |> # add terminal node info to test set
bind_cols(.term_node = term_node_test) |>
# risk group
mutate(
risk_group = case_when(
.term_node %in% c("5", "7") ~ "High risk",
.term_node == "6" ~ "Moderate risk",
.term_node == "4" ~ "Low risk"
)
)
# broom::tidy(km_fit) |>
# nest_by(strata) |>
# mutate(
# .term_node = parse_number(sub(".*=", "", strata)
# )
# )
n_test <- nrow(test)
# ?rpart.predict.leaves
# Construct an (n_test x m) matrix of survival probabilities
St_tree <- matrix(NA, nrow = n_test, ncol = m)
for (k in seq_len(N)) {
# Index test subjects in risk group k
ind <- which(pred_test$risk_group == TN[k])
# replicate the node-k survival curve for these subjects
if (length(ind) > 0) {
St_tree[ind, ] <- matrix(fitted_surv[, k], nrow = length(ind), ncol = m, byrow = TRUE)
}
}
# Risk score: rpart's fitted event rate in the subject's terminal node
tree_rs <- fit_tree$frame$yval[where_test]
# C index
# Harrell's C
tree_harrell_c <- Cindex_fun(riskscore = tree_rs,
time = test$time,
status = test$status)
# [1] 0.641105
# Uno's C at tau = 60
tree_uno_c_60 <- Cindex_fun(timepoints = 60,
riskscore = tree_rs,
time = test$time,
status = test$status)
# Store results
c_index_results4 <- rbind(
c_index_results3,
data.frame(
Model = "Survival Tree",
Harrell_C = tree_harrell_c,
Uno_C_60 = tree_uno_c_60$C
)
)
#------------------------------------------------------------------------------
# 5. Brier Score for the Survival Tree
#------------------------------------------------------------------------------
brier_score_tree <- BSfun(
timepoints = timepoints,
time = test$time,
status= test$status,
St = St_tree,
t = t_unique
)
brier_score_results4 <- rbind(
brier_score_results3,
data.frame(
Model = rep("Survival Tree", length(timepoints)),
Time = timepoints,
Brier_Score = brier_score_tree
)
)
# ------- Oblique random survival forest (ORSF) ----
# install.packages("aorsf")
library(aorsf)
# Fit a basic oblique random survival forest on training data
set.seed(12345) # for reproducibility
orsf_fit <- orsf(
Surv(time, status) ~ hormone + age + meno + size + grade + nodes + prog + estrg,
data = train,
n_tree = 200,
oobag_pred_type = 'risk',
oobag_pred_horizon = 60,
oobag_eval_every = 1
)
orsf_fit
#> ---------- Oblique random survival forest
#>
#> Linear combinations: Accelerated Cox regression
#> N observations: 400
#> N events: 174
#> N trees: 200
#> N predictors total: 8
#> N predictors per node: 3
#> Average leaves per tree: 29.855
#> Min observations in leaf: 5
#> Min events in leaf: 1
#> OOB stat value: 0.68
#> OOB stat type: Harrell's C-index
#> Variable importance: anova
#>
#> -----------------------------------------
# Plot out-of-bag C-statistic over number of trees
fig_rsf_ntree <- tibble(
n_tree = seq(1, 200, by = 1),
c_stat = orsf_fit$eval_oobag$stat_values
) |>
ggplot(aes(x = n_tree, y = c_stat)) +
geom_point(size = 3, alpha = 0.5) +
geom_line(linewidth = 1) +
labs(
x = "Number of trees grown",
y = paste0("OOB ", orsf_fit$eval_oobag$stat_type)
) +
theme_minimal(base_size = 12)
fig_rsf_ntree
# Variable importance by permutation
vi_rsf <- orsf_vi_permute(orsf_fit)
# set negative values to 0 and normalize
vi_rsf<- pmax(vi_rsf, 0)
vi_rsf <- vi_rsf / sum(vi_rsf) * 100 + 0.1 # +0.1 for cosmetic purpose
# plot
fig_rsf_vi <- tibble(
variable = names(vi_rsf),
vi = vi_rsf
) |>
ggplot(aes(y = reorder(variable, vi), x = vi)) +
geom_bar(stat = "identity", fill = "steelblue") +
labs(
y = NULL,
x = "Feature importance (%)"
) +
theme_minimal(base_size = 12) +
theme(
panel.grid.major.y = element_blank()
)
fig_rsf_vi
fig_gbc_rsf <- fig_rsf_ntree + fig_rsf_vi
fig_gbc_rsf
# ggsave("images/ml_rsf_gbc.png", fig_gbc_rsf,
# width = 8, height = 4.3, dpi = 300)
# ggsave("images/ml_rsf_gbc.eps", fig_gbc_rsf, device = cairo_ps,
# width = 8, height = 4.3)
# How to do prediction using orsf_fit
pred_orsf <- predict(orsf_fit, new_data = test, pred_type= 'surv',
pred_horizon = seq(0, 60, by = 1)
)
# Extract risk scores: predicted S(60). Horizons 0..60 occupy columns 1..61,
# so S(60) is the last column (pred_orsf[, 60] would be S(59)).
orsf_rs <- pred_orsf[, 61]
# C index
# Harrell's C
orsf_harrell_c <- Cindex_fun(riskscore = - orsf_rs, # negative of survival
time = test$time,
status = test$status)
# [1] 0.6893658
# Uno's C at tau = 60
orsf_uno_c_60 <- Cindex_fun(timepoints = 60,
riskscore = - orsf_rs,
time = test$time,
status = test$status)
# Store results
c_index_results5 <- rbind(
c_index_results4,
data.frame(
Model = "Oblique RSF",
Harrell_C = orsf_harrell_c,
Uno_C_60 = orsf_uno_c_60$C
)
)
# Brier Score for ORSF
St_orsf <- as.matrix(pred_orsf)
t_orsf <- seq(0, 60, by = 1)
brier_score_orsf <- BSfun(
timepoints = timepoints,
time = test$time,
status = test$status,
St = St_orsf,
t = t_orsf
)
brier_score_results5 <- rbind(
brier_score_results4,
data.frame(
Model = rep("Oblique RSF", length(timepoints)),
Time = timepoints,
Brier_Score = brier_score_orsf
)
)
## Survival support vector machine (SSVM)
# install.packages("survivalsvm")
library(survivalsvm)
# Preprocessing: standardize predictors in train and test
feature_list <- c("hormone", "age40", "meno", "size", "grade", "nodes", "prog", "estrg")
# Get mean and sd from train
train_means <- colMeans(train[, feature_list], na.rm = TRUE)
train_sds <- apply(train[, feature_list], 2, sd, na.rm = TRUE)
# Standardize train
train_scaled <- train
train_scaled[, feature_list] <- scale(train[, feature_list],
center = train_means,
scale = train_sds)
# For validation of gamma.mu parameter -----------------------------------
library(SurvMetrics)
# Perform 10-fold cross-validation on training set
cv_folds <- 10
set.seed(1234)
fold_indices <- sample(rep(1:cv_folds, length.out = nrow(train_scaled)))
# gamma values
gamma_values <- exp(seq(-2, 2, by = 0.2))
# CV C-index storage
cv_cindex <- matrix(NA, nrow = cv_folds, ncol = length(gamma_values))
cv_mods <- list()
for (fold in 1:cv_folds) {
train_fold <- train_scaled[fold_indices != fold, ]
val_fold <- train_scaled[fold_indices == fold, ]
for (i in seq_along(gamma_values)) {
gamma_mu <- gamma_values[i]
ssvm_model <- survivalsvm(
Surv(time, status) ~ hormone + age40 + meno + size + grade + nodes + prog + estrg,
data = train_fold,
type = "regression",
gamma.mu = gamma_mu
)
ssvm_pred_val <- predict(ssvm_model, newdata = val_fold)
surv_obj_val <- Surv(val_fold$time, val_fold$status)
cv_cindex[fold, i] <- Cindex(surv_obj_val, ssvm_pred_val$predicted)
}
cat("Fold", fold, "completed.\n")
cat(colMeans(cv_cindex, na.rm = TRUE), "\n")
}
saveRDS(cv_cindex, "Data/ssvmcv_cindex.RDS")
# Average CV C-index across folds
mean_cv_cindex <- colMeans(cv_cindex, na.rm = TRUE)
# mean_cv_cindex
gamma_opt <- gamma_values[which.max(mean_cv_cindex)]
# Plot mean CV C-index vs log(gamma.mu)
fig_ssvm_cv <- tibble(
log_gamma = log(gamma_values),
cindex = mean_cv_cindex
) |>
ggplot(aes(x = log_gamma, y = cindex)) +
geom_line(linewidth = 1) +
geom_vline(xintercept = log(gamma_opt), linetype = "dashed") +
geom_point(size = 3, alpha = 0.5) +
scale_x_reverse() +
scale_y_continuous(breaks = seq(0.686, 0.7, by = 0.002)) +
labs(
x = expression(log(gamma)),
y = "Cross-validated Harrell's C"
) +
theme_minimal(base_size = 12)
fig_ssvm_cv
# ----------------------------------------------------
# # Permutation method to assess variable importance
#
# # For the jth fold, fit model under gamma_opt on train_fold,
# # with variable k permuted,
# # predict on val_fold,
# # and compute C-index. Average across folds for variable k importance.
# var_names <- c("hormone", "age40", "meno", "size", "grade", "nodes", "prog", "estrg")
# var_importance <- numeric(length(var_names))
# names(var_importance) <- var_names
# set.seed(1234)
# for (k in seq_along(var_names)) {
# var_k <- var_names[k]
# cindex_perm <- numeric(cv_folds)
#
# for (fold in 1:cv_folds) {
# train_fold <- train_scaled[fold_indices != fold, ]
# val_fold <- train_scaled[fold_indices == fold, ]
#
# # Permute variable k in train_fold
# train_fold_perm <- train_fold
# train_fold_perm[[var_k]] <- sample(train_fold[[var_k]])
#
# # Fit SSVM under gamma_opt
# ssvm_model <- survivalsvm(
# Surv(time, status) ~ hormone + age40 + meno + size + grade + nodes + prog + estrg,
# data = train_fold_perm,
# type = "regression",
# gamma.mu = gamma_opt
# )
#
# ssvm_pred_val <- predict(ssvm_model, newdata = val_fold)
#
# surv_obj_val <- Surv(val_fold$time, val_fold$status)
#
# cindex_perm[fold] <- Cindex(surv_obj_val, ssvm_pred_val$predicted)
# }
#
# # Variable importance as decrease in C-index due to permutation
# var_importance[k] <- mean_cv_cindex[which(gamma_values == gamma_opt)] - mean(cindex_perm)
# cat("Variable", var_k, "importance:", var_importance[k], "\n")
# }
# saveRDS(var_importance, "Data/ssvm_var_importance.rds")
var_importance <- readRDS("Data/ssvm_var_importance.rds")
# Normalize variable importance
var_importance <- pmax(var_importance, 0)
var_importance <- var_importance / sum(var_importance) * 100 + 0.1 # +0.1 for cosmetic purpose
# Plot variable importance
fig_ssvm_vi <- tibble(
variable = names(var_importance),
vi = var_importance
) |>
ggplot(aes(y = reorder(variable, vi), x = vi)) +
geom_bar(stat = "identity", fill = "steelblue") +
scale_x_continuous(breaks = seq(0, 24, by = 8)) +
labs(
y = NULL,
x = "Feature importance (%)"
) +
theme_minimal(base_size = 12) +
theme(
panel.grid.major.y = element_blank()
)
fig_gbc_ssvm <- fig_ssvm_cv + fig_ssvm_vi
fig_gbc_ssvm
# ggsave("images/ml_ssvm_gbc.png", fig_gbc_ssvm,
# width = 8, height = 4.3, dpi = 300)
# ggsave("images/ml_ssvm_gbc.eps", fig_gbc_ssvm, device = cairo_ps,
# width = 8, height = 4.3)
# Fit entire training data under gamma_opt
ssvm_fit <- survivalsvm(
Surv(time, status) ~ hormone + age40 + meno + size + grade + nodes + prog + estrg,
data = train_scaled,
type = "regression",
gamma.mu = gamma_opt
)
test_scaled <- test
test_scaled[, feature_list] <- scale(test[, feature_list],
center = train_means,
scale = train_sds)
# Risk score from SSVM on test data
ssvm_test_rs <- - predict(ssvm_fit, newdata = test_scaled)$predicted |> as.vector()
# C index
# Harrell's C
ssvm_harrell_c <- Cindex_fun(riskscore = ssvm_test_rs,
time = test$time,
status = test$status)
# [1] 0.6740071
# Uno's C at tau = 60
ssvm_uno_c_60 <- Cindex_fun(timepoints = 60,
riskscore = ssvm_test_rs,
time = test$time,
status = test$status)
# Store results
c_index_results6 <- rbind(
c_index_results5,
data.frame(
Model = "Survival SVM",
Harrell_C = ssvm_harrell_c,
Uno_C_60 = ssvm_uno_c_60$C
)
)
# Brier Score for SSVM
# Predicted survival in test set
# Fit cox model against risk score on training set
ssvm_cox <- coxph(Surv(time, status) ~ ssvm_rs,
data = train_scaled |>
mutate(ssvm_rs = - predict(ssvm_fit, newdata = train_scaled)$predicted |> as.vector())
)
pred_ssvm <- predsurv_cox(ssvm_cox, as.matrix(test_scaled |>
mutate(ssvm_rs = ssvm_test_rs) |>
select(ssvm_rs)
))
St_ssvm <- pred_ssvm$St
t_ssvm <- pred_ssvm$t
brier_score_ssvm <- BSfun(
timepoints = timepoints,
time = test$time,
status = test$status,
St = St_ssvm,
t = t_ssvm
)
brier_score_results6 <- rbind(
brier_score_results5,
data.frame(
Model = rep("Survival SVM", length(timepoints)),
Time = timepoints,
Brier_Score = brier_score_ssvm
)
)
# Add IBS
c_index_results6 |>
left_join(
brier_score_results6 |>
mutate(
Model = fct(Model)
) |>
group_by(Model) |>
summarise(
IBS = mean(Brier_Score)
),
by = "Model"
) |>
mutate(
# round stats by 3 decimal places
Harrell_C = round(Harrell_C, 3),
Uno_C_60 = round(Uno_C_60, 3),
IBS = round(IBS, 3)
) |>
# Output to latex
knitr::kable(
format = "latex",
booktabs = TRUE,
caption = "C-index and Integrated Brier Score (IBS) for various models on the test set."
)
# Plot Brier Scores
brier_score_results6 |>
mutate(
Model = fct(Model)
) |>
ggplot(aes(x = Time, y = Brier_Score, color = Model)) +
geom_line(linewidth = 1.1, alpha = 0.7) +
geom_point(size = 2, alpha = 0.6) +
labs(
x = "Time (months)",
y = "Brier Score",
color = NULL
) +
scale_x_continuous(breaks = seq(0, 60, by = 12)) +
theme_minimal(base_size = 12) +
guides(
color = guide_legend(nrow = 1, byrow = TRUE)
) +
theme(
legend.position = "top",
legend.text = element_text(size = 11)
)
ggsave("images/ml_brier_scores_all_models.png",
width = 8, height = 4, dpi = 300)
ggsave("images/ml_brier_scores_all_models.eps", device = cairo_ps,
width = 8, height = 4)
# surv_obj <- Surv(test$time, test$status)
# Cindex(surv_obj, ssvm_pred$predicted)
Tidymodels Code
# install.packages("tidymodels")
library(tidymodels)