@@ -1,39 +1,78 @@
1 +
#
2 +
#   This file contains the mcmcRocPrc() S3 generic, which constructs objects
3 +
#   of class "mcmcRocPrc". For methods for this class, see mcmcRocPrc-methods.R
4 +
#   S3 methods for the mcmcRocPrc() generic handle different types of input
5 +
#   e.g. "rjags" input produced by R2jags.
6 +
#
7 +
8 +
9 +
1 10
#' ROC and Precision-Recall Curves using Bayesian MCMC estimates
2 11
#' 
3 12
#' Generate ROC and Precision-Recall curves after fitting a Bayesian logit or 
4 -
#' probit regression using [R2jags::jags()]
13 +
#' probit regression using [rstan::stan()], [rstanarm::stan_glm()], 
14 +
#' [R2jags::jags()], [R2WinBUGS::bugs()], [MCMCpack::MCMClogit()], or other 
15 +
#' functions that provide samples from a posterior density. 
5 16
#' 
6 -
#' @param object A "rjags" object (see [R2jags::jags()]) for a fitted binary
7 -
#'   choice model.
8 -
#' @param yname (`character(1)`)\cr
9 -
#'   The name of the dependent variable, should match the variable name in the 
10 -
#'   JAGS data object.
11 -
#' @param xnames ([base::character()])\cr
12 -
#'   A character vector of the independent variable names, should match the 
13 -
#'   corresponding names in the JAGS data object.
17 +
#' @param object A fitted binary choice model, e.g. "rjags" object 
18 +
#'   (see [R2jags::jags()]), or a `[N, iter]` matrix of predicted probabilities.
14 19
#' @param curves logical indicator of whether or not to return values to plot 
15 20
#'   the ROC or Precision-Recall curves. If set to `FALSE` (default), 
16 21
#'   results are returned as a list without the extra values. 
17 22
#' @param fullsims logical indicator of whether full object (based on all MCMC
18 23
#'   draws rather than their average) will be returned. Default is `FALSE`. 
19 24
#'   Note: If `TRUE` is chosen, the function takes notably longer to execute.
25 +
#' @param yvec A `numeric(N)` vector of observed outcomes. 
26 +
#' @param yname (`character(1)`)\cr
27 +
#'   The name of the dependent variable, should match the variable name in the 
28 +
#'   JAGS data object.
29 +
#' @param xnames ([base::character()])\cr
30 +
#'   A character vector of the independent variable names, should match the 
31 +
#'   corresponding names in the JAGS data object.
32 +
#' @param posterior_samples a "mcmc" object with the posterior samples
20 33
#' @param ... Used by methods
21 34
#' @param x a `mcmcRocPrc()` object
22 35
#' 
36 +
#' @details If only the average AUC-ROC and PR are of interest, setting 
37 +
#'   `curves = FALSE` and `fullsims = FALSE` can greatly speed up calculation 
38 +
#'   time. The curve data (`curves = TRUE`) is needed for plotting. The plot
39 +
#'   method will always plot both the ROC and PR curves, but the underlying
40 +
#'   data can easily be extracted from the output for your own plotting; 
41 +
#'   see the documentation of the value returned below. 
42 +
#'   
43 +
#'   The default method works with a matrix of predicted probabilities and the 
44 +
#'   vector of observed outcomes as input. Other methods accommodate some of the 
45 +
#'   common Bayesian modeling packages like rstan (which returns class "stanfit"),
46 +
#'   rstanarm ("stanreg"), R2jags ("rjags"), R2WinBUGS ("bugs"), and 
47 +
#'   MCMCpack ("mcmc"). Even if a package-specific method is not implemented, 
48 +
#'   the default method can always be used as a fallback by manually calculating
49 +
#'   the matrix of predicted probabilities for each posterior sample. 
50 +
#'   
51 +
#'   Note that MCMCpack returns generic "mcmc" output that is annotated with 
52 +
#'   some additional information as attributes, including the original function
53 +
#'   call. There is no inherent way to distinguish any other kind of "mcmc" 
54 +
#'   object from one generated by a proper MCMCpack modeling function, but as a
55 +
#'   basic precaution, `mcmcRocPrc()` will check the saved call and return an 
56 +
#'   error if the function called was not `MCMClogit()` or `MCMCprobit()`. 
57 +
#'   This behavior can be suppressed by setting `force = TRUE`. 
58 +
#' 
23 59
#' @references Beger, Andreas. 2016. “Precision-Recall Curves.” Available at 
24 60
#'   SSRN: [http://dx.doi.org/10.2139/ssrn.2765419](http://dx.doi.org/10.2139/ssrn.2765419)
25 61
#' 
26 62
#' @return Returns a list with length 2 or 4, depending on the "curves" 
27 63
#'   and "fullsims" argument values:
28 64
#'   
29 -
#'   - "area_under_roc": `numeric(1)`
30 -
#'   - "area_under_prc": `numeric(1)`
31 -
#'   - "prc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise
32 -
#'   - "roc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise
65 +
#'   - "area_under_roc": `numeric()`; either length 1 if `fullsims = FALSE`, or 
66 +
#'     one value for each posterior sample otherwise
67 +
#'   - "area_under_prc": `numeric()`; either length 1 if `fullsims = FALSE`, or 
68 +
#'     one value for each posterior sample otherwise
69 +
#'   - "prc_dat": only if `curves = TRUE`; a list with length 1 if 
70 +
#'     `fullsims = FALSE`, longer otherwise
71 +
#'   - "roc_dat": only if `curves = TRUE`; a list with length 1 if 
72 +
#'     `fullsims = FALSE`, longer otherwise
33 73
#'
34 74
#' @examples
35 75
#' # load simulated data and fitted model (see ?sim_data and ?jags_logit)
36 -
#' library(R2jags)
37 76
#' data("jags_logit")
38 77
#' 
39 78
#' # using mcmcRocPrc
@@ -44,41 +83,57 @@
44 83
#'                       fullsims = FALSE)
45 84
#' fit_sum                     
46 85
#' plot(fit_sum)
86 +
#' 
87 +
#' # Equivalently, we can calculate the matrix of predicted probabilities 
88 +
#' # ourselves; using the example from ?jags_logit:
89 +
#' library(R2jags)
90 +
#' 
91 +
#' data("sim_data")
92 +
#' yvec <- sim_data$Y
93 +
#' xmat <- sim_data[, c("X1", "X2")]
94 +
#' 
95 +
#' # add intercept to the X data
96 +
#' xmat <- as.matrix(cbind(Intercept = 1L, xmat))
97 +
#' 
98 +
#' beta <- as.matrix(as.mcmc(jags_logit))[, c("b[1]", "b[2]", "b[3]")]
99 +
#' pred_mat <- plogis(xmat %*% t(beta)) 
100 +
#' 
101 +
#' # the matrix of predictions has rows matching the number of rows in the data;
102 +
#' # the columns are the predictions for each of the 2,000 posterior samples
103 +
#' nrow(sim_data)
104 +
#' dim(pred_mat)
105 +
#' 
106 +
#' # now we can call mcmcRocPrc; the default method works with the matrix
107 +
#' # of predictions and vector of outcomes as input
108 +
#' mcmcRocPrc(object = pred_mat, curves = TRUE, fullsims = FALSE, yvec = yvec)
109 +
#' 
47 110
#' @export
48 111
#' @md
112 +
mcmcRocPrc <- function(object, curves = FALSE, fullsims = FALSE, ...) {
113 +
  UseMethod("mcmcRocPrc", object)
114 +
}
49 115
50 -
mcmcRocPrc <- function(object, 
51 -
                       yname, 
52 -
                       xnames, 
53 -
                       curves = FALSE, 
54 -
                       fullsims = FALSE) {
55 -
  
56 -
  link_logit  <- any(grepl("logit", object$model$model()))
57 -
  link_probit <- any(grepl("probit", object$model$model()))
58 -
  
59 -
  if (isFALSE(link_logit | link_probit)) {
60 -
    stop("Could not identify model link function")
61 -
  }
62 -
  
63 -
  mdl_data <- object$model$data()
64 -
  stopifnot(all(xnames %in% names(mdl_data)))
65 -
  stopifnot(all(yname %in% names(mdl_data)))
66 -
  
67 -
  # add intercept by default, maybe revisit this
68 -
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
69 -
  yvec <- mdl_data[[yname]]
70 -
  
71 -
  pardraws <- as.matrix(coda::as.mcmc(object))
72 -
  # this is not very robust, assumes pars are 'b[x]'
73 -
  # for both this and the intercept addition above, maybe a more robust solution
74 -
  # down the road would be to dig into the object$model$model() string
75 -
  betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))]
116 +
#' Constructor for mcmcRocPrc objects
117 +
#' 
118 +
#' This function actually does the heavy lifting once we have a matrix of 
119 +
#' predicted probabilities from a model, plus the vector of observed outcomes.
120 +
#' The reason to have it here in a single function is that we don't replicate 
121 +
#' it in each function that accommodates a JAGS, BUGS, RStan, etc. object.
122 +
#' 
123 +
#' @param pred_prob a `\[N, iter\]` matrix of predicted probabilities 
124 +
#' @param yvec a `numeric(N)` vector of observed outcomes
125 +
#' @param curves include curve data in output?
126 +
#' @param fullsims keep results for every posterior sample instead of collapsing them into a single summary?
127 +
#' 
128 +
#' @md
129 +
#' @keywords internal
130 +
new_mcmcRocPrc <- function(pred_prob, yvec, curves, fullsims) {
76 131
  
77 -
  if(isTRUE(link_logit)) {
78 -
    pred_prob <- plogis(xdata %*% t(betadraws))  
79 -
  } else if (isTRUE(link_probit)) {
80 -
    pred_prob <- pnorm(xdata %*% t(betadraws))
81 -
  } 
132 +
  stopifnot(
133 +
    "number of predictions and observed outcomes do not match" = nrow(pred_prob)==length(yvec),
134 +
    "yvec must be 0 or 1"                                      = all(yvec %in% c(0L, 1L)),
135 +
    "pred_prob must be in the interval [0, 1]"                 = all(pred_prob >= 0 & pred_prob <= 1)
136 +
  )
82 137
  
83 138
  # pred_prob is a [N, iter] matrix, i.e. each column are preds from one 
84 139
  # set of posterior samples
@@ -146,6 +201,21 @@
146 201
  )
147 202
}
148 203
204 +
#' @rdname mcmcRocPrc
205 +
#' 
206 +
#' @md
207 +
#' @export
208 +
mcmcRocPrc.default <- function(object, curves = FALSE, fullsims = FALSE, yvec, ...) {
209 +
  pred_prob <- object
210 +
  
211 +
  stopifnot(
212 +
    "mcmcRocPrc.default requires 'matrix' like input" = inherits(pred_prob, "matrix")
213 +
  )
214 +
  
215 +
  new_mcmcRocPrc(pred_prob, yvec, curves, fullsims)
216 +
}
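
The input checks performed by `new_mcmcRocPrc()` and the default method can be tried out on simulated data. The following is only a sketch with made-up inputs (`pred_mat`, `yvec`, and `checks` are hypothetical names); it uses base R and does not call the package itself.

```r
# Simulated inputs: N = 5 observations, iter = 3 posterior samples
set.seed(1)
pred_mat <- matrix(runif(15), nrow = 5, ncol = 3)
yvec     <- c(0, 1, 1, 0, 1)

# The same three conditions that new_mcmcRocPrc() asserts
checks <- c(
  dims_match = nrow(pred_mat) == length(yvec),
  y_binary   = all(yvec %in% c(0L, 1L)),
  p_in_range = all(pred_mat >= 0 & pred_mat <= 1)
)
checks

# A plain vector is not a matrix, so the default method would reject it
inherits(pred_mat[, 1], "matrix")
```

Subsetting a single column drops the matrix class, which is why the last call returns `FALSE`; `pred_mat[, 1, drop = FALSE]` would keep it.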
217 +
218 +
# Under the hood ROC/PRC calculations -------------------------------------
149 219
150 220
#' Compute ROC and PR curve points
151 221
#' 
@@ -213,192 +283,267 @@
213 283
}
214 284
215 285
286 +
# auc_roc and auc_pr are not really used, but keep around just in case
287 +
auc_roc <- function(obs, pred) {
288 +
  values <- compute_roc(obs, pred)
289 +
  caTools::trapz(values$x, values$y)
290 +
}
291 +
292 +
auc_pr <- function(obs, pred) {
293 +
  values <- compute_pr(obs, pred)
294 +
  caTools::trapz(values$x, values$y)
295 +
}
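
Both helpers integrate curve points with the trapezoidal rule. As a rough illustration, the sketch below replaces `caTools::trapz()` with a hand-written stand-in (`trapz_sketch()`, a hypothetical name) and builds ROC points directly, since `compute_roc()` is not shown in this hunk. It assumes there are no tied predictions.

```r
# Trapezoidal rule: sum of trapezoid areas between consecutive points
trapz_sketch <- function(x, y) {
  idx <- seq_len(length(x) - 1)
  sum((x[idx + 1] - x[idx]) * (y[idx + 1] + y[idx]) / 2)
}

# Toy data; ROC points are obtained by lowering the threshold step by step
obs  <- c(1, 1, 0, 1, 0)
pred <- c(0.9, 0.8, 0.7, 0.6, 0.2)
ord  <- order(pred, decreasing = TRUE)
tpr  <- c(0, cumsum(obs[ord] == 1) / sum(obs == 1))
fpr  <- c(0, cumsum(obs[ord] == 0) / sum(obs == 0))

auc <- trapz_sketch(fpr, tpr)
auc
```

The result equals the share of positive/negative pairs in which the positive case receives the higher prediction (5 of 6 pairs here).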
296 +
297 +
298 +
299 +
# JAGS-like input (rjags, R2jags, runjags) --------------------------------
216 300
217 301
#' @rdname mcmcRocPrc
218 302
#' 
219 303
#' @export
220 -
print.mcmcRocPrc <- function(x, ...) {
221 -
  
222 -
  auc_roc <- x$area_under_roc
223 -
  auc_prc <- x$area_under_prc
224 -
  
225 -
  has_curves <- !is.null(x$roc_dat)
226 -
  has_sims   <- length(auc_roc) > 1
227 -
  
228 -
  if (!has_sims) {
229 -
    roc_msg <- sprintf("%.3f", round(auc_roc, 3))
230 -
    prc_msg <- sprintf("%.3f", round(auc_prc, 3))
231 -
  } else {
232 -
    roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
233 -
                       round(mean(auc_roc), 3), 
234 -
                       round(quantile(auc_roc, 0.1), 3), 
235 -
                       round(quantile(auc_roc, 0.9), 3))
236 -
    prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
237 -
                       round(mean(auc_prc), 3), 
238 -
                       round(quantile(auc_prc, 0.1), 3), 
239 -
                       round(quantile(auc_prc, 0.9), 3))
304 +
mcmcRocPrc.jags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
305 +
                            xnames, posterior_samples, ...) {
306 +
  
307 +
  stopifnot(
308 +
    inherits(posterior_samples, c("mcmc", "mcmc.list"))
309 +
  )
310 +
  
311 +
  link_logit  <- any(grepl("logit", object$model()))
312 +
  link_probit <- any(grepl("probit", object$model()))
313 +
  
314 +
  if (isFALSE(link_logit | link_probit)) {
315 +
    stop("Could not identify model link function")
240 316
  }
241 317
  
242 -
  cat("mcmcRocPrc object\n")
243 -
  cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims))
244 -
  cat(sprintf("AUC-ROC: %s\n", roc_msg))
245 -
  cat(sprintf("AUC-PR:  %s\n", prc_msg))
318 +
  mdl_data <- object$data()
319 +
  stopifnot(all(xnames %in% names(mdl_data)))
320 +
  stopifnot(all(yname %in% names(mdl_data)))
321 +
  
322 +
  # add intercept by default, maybe revisit this
323 +
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
324 +
  yvec  <- mdl_data[[yname]]
325 +
  
326 +
  pardraws <- as.matrix(posterior_samples)
327 +
  # this is not very robust, assumes pars are 'b[x]'
328 +
  # for both this and the intercept addition above, maybe a more robust solution
329 +
  # down the road would be to dig into the object$model() string
330 +
  betadraws <- pardraws[, sprintf("b[%s]", seq_len(ncol(xdata)))]
331 +
  
332 +
  if(isTRUE(link_logit)) {
333 +
    pred_prob <- plogis(xdata %*% t(betadraws))  
334 +
  } else if (isTRUE(link_probit)) {
335 +
    pred_prob <- pnorm(xdata %*% t(betadraws))
336 +
  } 
246 337
  
247 -
  invisible(x)
338 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
339 +
                 fullsims = fullsims)
248 340
}
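
The core step in this method is turning coefficient draws into an `[N, iter]` matrix of predicted probabilities. Here is a minimal sketch of that step with simulated values; the draws below are random numbers rather than output from a fitted model, and all object names are illustrative.

```r
set.seed(42)
n_obs  <- 4   # N, number of observations
n_iter <- 6   # number of posterior draws

# Design matrix with an intercept column, as in the method above
xdata <- cbind(X0 = 1, X1 = rnorm(n_obs), X2 = rnorm(n_obs))

# Fake posterior draws: one row per draw, one column per coefficient
betadraws <- matrix(rnorm(n_iter * 3), nrow = n_iter,
                    dimnames = list(NULL, c("b[1]", "b[2]", "b[3]")))

# [N, k] %*% [k, iter] gives an [N, iter] matrix of linear predictors,
# which the inverse link maps to probabilities
pred_logit  <- plogis(xdata %*% t(betadraws))
pred_probit <- pnorm(xdata %*% t(betadraws))

dim(pred_logit)
range(pred_logit)
```

Each column of the result holds the predictions from one posterior draw, which is the layout `new_mcmcRocPrc()` expects.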
249 341
250 342
#' @rdname mcmcRocPrc
251 343
#' 
252 -
#' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw?
253 -
#' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1
344 +
#' @export
345 +
mcmcRocPrc.rjags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
346 +
                             xnames, ...) {
347 +
  
348 +
  if (!requireNamespace("R2jags", quietly = TRUE)) {
349 +
    stop("Package \"R2jags\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
350 +
  }
351 +
  
352 +
  jags_object <- object$model
353 +
  pardraws    <- coda::as.mcmc(object)
354 +
  
355 +
  # pass it on to the "jags" method
356 +
  mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, 
357 +
             yname = yname, xnames = xnames, posterior_samples = pardraws, ...)
358 +
}
359 +
360 +
#' @rdname mcmcRocPrc
254 361
#' 
255 362
#' @export
256 -
plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) {
363 +
mcmcRocPrc.runjags <- function(object, curves = FALSE, fullsims = FALSE, yname, 
364 +
                               xnames, ...) {
365 +
  jags_object <- runjags::as.jags(object, quiet = TRUE)
366 +
  # as.mcmc.runjags will issue a warning when converting multiple chains
367 +
  # because it combines them
368 +
  pardraws    <- suppressWarnings(coda::as.mcmc(object))
369 +
  
370 +
  # pass it on to the "jags" method
371 +
  mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, 
372 +
             yname = yname, xnames = xnames, posterior_samples = pardraws, ...)
373 +
}
257 374
258 -
  stopifnot(
259 -
    "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)),
260 -
    "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1),
261 -
    "n must be > 0" = (n > 0)
262 -
  )
375 +
376 +
# STAN-like input (rstan, rstanarm, brms) ---------------------------------
377 +
378 +
379 +
380 +
#' @rdname mcmcRocPrc
381 +
#' 
382 +
#' @param data the data that was used in the `stan(data = ?, ...)` call
383 +
#' 
384 +
#' @export
385 +
mcmcRocPrc.stanfit <- function(object, curves = FALSE, fullsims = FALSE, data, 
386 +
                               xnames, yname, ...) {
387 +
  if (!requireNamespace("rstan", quietly = TRUE)) {
388 +
    stop("Package \"rstan\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
389 +
  }
263 390
  
264 -
  obj<- x
265 -
  fullsims <- length(obj$roc_dat) > 1
391 +
  if (!is_binary_model(object)) {
392 +
    stop("the input model does not seem to be a binary choice model; if this is a mistake please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/")
393 +
  }
394 +
  link_type <- identify_link_function(object)
395 +
  if (is.na(link_type)) {
396 +
    stop("could not identify model link function; please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/")
397 +
  }
266 398
  
267 -
  if (!fullsims) {
268 -
    
269 -
    graphics::par(mfrow = c(1, 2))
270 -
    plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR")
271 -
    graphics::abline(a = 0, b = 1, lty = 3, col = "gray50")
272 -
    
273 -
    prc_dat <- obj$prc_dat[[1]]
274 -
    # use first non-NaN y-value for y[1]
275 -
    prc_dat$y[1] <- prc_dat$y[2]
276 -
    plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision",
277 -
         ylim = c(0, 1))
278 -
    graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50")
279 -
    
280 -
  } else {
281 -
    
282 -
    graphics::par(mfrow = c(1, 2))
283 -
    
284 -
    roc_dat <- obj$roc_dat
285 -
    
286 -
    x <- lapply(roc_dat, `[[`, 1)
287 -
    x <- do.call(cbind, x)
288 -
    colnames(x) <- paste0("sim", 1:ncol(x))
289 -
    
290 -
    y <- lapply(roc_dat, `[[`, 2)
291 -
    y <- do.call(cbind, y)
292 -
    colnames(y) <- paste0("sim", 1:ncol(y))
293 -
    
294 -
    xavg <- rowMeans(x)
295 -
    yavg <- rowMeans(y)
296 -
    
297 -
    plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR")
298 -
    samples <- sample(1:ncol(x), n)
299 -
    for (i in samples) {
300 -
      graphics::lines(
301 -
        x[, i], y[, i], type = "s",
302 -
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
303 -
      )
304 -
    }
305 -
    graphics::lines(xavg, yavg, type = "s")
306 -
    
307 -
    # PRC
308 -
    # The elements of prc_dat have different lengths, unlike roc_dat, so we
309 -
    # have to do the central curve differently.
310 -
    prc_dat <- obj$prc_dat
399 +
  mdl_data <- data
400 +
  stopifnot(all(xnames %in% names(mdl_data)))
401 +
  stopifnot(all(yname %in% names(mdl_data)))
402 +
   
403 +
  # add intercept by default, maybe revisit this
404 +
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
405 +
  yvec  <- mdl_data[[yname]]
406 +
  
407 +
  pardraws <- as.matrix(object)
408 +
  # this is not very robust, assumes pars are 'b[x]'
409 +
  betadraws <- pardraws[, sprintf("b[%s]", seq_len(ncol(xdata)))]
410 +
  
411 +
  if(link_type=="logit") {
412 +
    pred_prob <- plogis(xdata %*% t(betadraws))  
413 +
  } else if (link_type=="probit") {
414 +
    pred_prob <- pnorm(xdata %*% t(betadraws))
415 +
  } 
416 +
  
417 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
418 +
                 fullsims = fullsims)
419 +
  
420 +
  
421 +
}
311 422
312 -
    x <- lapply(prc_dat, `[[`, 1)
313 -
    y <- lapply(prc_dat, `[[`, 2)
314 -
    
315 -
    # Instead of combining the list of curve coordinates from each sample into
316 -
    # two x and y matrices, we can first make a point cloud with all curve 
317 -
    # points from all samples, and then average the y values at all distinct
318 -
    # x coordinates. The x-axis plots recall (TPR), which will only have as 
319 -
    # many distinct values as there are positives in the data, so this does 
320 -
    # not lose any information about the x coordinates. 
321 -
    point_cloud <- data.frame(
322 -
      x = unlist(x),
323 -
      y = unlist(y)
324 -
    )
325 -
    point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], 
326 -
                                    # factor implicitly encodes distinct values only,
327 -
                                    # since they will get the same labels
328 -
                                    by = list(x = as.factor(point_cloud$x)), 
329 -
                                    FUN = mean)
330 -
    point_cloud$x <- as.numeric(as.character(point_cloud$x))
331 -
    xavg <- point_cloud$x
332 -
    yavg <- point_cloud$y
333 -
334 -
    plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1))
335 -
    samples <- sample(1:length(prc_dat), n)
336 -
    for (i in samples) {
337 -
      graphics::lines(
338 -
        x[[i]], y[[i]], 
339 -
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
340 -
      )
341 -
    }
342 -
    graphics::lines(xavg, yavg)
343 -
    
423 +
#' Try to identify if a stanfit model is a binary choice model
424 +
#' 
425 +
#' @param obj stanfit object
426 +
#' 
427 +
#' @keywords internal
428 +
is_binary_model <- function(obj) {
429 +
  stopifnot(inherits(obj, "stanfit"))
430 +
  model_string <- rstan::get_stancode(obj)
431 +
  grepl("bernoulli", model_string)
432 +
}
433 +
434 +
#' Try to identify link function 
435 +
#' 
436 +
#' @param obj stanfit object
437 +
#' 
438 +
#' @return Either "logit" or "probit"; if neither can be identified the function
439 +
#' will return `NA_character_`. 
440 +
#' 
441 +
#' @keywords internal
442 +
identify_link_function <- function(obj) {
443 +
  stopifnot(inherits(obj, "stanfit"))
444 +
  model_string <- rstan::get_stancode(obj)
445 +
  if (grepl("logit", model_string)) return("logit")
446 +
  if (grepl("Phi", model_string)) return("probit")
447 +
  NA_character_
448 +
}
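
The two helpers above rely on simple pattern matching against the Stan model code. The sketch below applies the same matching logic to plain strings so it can run without rstan; the model fragments and the name `link_from_string()` are made up for illustration.

```r
# Hypothetical Stan model fragments, for illustration only
logit_code  <- "model { y ~ bernoulli_logit(X * b); }"
probit_code <- "model { y ~ bernoulli(Phi(X * b)); }"
other_code  <- "model { y ~ normal(X * b, sigma); }"

# Same matching logic as identify_link_function(), applied to a string
link_from_string <- function(model_string) {
  if (grepl("logit", model_string)) return("logit")
  if (grepl("Phi", model_string)) return("probit")
  NA_character_
}

vapply(c(logit_code, probit_code, other_code), link_from_string,
       character(1), USE.NAMES = FALSE)
```

Because this is substring matching, a model whose code merely mentions "logit" in a comment would also be classified as a logit model, so the result is best treated as a heuristic.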
449 +
450 +
#' @rdname mcmcRocPrc
451 +
#' 
452 +
#' @export
453 +
mcmcRocPrc.stanreg <- function(object, curves = FALSE, fullsims = FALSE, ...) {
454 +
  if (!requireNamespace("rstanarm", quietly = TRUE)) {
455 +
    stop("Package \"rstanarm\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
344 456
  }
457 +
  if (!stats::family(object)$family=="binomial") {
458 +
    stop("the input model does not seem to be a binary choice model; should be like 'obj <- stan_glm(family = binomial(), ...)'") 
459 +
  }
460 +
  pred_prob <- rstanarm::posterior_linpred(object, transform = TRUE)
461 +
  # posterior_linpred returns a matrix in which data cases are columns, and 
462 +
  # MCMC samples are rows; we need to transpose this so that columns are samples
463 +
  pred_prob <- t(pred_prob)
464 +
  yvec <- unname(object$y)
345 465
  
346 -
  invisible(x)
466 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
467 +
                 fullsims = fullsims)
347 468
}
348 469
349 470
#' @rdname mcmcRocPrc
350 471
#' 
351 -
#' @param row.names see [base::as.data.frame()] 
352 -
#' @param optional see [base::as.data.frame()]
353 -
#' @param what which information to extract and convert to a data frame?
472 +
#' @export
473 +
mcmcRocPrc.brmsfit <- function(object, curves = FALSE, fullsims = FALSE, ...) {
474 +
  if (!requireNamespace("brms", quietly = TRUE)) {
475 +
    stop("Package \"brms\" is needed for this function to work. Please install it.", call. = FALSE)  # nocov
476 +
  }
477 +
  if (!stats::family(object)$family=="bernoulli") {
478 +
    stop("the input model does not seem to be a binary choice model; should be like 'obj <- brm(family = bernoulli(), ...)'") 
479 +
  }
480 +
  
481 +
  pred_prob <- brms::posterior_epred(object)
482 +
  # posterior_epred returns a matrix in which data cases are columns, and 
483 +
  # MCMC samples are rows; we need to transpose this so that columns are samples
484 +
  pred_prob <- t(pred_prob)
485 +
  yvec <- stats::model.response(stats::model.frame(object))
486 +
  
487 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
488 +
                 fullsims = fullsims)
489 +
}
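
The transpose in the rstanarm and brms methods only changes the orientation of the prediction matrix. A small sketch with a stand-in matrix (not real `posterior_epred()` output) shows the before and after shapes:

```r
# Stand-in for a [draws, N] matrix: 3 draws (rows), 5 observations (columns)
draws_by_obs <- matrix((1:15) / 20, nrow = 3, ncol = 5)
dim(draws_by_obs)

# new_mcmcRocPrc() expects [N, iter], so transpose
pred_prob <- t(draws_by_obs)
dim(pred_prob)
```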
490 +
491 +
492 +
# Other input types (MCMCpack, ...) ---------------------------------------
493 +
494 +
495 +
#' #' @rdname mcmcRocPrc
496 +
#' #'
497 +
#' #' @export
498 +
#' mcmcRocPrc.bugs <- function(object, curves = FALSE, fullsims = FALSE, ...) {
499 +
#'   stop("not implemented yet")
500 +
#' }
501 +
502 +
503 +
#' @rdname mcmcRocPrc
354 504
#' 
505 +
#' @param type "logit" or "probit"
506 +
#' @param force for MCMCpack models, suppress the error raised when the object
507 +
#'   does not appear to come from `MCMClogit()` or `MCMCprobit()`?
508 +
#'
355 509
#' @export
356 -
as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE,
357 -
                                     what = c("auc", "roc", "prc"), ...) {
358 -
  what <- match.arg(what)
359 -
  if (what=="auc") {
360 -
    # all 4 output types have AUC, so this should work across the board
361 -
    return(as.data.frame(x[c("area_under_roc", "area_under_prc")]))
362 -
    
363 -
  } else if (what %in% c("roc", "prc")) {
364 -
    if (what=="roc") element <- "roc_dat" else element <- "prc_dat"
365 -
    
366 -
    # if curves was FALSE, there will be no curve data...
367 -
    if (is.null(x[[element]])) {
368 -
      stop("No curve data; use mcmcRocPrc(..., curves = TRUE)")
369 -
    }
370 -
    
371 -
    # Otherwise, there will be either one set of coordinates if mcmcmRegPrc()
372 -
    # was called with fullsims = FALSE, or else N_sims curve data sets.
373 -
    # If the latter, we can return a long data frame with an identifying 
374 -
    # "sim" column to delineate the sim sets. To ensure consistency in output,
375 -
    # also add this column when fullsims = FALSE.
376 -
    
377 -
    # averaged, single coordinate set
378 -
    if (length(x[[element]])==1L) {
379 -
      return(data.frame(sim = 1L, x[[element]][[1]]))
510 +
mcmcRocPrc.mcmc <- function(object, curves = FALSE, fullsims = FALSE, data, 
511 +
                            xnames, yname, type = c("logit", "probit"), 
512 +
                            force = FALSE, ...) {
513 +
  
514 +
  if (!force) {
515 +
    if (is.null(attr(object, "call"))) {
516 +
      stop("object does not have a 'call' attribute; was it generated with a MCMCpack function?")
517 +
    } else {
518 +
      func <- as.character(attr(object, "call"))[1]
519 +
      if (!func %in% c("MCMClogit", "MCMCprobit")) {
520 +
        stop("object does not appear to have been fitted using MCMCpack::MCMClogit() or MCMCprobit(); mcmcRocPrc only properly works for those functions. To be safe, consider manually calculating the matrix of predicted probabilities.")
521 +
      }
380 522
    }
381 -
    
382 -
    # full sims
383 -
    # add a unique ID to each coordinate set
384 -
    outlist <- x[[element]]
385 -
    outlist <- Map(cbind, sim = (1:length(outlist)), outlist)
386 -
    # combine into long data frame
387 -
    outdf <- do.call(rbind, outlist)
388 -
    return(outdf)
523 +
  }
524 +
  
525 +
  link_type <- match.arg(type)
526 +
  mdl_data <- data
527 +
  stopifnot(
528 +
    all(xnames %in% names(mdl_data)),
529 +
    all(yname %in% names(mdl_data))
530 +
  )
531 +
  
532 +
  # add intercept by default, maybe revisit this
533 +
  xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames])))
534 +
  yvec  <- mdl_data[[yname]]
535 +
  
536 +
  betadraws <- as.matrix(object)
537 +
538 +
  if(link_type=="logit") {
539 +
    pred_prob <- plogis(xdata %*% t(betadraws))  
540 +
  } else if (link_type=="probit") {
541 +
    pred_prob <- pnorm(xdata %*% t(betadraws))
389 542
  } 
390 -
  stop("Developer error (I should not be here): please file an issue on GitHub")  # nocov
543 +
  
544 +
  new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, 
545 +
                 fullsims = fullsims)
391 546
}
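
The `force` guard in this method inspects the call saved as an attribute on the object. The sketch below reproduces that check on a fake object; the matrix and the quoted call are invented for illustration and no MCMCpack function is run.

```r
# Fake draws carrying a saved call, mimicking annotated "mcmc" output
draws <- matrix(rnorm(6), nrow = 3,
                dimnames = list(NULL, c("(Intercept)", "X1")))
attr(draws, "call") <- quote(MCMClogit(formula = Y ~ X1, data = df))

# The first element of the call is the name of the function that was called
func <- as.character(attr(draws, "call"))[1]
func
func %in% c("MCMClogit", "MCMCprobit")

# An object without the attribute fails the first check
is.null(attr(matrix(0, 1, 1), "call"))
```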
392 547
393 548
394 549
395 -
# auc_roc and auc_pr are not really used, but keep around just in case
396 -
auc_roc <- function(obs, pred) {
397 -
  values <- compute_roc(obs, pred)
398 -
  caTools::trapz(values$x, values$y)
399 -
}
400 -
401 -
auc_pr <- function(obs, pred) {
402 -
  values <- compute_pr(obs, pred)
403 -
  caTools::trapz(values$x, values$y)
404 -
}

@@ -0,0 +1,183 @@
1 +
#
2 +
#   Methods for class "mcmcRocPrc", generated by mcmcRocPrc()
3 +
#
4 +
5 +
#' @rdname mcmcRocPrc
6 +
#' 
7 +
#' @export
8 +
print.mcmcRocPrc <- function(x, ...) {
9 +
  
10 +
  auc_roc <- x$area_under_roc
11 +
  auc_prc <- x$area_under_prc
12 +
  
13 +
  has_curves <- !is.null(x$roc_dat)
14 +
  has_sims   <- length(auc_roc) > 1
15 +
  
16 +
  if (!has_sims) {
17 +
    roc_msg <- sprintf("%.3f", round(auc_roc, 3))
18 +
    prc_msg <- sprintf("%.3f", round(auc_prc, 3))
19 +
  } else {
20 +
    roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
21 +
                       round(mean(auc_roc), 3), 
22 +
                       round(quantile(auc_roc, 0.1), 3), 
23 +
                       round(quantile(auc_roc, 0.9), 3))
24 +
    prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
25 +
                       round(mean(auc_prc), 3), 
26 +
                       round(quantile(auc_prc, 0.1), 3), 
27 +
                       round(quantile(auc_prc, 0.9), 3))
28 +
  }
29 +
  
30 +
  cat("mcmcRocPrc object\n")
31 +
  cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims))
32 +
  cat(sprintf("AUC-ROC: %s\n", roc_msg))
33 +
  cat(sprintf("AUC-PR:  %s\n", prc_msg))
34 +
  
35 +
  invisible(x)
36 +
}
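
When `fullsims = TRUE`, the print method summarizes the per-sample AUC values with their mean and an 80% interval. A minimal sketch of that formatting step, using made-up AUC values:

```r
# Hypothetical per-sample AUC values, one per posterior draw
auc_roc <- c(0.70, 0.72, 0.74, 0.76, 0.78)

# Posterior mean plus the 10th and 90th percentiles (80% interval)
roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]",
                   mean(auc_roc),
                   stats::quantile(auc_roc, 0.1),
                   stats::quantile(auc_roc, 0.9))
cat(sprintf("AUC-ROC: %s\n", roc_msg))
```

The `%.3f` format already rounds to three decimals, so the explicit `round()` calls in the method are not strictly needed for the printed output.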
37 +
38 +
#' @rdname mcmcRocPrc
39 +
#' 
40 +
#' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw?
41 +
#' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1
42 +
#' 
43 +
#' @export
44 +
plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) {
45 +
  
46 +
  stopifnot(
47 +
    "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)),
48 +
    "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1),
49 +
    "n must be > 0" = (n > 0)
50 +
  )
51 +
  
52 +
  obj<- x
53 +
  fullsims <- length(obj$roc_dat) > 1
54 +
  
55 +
  if (!fullsims) {
56 +
    
57 +
    graphics::par(mfrow = c(1, 2))
58 +
    plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR")
59 +
    graphics::abline(a = 0, b = 1, lty = 3, col = "gray50")
60 +
    
61 +
    prc_dat <- obj$prc_dat[[1]]
62 +
    # use first non-NaN y-value for y[1]
63 +
    prc_dat$y[1] <- prc_dat$y[2]
64 +
    plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision",
65 +
         ylim = c(0, 1))
66 +
    graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50")
67 +
    
68 +
  } else {
69 +
    
70 +
    graphics::par(mfrow = c(1, 2))
71 +
    
72 +
    roc_dat <- obj$roc_dat
73 +
    
74 +
    x <- lapply(roc_dat, `[[`, 1)
75 +
    x <- do.call(cbind, x)
76 +
    colnames(x) <- paste0("sim", 1:ncol(x))
77 +
    
78 +
    y <- lapply(roc_dat, `[[`, 2)
79 +
    y <- do.call(cbind, y)
80 +
    colnames(y) <- paste0("sim", 1:ncol(y))
81 +
    
82 +
    xavg <- rowMeans(x)
83 +
    yavg <- rowMeans(y)
84 +
    
85 +
    plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR")
86 +
    samples <- sample(1:ncol(x), n)
87 +
    for (i in samples) {
88 +
      graphics::lines(
89 +
        x[, i], y[, i], type = "s",
90 +
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
91 +
      )
92 +
    }
93 +
    graphics::lines(xavg, yavg, type = "s")
94 +
    
95 +
    # PRC
96 +
    # The elements of prc_dat have different lengths, unlike roc_dat, so we
97 +
    # have to do the central curve differently.
98 +
    prc_dat <- obj$prc_dat
99 +
    
100 +
    x <- lapply(prc_dat, `[[`, 1)
101 +
    y <- lapply(prc_dat, `[[`, 2)
102 +
    
103 +
    # Instead of combining the list of curve coordinates from each sample into
104 +
    # two x and y matrices, we can first make a point cloud with all curve 
105 +
    # points from all samples, and then average the y values at all distinct
106 +
    # x coordinates. The x-axis plots recall (TPR), which will only have as 
107 +
    # many distinct values as there are positives in the data, so this does 
108 +
    # not lose any information about the x coordinates. 
109 +
    point_cloud <- data.frame(
110 +
      x = unlist(x),
111 +
      y = unlist(y)
112 +
    )
113 +
    point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], 
114 +
                                    # factor implicitly encodes distinct values only,
115 +
                                    # since they will get the same labels
116 +
                                    by = list(x = as.factor(point_cloud$x)), 
117 +
                                    FUN = mean)
118 +
    point_cloud$x <- as.numeric(as.character(point_cloud$x))
119 +
    xavg <- point_cloud$x
120 +
    yavg <- point_cloud$y
121 +
    
122 +
    plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1))
123 +
    samples <- sample(1:length(prc_dat), n)
124 +
    for (i in samples) {
125 +
      graphics::lines(
126 +
        x[[i]], y[[i]], 
127 +
        col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255)
128 +
      )
129 +
    }
130 +
    graphics::lines(xavg, yavg)
131 +
    
132 +
  }
133 +
  
134 +
  invisible(x)
135 +
}
136 +
137 +
#' @rdname mcmcRocPrc
138 +
#' 
139 +
#' @param row.names see [base::as.data.frame()] 
140 +
#' @param optional see [base::as.data.frame()]
141 +
#' @param what which information to extract and convert to a data frame?
142 +
#' 
143 +
#' @export
144 +
as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE,
145 +
                                     what = c("auc", "roc", "prc"), ...) {
146 +
  what <- match.arg(what)
147 +
  if (what=="auc") {
148 +
    # all 4 output types have AUC, so this should work across the board
149 +
    return(as.data.frame(x[c("area_under_roc", "area_under_prc")]))
150 +
    
151 +
  } else if (what %in% c("roc", "prc")) {
152 +
    if (what=="roc") element <- "roc_dat" else element <- "prc_dat"
153 +
    
154 +
    # if curves was FALSE, there will be no curve data...
155 +
    if (is.null(x[[element]])) {
156 +
      stop("No curve data; use mcmcRocPrc(..., curves = TRUE)")
157 +
    }
158 +
    
159 +
    # Otherwise, there will be either one set of coordinates if mcmcRocPrc()
160 +
    # was called with fullsims = FALSE, or else N_sims curve data sets.
161 +
    # If the latter, we can return a long data frame with an identifying 
162 +
    # "sim" column to delineate the sim sets. To ensure consistency in output,
163 +
    # also add this column when fullsims = FALSE.
164 +
    
165 +
    # averaged, single coordinate set
166 +
    if (length(x[[element]])==1L) {
167 +
      return(data.frame(sim = 1L, x[[element]][[1]]))
168 +
    }
169 +
    
170 +
    # full sims
171 +
    # add a unique ID to each coordinate set
172 +
    outlist <- x[[element]]
173 +
    outlist <- Map(cbind, sim = (1:length(outlist)), outlist)
174 +
    # combine into long data frame
175 +
    outdf <- do.call(rbind, outlist)
176 +
    return(outdf)
177 +
  } 
178 +
  stop("Developer error (I should not be here): please file an issue on GitHub")  # nocov
179 +
}
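
For the full-simulation case, the method tags each curve data set with a `sim` ID and stacks the sets into one long data frame. The sketch below runs the same `Map()` and `rbind()` steps on two invented curve data sets:

```r
# Two hypothetical curve data sets, one per posterior sample
curve_list <- list(
  data.frame(x = c(0, 0.5, 1), y = c(0, 0.8, 1)),
  data.frame(x = c(0, 0.5, 1), y = c(0, 0.6, 1))
)

# Prepend a "sim" ID to each set, then stack into one long data frame
with_id <- Map(cbind, sim = seq_along(curve_list), curve_list)
long_df <- do.call(rbind, with_id)
long_df
```

The long format keeps one row per curve point, which makes it straightforward to group by `sim` when plotting with other tools.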
180 +
181 +
182 +
183 +
Files Coverage
R 83.11%
Project Totals (11 files) 83.11%
comment: false

coverage:
  status:
    project:
      default:
        target: auto
        threshold: 1%
    patch:
      default:
        target: auto
        threshold: 1%