Compare 559d018 ... +9 ... 4a6927b

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.


@@ -1,39 +1,78 @@
Loading
1 +
#
2 +
#   This file contains the mcmcRocPrc() S3 generic, which constructs objects
3 +
#   of class "mcmcRocPrc". For methods for this class, see mcmcRocPrc-methods.R
4 +
#   S3 methods for the mcmcRocPrc() generic handle different types of input
5 +
#   e.g. "rjags" input produced by R2jags.
6 +
#
7 +
8 +
9 +
1 10
#' ROC and Precision-Recall Curves using Bayesian MCMC estimates
2 11
#' 
3 12
#' Generate ROC and Precision-Recall curves after fitting a Bayesian logit or 
4 -
#' probit regression using [R2jags::jags()]
13 +
#' probit regression using [rstan::stan()], [rstanarm::stan_glm()], 
14 +
#' [R2jags::jags()], [R2WinBUGS::bugs()], [MCMCpack::MCMClogit()], or other 
15 +
#' functions that provide samples from a posterior density. 
5 16
#' 
6 -
#' @param object A "rjags" object (see [R2jags::jags()]) for a fitted binary
7 -
#'   choice model.
8 -
#' @param yname (`character(1)`)\cr
9 -
#'   The name of the dependent variable, should match the variable name in the 
10 -
#'   JAGS data object.
11 -
#' @param xnames ([base::character()])\cr
12 -
#'   A character vector of the independent variable names, should match the 
13 -
#'   corresponding names in the JAGS data object.
17 +
#' @param object A fitted binary choice model, e.g. "rjags" object 
18 +
#'   (see [R2jags::jags()]), or a `[N, iter]` matrix of predicted probabilites.
14 19
#' @param curves logical indicator of whether or not to return values to plot 
15 20
#'   the ROC or Precision-Recall curves. If set to `FALSE` (default), 
16 21
#'   results are returned as a list without the extra values. 
17 22
#' @param fullsims logical indicator of whether full object (based on all MCMC
18 23
#'   draws rather than their average) will be returned. Default is `FALSE`. 
19 24
#'   Note: If `TRUE` is chosen, the function takes notably longer to execute.
25 +
#' @param yvec A `numeric(N)` vector of observed outcomes. 
26 +
#' @param yname (`character(1)`)\cr
27 +
#'   The name of the dependent variable, should match the variable name in the 
28 +
#'   JAGS data object.
29 +
#' @param xnames ([base::character()])\cr
30 +
#'   A character vector of the independent variable names, should match the 
31 +
#'   corresponding names in the JAGS data object.
32 +
#' @param posterior_samples a "mcmc" object with the posterior samples
20 33
#' @param ... Used by methods
21 34
#' @param x a `mcmcRocPrc()` object
22 35
#' 
36 +
#' @details If only the average AUC-ROC and PR are of interest, setting 
37 +
#'   `curves = FALSE` and `fullsims = FALSE` can greatly speed up calculation 
38 +
#'   time. The curve data (`curves = TRUE`) is needed for plotting. The plot
39 +
#'   method will always plot both the ROC and PR curves, but the underlying
40 +
#'   data can easily be extracted from the output for your own plotting; 
41 +
#'   see the documentation of the value returned below. 
42 +
#'   
43 +
#'   The default method works with a matrix of predicted probabilities and the 
44 +
#'   vector of observed incomes as input. Other methods accommodate some of the 
45 +
#'   common Bayesian modeling packages like rstan (which returns class "stanfit"),
46 +
#'   rstanarm ("stanreg"), R2jags ("jags"), R2WinBUGS ("bugs"), and 
47 +
#'   MCMCpack ("mcmc"). Even if a package-specific method is not implemented, 
48 +
#'   the default method can always be used as a fallback by manually calculating
49 +
#'   the matrix of predicted probabilities for each posterior sample. 
50 +
#'   
51 +
#'   Note that MCMCpack returns generic "mcmc" output that is annotated with 
52 +
#'   some additional information as attributes, including the original function
53 +
#'   call. There is no inherent way to distinguish any other kind of "mcmc" 
54 +
#'   object from one generated by a proper MCMCpack modeling function, but as a
55 +
#'   basic precaution, `mcmcRocPrc()` will check the saved call and return an 
56 +
#'   error if the function called was not `MCMClogit()` or `MCMCprobit()`. 
57 +
#'   This behavior can be suppressed by setting `force = TRUE`. 
58 +
#' 
23 59
#' @references Beger, Andreas. 2016. “Precision-Recall Curves.” Available at 
24 60
#'   SSRN: [http://dx.doi.org/10.2139/ssrn.2765419](http://dx.doi.org/10.2139/ssrn.2765419)
25 61
#' 
26 62
#' @return Returns a list with length 2 or 4, depending on the on the "curves" 
27 63
#'   and "fullsims" argument values:
28 64
#'   
29 -
#'   - "area_under_roc": `numeric(1)`
30 -
#'   - "area_under_prc": `numeric(1)`
31 -
#'   - "prc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise
32 -
#'   - "roc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise
65 +
#'   - "area_under_roc": `numeric()`; either length 1 if `fullsims = FALSE`, or 
66 +
#'     one value for each posterior sample otherwise
67 +
#'   - "area_under_prc": `numeric()`; either length 1 if `fullsims = FALSE`, or 
68 +
#'     one value for each posterior sample otherwise
69 +
#'   - "prc_dat": only if `curves = TRUE`; a list with length 1 if 
70 +
#'     `fullsims = FALSE`, longer otherwise
71 +
#'   - "roc_dat": only if `curves = TRUE`; a list with length 1 if 
72 +
#'     `fullsims = FALSE`, longer otherwise
33 73
#'
34 74
#' @examples
35 75
#' # load simulated data and fitted model (see ?sim_data and ?jags_logit)
36 -
#' library(R2jags)
37 76
#' data("jags_logit")
38 77
#' 
39 78
#' # using mcmcRocPrc
@@ -44,41 +83,57 @@
Loading
44 83
#'                       fullsims = FALSE)
45 84
#' fit_sum                     
46 85
#' plot(fit_sum)
86 +
#' 
87 +
#' # Equivalently, we can calculate the matrix of predicted probabilities 
88 +
#' # ourselves; using the example from ?jags_logit:
89 +
#' library(R2jags)
90 +
#' 
91 +
#' data("sim_data")
92 +
#' yvec <- sim_data$Y
93 +
#' xmat <- sim_data[, c("X1", "X2")]
94 +
#' 
95 +
#' # add intercept to the X data
96 +
#' xmat <- as.matrix(cbind(Intercept = 1L, xmat))
97 +
#' 
98 +
#' beta <- as.matrix(as.mcmc(jags_logit))[, c("b[1]", "b[2]", "b[3]")]
99 +
#' pred_mat <- plogis(xmat %*% t(beta)) 
100 +
#' 
101 +
#' # the matrix of predictions has rows matching the number of rows in the data;
102 +
#' # the column are the predictions for each of the 2,000 posterior samples
103 +
#' nrow(sim_data)
104 +
#' dim(pred_mat)
105 +
#' 
106 +
#' # now we can call mcmcRocPrc; the default method works with the matrix
107 +
#' # of predictions and vector of outcomes as input
108 +
#' mcmcRocPrc(object = pred_mat, curves = TRUE, fullsims = FALSE, yvec = yvec)
109 +
#' 
47 110
#' @export
48 111
#' @md
112 +
mcmcRocPrc <- function(object, curves = FALSE, fullsims = FALSE, ...) {
113 +
  UseMethod("mcmcRocPrc", object)
114 +
}
49 115
50 -
mcmcRocPrc <- function(object, 
51 -
                       yname, 
52 -
                       xnames, 
53 -
                       curves = FALSE, 
54 -
                       fullsims = FALSE) {
55 -
  
56 -
  link_logit  <- any(grepl("logit", object$model$model()))
57 -
  link_probit <- any(grepl("probit", object$model$model()))
58 -
  
59 -
  if (isFALSE(link_logit | link_probit)) {
60 -
    stop("Could not identify model link function")
61 -
  }
62 -
  
63 -
  mdl_data <- object$model$data()
64 -
  stopifnot(all(xnames %in% names(mdl_data)))
65 -
  stopifnot(all(yname %in% names(mdl_data)))
66 -
  
67 -
  # add intercept by default, maybe revisit this
68 -
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
69 -
  yvec <- mdl_data[[yname]]
70 -
  
71 -
  pardraws <- as.matrix(coda::as.mcmc(object))
72 -
  # this is not very robust, assumes pars are 'b[x]'
73 -
  # for both this and the intercept addition above, maybe a more robust solution
74 -
  # down the road would be to dig into the object$model$model() string
75 -
  betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))]
116 +
#' Constructor for mcmcRocPrc objects
117 +
#' 
118 +
#' This function actually does the heavy lifting once we have a matrix of 
119 +
#' predicted probabilities from a model, plus the vector of observed outcomes.
120 +
#' The reason to have it here in a single function is that we don't replicate 
121 +
#' it in each function that accomodates a JAGS, BUGS, RStan, etc. object.
122 +
#' 
123 +
#' @param pred_prob a `\[N, iter\]` matrix of predicted probabilities 
124 +
#' @param yvec a `numeric(N)` vector of observed outcomes
125 +
#' @param curves include curve data in output?
126 +
#' @param fullsims collapse posterior samples into single summary?
127 +
#' 
128 +
#' @md
129 +
#' @keywords internal
130 +
new_mcmcRocPrc <- function(pred_prob, yvec, curves, fullsims) {
76 131
  
77 -
  if(isTRUE(link_logit)) {
78 -
    pred_prob <- plogis(xdata %*% t(betadraws))  
79 -
  } else if (isTRUE(link_probit)) {
80 -
    pred_prob <- pnorm(xdata %*% t(betadraws))
81 -
  } 
132 +
  stopifnot(
133 +
    "number of predictions and observed outcomes do not match" = nrow(pred_prob)==length(yvec),
134 +
    "yvec must be 0 or 1"                                      = all(yvec %in% c(0L, 1L)),
135 +
    "pred_prob must be in the interval [0, 1]"                 = all(pred_prob >= 0 & pred_prob <= 1)
136 +
  )
82 137
  
83 138
  # pred_prob is a [N, iter] matrix, i.e. each column are preds from one 
84 139
  # set of posterior samples
@@ -146,6 +201,21 @@
Loading
146 201
  )
147 202
}
148 203
204 +
#' @rdname mcmcRocPrc
205 +
#' 
206 +
#' @md
207 +
#' @export
208 +
mcmcRocPrc.default <- function(object, curves, fullsims, yvec, ...) {
209 +
  pred_prob <- object
210 +
  
211 +
  stopifnot(
212 +
    "mcmcRocPrc.default requires 'matrix' like input" = inherits(pred_prob, "matrix")
213 +
  )
214 +
  
215 +
  new_mcmcRocPrc(pred_prob, yvec, curves, fullsims)
216 +
}
217 +
218 +
# Under the hood ROC/PRC calculations -------------------------------------
149 219
150 220
#' Compute ROC and PR curve points
151 221
#' 
@@ -213,192 +283,267 @@
Loading
213 283
}
214 284
215 285
286 +
# auc_roc and auc_pr are not really used, but keep around just in case
287 +
auc_roc <- function(obs, pred) {
288 +
  values <- compute_roc(obs, pred)
289 +
  caTools::trapz(values$x, values$y)
290 +
}
291 +
292 +
auc_pr <- function(obs, pred) {
293 +
  values <- compute_pr(obs, pred)
294 +
  caTools::trapz(values$x, values$y)
295 +
}
296 +
297 +
298 +
299 +
# JAGS-like input (rjags, R2jags, runjags) --------------------------------
216 300
217 301
#' @rdname mcmcRocPrc
218 302
#' 
219 303
#' @export
220 -
print.mcmcRocPrc <- function(x, ...) {
221 -
  
222 -
  auc_roc <- x$area_under_roc
223 -
  auc_prc <- x$area_under_prc
224 -
  
225 -
  has_curves <- !is.null(x$roc_dat)
226 -
  has_sims   <- length(auc_roc) > 1
227 -
  
228 -
  if (!has_sims) {
229 -
    roc_msg <- sprintf("%.3f", round(auc_roc, 3))
230 -
    prc_msg <- sprintf("%.3f", round(auc_prc, 3))
231 -
  } else {
232 -
    roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
233 -
                       round(mean(auc_roc), 3), 
234 -
                       round(quantile(auc_roc, 0.1), 3), 
235 -
                       round(quantile(auc_roc, 0.9), 3))
236 -
    prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
237 -
                       round(mean(auc_prc), 3), 
238 -
                       round(quantile(auc_prc, 0.1), 3), 
239 -
                       round(quantile(auc_prc, 0.9), 3))
304 +
mcmcRocPrc.jags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
305 +
                            xnames, posterior_samples, ...) {
306 +
  
307 +
  stopifnot(
308 +
    inherits(posterior_samples, c("mcmc", "mcmc.list"))
309 +
  )
310 +
  
311 +
  link_logit  <- any(grepl("logit", object$model()))
312 +
  link_probit <- any(grepl("probit", object$model()))
313 +
  
314 +
  if (isFALSE(link_logit | link_probit)) {
315 +
    stop("Could not identify model link function")
240 316
  }
241 317
  
242 -
  cat("mcmcRocPrc object\n")
243 -
  cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims))
244 -
  cat(sprintf("AUC-ROC: %s\n", roc_msg))
245 -
  cat(sprintf("AUC-PR:  %s\n", prc_msg))
318 +
  mdl_data <- object$data()
319 +
  stopifnot(all(xnames %in% names(mdl_data)))
320 +
  stopifnot(all(yname %in% names(mdl_data)))
321 +
  
322 +
  # add intercept by default, maybe revisit this
323 +
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
324 +
  yvec  <- mdl_data[[yname]]
325 +
  
326 +
  pardraws <- as.matrix(posterior_samples)
327 +
  # this is not very robust, assumes pars are 'b[x]'
328 +
  # for both this and the intercept addition above, maybe a more robust solution
329 +
  # down the road would be to dig into the object$model$model() string
330 +
  betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))]
331 +
  
332 +
  if(isTRUE(link_logit)) {
333 +
    pred_prob <- plogis(xdata %*% t(betadraws))  
334 +
  } else if (isTRUE(link_probit)) {
335 +
    pred_prob <- pnorm(xdata %*% t(betadraws))
336 +
  } 
246 337
  
247 -
  invisible(x)
338 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
339 +
                 fullsims = fullsims)
248 340
}
249 341
250 342
#' @rdname mcmcRocPrc
251 343
#' 
252 -
#' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw?
253 -
#' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1
344 +
#' @export
345 +
mcmcRocPrc.rjags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
346 +
                             xnames, ...) {
347 +
  
348 +
  if (!requireNamespace("R2jags", quietly = TRUE)) {
349 +
    stop("Package \"R2jags\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
350 +
  }
351 +
  
352 +
  jags_object <- object$model
353 +
  pardraws    <- coda::as.mcmc(object)
354 +
  
355 +
  # pass it on to the "jags" method
356 +
  mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, 
357 +
             yname = yname, xnames = xnames, posterior_samples = pardraws, ...)
358 +
}
359 +
360 +
#' @rdname mcmcRocPrc
254 361
#' 
255 362
#' @export
256 -
plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) {
363 +
mcmcRocPrc.runjags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
364 +
                               xnames, ...) {
365 +
  jags_object <- runjags::as.jags(object, quiet = TRUE)
366 +
  # as.mcmc.runjags will issue a warning when converting multiple chains
367 +
  # because it combines them
368 +
  pardraws    <- suppressWarnings(coda::as.mcmc(object))
369 +
  
370 +
  # pass it on to the "jags" method
371 +
  mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, 
372 +
             yname = yname, xnames = xnames, posterior_samples = pardraws, ...)
373 +
}
257 374
258 -
  stopifnot(
259 -
    "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)),
260 -
    "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1),
261 -
    "n must be > 0" = (n > 0)
262 -
  )
375 +
376 +
# STAN-like input (rstan, rstanarm, brms) ---------------------------------
377 +
378 +
379 +
380 +
#' @rdname mcmcRocPrc
381 +
#' 
382 +
#' @param data the data that was used in the `stan(data = ?, ...)` call
383 +
#' 
384 +
#' @export
385 +
mcmcRocPrc.stanfit <- function(object, curves = FALSE, fullsims = FALSE, data, 
386 +
                               xnames, yname, ...) {
387 +
  if (!requireNamespace("rstan", quietly = TRUE)) {
388 +
    stop("Package \"rstan\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
389 +
  }
263 390
  
264 -
  obj<- x
265 -
  fullsims <- length(obj$roc_dat) > 1
391 +
  if (!is_binary_model(object)) {
392 +
    stop("the input model does not seem to be a binary choice model; if this is a mistake please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/")
393 +
  }
394 +
  link_type <- identify_link_function(object)
395 +
  if (is.na(link_type)) {
396 +
    stop("could not identify model link function; please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/")
397 +
  }
266 398
  
267 -
  if (!fullsims) {
268 -
    
269 -
    graphics::par(mfrow = c(1, 2))
270 -
    plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR")
271 -
    graphics::abline(a = 0, b = 1, lty = 3, col = "gray50")
272 -
    
273 -
    prc_dat <- obj$prc_dat[[1]]
274 -
    # use first non-NaN y-value for y[1]
275 -
    prc_dat$y[1] <- prc_dat$y[2]
276 -
    plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision",
277 -
         ylim = c(0, 1))
278 -
    graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50")
279 -
    
280 -
  } else {
281 -
    
282 -
    graphics::par(mfrow = c(1, 2))
283 -
    
284 -
    roc_dat <- obj$roc_dat
285 -
    
286 -
    x <- lapply(roc_dat, `[[`, 1)
287 -
    x <- do.call(cbind, x)
288 -
    colnames(x) <- paste0("sim", 1:ncol(x))
289 -
    
290 -
    y <- lapply(roc_dat, `[[`, 2)
291 -
    y <- do.call(cbind, y)
292 -
    colnames(y) <- paste0("sim", 1:ncol(y))
293 -
    
294 -
    xavg <- rowMeans(x)
295 -
    yavg <- rowMeans(y)
296 -
    
297 -
    plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR")
298 -
    samples <- sample(1:ncol(x), n)
299 -
    for (i in samples) {
300 -
      graphics::lines(
301 -
        x[, i], y[, i], type = "s",
302 -
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
303 -
      )
304 -
    }
305 -
    graphics::lines(xavg, yavg, type = "s")
306 -
    
307 -
    # PRC
308 -
    # The elements of prc_dat have different lengths, unlike roc_dat, so we
309 -
    # have to do the central curve differently.
310 -
    prc_dat <- obj$prc_dat
399 +
  mdl_data <- data
400 +
  stopifnot(all(xnames %in% names(mdl_data)))
401 +
  stopifnot(all(yname %in% names(mdl_data)))
402 +
   
403 +
  # add intercept by default, maybe revisit this
404 +
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
405 +
  yvec  <- mdl_data[[yname]]
406 +
  
407 +
  pardraws <- as.matrix(object)
408 +
  # this is not very robust, assumes pars are 'b[x]'
409 +
  betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))]
410 +
  
411 +
  if(link_type=="logit") {
412 +
    pred_prob <- plogis(xdata %*% t(betadraws))  
413 +
  } else if (link_type=="probit") {
414 +
    pred_prob <- pnorm(xdata %*% t(betadraws))
415 +
  } 
416 +
  
417 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
418 +
                 fullsims = fullsims)
419 +
  
420 +
  
421 +
}
311 422
312 -
    x <- lapply(prc_dat, `[[`, 1)
313 -
    y <- lapply(prc_dat, `[[`, 2)
314 -
    
315 -
    # Instead of combining the list of curve coordinates from each sample into
316 -
    # two x and y matrices, we can first make a point cloud with all curve 
317 -
    # points from all samples, and then average the y values at all distinct
318 -
    # x coordinates. The x-axis plots recall (TPR), which will only have as 
319 -
    # many distinct values as there are positives in the data, so this does 
320 -
    # not lose any information about the x coordinates. 
321 -
    point_cloud <- data.frame(
322 -
      x = unlist(x),
323 -
      y = unlist(y)
324 -
    )
325 -
    point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], 
326 -
                                    # factor implicitly encodes distinct values only,
327 -
                                    # since they will get the same labels
328 -
                                    by = list(x = as.factor(point_cloud$x)), 
329 -
                                    FUN = mean)
330 -
    point_cloud$x <- as.numeric(as.character(point_cloud$x))
331 -
    xavg <- point_cloud$x
332 -
    yavg <- point_cloud$y
333 -
334 -
    plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1))
335 -
    samples <- sample(1:length(prc_dat), n)
336 -
    for (i in samples) {
337 -
      graphics::lines(
338 -
        x[[i]], y[[i]], 
339 -
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
340 -
      )
341 -
    }
342 -
    graphics::lines(xavg, yavg)
343 -
    
423 +
#' Try to identify if a stanfit model is a binary choice model
424 +
#' 
425 +
#' @param obj stanfit object
426 +
#' 
427 +
#' @keywords internal
428 +
is_binary_model <- function(obj) {
429 +
  stopifnot(inherits(obj, "stanfit"))
430 +
  model_string <- rstan::get_stancode(obj)
431 +
  grepl("bernoulli", model_string)
432 +
}
433 +
434 +
#' Try to identify link function 
435 +
#' 
436 +
#' @param obj stanfit object
437 +
#' 
438 +
#' @return Either "logit" or "probit"; if neither can be identified the function
439 +
#' will return `NA_character_`. 
440 +
#' 
441 +
#' @keywords internal
442 +
identify_link_function <- function(obj) {
443 +
  stopifnot(inherits(obj, "stanfit"))
444 +
  model_string <- rstan::get_stancode(obj)
445 +
  if (grepl("logit", model_string)) return("logit")
446 +
  if (grepl("Phi", model_string)) return("probit")
447 +
  NA_character_
448 +
}
449 +
450 +
#' @rdname mcmcRocPrc
451 +
#' 
452 +
#' @export
453 +
mcmcRocPrc.stanreg <- function(object, curves = FALSE, fullsims = FALSE, ...) {
454 +
  if (!requireNamespace("rstanarm", quietly = TRUE)) {
455 +
    stop("Package \"rstanarm\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
344 456
  }
457 +
  if (!stats::family(object)$family=="binomial") {
458 +
    stop("the input model does not seem to be a binary choice model; should be like 'obj <- stan_glm(family = binomial(), ...)'") 
459 +
  }
460 +
  pred_prob <- rstanarm::posterior_linpred(object, transform = TRUE)
461 +
  # posterior_linepred returns a matrix in which data cases are columns, and 
462 +
  # MCMC samples are row; we need to transpose this so that columns are samples
463 +
  pred_prob <- t(pred_prob)
464 +
  yvec <- unname(object$y)
345 465
  
346 -
  invisible(x)
466 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
467 +
                 fullsims = fullsims)
347 468
}
348 469
349 470
#' @rdname mcmcRocPrc
350 471
#' 
351 -
#' @param row.names see [base::as.data.frame()] 
352 -
#' @param optional see [base::as.data.frame()]
353 -
#' @param what which information to extract and convert to a data frame?
472 +
#' @export
473 +
mcmcRocPrc.brmsfit <- function(object, curves = FALSE, fullsims = FALSE, ...) {
474 +
  if (!requireNamespace("brms", quietly = TRUE)) {
475 +
    stop("Package \"brms\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
476 +
  }
477 +
  if (!stats::family(object)$family=="bernoulli") {
478 +
    stop("the input model does not seem to be a binary choice model; should be like 'obj <- brm(family = bernoulli(), ...)'") 
479 +
  }
480 +
  
481 +
  pred_prob <- brms::posterior_epred(object)
482 +
  # posterior_epred returns a matrix in which data cases are columns, and 
483 +
  # MCMC samples are row; we need to transpose this so that columns are samples
484 +
  pred_prob <- t(pred_prob)
485 +
  yvec <- stats::model.response(stats::model.frame(object))
486 +
  
487 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
488 +
                 fullsims = fullsims)
489 +
}
490 +
491 +
492 +
# Other input types (MCMCpack, ...) ---------------------------------------
493 +
494 +
495 +
#' #' @rdname mcmcRocPrc
496 +
#' #'
497 +
#' #' @export
498 +
#' mcmcRocPrc.bugs <- function(object, curves = FALSE, fullsims = FALSE, ...) {
499 +
#'   stop("not implemented yet")
500 +
#' }
501 +
502 +
503 +
#' @rdname mcmcRocPrc
354 504
#' 
505 +
#' @param type "logit" or "probit"
506 +
#' @param force for MCMCpack models, suppress warning if the model does not 
507 +
#'   appear to be a binary choice model?
508 +
#'
355 509
#' @export
356 -
as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE,
357 -
                                     what = c("auc", "roc", "prc"), ...) {
358 -
  what <- match.arg(what)
359 -
  if (what=="auc") {
360 -
    # all 4 output types have AUC, so this should work across the board
361 -
    return(as.data.frame(x[c("area_under_roc", "area_under_prc")]))
362 -
    
363 -
  } else if (what %in% c("roc", "prc")) {
364 -
    if (what=="roc") element <- "roc_dat" else element <- "prc_dat"
365 -
    
366 -
    # if curves was FALSE, there will be no curve data...
367 -
    if (is.null(x[[element]])) {
368 -
      stop("No curve data; use mcmcRocPrc(..., curves = TRUE)")
369 -
    }
370 -
    
371 -
    # Otherwise, there will be either one set of coordinates if mcmcmRegPrc()
372 -
    # was called with fullsims = FALSE, or else N_sims curve data sets.
373 -
    # If the latter, we can return a long data frame with an identifying 
374 -
    # "sim" column to delineate the sim sets. To ensure consistency in output,
375 -
    # also add this column when fullsims = FALSE.
376 -
    
377 -
    # averaged, single coordinate set
378 -
    if (length(x[[element]])==1L) {
379 -
      return(data.frame(sim = 1L, x[[element]][[1]]))
510 +
mcmcRocPrc.mcmc <- function(object, curves = FALSE, fullsims = FALSE, data, 
511 +
                            xnames, yname, type = c("logit", "probit"), 
512 +
                            force = FALSE, ...) {
513 +
  
514 +
  if (!force) {
515 +
    if (is.null(attr(object, "call"))) {
516 +
      stop("object does not have a 'call' attribute; was it generated with a MCMCpack function?")
517 +
    } else {
518 +
      func <- as.character(attr(object, "call"))[1]
519 +
      if (!func %in% c("MCMClogit", "MCMCprobit")) {
520 +
        stop("object does not appear to have been fitted using MCMCpack::MCMClogit() or MCMCprobit(); mcmcRocPrc only properly works for those function. To be safe, consider manually calculating the matrix of predicted probabilities.")
521 +
      }
380 522
    }
381 -
    
382 -
    # full sims
383 -
    # add a unique ID to each coordinate set
384 -
    outlist <- x[[element]]
385 -
    outlist <- Map(cbind, sim = (1:length(outlist)), outlist)
386 -
    # combine into long data frame
387 -
    outdf <- do.call(rbind, outlist)
388 -
    return(outdf)
523 +
  }
524 +
  
525 +
  link_type <- match.arg(type)
526 +
  mdl_data <- data
527 +
  stopifnot(
528 +
    all(xnames %in% names(mdl_data)),
529 +
    all(yname %in% names(mdl_data))
530 +
  )
531 +
  
532 +
  # add intercept by default, maybe revisit this
533 +
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
534 +
  yvec  <- mdl_data[[yname]]
535 +
  
536 +
  betadraws <- as.matrix(object)
537 +
538 +
  if(link_type=="logit") {
539 +
    pred_prob <- plogis(xdata %*% t(betadraws))  
540 +
  } else if (link_type=="probit") {
541 +
    pred_prob <- pnorm(xdata %*% t(betadraws))
389 542
  } 
390 -
  stop("Developer error (I should not be here): please file an issue on GitHub")  # nocov
543 +
  
544 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
545 +
                 fullsims = fullsims)
391 546
}
392 547
393 548
394 549
395 -
# auc_roc and auc_pr are not really used, but keep around just in case
396 -
auc_roc <- function(obs, pred) {
397 -
  values <- compute_roc(obs, pred)
398 -
  caTools::trapz(values$x, values$y)
399 -
}
400 -
401 -
auc_pr <- function(obs, pred) {
402 -
  values <- compute_pr(obs, pred)
403 -
  caTools::trapz(values$x, values$y)
404 -
}

@@ -0,0 +1,183 @@
Loading
1 +
#
2 +
#   Methods for class "mcmcRocPrc", generated by mcmcRocPrc()
3 +
#
4 +
5 +
#' @rdname mcmcRocPrc
6 +
#' 
7 +
#' @export
8 +
print.mcmcRocPrc <- function(x, ...) {
9 +
  
10 +
  auc_roc <- x$area_under_roc
11 +
  auc_prc <- x$area_under_prc
12 +
  
13 +
  has_curves <- !is.null(x$roc_dat)
14 +
  has_sims   <- length(auc_roc) > 1
15 +
  
16 +
  if (!has_sims) {
17 +
    roc_msg <- sprintf("%.3f", round(auc_roc, 3))
18 +
    prc_msg <- sprintf("%.3f", round(auc_prc, 3))
19 +
  } else {
20 +
    roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
21 +
                       round(mean(auc_roc), 3), 
22 +
                       round(quantile(auc_roc, 0.1), 3), 
23 +
                       round(quantile(auc_roc, 0.9), 3))
24 +
    prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
25 +
                       round(mean(auc_prc), 3), 
26 +
                       round(quantile(auc_prc, 0.1), 3), 
27 +
                       round(quantile(auc_prc, 0.9), 3))
28 +
  }
29 +
  
30 +
  cat("mcmcRocPrc object\n")
31 +
  cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims))
32 +
  cat(sprintf("AUC-ROC: %s\n", roc_msg))
33 +
  cat(sprintf("AUC-PR:  %s\n", prc_msg))
34 +
  
35 +
  invisible(x)
36 +
}
37 +
38 +
#' @rdname mcmcRocPrc
39 +
#' 
40 +
#' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw?
41 +
#' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1
42 +
#' 
43 +
#' @export
44 +
plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) {
45 +
  
46 +
  stopifnot(
47 +
    "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)),
48 +
    "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1),
49 +
    "n must be > 0" = (n > 0)
50 +
  )
51 +
  
52 +
  obj<- x
53 +
  fullsims <- length(obj$roc_dat) > 1
54 +
  
55 +
  if (!fullsims) {
56 +
    
57 +
    graphics::par(mfrow = c(1, 2))
58 +
    plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR")
59 +
    graphics::abline(a = 0, b = 1, lty = 3, col = "gray50")
60 +
    
61 +
    prc_dat <- obj$prc_dat[[1]]
62 +
    # use first non-NaN y-value for y[1]
63 +
    prc_dat$y[1] <- prc_dat$y[2]
64 +
    plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision",
65 +
         ylim = c(0, 1))
66 +
    graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50")
67 +
    
68 +
  } else {
69 +
    
70 +
    graphics::par(mfrow = c(1, 2))
71 +
    
72 +
    roc_dat <- obj$roc_dat
73 +
    
74 +
    x <- lapply(roc_dat, `[[`, 1)
75 +
    x <- do.call(cbind, x)
76 +
    colnames(x) <- paste0("sim", 1:ncol(x))
77 +
    
78 +
    y <- lapply(roc_dat, `[[`, 2)
79 +
    y <- do.call(cbind, y)
80 +
    colnames(y) <- paste0("sim", 1:ncol(y))
81 +
    
82 +
    xavg <- rowMeans(x)
83 +
    yavg <- rowMeans(y)
84 +
    
85 +
    plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR")
86 +
    samples <- sample(1:ncol(x), n)
87 +
    for (i in samples) {
88 +
      graphics::lines(
89 +
        x[, i], y[, i], type = "s",
90 +
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
91 +
      )
92 +
    }
93 +
    graphics::lines(xavg, yavg, type = "s")
94 +
    
95 +
    # PRC
96 +
    # The elements of prc_dat have different lengths, unlike roc_dat, so we
97 +
    # have to do the central curve differently.
98 +
    prc_dat <- obj$prc_dat
99 +
    
100 +
    x <- lapply(prc_dat, `[[`, 1)
101 +
    y <- lapply(prc_dat, `[[`, 2)
102 +
    
103 +
    # Instead of combining the list of curve coordinates from each sample into
104 +
    # two x and y matrices, we can first make a point cloud with all curve 
105 +
    # points from all samples, and then average the y values at all distinct
106 +
    # x coordinates. The x-axis plots recall (TPR), which will only have as 
107 +
    # many distinct values as there are positives in the data, so this does 
108 +
    # not lose any information about the x coordinates. 
109 +
    point_cloud <- data.frame(
110 +
      x = unlist(x),
111 +
      y = unlist(y)
112 +
    )
113 +
    point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], 
114 +
                                    # factor implicitly encodes distinct values only,
115 +
                                    # since they will get the same labels
116 +
                                    by = list(x = as.factor(point_cloud$x)), 
117 +
                                    FUN = mean)
118 +
    point_cloud$x <- as.numeric(as.character(point_cloud$x))
119 +
    xavg <- point_cloud$x
120 +
    yavg <- point_cloud$y
121 +
    
122 +
    plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1))
123 +
    samples <- sample(1:length(prc_dat), n)
124 +
    for (i in samples) {
125 +
      graphics::lines(
126 +
        x[[i]], y[[i]], 
127 +
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
128 +
      )
129 +
    }
130 +
    graphics::lines(xavg, yavg)
131 +
    
132 +
  }
133 +
  
134 +
  invisible(x)
135 +
}
136 +
137 +
#' @rdname mcmcRocPrc
138 +
#' 
139 +
#' @param row.names see [base::as.data.frame()] 
140 +
#' @param optional see [base::as.data.frame()]
141 +
#' @param what which information to extract and convert to a data frame?
142 +
#' 
143 +
#' @export
144 +
as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE,
145 +
                                     what = c("auc", "roc", "prc"), ...) {
146 +
  what <- match.arg(what)
147 +
  if (what=="auc") {
148 +
    # all 4 output types have AUC, so this should work across the board
149 +
    return(as.data.frame(x[c("area_under_roc", "area_under_prc")]))
150 +
    
151 +
  } else if (what %in% c("roc", "prc")) {
152 +
    if (what=="roc") element <- "roc_dat" else element <- "prc_dat"
153 +
    
154 +
    # if curves was FALSE, there will be no curve data...
155 +
    if (is.null(x[[element]])) {
156 +
      stop("No curve data; use mcmcRocPrc(..., curves = TRUE)")
157 +
    }
158 +
    
159 +
    # Otherwise, there will be either one set of coordinates if mcmcmRegPrc()
160 +
    # was called with fullsims = FALSE, or else N_sims curve data sets.
161 +
    # If the latter, we can return a long data frame with an identifying 
162 +
    # "sim" column to delineate the sim sets. To ensure consistency in output,
163 +
    # also add this column when fullsims = FALSE.
164 +
    
165 +
    # averaged, single coordinate set
166 +
    if (length(x[[element]])==1L) {
167 +
      return(data.frame(sim = 1L, x[[element]][[1]]))
168 +
    }
169 +
    
170 +
    # full sims
171 +
    # add a unique ID to each coordinate set
172 +
    outlist <- x[[element]]
173 +
    outlist <- Map(cbind, sim = (1:length(outlist)), outlist)
174 +
    # combine into long data frame
175 +
    outdf <- do.call(rbind, outlist)
176 +
    return(outdf)
177 +
  } 
178 +
  stop("Developer error (I should not be here): please file an issue on GitHub")  # nocov
179 +
}
180 +
181 +
182 +
183 +

Learn more Showing 1 files with coverage changes found.

New file R/mcmcRocPrc-methods.R
New
Loading file...
Files Coverage
R 0.98% 83.11%
Project Totals (11 files) 83.11%
Loading