# Joint analysis of longitudinal and survival data measured on nested time-scales using shared parameter models: an application to fecundity data

## Oxford_prediction.R

library(MASS)
library(nlme)
library(splines)
library(mvtnorm)

# Day-level flag: 1 = this day is used to estimate the random effects (RE).
RE_ind <- day_level[, 16]

# Day-level data for estimating the random effects and covariates (RE_ind == 1).
Y        <- day_level[RE_ind == 1, 1]                # intercourse indicators
ind      <- day_level[RE_ind == 1, 2]                # subject numbers
X        <- as.matrix(day_level[RE_ind == 1, 3:11])  # covariates for intercourse probability
X_P      <- as.matrix(day_level[RE_ind == 1, 12:13]) # covariates for the peak
day      <- day_level[RE_ind == 1, 14]               # day relative to ovulation (if known)
day_miss <- day_level[RE_ind == 1, 15]               # indicator that the day is known

# Women-level data. Only one cycle per woman is used here, so every TTP is 1
# (an NA in column 3 stays NA).
X_T    <- as.matrix(women_level[, 1:2])                           # TTP covariates
TTP    <- 0 * women_level[, 3] + 1
# 1 unless the woman got pregnant in the first cycle (TTP == 1, uncensored).
CEN    <- 1 * I(!(women_level[, 3] == 1 & women_level[, 4] == 0))
nij    <- matrix(women_level[, 5], nrow = length(TTP), ncol = 1)  # follicular phase length
ind_po <- I(nij[, 1] != 0) * 1                                    # 1 = follicular length known

Dat <- list(
  Y = Y, ind = ind, X = X, X_P = X_P, X_T = X_T,
  T = TTP, CEN = CEN, nij = nij,
  day = day, day_miss = day_miss, ind_po = ind_po
)

# B-spline basis in the day relative to ovulation, with fixed interior knots
# and boundary knots BK. Days outside BK are first clamped onto the boundary.
knots <- c(-13, -10, -7, -4, -2, 1, 4, 7, 10, 13)
BK <- c(-20, 50)
day <- pmin(pmax(day, BK[1]), BK[2])
basis <- bs(day, knots = knots, Boundary.knots = BK)
# Re-index every row and column of the basis and bind into a matrix;
# presumably done to drop the extra attributes bs() attaches — confirm.
X_B <- cbind(basis[1:nrow(basis), 1:ncol(basis)])
# Passed to mjm.EB(); presumably the number of quadrature nodes — confirm in
# R_programs.R.
nodes <- 10

# Load the model code; presumably defines mjm.EB() and mjm.pred.probs(),
# which are called in this script but not defined in it.
source("R_programs.R")

# Fitted parameter estimates and their sampling variances (one variance per
# parameter; they are used as independent normal variances in the Monte Carlo
# prediction step).
# NOTE: the name `par` masks graphics::par(). Calls such as par(mfrow = ...)
# still reach the graphics function because R skips non-function bindings
# when resolving a function call, but the name is easy to misread.
par <- c(-1.019623017, -0.022231701, -0.222435068, 0.183268558, -2.199957766, 0.074197907, -0.148694702, -0.066871265, -0.213375205, -0.176945759, -0.478504015, 0.911746709, 0.236140303, 0.300086114, -1.891660637, -1.420562322, -1.627560036, -2.053313444, -2.52622992, -2.538376681, -0.261931586, 0.569964381, -0.205466003, 0.011981404, 0.172473525, 0.093544916, 0.608545158, 0.450351538, 1.590301719, -0.395095029, -0.339227847, -0.276937273, -1.407143639, 1.453161218, -1.282294669, 0.091406746, -0.014844835, -0.640734346, 15.64263015,-1.560648012)
par_var <- c(0.035825178, 0.000226101, 0.015432345, 0.002584707, 0.011242709, 0.001502623, 0.010793378, 0.000503257, 0.002061199, 0.009917989, 0.00831891, 0.005718286, 0.010713138, 0.015801399, 0.036025078, 0.031809002, 0.051345602, 0.089138962, 0.205026884, 0.243448575, 0.007686414, 0.031959542, 0.179537367, 0.068544324, 0.038251997, 0.030581415, 0.031546658, 0.029525749, 0.048995353, 0.02899384, 0.034356965, 0.028434758, 0.213223587, 0.915985517, 0.406959903, 0.004495922, 0.000228671, 0.0104776, 0.044865411, 0.0011484)
# rnorm(length(par), par, sqrt(par_var)) would silently recycle a shorter
# variance vector (and sqrt() of a negative gives NaN), so fail fast here.
stopifnot(length(par) == length(par_var), all(par_var >= 0))

# Run the empirical Bayes procedure.
EB_proc <- mjm.EB(par, Dat, nodes, knots, BK)

# Extract the empirical Bayes estimates and their covariance matrices.
EB_est <- EB_proc$Fin_EB.mat
EB_var <- EB_proc$Fin_Var.mat

# Per-subject summaries of the estimation-period covariates, used later to
# impute covariates on the prediction days. Columns of X used (per the
# original description: cigarettes, alcohol, period length):
#   [, 1] max of X[, 3]
#   [, 2] mean of X[, 4]
#   [, 3] number of days with X[, 5] == 1
# Rows follow the order of unique(ind) and are named by subject id.
re_subjects <- unique(ind)
cov_est <- matrix(NA_real_, nrow = length(re_subjects), ncol = 3,
                  dimnames = list(re_subjects, NULL))
for (s in seq_along(re_subjects)) {
  # drop = FALSE keeps a subject with a single day as a 1-row matrix
  # (plain X[mask, ] would return a vector and t_X[, 3] would error).
  t_X <- X[ind == re_subjects[s], , drop = FALSE]
  cov_est[s, ] <- c(max(t_X[, 3]),
                    mean(t_X[, 4]),
                    sum(t_X[, 5] == 1, na.rm = TRUE))
}

# Women-level covariates and observed outcomes for the prediction step.
X_T <- as.matrix(women_level[, 1:2]) # TTP covariates
TTP <- women_level[, 3]              # observed time to pregnancy
CEN <- women_level[, 4]              # censoring indicator

# Remaining day-level data (RE_ind == 0), i.e. the days held out from
# random-effect estimation, used for prediction.
Y   <- day_level[RE_ind == 0, 1]                 # intercourse indicators
ind <- day_level[RE_ind == 0, 2]                 # subject numbers
X   <- as.matrix(day_level[RE_ind == 0, 3:11])   # covariates for intercourse probability
X_P <- as.matrix(day_level[RE_ind == 0, 12:13])  # covariates for the peak

# Replace the prediction-period covariates with the per-subject estimates in
# cov_est (lagging is done later):
#   X[, 3] <- cov_est[j, 1];  X[, 4] <- cov_est[j, 2];
#   X[, 5] <- 1 on the first cov_est[j, 3] days of each cycle, 0 afterwards.
# Cycles are identified by X_P[, 1].
# NOTE(review): cov_est is indexed by row position j, which assumes subject
# ids are 1..n in the same order as the rows of cov_est — confirm.
for (j in unique(ind)) {
  t_est_covs <- cov_est[j, ]
  rows_j <- which(ind == j)
  # drop = FALSE keeps single-day subjects/cycles as matrices.
  t_X <- X[rows_j, , drop = FALSE]
  t_X[, 3] <- t_est_covs[1]
  t_X[, 4] <- t_est_covs[2]
  t_cyc <- X_P[rows_j, 1]
  for (k in unique(t_cyc)) {
    in_cyc <- t_cyc == k
    cyc_len <- sum(in_cyc)
    n_period <- min(t_est_covs[3], cyc_len)
    t_X[in_cyc, 5] <- rep(c(1, 0), times = c(n_period, cyc_len - n_period))
  }
  X[rows_j, ] <- t_X
}

# Monte Carlo predictions. Each of the MC replicates:
#   1. draws a parameter vector from independent normals N(par, par_var);
#   2. draws each subject's random effects from a multivariate normal with
#      mean EB_est[j, ] and covariance EB_var[j, ] reshaped to n_re x n_re;
#   3. calls mjm.pred.probs() on the prediction data and stores column 2 of
#      int.prob.mat and all but the first column of surv.prob.mat.
# NOTE(review): subject ids are used as row indices into EB_est, EB_var and
# t_re_mat, which assumes ids are 1..n — confirm.
MC <- 2000
set.seed(4)
n_re <- ncol(EB_est) # random effects per subject; must match EB_var's layout
surv_array <- array(0, dim = c(length(unique(ind)), 6, MC))
# Collect one row per replicate in a preallocated list instead of growing a
# matrix with rbind() (which copies the whole matrix every iteration).
int_rows <- vector("list", MC)
# unix.time() is defunct in current R; system.time() is its replacement.
system.time(for (u in seq_len(MC)) {

  tpar <- rnorm(length(par), par, sqrt(par_var))
  t_re_mat <- matrix(0, nrow(EB_est), n_re)
  for (j in unique(ind)) {
    v.eb <- matrix(EB_var[j, ], n_re, n_re)
    t_re_mat[j, ] <- rmvnorm(1, EB_est[j, ], v.eb)
  }

  Dat <- list(ind = ind, X = X, X_P = X_P, X_T = X_T, re_mat = t_re_mat)

  Prob_est <- mjm.pred.probs(tpar, Dat, knots, BK)
  int_rows[[u]] <- Prob_est$int.prob.mat[, 2]
  surv_array[, , u] <- Prob_est$surv.prob.mat[, -1]
})
# MC x (number of prediction days) matrix of intercourse probabilities.
int_matrix <- do.call(rbind, int_rows)

# Summarize the MC draws: one predicted intercourse probability per
# prediction day, paired with the observed intercourse indicator.
pred.probs <- apply(int_matrix, 2, mean)
Y.a <- Y

# Calibration curve: split [0, 1] into 20 bins of width 0.05 and, for each
# non-empty bin, compare the mean predicted probability with the observed
# intercourse rate.
div.0.1 <- seq(0, 1, 0.05)
# findInterval() reproduces the [lower, upper) bins of div.0.1;
# rightmost.closed = TRUE puts a prediction of exactly 1 into the last bin
# instead of dropping it. Out-of-range and NA predictions are excluded.
bin <- findInterval(pred.probs, div.0.1, rightmost.closed = TRUE)
in_range <- !is.na(bin) & bin >= 1 & bin < length(div.0.1)
# tapply() only returns non-empty bins, in increasing bin order.
avg.pred <- as.vector(tapply(pred.probs[in_range], bin[in_range], mean))
avg.obs <- as.vector(tapply(Y.a[in_range], bin[in_range], mean))

par(mfrow = c(1,2),mar = c(5,5, 3, 3),mgp=c(3,1,0))
plot(avg.pred,avg.obs,type="l",lwd=3,xlab=c("Predicted Probability"),ylab= c("Observed Probability"),ylim=c(0,0.9),xlim=c(0,0.9),cex.axis=1.3,cex.lab=1.3)
lines(c(0,1),c(0,1),lwd=2,lty=2)

# ROC curve and AUC for day-level intercourse prediction. A day is called
# positive when its predicted probability is strictly greater than the
# cutoff; every distinct predicted value is used as a cutoff.
# (Named `cutoffs` rather than `c` to avoid masking base::c.)
cutoffs <- sort(unique(pred.probs))
pos_scores <- sort(pred.probs[Y.a == 1])
neg_scores <- sort(pred.probs[Y.a == 0])
if (length(pos_scores) == 0 || length(neg_scores) == 0) {
  warning("Intercourse ROC/AUC undefined: only one outcome class is present.")
}
# For a sorted vector v, findInterval(x, v) counts the elements of v that are
# <= x. This replaces a scan over all days for every cutoff.
FN <- findInterval(cutoffs, pos_scores)
TN <- findInterval(cutoffs, neg_scores)
TP <- length(pos_scores) - FN
FP <- length(neg_scores) - TN
TPR <- TP / (TP + FN)
FPR <- FP / (FP + TN)

# Close the curve at (1, 1) and (0, 0). Integrate with the trapezoidal rule:
# a left-rectangle rule using TPR[j - 1] counts cutoffs that drop positives and
# negatives together (ties) as fully concordant, which overstates the AUC.
TPR <- c(1, TPR, 0)
FPR <- c(1, FPR, 0)
AUC <- sum((FPR[-length(FPR)] - FPR[-1]) * (TPR[-length(TPR)] + TPR[-1]) / 2)
AUC

plot(FPR,TPR,type="l",axes=FALSE, ann=FALSE,lab=c(6,5,3),lwd=3)

axis(2,las=1, at=seq(0,1,.2), cex.axis=1.5)
axis(1,las=1, at=seq(0,1,.2), cex.axis=1.5)
lines(c(0,1),c(0,1))
box()
title(ylab = "Sensitivity",cex.lab=2)
title(xlab = "1 - Specificity",cex.lab=2)

#########################TTP ROC##################

# ROC analysis for time to pregnancy (TTP): how well does the MC-averaged
# predicted survival probability at cycle T.pred separate women with observed
# TTP >= T.pred (Tg6 = 1) from those with TTP < T.pred (Tg6 = 0)?
T.pred <- 6
# Average over the MC replicates. drop = FALSE keeps the subject dimension,
# so apply() still works when only one subject is present.
surv.est <- apply(surv_array[, T.pred, , drop = FALSE], 1, mean)
# NOTE(review): TTP and CEN are indexed by subject id, which assumes ids are
# row positions in women_level — confirm.
t_TTP <- TTP[unique(ind)]
t_CEN <- CEN[unique(ind)]
Tg6 <- 1 * I(t_TTP >= T.pred)

# Women censored before T.pred have unknown status at T.pred: exclude them.
known_status <- !(t_TTP < T.pred & t_CEN == 1)
surv.est <- surv.est[known_status]
Tg6 <- Tg6[known_status]
T_unq <- unique(ind)[known_status]

True = Tg6
Prob = surv.est

# Positive when Prob is strictly greater than the cutoff; every distinct
# predicted value is a cutoff. (`cutoffs`, not `c`, to avoid masking base::c.)
cutoffs <- sort(unique(Prob))
pos_scores <- sort(Prob[True == 1])
neg_scores <- sort(Prob[True == 0])
if (length(pos_scores) == 0 || length(neg_scores) == 0) {
  warning("TTP ROC/AUC undefined: only one outcome class is present.")
}
# For a sorted vector v, findInterval(x, v) counts the elements of v <= x.
FN <- findInterval(cutoffs, pos_scores)
TN <- findInterval(cutoffs, neg_scores)
TP <- length(pos_scores) - FN
FP <- length(neg_scores) - TN
TPR <- TP / (TP + FN)
FPR <- FP / (FP + TN)

# Close the curve at (1, 1) and (0, 0). Use the trapezoidal rule so that tied
# scores across the two classes count as half-concordant instead of fully
# concordant (the left-rectangle rule overstates the AUC when ties occur).
TPR <- c(1, TPR, 0)
FPR <- c(1, FPR, 0)
AUC <- sum((FPR[-length(FPR)] - FPR[-1]) * (TPR[-length(TPR)] + TPR[-1]) / 2)
AUC

par(mfrow = c(1,1),mar = c(5,5, 2, 2),mgp=c(3,1,0))
plot(FPR,TPR,type="l",axes=FALSE, ann=FALSE,lab=c(6,5,3),lwd=3)

axis(2,las=1, at=seq(0,1,.2), cex.axis=1.5)
axis(1,las=1, at=seq(0,1,.2), cex.axis=1.5)
lines(c(0,1),c(0,1))
box()
title(ylab = "Sensitivity",cex.lab=2)
title(xlab = "1 - Specificity",cex.lab=2)