sahirbhatnagar / casebase
1
# This is where all utility functions should appear
2
# These functions are not exported
3

4
# Negation of %in%: TRUE for each element of `x` that does not occur in `table`.
`%ni%` <- function(x, table) {
  !(x %in% table)
}
5

6
# Handling warning messages coming from predictvglm when offset = 0
7
handler_offset <- function(msg) {
  # Calling handler for withCallingHandlers(): muffles warnings mentioning
  # "offset". grepl() coerces the condition with as.character(), so both the
  # message and the deparsed call are searched.
  found <- grepl("offset", msg)
  if (!any(found)) {
    return(invisible(NULL))
  }
  invokeRestart("muffleWarning")
}
10
# Handling warning messages coming from predictvglm when using b-splines
11
handler_bsplines <- function(msg) {
  # Calling handler: muffles the "ill-conditioned bases" warning that B-spline
  # terms can trigger. The pattern is matched against the condition coerced
  # to character (message and deparsed call).
  is_spline_warning <- any(grepl(pattern = "ill-conditioned bases", x = msg))
  if (is_spline_warning) invokeRestart("muffleWarning")
}
14
# Handling warning messages coming from vglm.fitter
15
handler_fitter <- function(msg) {
  # Calling handler: muffles warnings whose message or deparsed call mentions
  # vglm.fitter. The call is part of what grepl() searches because the
  # condition object is coerced with as.character(). Note the "." in the
  # pattern is a regex wildcard.
  pattern <- "vglm.fitter"
  if (any(grepl(pattern, msg))) {
    invokeRestart("muffleWarning")
  }
}
18

19
# Check if provided time and event variables are in the dataset
20
# and also check for any good substitute
21
#' @rdname popTime
22
#' @export
23
checkArgsTimeEvent <- function(data, time, event) {
  # Resolve the names of the time and event columns of `data`.
  # When `time` / `event` are not supplied, they are guessed from the column
  # names:
  # - a warning is raised if several columns match;
  # - a message names the column that will be used otherwise.
  # Returns list(time = <name>, event = <name>). Stops if a column cannot be
  # found or the supplied names are not columns of `data`.
  time_pattern <- "[\\s\\W_]+time|^time\\b"
  event_pattern <- "[\\s\\W_]+event|^event\\b|[\\s\\W_]+status|^status\\b"

  if (missing(time)) {
    time <- grep(time_pattern, names(data),
                 ignore.case = TRUE, value = TRUE, perl = TRUE)
    if (length(time) == 0L) {
      stop("data does not contain time variable")
    }
    if (length(time) > 1L) {
      warning(paste0(
        "The following variables for time were found in the data: ",
        paste0(time, collapse = ", "), ". '", time[1],
        "' will be used as the time variable"
      ))
    } else {
      message(paste0(
        "'", time, "'",
        " will be used as the time variable"
      ))
    }
  }

  if (missing(event)) {
    # Exclude the time column from the candidates. setdiff() is used because
    # names(data)[-which(...)] drops *every* name when `time` is not a column
    # of `data` (which() returns integer(0)).
    event_candidates <- setdiff(names(data), time[1])
    event <- grep(event_pattern, event_candidates,
                  ignore.case = TRUE, value = TRUE, perl = TRUE)
    if (length(event) == 0L) {
      stop("data does not contain event or status variable")
    }
    if (length(event) > 1L) {
      warning(paste0(
        "The following variables for event were found in the data: ",
        paste0(event, collapse = ", "), ". '", event[1],
        "' will be used as the event variable"
      ))
    } else {
      message(paste0(
        "'", event, "'",
        " will be used as the event variable"
      ))
    }
  }

  if (!all(c(time, event) %in% colnames(data))) {
    stop("data does not contain supplied time and/or event variables")
  }

  list(time = time[1], event = event[1])
}
80

81

82
#' Check that Event is in Correct Format
83
#'
84
#' Checks for event categories and gives a warning message indicating which
85
#' level is assumed to be the reference level.
86
#'
87
#' @inheritParams popTime
88
#' @return A list of length two. The first element is the factored event, and
89
#'   the second element is the numeric representation of the event
90
#'
91
#' @export
92
#' @examples
93
#' if (requireNamespace("survival", quietly = TRUE)) {
94
#' library(survival) # for veteran data
95
#' checkArgsEventIndicator(data = veteran, event = "celltype",
96
#'                         censored.indicator = "smallcell")
97
#' checkArgsEventIndicator(data = veteran, event = "status")
98
#' }
99
#' data("bmtcrr") # from casebase
100
#' checkArgsEventIndicator(data = bmtcrr, event = "Sex",
101
#'                         censored.indicator = "M")
102
#' checkArgsEventIndicator(data = bmtcrr, event = "D",
103
#'                         censored.indicator = "AML")
104
#' checkArgsEventIndicator(data = bmtcrr, event = "Status")
105
checkArgsEventIndicator <- function(data, event, censored.indicator) {
  # Convert the event column `data[[event]]` into a factor whose first level
  # is the censoring level. Also returns its 0-based numeric coding and the
  # number of distinct values.
  # - numeric events: 0 is always "censored", the next code (sorted) is
  #   "event", any remaining codes are competing events; a supplied
  #   `censored.indicator` is ignored with a warning.
  # - factor/character events: `censored.indicator` selects the censoring
  #   level; when it is missing/NULL the first level is assumed (with a
  #   warning).
  x <- data[[event]]
  isFactor <- is.factor(x)
  isNumeric <- is.numeric(x)
  isCharacter <- is.character(x)

  # Messages are built with paste() rather than strwrap(): strwrap() returns
  # several elements for long text and stop()/warning() join them without
  # spaces.
  if (!any(isFactor, isNumeric, isCharacter)) {
    stop(paste("event variable must be either a factor,",
               "numeric or character variable"))
  }

  nLevels <- nlevels(factor(x))
  if (nLevels < 2) {
    stop(paste("event variable must have",
               "at least two unique values"))
  }

  noIndicator <- missing(censored.indicator) || is.null(censored.indicator)

  if (isNumeric) {
    if (!noIndicator) {
      warning(paste("censored.indicator specified but ignored because",
                    "event is a numeric variable"))
    }
    if (!any(x %in% 0)) {
      stop(paste("event is a numeric variable that",
                 "doesn't contain 0. if event is a",
                 "numeric it must contain some 0's",
                 "to indicate censored observations"))
    }
    # Put "0" first so it is labelled "censored" even when negative codes
    # exist; the remaining codes keep their sorted order.
    codes <- levels(factor(x))
    codes <- c("0", setdiff(codes, "0"))
    competing <- if (nLevels >= 3) {
      paste0("competing event", if (nLevels >= 4) seq_len(nLevels - 2))
    }
    event.factored <- factor(x,
      levels = codes,
      labels = c("censored", "event", competing)
    )
  } else if (noIndicator) {
    event.factored <- if (isFactor) x else factor(x)
    slev <- levels(event.factored)
    warning(paste0(
      "censor.indicator not specified. assuming ",
      slev[1], " represents a censored observation and ",
      slev[2], " is the event of interest"
    ))
  } else {
    if (!all(censored.indicator %in% x)) {
      stop("censored.indicator not found in event variable of data")
    }
    event.factored <- relevel(factor(x), censored.indicator)
    slev <- levels(event.factored)
    message(paste0(
      "assuming ",
      slev[1], " represents a censored observation and ",
      slev[2], " is the event of interest"
    ))
  }

  list(
    event.factored = event.factored,
    event.numeric = as.numeric(event.factored) - 1,
    nLevels = nLevels
  )
}
209

210
# Remove offset from formula
211
# https://stackoverflow.com/a/40313732/2836971
212

213

214
#' @importFrom stats model.matrix
215
#' @importFrom stats contrasts
216
#' @details `prepareX` is a slightly modified version of the same function from
217
#'   the `glmnet` package. It can be used to convert a data.frame to a matrix
218
#'   with categorical variables converted to dummy variables using one-hot
219
#'   encoding
220
#' @rdname fitSmoothHazard
221
#' @export
222
prepareX <- function(formula, data) {
  # Build a model matrix (without intercept) from `formula` and `data`.
  # Every factor is expanded into one indicator column per level (one-hot),
  # instead of the default treatment contrasts. The "assign" and
  # "contrasts" attributes are removed so a plain numeric matrix is returned.
  # vapply() always yields a logical vector, even for a data frame with no
  # columns (sapply() would return list()).
  whichfac <- vapply(data, is.factor, logical(1))
  ctr <- if (any(whichfac)) {
    # contrasts = FALSE gives an identity contrast matrix, i.e. one column
    # per level.
    lapply(data[whichfac], contrasts, contrasts = FALSE)
  } else {
    NULL
  }
  X <- model.matrix(update(formula, ~ . - 1), data = data, contrasts.arg = ctr)
  if (any(whichfac)) {
    attr(X, "contrasts") <- NULL
  }
  attr(X, "assign") <- NULL
  X
}
234

235
cv.glmnet.formula <- function(formula, data, event,
                              competingRisk = FALSE, ...) {
  # Cross-validated penalized fit for case-base data.
  # The design matrix comes from prepareX(formula, data) and the response
  # from the `event` column.
  # - single event: binomial family, using the "offset" column;
  # - competing risks: multinomial family, no offset.
  # Extra arguments in `...` are forwarded to cv.glmnet_offset_hack().
  needed <- c(event, if (!competingRisk) "offset")
  absent <- setdiff(needed, names(data))
  if (length(absent) > 0L) {
    stop("undefined columns selected: ", paste(absent, collapse = ", "))
  }
  X <- prepareX(formula, data)
  # [[ ]] always returns a plain vector; data[, col] returns a one-column
  # table for tibbles and data.tables.
  Y <- data[[event]]
  if (competingRisk) {
    fam <- "multinomial"
    offset <- NULL
  } else {
    fam <- "binomial"
    offset <- data[["offset"]]
  }
  cv.glmnet_offset_hack(X, Y, offset = offset, family = fam,
                        type.multinomial = "grouped", ...)
}
249

250
cv.glmnet_offset_hack <- function(x, y, offset = NULL, ...) {
  # Run glmnet::cv.glmnet() while honouring a *constant* offset.
  # For some offset values cv.glmnet does not converge, so the model is
  # fitted without the offset. The constant is then subtracted from the
  # fitted intercepts (a0) of the returned glmnet fit. Non-constant offsets
  # are rejected. Arguments in `...` go to glmnet::cv.glmnet().

  # No offset (e.g. the multinomial competing-risk fit): nothing to shift.
  # range(NULL) would warn, and subtracting NULL would not give valid
  # intercepts.
  if (is.null(offset) || length(offset) == 0L) {
    return(glmnet::cv.glmnet(x, y, ...))
  }

  if (diff(range(offset)) > 1e-06) {
    stop("Glmnet is only available with constant offset",
      call. = FALSE
    )
  }

  offset_value <- offset[1]
  # 1. Fit without offset
  out <- glmnet::cv.glmnet(x, y, ...)
  # 2. Fix the intercept
  out$glmnet.fit$a0 <- out$glmnet.fit$a0 - offset_value

  out
}
267

268
# Monte Carlo integration
269
# Mimic the interface of integrate
270
integrate_mc <- function(f, lower, upper, ..., subdivisions = 100L) {
  # Estimate the integral of `f` over [lower, upper] by averaging `f` at
  # `subdivisions` uniform random points and scaling by the interval width.
  # `f` must accept a vector of points; extra arguments go through `...`.
  draws <- runif(subdivisions,
    min = lower,
    max = upper
  )
  width <- upper - lower
  width * mean(f(draws, ...))
}
277

278
# Taken from brms package
279
expand_dot_formula <- function(formula, data = NULL) {
  # Replace a "." in `formula` by the columns of `data` (via stats::terms).
  # The original attributes (class, environment) are restored. If the
  # expansion fails, e.g. no `data` was given, the formula is returned
  # unchanged.
  if (!isTRUE("." %in% all.vars(formula))) {
    return(formula)
  }
  saved_attributes <- attributes(formula)
  expanded <- tryCatch(
    stats::terms(formula, data = data),
    error = function(e) NULL
  )
  if (!is.null(expanded)) {
    formula <- formula(expanded)
  }
  attributes(formula) <- saved_attributes
  formula
}
293

294
# Streamlined version of pracma::cumtrapz
295
trap_int <- function(x, y) {
  # Cumulative trapezoidal integral of `y` with respect to `x`, evaluated at
  # every grid point.
  # - x: vector of grid points.
  # - y: vector, or matrix whose rows line up with `x` (one column per
  #   curve).
  # Returns a length(x) x ncol(y) matrix whose first row is 0.
  x <- c(x)
  m <- length(x)
  y <- as.matrix(y)
  n <- ncol(y)
  if (m < 2L) {
    # Zero or one grid point: the integral over a degenerate range is 0
    # (1:(m - 1) would index rows 1 and 0 here).
    return(matrix(0, nrow = m, ncol = n))
  }
  # drop = FALSE keeps a single-column `y` as a matrix.
  left <- y[seq_len(m - 1L), , drop = FALSE]
  right <- y[2:m, , drop = FALSE]
  # Area of each trapezoid; 0.5 * diff(x) is recycled down every column.
  pieces <- 0.5 * diff(x) * (left + right)
  rbind(0, apply(pieces, 2, cumsum))
}
304

305
# Detect if formula contains a function of time or interaction----
306 2
count_matches <- function(pat, vec) {
  # Number of (non-overlapping) matches of regex `pat` in each element of
  # `vec`. lengths() always returns an integer vector (integer(0) for empty
  # input), whereas sapply() returns list() when `vec` is empty.
  lengths(regmatches(vec, gregexpr(pat, vec)))
}
308

309
balance_parentheses <- function(str) {
  # For each string with unequal numbers of "(" and ")", delete the first
  # occurrence of the over-represented parenthesis (one character only).
  n_open <- count_matches("\\(", str)
  n_close <- count_matches("\\)", str)
  excess_open <- n_open > n_close
  excess_close <- n_open < n_close

  str[excess_open] <- sub("(", "", str[excess_open], fixed = TRUE)
  str[excess_close] <- sub(")", "", str[excess_close], fixed = TRUE)
  str
}
318

319
detect_nonlinear_time <- function(formula, timeVar) {
  # TRUE if some right-hand-side term of `formula` passes `timeVar` itself
  # as an argument of a function call, e.g. log(time), bs(time, df = 3) or
  # log(exp(time)). An argument such as "time + 1" is not matched, because
  # each argument must equal `timeVar` exactly after cleanup.

  # Contents of a "( ... )" group, up to the first closing parenthesis.
  args_regex <- "\\(\\s*([^)]+?)\\s*\\)"
  # `timeVar` is pasted into the regex unescaped.
  exact_time <- paste0("^", timeVar, "$")

  term_labels <- attr(terms(formula), "term.labels")
  # Terms without a parenthesised group are dropped by regmatches().
  inner <- balance_parentheses(
    regmatches(term_labels, regexpr(args_regex, term_labels))
  )
  # Peel nested calls one layer per iteration until every string is itself
  # a single "( ... )" group.
  repeat {
    peeled <- regmatches(inner, regexpr(args_regex, inner))
    if (!any(inner != peeled)) {
      break
    }
    inner <- balance_parentheses(peeled)
  }

  # Split each group into arguments, strip parentheses and any "name ="
  # prefix, then compare to `timeVar`.
  mentions_time <- vapply(
    strsplit(inner, ","),
    function(pieces) {
      unwrapped <- gsub("(\\(\\s*|\\s*\\))", "", pieces)
      values <- gsub(".*=", "", unwrapped)
      any(grepl(exact_time, trimws(values)))
    },
    logical(1)
  )
  any(mentions_time)
}
348

349
detect_interaction <- function(formula) {
  # The "order" attribute of a terms object gives, for each term, how many
  # variables it involves; an interaction such as a:b has order 2 or more.
  term_order <- attr(terms(formula), "order")
  length(term_order) > 0L && max(term_order) > 1
}
355

356
# Get typical covariate profile from dataset
357
#' @importFrom stats median
358
get_typical <- function(data) {
  # One-row data frame with a "typical" value for every column of `data`:
  # - numeric and Date columns: the median (NAs removed);
  # - other columns: the most common value. Ties go to the first value in
  #   table() order; the result is NA when the column has no non-missing
  #   values.
  # Logical columns stay logical, and factor columns keep all their levels
  # and their orderedness. Character and other columns become a factor over
  # their observed values.
  typical_value <- function(col) {
    if (is.numeric(col) || inherits(col, "Date")) {
      return(median(col, na.rm = TRUE))
    }
    counts <- table(col)
    most_common <- if (length(counts) > 0L) {
      names(counts)[which.max(counts)]
    } else {
      NA
    }
    if (is.logical(col)) {
      as.logical(most_common)
    } else if (is.factor(col)) {
      factor(most_common, levels = levels(col), ordered = is.ordered(col))
    } else {
      factor(most_common, levels = levels(factor(col)))
    }
  }
  # check.names = FALSE keeps non-syntactic column names (e.g. "my var")
  # identical to those of `data`.
  data.frame(lapply(data, typical_value), check.names = FALSE)
}
370

371
#' @rdname plot.singleEventCB
372
incrVar <- function(var, increment = 1) {
  # Returns a function(data) that shifts each column named in `var` by the
  # matching element of `increment`. A scalar increment is reused for every
  # column. Numeric columns are incremented arithmetically; factor columns
  # are shifted by levels via fct_shift_ord().

  # Force `increment` now. Otherwise, when length(var) == 1 the promise is
  # only evaluated inside the returned closure, so a caller passing a loop
  # variable would see its final value.
  force(increment)
  n <- length(var)
  if (n > 1 && length(increment) == 1) {
    increment <- rep(increment, n)
  }
  function(data) {
    # seq_len() makes an empty `var` a no-op (1:0 would iterate twice).
    for (i in seq_len(n)) {
      column <- var[i]
      if (is.factor(data[[column]])) {
        data[[column]] <- fct_shift_ord(data[[column]],
                                        increment = increment[i])
      } else {
        data[[column]] <- data[[column]] + increment[i]
      }
    }
    data
  }
}
389

390

391
fct_shift_ord <- function(x, increment = 1, cap = TRUE, .fun = `+`) {
  # Shift a factor by whole levels. `.fun` is applied to the integer codes
  # of `x` and `increment`, and the result is returned as an ordered factor
  # with the original levels. With `cap = TRUE`, codes falling outside
  # 1..nlevels(x) are clamped to the nearest end. Otherwise they become NA.
  n_levels <- nlevels(x)
  level_labels <- levels(x)

  shifted <- .fun(as.numeric(x), increment)

  if (cap) {
    shifted <- pmin(pmax(shifted, 1), n_levels)
  }
  # seq_len() keeps a factor with no levels valid (1:0 would be c(1, 0)).
  ordered(shifted, levels = seq_len(n_levels), labels = level_labels)
}
406

407

408
hrJacobian <- function(object, newdata, newdata2, term) {
  # Difference of the design matrices for `term` between two covariate
  # profiles: rows of model.matrix(newdata2) minus rows of
  # model.matrix(newdata). Factor levels and contrasts come from
  # `object$xlevels` / `object$contrasts`, so both matrices share one coding.
  # Offsets are set to 0 in both profiles first.
  newdata$offset <- 0
  newdata2$offset <- 0

  # Keep the `data = newdata` / `data = newdata2` arguments literal:
  # model.frame.default() inspects substitute(data).
  frame2 <- stats::model.frame(term,
    data = newdata2,
    na.action = stats::na.pass,
    xlev = object$xlevels
  )
  frame1 <- stats::model.frame(term,
    data = newdata,
    na.action = stats::na.pass,
    xlev = object$xlevels
  )

  ctrs <- object$contrasts
  design2 <- stats::model.matrix(term, frame2, contrasts.arg = ctrs)
  design1 <- stats::model.matrix(term, frame1, contrasts.arg = ctrs)

  design2 - design1
}

Read our documentation on viewing source code .

Loading