No flags found
Use flags to group coverage reports by test type, project and/or folders.
Then set up custom commit statuses and notifications for each flag.
e.g., #unittest #integration
#production #enterprise
#frontend #backend
559d018
... +8 ...
7d49199
Use flags to group coverage reports by test type, project and/or folders.
Then set up custom commit statuses and notifications for each flag.
e.g., #unittest #integration
#production #enterprise
#frontend #backend
1 | + | # |
|
2 | + | # This file contains the mcmcRocPrc() S3 generic, which constructs objects |
|
3 | + | # of class "mcmcRocPrc". For methods for this class, see mcmcRocPrc-methods.R |
|
4 | + | # S3 methods for the mcmcRocPrc() generic handle different types of input |
|
5 | + | # e.g. "rjags" input produced by R2jags. |
|
6 | + | # |
|
7 | + | ||
8 | + | ||
9 | + | ||
1 | 10 | #' ROC and Precision-Recall Curves using Bayesian MCMC estimates |
|
2 | 11 | #' |
|
3 | 12 | #' Generate ROC and Precision-Recall curves after fitting a Bayesian logit or |
|
4 | - | #' probit regression using [R2jags::jags()] |
|
13 | + | #' probit regression using [rstan::stan()], [rstanarm::stan_glm()], |
|
14 | + | #' [R2jags::jags()], [R2WinBUGS::bugs()], [MCMCpack::MCMClogit()], or other |
|
15 | + | #' functions that provide samples from a posterior density. |
|
5 | 16 | #' |
|
6 | - | #' @param object A "rjags" object (see [R2jags::jags()]) for a fitted binary |
|
7 | - | #' choice model. |
|
8 | - | #' @param yname (`character(1)`)\cr |
|
9 | - | #' The name of the dependent variable, should match the variable name in the |
|
10 | - | #' JAGS data object. |
|
11 | - | #' @param xnames ([base::character()])\cr |
|
12 | - | #' A character vector of the independent variable names, should match the |
|
13 | - | #' corresponding names in the JAGS data object. |
|
17 | + | #' @param object A fitted binary choice model, e.g. "rjags" object |
|
18 | + | #' (see [R2jags::jags()]), or a `[N, iter]` matrix of predicted probabilities. |
|
14 | 19 | #' @param curves logical indicator of whether or not to return values to plot |
|
15 | 20 | #' the ROC or Precision-Recall curves. If set to `FALSE` (default), |
|
16 | 21 | #' results are returned as a list without the extra values. |
|
17 | 22 | #' @param fullsims logical indicator of whether full object (based on all MCMC |
|
18 | 23 | #' draws rather than their average) will be returned. Default is `FALSE`. |
|
19 | 24 | #' Note: If `TRUE` is chosen, the function takes notably longer to execute. |
|
25 | + | #' @param yvec A `numeric(N)` vector of observed outcomes. |
|
26 | + | #' @param yname (`character(1)`)\cr |
|
27 | + | #' The name of the dependent variable, should match the variable name in the |
|
28 | + | #' JAGS data object. |
|
29 | + | #' @param xnames ([base::character()])\cr |
|
30 | + | #' A character vector of the independent variable names, should match the |
|
31 | + | #' corresponding names in the JAGS data object. |
|
32 | + | #' @param posterior_samples a "mcmc" object with the posterior samples |
|
20 | 33 | #' @param ... Used by methods |
|
21 | 34 | #' @param x a `mcmcRocPrc()` object |
|
22 | 35 | #' |
|
36 | + | #' @details If only the average AUC-ROC and PR are of interest, setting |
|
37 | + | #' `curves = FALSE` and `fullsims = FALSE` can greatly speed up calculation |
|
38 | + | #' time. The curve data (`curves = TRUE`) is needed for plotting. The plot |
|
39 | + | #' method will always plot both the ROC and PR curves, but the underlying |
|
40 | + | #' data can easily be extracted from the output for your own plotting; |
|
41 | + | #' see the documentation of the value returned below. |
|
42 | + | #' |
|
43 | + | #' The default method works with a matrix of predicted probabilities and the |
|
44 | + | #' vector of observed outcomes as input. Other methods accommodate some of the |
|
45 | + | #' common Bayesian modeling packages like rstan (which returns class "stanfit"), |
|
46 | + | #' rstanarm ("stanreg"), R2jags ("rjags"), R2WinBUGS ("bugs"), and |
|
47 | + | #' MCMCpack ("mcmc"). Even if a package-specific method is not implemented, |
|
48 | + | #' the default method can always be used as a fallback by manually calculating |
|
49 | + | #' the matrix of predicted probabilities for each posterior sample. |
|
50 | + | #' |
|
51 | + | #' Note that MCMCpack returns generic "mcmc" output that is annotated with |
|
52 | + | #' some additional information as attributes, including the original function |
|
53 | + | #' call. There is no inherent way to distinguish any other kind of "mcmc" |
|
54 | + | #' object from one generated by a proper MCMCpack modeling function, but as a |
|
55 | + | #' basic precaution, `mcmcRocPrc()` will check the saved call and return an |
|
56 | + | #' error if the function called was not `MCMClogit()` or `MCMCprobit()`. |
|
57 | + | #' This behavior can be suppressed by setting `force = TRUE`. |
|
58 | + | #' |
|
23 | 59 | #' @references Beger, Andreas. 2016. “Precision-Recall Curves.” Available at |
|
24 | 60 | #' SSRN: [http://dx.doi.org/10.2139/ssrn.2765419](http://dx.doi.org/10.2139/ssrn.2765419) |
|
25 | 61 | #' |
|
26 | 62 | #' @return Returns a list with length 2 or 4, depending on the "curves" |
|
27 | 63 | #' and "fullsims" argument values: |
|
28 | 64 | #' |
|
29 | - | #' - "area_under_roc": `numeric(1)` |
|
30 | - | #' - "area_under_prc": `numeric(1)` |
|
31 | - | #' - "prc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise |
|
32 | - | #' - "roc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise |
|
65 | + | #' - "area_under_roc": `numeric()`; either length 1 if `fullsims = FALSE`, or |
|
66 | + | #' one value for each posterior sample otherwise |
|
67 | + | #' - "area_under_prc": `numeric()`; either length 1 if `fullsims = FALSE`, or |
|
68 | + | #' one value for each posterior sample otherwise |
|
69 | + | #' - "prc_dat": only if `curves = TRUE`; a list with length 1 if |
|
70 | + | #' `fullsims = FALSE`, longer otherwise |
|
71 | + | #' - "roc_dat": only if `curves = TRUE`; a list with length 1 if |
|
72 | + | #' `fullsims = FALSE`, longer otherwise |
|
33 | 73 | #' |
|
34 | 74 | #' @examples |
|
35 | 75 | #' # load simulated data and fitted model (see ?sim_data and ?jags_logit) |
|
36 | - | #' library(R2jags) |
|
37 | 76 | #' data("jags_logit") |
|
38 | 77 | #' |
|
39 | 78 | #' # using mcmcRocPrc |
44 | 83 | #' fullsims = FALSE) |
|
45 | 84 | #' fit_sum |
|
46 | 85 | #' plot(fit_sum) |
|
86 | + | #' |
|
87 | + | #' # Equivalently, we can calculate the matrix of predicted probabilities |
|
88 | + | #' # ourselves; using the example from ?jags_logit: |
|
89 | + | #' library(R2jags) |
|
90 | + | #' |
|
91 | + | #' data("sim_data") |
|
92 | + | #' yvec <- sim_data$Y |
|
93 | + | #' xmat <- sim_data[, c("X1", "X2")] |
|
94 | + | #' |
|
95 | + | #' # add intercept to the X data |
|
96 | + | #' xmat <- as.matrix(cbind(Intercept = 1L, xmat)) |
|
97 | + | #' |
|
98 | + | #' beta <- as.matrix(as.mcmc(jags_logit))[, c("b[1]", "b[2]", "b[3]")] |
|
99 | + | #' pred_mat <- plogis(xmat %*% t(beta)) |
|
100 | + | #' |
|
101 | + | #' # the matrix of predictions has rows matching the number of rows in the data; |
|
102 | + | #' # the columns are the predictions for each of the 2,000 posterior samples |
|
103 | + | #' nrow(sim_data) |
|
104 | + | #' dim(pred_mat) |
|
105 | + | #' |
|
106 | + | #' # now we can call mcmcRocPrc; the default method works with the matrix |
|
107 | + | #' # of predictions and vector of outcomes as input |
|
108 | + | #' mcmcRocPrc(object = pred_mat, curves = TRUE, fullsims = FALSE, yvec = yvec) |
|
109 | + | #' |
|
47 | 110 | #' @export |
|
48 | 111 | #' @md |
|
112 | + | mcmcRocPrc <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
113 | + | UseMethod("mcmcRocPrc", object) |
|
114 | + | } |
|
49 | 115 | ||
50 | - | mcmcRocPrc <- function(object, |
|
51 | - | yname, |
|
52 | - | xnames, |
|
53 | - | curves = FALSE, |
|
54 | - | fullsims = FALSE) { |
|
55 | - | ||
56 | - | link_logit <- any(grepl("logit", object$model$model())) |
|
57 | - | link_probit <- any(grepl("probit", object$model$model())) |
|
58 | - | ||
59 | - | if (isFALSE(link_logit | link_probit)) { |
|
60 | - | stop("Could not identify model link function") |
|
61 | - | } |
|
62 | - | ||
63 | - | mdl_data <- object$model$data() |
|
64 | - | stopifnot(all(xnames %in% names(mdl_data))) |
|
65 | - | stopifnot(all(yname %in% names(mdl_data))) |
|
66 | - | ||
67 | - | # add intercept by default, maybe revisit this |
|
68 | - | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
69 | - | yvec <- mdl_data[[yname]] |
|
70 | - | ||
71 | - | pardraws <- as.matrix(coda::as.mcmc(object)) |
|
72 | - | # this is not very robust, assumes pars are 'b[x]' |
|
73 | - | # for both this and the intercept addition above, maybe a more robust solution |
|
74 | - | # down the road would be to dig into the object$model$model() string |
|
75 | - | betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))] |
|
116 | + | #' Constructor for mcmcRocPrc objects |
|
117 | + | #' |
|
118 | + | #' This function actually does the heavy lifting once we have a matrix of |
|
119 | + | #' predicted probabilities from a model, plus the vector of observed outcomes. |
|
120 | + | #' The reason to have it here in a single function is that we don't replicate |
|
121 | + | #' it in each function that accommodates a JAGS, BUGS, RStan, etc. object. |
|
122 | + | #' |
|
123 | + | #' @param pred_prob a `\[N, iter\]` matrix of predicted probabilities |
|
124 | + | #' @param yvec a `numeric(N)` vector of observed outcomes |
|
125 | + | #' @param curves include curve data in output? |
|
126 | + | #' @param fullsims keep results for every posterior sample instead of collapsing them into a single summary? |
|
127 | + | #' |
|
128 | + | #' @md |
|
129 | + | #' @keywords internal |
|
130 | + | new_mcmcRocPrc <- function(pred_prob, yvec, curves, fullsims) { |
|
76 | 131 | ||
77 | - | if(isTRUE(link_logit)) { |
|
78 | - | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
79 | - | } else if (isTRUE(link_probit)) { |
|
80 | - | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
81 | - | } |
|
132 | + | stopifnot( |
|
133 | + | "number of predictions and observed outcomes do not match" = nrow(pred_prob)==length(yvec), |
|
134 | + | "yvec must be 0 or 1" = all(yvec %in% c(0L, 1L)), |
|
135 | + | "pred_prob must be in the interval [0, 1]" = all(pred_prob >= 0 & pred_prob <= 1) |
|
136 | + | ) |
|
82 | 137 | ||
83 | 138 | # pred_prob is a [N, iter] matrix, i.e. each column are preds from one |
|
84 | 139 | # set of posterior samples |
146 | 201 | ) |
|
147 | 202 | } |
|
148 | 203 | ||
204 | + | #' @rdname mcmcRocPrc |
|
205 | + | #' |
|
206 | + | #' @md |
|
207 | + | #' @export |
|
208 | + | mcmcRocPrc.default <- function(object, curves, fullsims, yvec, ...) { |
|
209 | + | pred_prob <- object |
|
210 | + | ||
211 | + | stopifnot( |
|
212 | + | "mcmcRocPrc.default requires 'matrix' like input" = inherits(pred_prob, "matrix") |
|
213 | + | ) |
|
214 | + | ||
215 | + | new_mcmcRocPrc(pred_prob, yvec, curves, fullsims) |
|
216 | + | } |
|
217 | + | ||
218 | + | # Under the hood ROC/PRC calculations ------------------------------------- |
|
149 | 219 | ||
150 | 220 | #' Compute ROC and PR curve points |
|
151 | 221 | #' |
213 | 283 | } |
|
214 | 284 | ||
215 | 285 | ||
286 | + | # auc_roc and auc_pr are not really used, but keep around just in case |
|
287 | + | auc_roc <- function(obs, pred) { |
|
288 | + | values <- compute_roc(obs, pred) |
|
289 | + | caTools::trapz(values$x, values$y) |
|
290 | + | } |
|
291 | + | ||
292 | + | auc_pr <- function(obs, pred) { |
|
293 | + | values <- compute_pr(obs, pred) |
|
294 | + | caTools::trapz(values$x, values$y) |
|
295 | + | } |
|
296 | + | ||
297 | + | ||
298 | + | ||
299 | + | # JAGS-like input (rjags, R2jags, runjags) -------------------------------- |
|
216 | 300 | ||
217 | 301 | #' @rdname mcmcRocPrc |
|
218 | 302 | #' |
|
219 | 303 | #' @export |
|
220 | - | print.mcmcRocPrc <- function(x, ...) { |
|
221 | - | ||
222 | - | auc_roc <- x$area_under_roc |
|
223 | - | auc_prc <- x$area_under_prc |
|
224 | - | ||
225 | - | has_curves <- !is.null(x$roc_dat) |
|
226 | - | has_sims <- length(auc_roc) > 1 |
|
227 | - | ||
228 | - | if (!has_sims) { |
|
229 | - | roc_msg <- sprintf("%.3f", round(auc_roc, 3)) |
|
230 | - | prc_msg <- sprintf("%.3f", round(auc_prc, 3)) |
|
231 | - | } else { |
|
232 | - | roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
233 | - | round(mean(auc_roc), 3), |
|
234 | - | round(quantile(auc_roc, 0.1), 3), |
|
235 | - | round(quantile(auc_roc, 0.9), 3)) |
|
236 | - | prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
237 | - | round(mean(auc_prc), 3), |
|
238 | - | round(quantile(auc_prc, 0.1), 3), |
|
239 | - | round(quantile(auc_prc, 0.9), 3)) |
|
304 | + | mcmcRocPrc.jags <- function(object, curves = FALSE, fullsims = FALSE, yname, |
|
305 | + | xnames, posterior_samples, ...) { |
|
306 | + | ||
307 | + | stopifnot( |
|
308 | + | inherits(posterior_samples, c("mcmc", "mcmc.list")) |
|
309 | + | ) |
|
310 | + | ||
311 | + | link_logit <- any(grepl("logit", object$model())) |
|
312 | + | link_probit <- any(grepl("probit", object$model())) |
|
313 | + | ||
314 | + | if (isFALSE(link_logit | link_probit)) { |
|
315 | + | stop("Could not identify model link function") |
|
240 | 316 | } |
|
241 | 317 | ||
242 | - | cat("mcmcRocPrc object\n") |
|
243 | - | cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims)) |
|
244 | - | cat(sprintf("AUC-ROC: %s\n", roc_msg)) |
|
245 | - | cat(sprintf("AUC-PR: %s\n", prc_msg)) |
|
318 | + | mdl_data <- object$data() |
|
319 | + | stopifnot(all(xnames %in% names(mdl_data))) |
|
320 | + | stopifnot(all(yname %in% names(mdl_data))) |
|
321 | + | ||
322 | + | # add intercept by default, maybe revisit this |
|
323 | + | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
324 | + | yvec <- mdl_data[[yname]] |
|
325 | + | ||
326 | + | pardraws <- as.matrix(posterior_samples) |
|
327 | + | # this is not very robust, assumes pars are 'b[x]' |
|
328 | + | # for both this and the intercept addition above, maybe a more robust solution |
|
329 | + | # down the road would be to dig into the object$model$model() string |
|
330 | + | betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))] |
|
331 | + | ||
332 | + | if(isTRUE(link_logit)) { |
|
333 | + | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
334 | + | } else if (isTRUE(link_probit)) { |
|
335 | + | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
336 | + | } |
|
246 | 337 | ||
247 | - | invisible(x) |
|
338 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
339 | + | fullsims = fullsims) |
|
248 | 340 | } |
|
249 | 341 | ||
250 | 342 | #' @rdname mcmcRocPrc |
|
251 | 343 | #' |
|
252 | - | #' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw? |
|
253 | - | #' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1 |
|
344 | + | #' @export |
|
345 | + | mcmcRocPrc.rjags <- function(object, curves = FALSE, fullsims = FALSE, yname, |
|
346 | + | xnames, ...) { |
|
347 | + | ||
348 | + | if (!requireNamespace("R2jags", quietly = TRUE)) { |
|
349 | + | stop("Package \"R2jags\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
350 | + | } |
|
351 | + | ||
352 | + | jags_object <- object$model |
|
353 | + | pardraws <- coda::as.mcmc(object) |
|
354 | + | ||
355 | + | # pass it on to the "jags" method |
|
356 | + | mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, |
|
357 | + | yname = yname, xnames = xnames, posterior_samples = pardraws, ...) |
|
358 | + | } |
|
359 | + | ||
360 | + | #' @rdname mcmcRocPrc |
|
254 | 361 | #' |
|
255 | 362 | #' @export |
|
256 | - | plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) { |
|
363 | + | mcmcRocPrc.runjags <- function(object, curves = FALSE, fullsims = FALSE, yname, |
|
364 | + | xnames, ...) { |
|
365 | + | jags_object <- runjags::as.jags(object, quiet = TRUE) |
|
366 | + | # as.mcmc.runjags will issue a warning when converting multiple chains |
|
367 | + | # because it combines them |
|
368 | + | pardraws <- suppressWarnings(coda::as.mcmc(object)) |
|
369 | + | ||
370 | + | # pass it on to the "jags" method |
|
371 | + | mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, |
|
372 | + | yname = yname, xnames = xnames, posterior_samples = pardraws, ...) |
|
373 | + | } |
|
257 | 374 | ||
258 | - | stopifnot( |
|
259 | - | "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)), |
|
260 | - | "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1), |
|
261 | - | "n must be > 0" = (n > 0) |
|
262 | - | ) |
|
375 | + | ||
376 | + | # STAN-like input (rstan, rstanarm, brms) --------------------------------- |
|
377 | + | ||
378 | + | ||
379 | + | ||
380 | + | #' @rdname mcmcRocPrc |
|
381 | + | #' |
|
382 | + | #' @param data the data that was used in the `stan(data = ?, ...)` call |
|
383 | + | #' |
|
384 | + | #' @export |
|
385 | + | mcmcRocPrc.stanfit <- function(object, curves = FALSE, fullsims = FALSE, data, |
|
386 | + | xnames, yname, ...) { |
|
387 | + | if (!requireNamespace("rstan", quietly = TRUE)) { |
|
388 | + | stop("Package \"rstan\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
389 | + | } |
|
263 | 390 | ||
264 | - | obj<- x |
|
265 | - | fullsims <- length(obj$roc_dat) > 1 |
|
391 | + | if (!is_binary_model(object)) { |
|
392 | + | stop("the input model does not seem to be a binary choice model; if this is a mistake please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/") |
|
393 | + | } |
|
394 | + | link_type <- identify_link_function(object) |
|
395 | + | if (is.na(link_type)) { |
|
396 | + | stop("could not identify model link function; please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/") |
|
397 | + | } |
|
266 | 398 | ||
267 | - | if (!fullsims) { |
|
268 | - | ||
269 | - | graphics::par(mfrow = c(1, 2)) |
|
270 | - | plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR") |
|
271 | - | graphics::abline(a = 0, b = 1, lty = 3, col = "gray50") |
|
272 | - | ||
273 | - | prc_dat <- obj$prc_dat[[1]] |
|
274 | - | # use first non-NaN y-value for y[1] |
|
275 | - | prc_dat$y[1] <- prc_dat$y[2] |
|
276 | - | plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision", |
|
277 | - | ylim = c(0, 1)) |
|
278 | - | graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50") |
|
279 | - | ||
280 | - | } else { |
|
281 | - | ||
282 | - | graphics::par(mfrow = c(1, 2)) |
|
283 | - | ||
284 | - | roc_dat <- obj$roc_dat |
|
285 | - | ||
286 | - | x <- lapply(roc_dat, `[[`, 1) |
|
287 | - | x <- do.call(cbind, x) |
|
288 | - | colnames(x) <- paste0("sim", 1:ncol(x)) |
|
289 | - | ||
290 | - | y <- lapply(roc_dat, `[[`, 2) |
|
291 | - | y <- do.call(cbind, y) |
|
292 | - | colnames(y) <- paste0("sim", 1:ncol(y)) |
|
293 | - | ||
294 | - | xavg <- rowMeans(x) |
|
295 | - | yavg <- rowMeans(y) |
|
296 | - | ||
297 | - | plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR") |
|
298 | - | samples <- sample(1:ncol(x), n) |
|
299 | - | for (i in samples) { |
|
300 | - | graphics::lines( |
|
301 | - | x[, i], y[, i], type = "s", |
|
302 | - | col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255) |
|
303 | - | ) |
|
304 | - | } |
|
305 | - | graphics::lines(xavg, yavg, type = "s") |
|
306 | - | ||
307 | - | # PRC |
|
308 | - | # The elements of prc_dat have different lengths, unlike roc_dat, so we |
|
309 | - | # have to do the central curve differently. |
|
310 | - | prc_dat <- obj$prc_dat |
|
399 | + | mdl_data <- data |
|
400 | + | stopifnot(all(xnames %in% names(mdl_data))) |
|
401 | + | stopifnot(all(yname %in% names(mdl_data))) |
|
402 | + | ||
403 | + | # add intercept by default, maybe revisit this |
|
404 | + | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
405 | + | yvec <- mdl_data[[yname]] |
|
406 | + | ||
407 | + | pardraws <- as.matrix(object) |
|
408 | + | # this is not very robust, assumes pars are 'b[x]' |
|
409 | + | betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))] |
|
410 | + | ||
411 | + | if(link_type=="logit") { |
|
412 | + | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
413 | + | } else if (link_type=="probit") { |
|
414 | + | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
415 | + | } |
|
416 | + | ||
417 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
418 | + | fullsims = fullsims) |
|
419 | + | ||
420 | + | ||
421 | + | } |
|
311 | 422 | ||
312 | - | x <- lapply(prc_dat, `[[`, 1) |
|
313 | - | y <- lapply(prc_dat, `[[`, 2) |
|
314 | - | ||
315 | - | # Instead of combining the list of curve coordinates from each sample into |
|
316 | - | # two x and y matrices, we can first make a point cloud with all curve |
|
317 | - | # points from all samples, and then average the y values at all distinct |
|
318 | - | # x coordinates. The x-axis plots recall (TPR), which will only have as |
|
319 | - | # many distinct values as there are positives in the data, so this does |
|
320 | - | # not lose any information about the x coordinates. |
|
321 | - | point_cloud <- data.frame( |
|
322 | - | x = unlist(x), |
|
323 | - | y = unlist(y) |
|
324 | - | ) |
|
325 | - | point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], |
|
326 | - | # factor implicitly encodes distinct values only, |
|
327 | - | # since they will get the same labels |
|
328 | - | by = list(x = as.factor(point_cloud$x)), |
|
329 | - | FUN = mean) |
|
330 | - | point_cloud$x <- as.numeric(as.character(point_cloud$x)) |
|
331 | - | xavg <- point_cloud$x |
|
332 | - | yavg <- point_cloud$y |
|
333 | - | ||
334 | - | plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1)) |
|
335 | - | samples <- sample(1:length(prc_dat), n) |
|
336 | - | for (i in samples) { |
|
337 | - | graphics::lines( |
|
338 | - | x[[i]], y[[i]], |
|
339 | - | col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255) |
|
340 | - | ) |
|
341 | - | } |
|
342 | - | graphics::lines(xavg, yavg) |
|
343 | - | ||
423 | + | #' Try to identify if a stanfit model is a binary choice model |
|
424 | + | #' |
|
425 | + | #' @param obj stanfit object |
|
426 | + | #' |
|
427 | + | #' @keywords internal |
|
428 | + | is_binary_model <- function(obj) { |
|
429 | + | stopifnot(inherits(obj, "stanfit")) |
|
430 | + | model_string <- rstan::get_stancode(obj) |
|
431 | + | grepl("bernoulli", model_string) |
|
432 | + | } |
|
433 | + | ||
434 | + | #' Try to identify link function |
|
435 | + | #' |
|
436 | + | #' @param obj stanfit object |
|
437 | + | #' |
|
438 | + | #' @return Either "logit" or "probit"; if neither can be identified the function |
|
439 | + | #' will return `NA_character_`. |
|
440 | + | #' |
|
441 | + | #' @keywords internal |
|
442 | + | identify_link_function <- function(obj) { |
|
443 | + | stopifnot(inherits(obj, "stanfit")) |
|
444 | + | model_string <- rstan::get_stancode(obj) |
|
445 | + | if (grepl("logit", model_string)) return("logit") |
|
446 | + | if (grepl("Phi", model_string)) return("probit") |
|
447 | + | NA_character_ |
|
448 | + | } |
|
449 | + | ||
450 | + | #' @rdname mcmcRocPrc |
|
451 | + | #' |
|
452 | + | #' @export |
|
453 | + | mcmcRocPrc.stanreg <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
454 | + | if (!requireNamespace("rstanarm", quietly = TRUE)) { |
|
455 | + | stop("Package \"rstanarm\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
344 | 456 | } |
|
457 | + | if (!stats::family(object)$family=="binomial") { |
|
458 | + | stop("the input model does not seem to be a binary choice model; should be like 'obj <- stan_glm(family = binomial(), ...)'") |
|
459 | + | } |
|
460 | + | pred_prob <- rstanarm::posterior_linpred(object, transform = TRUE) |
|
461 | + | # posterior_linpred returns a matrix in which data cases are columns, and |
|
462 | + | # MCMC samples are rows; we need to transpose this so that columns are samples |
|
463 | + | pred_prob <- t(pred_prob) |
|
464 | + | yvec <- unname(object$y) |
|
345 | 465 | ||
346 | - | invisible(x) |
|
466 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
467 | + | fullsims = fullsims) |
|
347 | 468 | } |
|
348 | 469 | ||
349 | 470 | #' @rdname mcmcRocPrc |
|
350 | 471 | #' |
|
351 | - | #' @param row.names see [base::as.data.frame()] |
|
352 | - | #' @param optional see [base::as.data.frame()] |
|
353 | - | #' @param what which information to extract and convert to a data frame? |
|
472 | + | #' @export |
|
473 | + | mcmcRocPrc.brmsfit <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
474 | + | if (!requireNamespace("brms", quietly = TRUE)) { |
|
475 | + | stop("Package \"brms\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
476 | + | } |
|
477 | + | if (!stats::family(object)$family=="bernoulli") { |
|
478 | + | stop("the input model does not seem to be a binary choice model; should be like 'obj <- brm(family = bernoulli(), ...)'") |
|
479 | + | } |
|
480 | + | ||
481 | + | pred_prob <- brms::posterior_epred(object) |
|
482 | + | # posterior_epred returns a matrix in which data cases are columns, and |
|
483 | + | # MCMC samples are rows; we need to transpose this so that columns are samples |
|
484 | + | pred_prob <- t(pred_prob) |
|
485 | + | yvec <- stats::model.response(stats::model.frame(object)) |
|
486 | + | ||
487 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
488 | + | fullsims = fullsims) |
|
489 | + | } |
|
490 | + | ||
491 | + | ||
492 | + | # Other input types (MCMCpack, ...) --------------------------------------- |
|
493 | + | ||
494 | + | ||
495 | + | #' #' @rdname mcmcRocPrc |
|
496 | + | #' #' |
|
497 | + | #' #' @export |
|
498 | + | #' mcmcRocPrc.bugs <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
499 | + | #' stop("not implemented yet") |
|
500 | + | #' } |
|
501 | + | ||
502 | + | ||
503 | + | #' @rdname mcmcRocPrc |
|
354 | 504 | #' |
|
505 | + | #' @param type "logit" or "probit" |
|
506 | + | #' @param force for MCMCpack models, suppress the error if the model does not |
|
507 | + | #' appear to be a binary choice model? |
|
508 | + | #' |
|
355 | 509 | #' @export |
|
356 | - | as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE, |
|
357 | - | what = c("auc", "roc", "prc"), ...) { |
|
358 | - | what <- match.arg(what) |
|
359 | - | if (what=="auc") { |
|
360 | - | # all 4 output types have AUC, so this should work across the board |
|
361 | - | return(as.data.frame(x[c("area_under_roc", "area_under_prc")])) |
|
362 | - | ||
363 | - | } else if (what %in% c("roc", "prc")) { |
|
364 | - | if (what=="roc") element <- "roc_dat" else element <- "prc_dat" |
|
365 | - | ||
366 | - | # if curves was FALSE, there will be no curve data... |
|
367 | - | if (is.null(x[[element]])) { |
|
368 | - | stop("No curve data; use mcmcRocPrc(..., curves = TRUE)") |
|
369 | - | } |
|
370 | - | ||
371 | - | # Otherwise, there will be either one set of coordinates if mcmcmRegPrc() |
|
372 | - | # was called with fullsims = FALSE, or else N_sims curve data sets. |
|
373 | - | # If the latter, we can return a long data frame with an identifying |
|
374 | - | # "sim" column to delineate the sim sets. To ensure consistency in output, |
|
375 | - | # also add this column when fullsims = FALSE. |
|
376 | - | ||
377 | - | # averaged, single coordinate set |
|
378 | - | if (length(x[[element]])==1L) { |
|
379 | - | return(data.frame(sim = 1L, x[[element]][[1]])) |
|
510 | + | mcmcRocPrc.mcmc <- function(object, curves = FALSE, fullsims = FALSE, data, |
|
511 | + | xnames, yname, type = c("logit", "probit"), |
|
512 | + | force = FALSE, ...) { |
|
513 | + | ||
514 | + | if (!force) { |
|
515 | + | if (is.null(attr(object, "call"))) { |
|
516 | + | stop("object does not have a 'call' attribute; was it generated with a MCMCpack function?") |
|
517 | + | } else { |
|
518 | + | func <- as.character(attr(object, "call"))[1] |
|
519 | + | if (!func %in% c("MCMClogit", "MCMCprobit")) { |
|
520 | + | stop("object does not appear to have been fitted using MCMCpack::MCMClogit() or MCMCprobit(); mcmcRocPrc only properly works for those function. To be safe, consider manually calculating the matrix of predicted probabilities.") |
|
521 | + | } |
|
380 | 522 | } |
|
381 | - | ||
382 | - | # full sims |
|
383 | - | # add a unique ID to each coordinate set |
|
384 | - | outlist <- x[[element]] |
|
385 | - | outlist <- Map(cbind, sim = (1:length(outlist)), outlist) |
|
386 | - | # combine into long data frame |
|
387 | - | outdf <- do.call(rbind, outlist) |
|
388 | - | return(outdf) |
|
523 | + | } |
|
524 | + | ||
525 | + | link_type <- match.arg(type) |
|
526 | + | mdl_data <- data |
|
527 | + | stopifnot( |
|
528 | + | all(xnames %in% names(mdl_data)), |
|
529 | + | all(yname %in% names(mdl_data)) |
|
530 | + | ) |
|
531 | + | ||
532 | + | # add intercept by default, maybe revisit this |
|
533 | + | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
534 | + | yvec <- mdl_data[[yname]] |
|
535 | + | ||
536 | + | betadraws <- as.matrix(object) |
|
537 | + | ||
538 | + | if(link_type=="logit") { |
|
539 | + | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
540 | + | } else if (link_type=="probit") { |
|
541 | + | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
389 | 542 | } |
|
390 | - | stop("Developer error (I should not be here): please file an issue on GitHub") # nocov |
|
543 | + | ||
544 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
545 | + | fullsims = fullsims) |
|
391 | 546 | } |
|
392 | 547 | ||
393 | 548 | ||
394 | 549 | ||
395 | - | # auc_roc and auc_pr are not really used, but keep around just in case |
|
396 | - | auc_roc <- function(obs, pred) { |
|
397 | - | values <- compute_roc(obs, pred) |
|
398 | - | caTools::trapz(values$x, values$y) |
|
399 | - | } |
|
400 | - | ||
401 | - | auc_pr <- function(obs, pred) { |
|
402 | - | values <- compute_pr(obs, pred) |
|
403 | - | caTools::trapz(values$x, values$y) |
|
404 | - | } |
1 | + | # |
|
2 | + | # Methods for class "mcmcRocPrc", generated by mcmcRocPrc() |
|
3 | + | # |
|
4 | + | ||
5 | + | #' @rdname mcmcRocPrc |
|
6 | + | #' |
|
7 | + | #' @export |
|
8 | + | print.mcmcRocPrc <- function(x, ...) { |
|
9 | + | ||
10 | + | auc_roc <- x$area_under_roc |
|
11 | + | auc_prc <- x$area_under_prc |
|
12 | + | ||
13 | + | has_curves <- !is.null(x$roc_dat) |
|
14 | + | has_sims <- length(auc_roc) > 1 |
|
15 | + | ||
16 | + | if (!has_sims) { |
|
17 | + | roc_msg <- sprintf("%.3f", round(auc_roc, 3)) |
|
18 | + | prc_msg <- sprintf("%.3f", round(auc_prc, 3)) |
|
19 | + | } else { |
|
20 | + | roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
21 | + | round(mean(auc_roc), 3), |
|
22 | + | round(quantile(auc_roc, 0.1), 3), |
|
23 | + | round(quantile(auc_roc, 0.9), 3)) |
|
24 | + | prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
25 | + | round(mean(auc_prc), 3), |
|
26 | + | round(quantile(auc_prc, 0.1), 3), |
|
27 | + | round(quantile(auc_prc, 0.9), 3)) |
|
28 | + | } |
|
29 | + | ||
30 | + | cat("mcmcRocPrc object\n") |
|
31 | + | cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims)) |
|
32 | + | cat(sprintf("AUC-ROC: %s\n", roc_msg)) |
|
33 | + | cat(sprintf("AUC-PR: %s\n", prc_msg)) |
|
34 | + | ||
35 | + | invisible(x) |
|
36 | + | } |
|
37 | + | ||
#' @rdname mcmcRocPrc
#'
#' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw?
#'   Values larger than the number of stored curves are capped at that number.
#' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1
#'
#' @export
plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) {

  stopifnot(
    "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)),
    "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1),
    "n must be > 0" = (n > 0)
  )

  # Both branches draw two side-by-side panels (ROC left, PR right). Restore
  # the caller's layout on exit so later plots are not stuck in a 1 x 2 grid.
  old_par <- graphics::par(mfrow = c(1, 2))
  on.exit(graphics::par(old_par), add = TRUE)

  roc_dat <- x$roc_dat
  prc_dat <- x$prc_dat
  fullsims <- length(roc_dat) > 1

  if (!fullsims) {

    plot(roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR")
    graphics::abline(a = 0, b = 1, lty = 3, col = "gray50")

    prc_single <- prc_dat[[1]]
    # use first non-NaN y-value for y[1]
    prc_single$y[1] <- prc_single$y[2]
    plot(prc_single, type = "l", xlab = "TPR", ylab = "Precision",
         ylim = c(0, 1))
    graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50")

  } else {

    # Colour for the individual sampled curves; identical for every curve, so
    # compute it once.
    sim_col <- grDevices::rgb(127, 127, 127, alpha = alpha * 255,
                              maxColorValue = 255)

    # ROC
    # Bind the per-sim coordinate vectors column-wise; the row means give the
    # central curve.
    roc_x <- do.call(cbind, lapply(roc_dat, `[[`, 1))
    roc_y <- do.call(cbind, lapply(roc_dat, `[[`, 2))
    roc_x_avg <- rowMeans(roc_x)
    roc_y_avg <- rowMeans(roc_y)

    plot(roc_x_avg, roc_y_avg, type = "n", xlab = "FPR", ylab = "TPR")
    # Sampling is without replacement, so never ask for more curves than exist.
    n_roc <- ncol(roc_x)
    for (i in sample.int(n_roc, min(n, n_roc))) {
      graphics::lines(roc_x[, i], roc_y[, i], type = "s", col = sim_col)
    }
    graphics::lines(roc_x_avg, roc_y_avg, type = "s")

    # PRC
    # The elements of prc_dat have different lengths, unlike roc_dat, so we
    # have to do the central curve differently.
    prc_x <- lapply(prc_dat, `[[`, 1)
    prc_y <- lapply(prc_dat, `[[`, 2)

    # Instead of combining the list of curve coordinates from each sample into
    # two x and y matrices, we can first make a point cloud with all curve
    # points from all samples, and then average the y values at all distinct
    # x coordinates. The x-axis plots recall (TPR), which will only have as
    # many distinct values as there are positives in the data, so this does
    # not lose any information about the x coordinates.
    point_cloud <- data.frame(
      x = unlist(prc_x),
      y = unlist(prc_y)
    )
    point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE],
                                    # factor implicitly encodes distinct values only,
                                    # since they will get the same labels
                                    by = list(x = as.factor(point_cloud$x)),
                                    FUN = mean)
    point_cloud$x <- as.numeric(as.character(point_cloud$x))
    prc_x_avg <- point_cloud$x
    prc_y_avg <- point_cloud$y

    plot(prc_x_avg, prc_y_avg, type = "n", xlab = "TPR", ylab = "Precision",
         ylim = c(0, 1))
    n_prc <- length(prc_dat)
    for (i in sample.int(n_prc, min(n, n_prc))) {
      graphics::lines(prc_x[[i]], prc_y[[i]], col = sim_col)
    }
    graphics::lines(prc_x_avg, prc_y_avg)

  }

  # Return the object that was plotted. `x` is never reassigned above, so this
  # is the caller's original input in both branches.
  invisible(x)
}
|
136 | + | ||
#' @rdname mcmcRocPrc
#'
#' @param row.names see [base::as.data.frame()]
#' @param optional see [base::as.data.frame()]
#' @param what which information to extract and convert to a data frame?
#'
#' @export
as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE,
                                     what = c("auc", "roc", "prc"), ...) {
  # match.arg() restricts `what` to exactly one of "auc", "roc" or "prc", so
  # no fall-through branch is needed further down.
  what <- match.arg(what)

  if (what == "auc") {
    # all 4 output types have AUC, so this should work across the board
    return(as.data.frame(x[c("area_under_roc", "area_under_prc")]))
  }

  element <- switch(what, roc = "roc_dat", prc = "prc_dat")
  curves <- x[[element]]

  # if curves was FALSE, there will be no curve data; an empty list of curves
  # is treated the same way rather than failing further down.
  if (length(curves) == 0L) {
    stop("No curve data; use mcmcRocPrc(..., curves = TRUE)")
  }

  # Otherwise, there will be either one set of coordinates if mcmcRocPrc()
  # was called with fullsims = FALSE, or else N_sims curve data sets.
  # If the latter, we can return a long data frame with an identifying
  # "sim" column to delineate the sim sets. To ensure consistency in output,
  # also add this column when fullsims = FALSE.

  # averaged, single coordinate set
  if (length(curves) == 1L) {
    return(data.frame(sim = 1L, curves[[1]]))
  }

  # full sims: add a unique ID to each coordinate set, then combine into one
  # long data frame
  tagged <- Map(cbind, sim = seq_along(curves), curves)
  do.call(rbind, tagged)
}
|
180 | + | ||
181 | + | ||
182 | + | ||
183 | + |
Learn more Showing 1 files with coverage changes found.
R/mcmcRocPrc-methods.R
Files | Coverage |
---|---|
R | 0.98% 83.11% |
Project Totals (11 files) | 83.11% |
7d49199
cf855a3
8e96b05
1f86c24
5aec1bb
d2c0ac0
f1bb732
86da204
9f9d5b6
559d018