Chapter 15 - Machine Learning in Survival Analysis

Slides

Lecture slides here. (To convert the HTML slides to PDF, press E \(\to\) Print \(\to\) Destination: Save as PDF.)

Base R Code

Show the code
##################################################################
# This code generates all numerical results in chapter 15.     ##
##################################################################


# library("glmpath")
library("glmnet")
library(survival)
# Read in the complete data (path is relative to the working directory)
gbc <- read.table("Data//German Breast Cancer Study//gbc.txt")

# Reduce to one row per subject: the earliest record for each id.
# Step 1: sort the rows by id, and by time within each id
o <- order(gbc$id,gbc$time)
gbc <- gbc[o,]
# Step 2: keep the first (i.e. earliest) row of each id;
# duplicated() flags every later row of the same id
data.CE <- gbc[!duplicated(gbc$id),]

# Collapse the status codes to a 0/1 indicator: any positive code
# (1 or 2) becomes 1; "+ 0" converts the logical to numeric.
# NOTE(review): the meaning of codes 1 and 2 is not shown here --
# presumably two event types merged into one endpoint; confirm.
data.CE$status <- (data.CE$status > 0 ) + 0

# Create a binary (0/1) indicator for age <= 40 years
data.CE$age40 <- (data.CE$age <= 40) + 0



# Number of subjects after the reduction, and a quick look at the data
n <- nrow(data.CE)
head(data.CE)

## Select the training set: 400 subjects drawn at random,
## the remaining n - 400 form the test set
# N=400
set.seed(1234)
# Random permutation of 1..n; its first 400 entries index the training rows
ind <- sample(1:n)[1:400]
train <- data.CE[ind,]
test <- data.CE[-ind,]

# Predictor list for Cox-lasso (uses the binary age40, not raw age)
pred_list <- c("hormone", "age40", "meno", "size", "grade",
            "nodes", "prog", "estrg")

# Training covariate matrix (one column per predictor in pred_list)
Z <- as.matrix(train[,pred_list])

# Training time and status variables.
# Note: these globals hold the TRAINING outcomes only.
time <- train$time
status <- train$status

# Dimension of the covariate matrix
dim(Z)
# 400 x 8

# Compute the coefficient path as a function of lambda
# on the training data (Z, time, status)
# alpha=1: L_1 penalty (lasso)
obj <- glmnet(Z,Surv(time, status), family="cox", alpha = 1)
summary(obj)

# Compute 10-fold (default) cross-validation over the lambda path.
# NOTE(review): the fold assignment presumably draws on the RNG stream
# seeded by set.seed(1234) above, so inserting RNG-consuming code before
# this call would change the selected lambda -- confirm.
obj.cv <- cv.glmnet(Z,Surv(time, status), family="cox", alpha = 1)


# Figure parameters: two panels side by side
par(mfrow = c(1,2))
# par(cex = 1.2)

# Left panel: plot the coefficient paths as
# a function of log-lambda
library(plotmo) # for plot_glmnet
plotmo::plot_glmnet(obj, lwd=2)

# plot(obj,xvar="lambda",lwd=2,label=TRUE)

# Right panel: plot the validation error (partial-likelihood deviance)
# as a function of log-lambda
plot(obj.cv)

# The optimal lambda (stored by cv.glmnet as lambda.min)
obj.cv$lambda.min

# ... and on the log scale used by the plots above
log(obj.cv$lambda.min)

# The coefficients at the optimal lambda (a one-column matrix)
beta <- coef(obj.cv, s = "lambda.min")
# The non-zero coefficients, i.e. the predictors kept by the lasso
# NOTE(review): the code below relies on this subset keeping the
# predictor names; verify that still holds if only one coefficient
# is non-zero.
beta.selected <- beta[abs(beta[,1])>0,]
# print out the non-zero coefficients
beta.selected
# number of non-zero coefficients
length(beta.selected)


# Refit an unpenalized Cox model on the training data using only the
# selected variables. This OVERWRITES obj: from here on obj is the
# coxph refit, not the glmnet path.
selected <- names(beta.selected)
obj <- coxph(Surv(train$time, train$status) ~ as.matrix(train[,selected]))

###############################################################################
## predsurv_cox(): predicted survival curves from a fitted Cox model.
##
## For subject i with covariates z_i the prediction is
##   S(t | z_i) = exp{ -Lambda_0(t) * exp(z_i' beta) },
## where Lambda_0 is the baseline cumulative hazard at covariate value zero
## (basehaz() with centered = FALSE).
## Input:
### obj: a fitted coxph object;
### Z: (n x p) test covariate matrix (a numeric data frame is also accepted);
###  columns must be in the same order as the coefficients of obj.
## Output (a list):
### St: (n x m) matrix with the ith row the predicted survival rates for the
###  ith subject; the first column is S(0) = 1;
### t: m-vector of times, starting at 0.
###############################################################################
predsurv_cox <- function(obj, Z) {
  # %*% needs a numeric matrix; coercing here lets callers pass a data frame
  Z <- as.matrix(Z)
  beta <- obj$coefficients

  # Uncentered baseline, so that it matches the raw linear predictor Z %*% beta
  bhaz <- basehaz(obj, centered = FALSE)
  cum_haz <- bhaz$hazard

  # n x 1 vector of relative risks exp(z_i' beta)
  rel_risk <- exp(Z %*% beta)

  # (n x 1) %*% (1 x (m - 1)) gives each subject's cumulative hazard at every
  # baseline time; prepend a column of ones for time 0
  St <- cbind(rep(1, nrow(Z)), exp(-rel_risk %*% t(cum_haz)))

  list(St = St, t = c(0, bhaz$time))
}

## Get the predicted survival rates for the test set by Cox-lasso
## (obj is the coxph refit on the selected variables; the test covariates
## are passed in the same column order, `selected`)
pred_surv <- predsurv_cox(obj, Z=as.matrix(test[,selected]))

# St: (n_test x m) predicted survival matrix; t: the m time points.
# Note: the variable t shadows the name of base::t(), but calls such as
# t(x) still resolve to the function.
St <- pred_surv$St
t <- pred_surv$t


## Get the KM estimates for the censoring distribution on the TEST set:
## "1-status" flips the indicator so that censoring is treated as the event.
## Because data=test is supplied, time and status here are the test-set
## columns, not the training-set globals of the same name.
G_obj <- summary(survfit(Surv(time, 1-status)~1, data=test))
# Prepend the starting point G(0) = 1 so lookups before the first
# censoring time are defined
tc <- c(0,G_obj$time)
Gtc <- c(1,G_obj$surv)

#####################################################
# A function calculating the Brier score BS(tau),
# weighted by the inverse of the censoring survival
# function G
# Input:
#   tau: time at which the score is evaluated
#   (time, status): observed test outcomes
#     (status == 1 means the event was observed)
#   St,t: predicted survival rates; St[i, j] is the
#     prediction for subject i at time t[j]
#   Gtc,tc: KM estimates for censoring distributions;
#     Gtc[j] is the estimate at time tc[j]
# Output:
#   the mean over subjects of
#     S_i(tau)^2 / G(X_i)         if X_i <= tau, event
#     (1 - S_i(tau))^2 / G(tau)   if X_i > tau
#     0                           if censored by tau
#   (NA contributions are dropped from the mean)
######################################################
BSfun <- function(tau, time, status, St, t, Gtc, tc) {

  n <- length(time)

  # Column of St holding S(tau): the last grid time not exceeding tau
  pos <- sum(t <= tau)
  # When tau precedes every grid time (a grid that does not start at 0),
  # no predicted failure has happened yet, so S(tau) = 1. Indexing
  # column 0 here used to abort with "replacement has length zero".
  S_tau <- if (pos > 0L) St[seq_len(n), pos] else rep(1, n)

  # Step-function lookup of G at the times x; a time before tc[1]
  # has no estimate and yields NA
  G_at <- function(x) {
    idx <- vapply(x, function(u) sum(tc <= u), integer(1))
    idx[which(idx == 0L)] <- NA_integer_
    Gtc[idx]
  }

  failed <- time <= tau & status == 1  # event observed by tau
  alive <- time > tau                  # still under observation at tau

  # Weight: G at the event time for early events, G at tau otherwise
  G <- ifelse(failed, G_at(time), G_at(tau))

  BSvec <- ifelse(failed,
                  S_tau^2 / G,
                  ifelse(alive, (1 - S_tau)^2 / G, 0))

  mean(BSvec, na.rm = TRUE)
}

# Brier score of the Cox-lasso predictions on the test set,
# evaluated at every month tau = 12, ..., 60
tau_list <- 12:60
BS_tau <- vapply(
  tau_list,
  function(tau) BSfun(tau = tau, test$time, test$status, St, t, Gtc, tc),
  numeric(1)
)

plot(tau_list, BS_tau, type = 'l', lwd = 2)

# Full Cox model: all predictors in pred_list, no selection
obj_full <- coxph(
  Surv(train$time, train$status) ~ as.matrix(train[, pred_list])
)

# Predicted survival rates of the full model for the test subjects
pred_surv_full <- predsurv_cox(obj_full, Z = as.matrix(test[, pred_list]))
St_full <- pred_surv_full$St
t_full <- pred_surv_full$t

# Brier score of the full Cox model over the same grid of tau
BS_tau_full <- vapply(
  tau_list,
  function(tau) {
    BSfun(tau = tau, test$time, test$status, St_full, t_full, Gtc, tc)
  },
  numeric(1)
)

  

##############################
# Survival trees             #
##############################


library(rpart)
library(rpart.plot)


### Build survival tree with cross-validation error ###
# Seed fixed so that the cross-validation results below are reproducible
set.seed(12345)

# Grow the tree on the training data. This OVERWRITES obj once more:
# from here on obj is the unpruned rpart tree.
# Conduct 10-fold cross-validation (xval = 10),
# with minimum terminal node size 2 (minbucket = 2);
# cp = 0 sets the complexity threshold to zero, leaving the choice of
# tree size to the cross-validated pruning below.
# Note: the tree uses the continuous age, whereas the Cox models above
# use the binary age40.
obj <- rpart(Surv(time, status) ~ hormone+meno+size+grade+nodes+ 
              prog+estrg+age,
             control = rpart.control(xval = 10, minbucket = 2, cp=0),
             data = train)
printcp(obj)

# Recorded output of printcp(obj); the smallest xerror is in row 4
# (3 splits):
#             CP nsplit rel.error  xerror   xstd
# 1  0.07556835      0   1.00000 1.00411 0.046231
# 2  0.03720019      1   0.92443 0.96817 0.047281
# 3  0.02661914      2   0.88723 0.95124 0.046567
# 4  0.01716925      3   0.86061 0.92745 0.046606
# 5  0.01398306      4   0.84344 0.92976 0.047514
# 6  0.01394869      5   0.82946 0.93941 0.048404
# 7  0.01055028      9   0.77120 0.97722 0.052133
# 8  0.01053135     10   0.76065 1.00140 0.054295

# summary(obj)


# Cross-validation results (the table printed above)
cptable <- obj$cptable
# Complexity parameter values (column 1, "CP")
CP <- cptable[, 1]
# Obtain the optimal parameter: the CP whose cross-validated error
# (column 4, "xerror") is smallest
cp.opt <- CP[which.min(cptable[, 4])]
# Prune the tree back to that complexity
fit <- prune(obj, cp = cp.opt)



# Two panels: pruned tree (left) and node-wise KM curves (right)
par(mfrow = c(1, 2))
# plot the pruned tree structure
rpart.plot(fit)
# plot the KM curves for the terminal nodes;
# fit$where holds, for each training row, the terminal node it falls in
km <- survfit(Surv(time, status) ~ fit$where, train)
# One line type per terminal node, rather than a hard-coded 1:4,
# so curves and legend stay in step if the pruned tree changes size
node_ids <- sort(unique(fit$where))
node_lty <- seq_along(node_ids)
# Follow-up time is in months (the Brier scores use tau = 12..60 months and
# divide by 12 for years), so xscale = 12 makes the axis match its "Years"
# label. The duplicated cex.lab is replaced by cex.axis, matching the final
# Brier score plot.
plot(km, lty = node_lty, mark.time = FALSE, xscale = 12,
     xlab = "Years", ylab = "Progression", lwd = 2,
     cex.lab = 1.2, cex.axis = 1.2)
legend("bottomleft", paste('Node', node_ids), lty = node_lty,
       lwd = 2, cex = 1.2)


# Pull the node-wise KM estimates out of the survfit summary
tmp <- summary(km)
# Strata labels end in "=<node id>"; keep just the id
tmp.strata <- as.integer(sub(".*=", "", tmp$strata))
tmp.t <- tmp$time
tmp.surv <- tmp$surv

# Terminal node ids (TN) and how many there are (N)
TN <- unique(tmp.strata)
N <- length(TN)

# Put every node's KM curve on one common time grid t:
# fitted_surv[j, k] is the KM estimate of node TN[k] at time t[j]
t <- sort(unique(tmp.t))
m <- length(t)
fitted_surv <- matrix(NA, m, N)
for (k in seq_len(N)) {
  in_node <- tmp.strata == TN[k]
  # Node-specific step function, started at S(0) = 1
  tk <- c(0, tmp.t[in_node])
  survk <- c(1, tmp.surv[in_node])
  # For each grid time, the last node-specific time not exceeding it
  last_idx <- vapply(t, function(tj) sum(tk <= tj), integer(1))
  fitted_surv[, k] <- survk[last_idx]
}



# Get the terminal node prediction
# for the test data
library(treeClust)
test_term <- rpart.predict.leaves(fit, test)

n <- length(test_term)

# Row i of St_tree is the KM curve (on the grid t) of the terminal node that
# test subject i falls into. A match() lookup replaces the per-node loop,
# which reused the name `ind` and thereby overwrote the train/test split
# index. Subjects whose node is not in TN get a row of NA, as before.
node_col <- match(test_term, TN)
St_tree <- t(fitted_surv[, node_col, drop = FALSE])



## KM estimate of the censoring distribution on the test set
## (censoring is the "event": 1 - status), started at G(0) = 1
G_obj <- summary(survfit(Surv(time, 1 - status) ~ 1, data = test))
tc <- c(0, G_obj$time)
Gtc <- c(1, G_obj$surv)

# Brier score of the pruned survival tree on the test set,
# evaluated at every month tau = 12, ..., 60
tau_list <- 12:60
BS_tau_tree <- vapply(
  tau_list,
  function(tau) {
    BSfun(tau = tau, test$time, test$status, St_tree, t, Gtc, tc)
  },
  numeric(1)
)

# Plot the Brier score curves for Cox-lasso, Cox-full, and survival tree
# on a single panel; tau_list is in months, hence the division by 12
par(mfrow = c(1, 1))

plot(
  tau_list / 12, BS_tau_tree,
  type = 'l', lwd = 2, col = "red",
  cex.axis = 1.2, cex.lab = 1.2,
  xlab = "t (years)", ylab = "BS(t)",
  ylim = c(0, 0.25)
)
lines(tau_list / 12, BS_tau, lty = 1, lwd = 2, col = "blue")
lines(tau_list / 12, BS_tau_full, lty = 1, lwd = 2, col = "black")
# The keyword "bottomright" fixes the position, so no y coordinate is needed
legend(
  "bottomright",
  legend = c("Survival Tree", "Cox-lasso", "Cox-full"),
  col = c("red", "blue", "black"),
  lwd = 2, cex = 1.2
)