Chapter 15 - Machine Learning in Survival Analysis
Slides
Lecture slides are available here. (To convert the HTML slides to PDF, press E → Print → Destination: Save as PDF.)
Base R Code
Show the code
###################################################################
# This code generates all numerical results in chapter 15.
###################################################################
library("glmpath")
library("glmnet")
library(survival)

# Read in the complete data
gbc <- read.table("Data//German Breast Cancer Study//gbc.txt")

# Subset to first-event data:
# sort the data by time within each id ...
o <- order(gbc$id, gbc$time)
gbc <- gbc[o, ]
# ... and keep the first row for each id
data.CE <- gbc[!duplicated(gbc$id), ]
# Set status = 1 if status is 1 or 2 (status > 0), 0 otherwise
data.CE$status <- (data.CE$status > 0) + 0
# Create a binary variable for age <= 40 years
data.CE$age40 <- (data.CE$age <= 40) + 0
n <- nrow(data.CE)
head(data.CE)

## Select a training set of N = 400 subjects; the rest form the test set
set.seed(1234)
ind <- sample(1:n)[1:400]
train <- data.CE[ind, ]
test <- data.CE[-ind, ]

# Predictor list for Cox-lasso
pred_list <- c("hormone", "age40", "meno", "size", "grade",
               "nodes", "prog", "estrg")
# Covariate matrix
Z <- as.matrix(train[, pred_list])
# Time and status variables
time <- train$time
status <- train$status
# Dimension of covariate matrix
dim(Z)
# 400 x 8

# Compute the coefficient paths as a function of lambda
# alpha = 1: L_1 penalty (lasso)
obj <- glmnet(Z, Surv(time, status), family = "cox", alpha = 1)
summary(obj)
# Compute 10-fold (default) cross-validation
obj.cv <- cv.glmnet(Z, Surv(time, status), family = "cox", alpha = 1)

# Figure parameters
par(mfrow = c(1, 2))
# par(cex = 1.2)
# Plot the coefficient paths as a function of log-lambda
library(plotmo) # for plot_glmnet
plotmo::plot_glmnet(obj, lwd = 2)
# plot(obj, xvar = "lambda", lwd = 2, label = TRUE)
# Plot the validation error (partial-likelihood deviance)
# as a function of log-lambda
plot(obj.cv)

# The optimal lambda
obj.cv$lambda.min
log(obj.cv$lambda.min)
# The beta at the optimal lambda (one column, rows named by predictor)
beta <- coef(obj.cv, s = "lambda.min")
# The non-zero coefficients, kept as a *named* vector. Plain
# beta[rows, ] drops the names when exactly one predictor is selected,
# which would break names(beta.selected) in the refit step.
beta_dense <- as.matrix(beta)
is_selected <- beta_dense[, 1] != 0
beta.selected <- setNames(beta_dense[is_selected, 1],
                          rownames(beta_dense)[is_selected])
# Print out the non-zero coefficients
beta.selected
# Number of non-zero coefficients
length(beta.selected)
# Refit the training data using the variables selected
selected <- names(beta.selected)
obj <- coxph(Surv(train$time, train$status) ~ as.matrix(train[, selected]))

###############################################################################
## predsurv_cox: predicted survival functions from a fitted Cox model.
## Input:
###   obj: a coxph object;
###   Z:   (n x p) test covariate matrix, columns in the same order as
###        obj$coefficients.
## Output (list):
###   St: (n x m) matrix with the ith row the predicted survival rates for
###       the ith subject;
###   t:  m-vector of times (0 followed by the baseline-hazard times).
## Uses the uncentered baseline cumulative hazard L(t), so that
## S(t | Z) = exp(-exp(Z beta) L(t)) and S(0 | Z) = 1.
###############################################################################
predsurv_cox <- function(obj, Z) {
  beta <- obj$coefficients
  bhaz <- basehaz(obj, centered = FALSE)
  cum_haz <- bhaz$hazard
  # (n x 1) relative risks times (1 x (m - 1)) cumulative hazards
  risk <- exp(Z %*% beta)
  surv <- exp(-risk %*% t(cum_haz))
  list(St = cbind(rep(1, nrow(Z)), surv), t = c(0, bhaz$time))
}

## Get the predicted survival rates for the test set by Cox-lasso
pred_surv <- predsurv_cox(obj, Z = as.matrix(test[, selected]))
St <- pred_surv$St
t <- pred_surv$t

## Get the KM estimates for the censoring distribution
## (censoring is the "event" here, hence 1 - status)
G_obj <- summary(survfit(Surv(time, 1 - status) ~ 1, data = test))
tc <- c(0, G_obj$time)
Gtc <- c(1, G_obj$surv)

###############################################################################
## BSfun: inverse-probability-of-censoring-weighted Brier score BS(tau).
## Input:
###   tau: time at which the score is evaluated;
###   (time, status): observed test outcomes (status == 1 for an event);
###   St, t: predicted survival rates (row i = subject i) at the times in t;
###          S(tau) is read from the last column with t <= tau;
###   Gtc, tc: KM estimates for the censoring distribution, read as a step
###            function equal to Gtc[j] on [tc[j], tc[j + 1]); tc must be
###            sorted increasingly.
## Output: mean over subjects (NA terms dropped) of
###   S(tau)^2 / G(X_i)        if X_i <= tau and status == 1,
###   (1 - S(tau))^2 / G(tau)  if X_i > tau,
###   0                        otherwise (censored at or before tau).
## Returns NA with a warning if tau precedes every time in t.
###############################################################################
BSfun <- function(tau, time, status, St, t, Gtc, tc) {
  pos <- sum(t <= tau)
  if (pos == 0) {
    warning("tau = ", tau, " precedes all prediction times; returning NA",
            call. = FALSE)
    return(NA_real_)
  }
  # Censoring survival at x, i.e. Gtc[sum(tc <= x)]; NA left of tc[1]
  G_at <- function(x) {
    j <- findInterval(x, tc)
    j[which(j == 0L)] <- NA_integer_
    Gtc[j]
  }
  S_tau <- St[, pos]
  # Elementwise `&` has the same NA semantics as the scalar `&&`
  contrib <- ifelse(time <= tau & status == 1,
                    S_tau^2 / G_at(time),
                    ifelse(time > tau, (1 - S_tau)^2 / G_at(tau), 0))
  mean(contrib, na.rm = TRUE)
}

# Compute the Brier score at tau = 12 to 60 months under Cox-lasso
tau_list <- 12:60
BS_tau <- vapply(tau_list, BSfun, numeric(1),
                 time = test$time, status = test$status,
                 St = St, t = t, Gtc = Gtc, tc = tc)
plot(tau_list, BS_tau, type = "l", lwd = 2)

# Full Cox model
obj_full <- coxph(Surv(train$time, train$status) ~
                    as.matrix(train[, pred_list]))
# Get the predicted survival rates
pred_surv_full <- predsurv_cox(obj_full, Z = as.matrix(test[, pred_list]))
St_full <- pred_surv_full$St
t_full <- pred_surv_full$t
# Compute the Brier score at tau = 12 to 60 months under Cox-full
BS_tau_full <- vapply(tau_list, BSfun, numeric(1),
                      time = test$time, status = test$status,
                      St = St_full, t = t_full, Gtc = Gtc, tc = tc)

##############################
# Survival trees
##############################
library(rpart)
library(rpart.plot)

### Build survival tree with cross-validation error ###
set.seed(12345)
# Conduct 10-fold cross-validation (xval = 10),
# with minimum terminal node size 2 (minbucket = 2)
obj <- rpart(Surv(time, status) ~ hormone + meno + size + grade + nodes +
               prog + estrg + age,
             control = rpart.control(xval = 10, minbucket = 2, cp = 0),
             data = train)
printcp(obj)
#          CP nsplit rel.error  xerror     xstd
# 1 0.07556835      0   1.00000 1.00411 0.046231
# 2 0.03720019      1   0.92443 0.96817 0.047281
# 3 0.02661914      2   0.88723 0.95124 0.046567
# 4 0.01716925      3   0.86061 0.92745 0.046606
# 5 0.01398306      4   0.84344 0.92976 0.047514
# 6 0.01394869      5   0.82946 0.93941 0.048404
# 7 0.01055028      9   0.77120 0.97722 0.052133
# 8 0.01053135     10   0.76065 1.00140 0.054295
# summary(obj)

# Cross-validation results
cptable <- obj$cptable
# Complexity parameter values
CP <- cptable[, "CP"]
# Obtain the optimal parameter (minimum cross-validated error)
cp.opt <- CP[which.min(cptable[, "xerror"])]
# Prune the tree
fit <- prune(obj, cp = cp.opt)

# Terminal nodes of the pruned tree, numbered as in fit$where; taken from
# the tree itself so that a node without events is not lost
TN <- sort(unique(fit$where))
N <- length(TN)

par(mfrow = c(1, 2))
# Plot the pruned tree structure
rpart.plot(fit)
# Plot the KM curves for the terminal nodes; times are in months, so
# xscale = 12 labels the axis in years
km <- survfit(Surv(time, status) ~ fit$where, train)
plot(km, lty = seq_len(N), mark.time = FALSE, xscale = 12,
     xlab = "Years", ylab = "Progression", lwd = 2,
     cex.lab = 1.2, cex.axis = 1.2)
legend("bottomleft", paste("Node", TN), lty = seq_len(N), lwd = 2, cex = 1.2)

# Get the KM estimates for the outcome in each terminal node
tmp <- summary(km)
tmp.strata <- as.integer(sub(".*=", "", tmp$strata))
tmp.t <- tmp$time
tmp.surv <- tmp$surv

# Combine the predicted survival rates together, as functions of t:
# a common grid of 0 plus every event time in any node
t <- sort(unique(c(0, tmp.t)))
m <- length(t)
# fitted_surv[j, k] = KM survival of node TN[k] at time t[j]
# (a node with no events keeps survival 1)
fitted_surv <- matrix(NA_real_, m, N)
for (k in seq_len(N)) {
  in_node <- tmp.strata == TN[k]
  tk <- c(0, tmp.t[in_node])
  survk <- c(1, tmp.surv[in_node])
  fitted_surv[, k] <- survk[findInterval(t, tk)]
}

# Get the terminal node prediction for the test data
library(treeClust)
test_term <- rpart.predict.leaves(fit, test)
n <- length(test_term)
St_tree <- matrix(NA_real_, n, m)
for (k in seq_len(N)) {
  ind <- which(test_term == TN[k])
  St_tree[ind, ] <- matrix(fitted_surv[, k], nrow = length(ind), ncol = m,
                           byrow = TRUE)
}

# Compute the Brier score at tau = 12 to 60 months under the pruned
# survival tree (same censoring KM estimates Gtc, tc as above)
BS_tau_tree <- vapply(tau_list, BSfun, numeric(1),
                      time = test$time, status = test$status,
                      St = St_tree, t = t, Gtc = Gtc, tc = tc)

# Plot the Brier score curves for Cox-lasso, Cox-full, and survival tree
par(mfrow = c(1, 1))
plot(tau_list / 12, BS_tau_tree, type = "l", lwd = 2, col = "red",
     cex.axis = 1.2, cex.lab = 1.2,
     xlab = "t (years)", ylab = "BS(t)", ylim = c(0, 0.25))
lines(tau_list / 12, BS_tau, lty = 1, lwd = 2, col = "blue")
lines(tau_list / 12, BS_tau_full, lty = 1, lwd = 2, col = "black")
legend("bottomright",
       legend = c("Survival Tree", "Cox-lasso", "Cox-full"),
       col = c("red", "blue", "black"), lwd = 2, cex = 1.2)