Generalize mcmcRocPrc for various kinds of input objects
Showing 2 of 24 files from the diff.
R/mcmcRocPrc.R
changed.
Newly tracked file
R/mcmcRocPrc-methods.R
created.
Other files ignored by Codecov
man/is_binary_model.Rd
is new.
tests/testthat/test_mcmcReg.R
has changed.
tests/testdata/rstan-logit.rds
has changed.
tests/testdata-raw/rstan-logit.R
is new.
man/identify_link_function.Rd
is new.
tests/testdata-raw/README.md
is new.
tests/testthat/test_mcmcRocPrc.R
has changed.
tests/testdata/brms-logit.rds
has changed.
NAMESPACE
has changed.
tests/testdata/mcmcpack-logit.rds
has changed.
DESCRIPTION
has changed.
tests/testthat.R
has changed.
tests/testthat/helper-libs.R
is new.
man/new_mcmcRocPrc.Rd
is new.
tests/testdata-raw/brms-logit.R
is new.
tests/testdata/rstanarm-logit.rds
has changed.
NEWS.md
has changed.
tests/testdata/runjags-logit.rds
has changed.
man/mcmcRocPrc.Rd
has changed.
@@ -1,39 +1,78 @@
Loading
1 | + | # |
|
2 | + | # This file contains the mcmcRocPrc() S3 generic, which constructs objects |
|
3 | + | # of class "mcmcRocPrc". For methods for this class, see mcmcRocPrc-methods.R |
|
4 | + | # S3 methods for the mcmcRocPrc() generic handle different types of input |
|
5 | + | # e.g. "rjags" input produced by R2jags. |
|
6 | + | # |
|
7 | + | ||
8 | + | ||
9 | + | ||
1 | 10 | #' ROC and Precision-Recall Curves using Bayesian MCMC estimates |
|
2 | 11 | #' |
|
3 | 12 | #' Generate ROC and Precision-Recall curves after fitting a Bayesian logit or |
|
4 | - | #' probit regression using [R2jags::jags()] |
|
13 | + | #' probit regression using [rstan::stan()], [rstanarm::stan_glm()], |
|
14 | + | #' [R2jags::jags()], [R2WinBUGS::bugs()], [MCMCpack::MCMClogit()], or other |
|
15 | + | #' functions that provide samples from a posterior density. |
|
5 | 16 | #' |
|
6 | - | #' @param object A "rjags" object (see [R2jags::jags()]) for a fitted binary |
|
7 | - | #' choice model. |
|
8 | - | #' @param yname (`character(1)`)\cr |
|
9 | - | #' The name of the dependent variable, should match the variable name in the |
|
10 | - | #' JAGS data object. |
|
11 | - | #' @param xnames ([base::character()])\cr |
|
12 | - | #' A character vector of the independent variable names, should match the |
|
13 | - | #' corresponding names in the JAGS data object. |
|
17 | + | #' @param object A fitted binary choice model, e.g. "rjags" object |
|
18 | + | #' (see [R2jags::jags()]), or a `[N, iter]` matrix of predicted probabilites. |
|
14 | 19 | #' @param curves logical indicator of whether or not to return values to plot |
|
15 | 20 | #' the ROC or Precision-Recall curves. If set to `FALSE` (default), |
|
16 | 21 | #' results are returned as a list without the extra values. |
|
17 | 22 | #' @param fullsims logical indicator of whether full object (based on all MCMC |
|
18 | 23 | #' draws rather than their average) will be returned. Default is `FALSE`. |
|
19 | 24 | #' Note: If `TRUE` is chosen, the function takes notably longer to execute. |
|
25 | + | #' @param yvec A `numeric(N)` vector of observed outcomes. |
|
26 | + | #' @param yname (`character(1)`)\cr |
|
27 | + | #' The name of the dependent variable, should match the variable name in the |
|
28 | + | #' JAGS data object. |
|
29 | + | #' @param xnames ([base::character()])\cr |
|
30 | + | #' A character vector of the independent variable names, should match the |
|
31 | + | #' corresponding names in the JAGS data object. |
|
32 | + | #' @param posterior_samples a "mcmc" object with the posterior samples |
|
20 | 33 | #' @param ... Used by methods |
|
21 | 34 | #' @param x a `mcmcRocPrc()` object |
|
22 | 35 | #' |
|
36 | + | #' @details If only the average AUC-ROC and PR are of interest, setting |
|
37 | + | #' `curves = FALSE` and `fullsims = FALSE` can greatly speed up calculation |
|
38 | + | #' time. The curve data (`curves = TRUE`) is needed for plotting. The plot |
|
39 | + | #' method will always plot both the ROC and PR curves, but the underlying |
|
40 | + | #' data can easily be extracted from the output for your own plotting; |
|
41 | + | #' see the documentation of the value returned below. |
|
42 | + | #' |
|
43 | + | #' The default method works with a matrix of predicted probabilities and the |
|
44 | + | #' vector of observed incomes as input. Other methods accommodate some of the |
|
45 | + | #' common Bayesian modeling packages like rstan (which returns class "stanfit"), |
|
46 | + | #' rstanarm ("stanreg"), R2jags ("jags"), R2WinBUGS ("bugs"), and |
|
47 | + | #' MCMCpack ("mcmc"). Even if a package-specific method is not implemented, |
|
48 | + | #' the default method can always be used as a fallback by manually calculating |
|
49 | + | #' the matrix of predicted probabilities for each posterior sample. |
|
50 | + | #' |
|
51 | + | #' Note that MCMCpack returns generic "mcmc" output that is annotated with |
|
52 | + | #' some additional information as attributes, including the original function |
|
53 | + | #' call. There is no inherent way to distinguish any other kind of "mcmc" |
|
54 | + | #' object from one generated by a proper MCMCpack modeling function, but as a |
|
55 | + | #' basic precaution, `mcmcRocPrc()` will check the saved call and return an |
|
56 | + | #' error if the function called was not `MCMClogit()` or `MCMCprobit()`. |
|
57 | + | #' This behavior can be suppressed by setting `force = TRUE`. |
|
58 | + | #' |
|
23 | 59 | #' @references Beger, Andreas. 2016. “Precision-Recall Curves.” Available at |
|
24 | 60 | #' SSRN: [http://dx.doi.org/10.2139/ssrn.2765419](http://dx.doi.org/10.2139/ssrn.2765419) |
|
25 | 61 | #' |
|
26 | 62 | #' @return Returns a list with length 2 or 4, depending on the on the "curves" |
|
27 | 63 | #' and "fullsims" argument values: |
|
28 | 64 | #' |
|
29 | - | #' - "area_under_roc": `numeric(1)` |
|
30 | - | #' - "area_under_prc": `numeric(1)` |
|
31 | - | #' - "prc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise |
|
32 | - | #' - "roc_dat": only if `curves = TRUE`; a list with length 1 if `fullsims = FALSE`, longer otherwise |
|
65 | + | #' - "area_under_roc": `numeric()`; either length 1 if `fullsims = FALSE`, or |
|
66 | + | #' one value for each posterior sample otherwise |
|
67 | + | #' - "area_under_prc": `numeric()`; either length 1 if `fullsims = FALSE`, or |
|
68 | + | #' one value for each posterior sample otherwise |
|
69 | + | #' - "prc_dat": only if `curves = TRUE`; a list with length 1 if |
|
70 | + | #' `fullsims = FALSE`, longer otherwise |
|
71 | + | #' - "roc_dat": only if `curves = TRUE`; a list with length 1 if |
|
72 | + | #' `fullsims = FALSE`, longer otherwise |
|
33 | 73 | #' |
|
34 | 74 | #' @examples |
|
35 | 75 | #' # load simulated data and fitted model (see ?sim_data and ?jags_logit) |
|
36 | - | #' library(R2jags) |
|
37 | 76 | #' data("jags_logit") |
|
38 | 77 | #' |
|
39 | 78 | #' # using mcmcRocPrc |
@@ -44,41 +83,57 @@
Loading
44 | 83 | #' fullsims = FALSE) |
|
45 | 84 | #' fit_sum |
|
46 | 85 | #' plot(fit_sum) |
|
86 | + | #' |
|
87 | + | #' # Equivalently, we can calculate the matrix of predicted probabilities |
|
88 | + | #' # ourselves; using the example from ?jags_logit: |
|
89 | + | #' library(R2jags) |
|
90 | + | #' |
|
91 | + | #' data("sim_data") |
|
92 | + | #' yvec <- sim_data$Y |
|
93 | + | #' xmat <- sim_data[, c("X1", "X2")] |
|
94 | + | #' |
|
95 | + | #' # add intercept to the X data |
|
96 | + | #' xmat <- as.matrix(cbind(Intercept = 1L, xmat)) |
|
97 | + | #' |
|
98 | + | #' beta <- as.matrix(as.mcmc(jags_logit))[, c("b[1]", "b[2]", "b[3]")] |
|
99 | + | #' pred_mat <- plogis(xmat %*% t(beta)) |
|
100 | + | #' |
|
101 | + | #' # the matrix of predictions has rows matching the number of rows in the data; |
|
102 | + | #' # the column are the predictions for each of the 2,000 posterior samples |
|
103 | + | #' nrow(sim_data) |
|
104 | + | #' dim(pred_mat) |
|
105 | + | #' |
|
106 | + | #' # now we can call mcmcRocPrc; the default method works with the matrix |
|
107 | + | #' # of predictions and vector of outcomes as input |
|
108 | + | #' mcmcRocPrc(object = pred_mat, curves = TRUE, fullsims = FALSE, yvec = yvec) |
|
109 | + | #' |
|
47 | 110 | #' @export |
|
48 | 111 | #' @md |
|
112 | + | mcmcRocPrc <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
113 | + | UseMethod("mcmcRocPrc", object) |
|
114 | + | } |
|
49 | 115 | ||
50 | - | mcmcRocPrc <- function(object, |
|
51 | - | yname, |
|
52 | - | xnames, |
|
53 | - | curves = FALSE, |
|
54 | - | fullsims = FALSE) { |
|
55 | - | ||
56 | - | link_logit <- any(grepl("logit", object$model$model())) |
|
57 | - | link_probit <- any(grepl("probit", object$model$model())) |
|
58 | - | ||
59 | - | if (isFALSE(link_logit | link_probit)) { |
|
60 | - | stop("Could not identify model link function") |
|
61 | - | } |
|
62 | - | ||
63 | - | mdl_data <- object$model$data() |
|
64 | - | stopifnot(all(xnames %in% names(mdl_data))) |
|
65 | - | stopifnot(all(yname %in% names(mdl_data))) |
|
66 | - | ||
67 | - | # add intercept by default, maybe revisit this |
|
68 | - | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
69 | - | yvec <- mdl_data[[yname]] |
|
70 | - | ||
71 | - | pardraws <- as.matrix(coda::as.mcmc(object)) |
|
72 | - | # this is not very robust, assumes pars are 'b[x]' |
|
73 | - | # for both this and the intercept addition above, maybe a more robust solution |
|
74 | - | # down the road would be to dig into the object$model$model() string |
|
75 | - | betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))] |
|
116 | + | #' Constructor for mcmcRocPrc objects |
|
117 | + | #' |
|
118 | + | #' This function actually does the heavy lifting once we have a matrix of |
|
119 | + | #' predicted probabilities from a model, plus the vector of observed outcomes. |
|
120 | + | #' The reason to have it here in a single function is that we don't replicate |
|
121 | + | #' it in each function that accomodates a JAGS, BUGS, RStan, etc. object. |
|
122 | + | #' |
|
123 | + | #' @param pred_prob a `\[N, iter\]` matrix of predicted probabilities |
|
124 | + | #' @param yvec a `numeric(N)` vector of observed outcomes |
|
125 | + | #' @param curves include curve data in output? |
|
126 | + | #' @param fullsims collapse posterior samples into single summary? |
|
127 | + | #' |
|
128 | + | #' @md |
|
129 | + | #' @keywords internal |
|
130 | + | new_mcmcRocPrc <- function(pred_prob, yvec, curves, fullsims) { |
|
76 | 131 | ||
77 | - | if(isTRUE(link_logit)) { |
|
78 | - | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
79 | - | } else if (isTRUE(link_probit)) { |
|
80 | - | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
81 | - | } |
|
132 | + | stopifnot( |
|
133 | + | "number of predictions and observed outcomes do not match" = nrow(pred_prob)==length(yvec), |
|
134 | + | "yvec must be 0 or 1" = all(yvec %in% c(0L, 1L)), |
|
135 | + | "pred_prob must be in the interval [0, 1]" = all(pred_prob >= 0 & pred_prob <= 1) |
|
136 | + | ) |
|
82 | 137 | ||
83 | 138 | # pred_prob is a [N, iter] matrix, i.e. each column are preds from one |
|
84 | 139 | # set of posterior samples |
@@ -146,6 +201,21 @@
Loading
146 | 201 | ) |
|
147 | 202 | } |
|
148 | 203 | ||
204 | + | #' @rdname mcmcRocPrc |
|
205 | + | #' |
|
206 | + | #' @md |
|
207 | + | #' @export |
|
208 | + | mcmcRocPrc.default <- function(object, curves, fullsims, yvec, ...) { |
|
209 | + | pred_prob <- object |
|
210 | + | ||
211 | + | stopifnot( |
|
212 | + | "mcmcRocPrc.default requires 'matrix' like input" = inherits(pred_prob, "matrix") |
|
213 | + | ) |
|
214 | + | ||
215 | + | new_mcmcRocPrc(pred_prob, yvec, curves, fullsims) |
|
216 | + | } |
|
217 | + | ||
218 | + | # Under the hood ROC/PRC calculations ------------------------------------- |
|
149 | 219 | ||
150 | 220 | #' Compute ROC and PR curve points |
|
151 | 221 | #' |
@@ -213,192 +283,267 @@
Loading
213 | 283 | } |
|
214 | 284 | ||
215 | 285 | ||
286 | + | # auc_roc and auc_pr are not really used, but keep around just in case |
|
287 | + | auc_roc <- function(obs, pred) { |
|
288 | + | values <- compute_roc(obs, pred) |
|
289 | + | caTools::trapz(values$x, values$y) |
|
290 | + | } |
|
291 | + | ||
292 | + | auc_pr <- function(obs, pred) { |
|
293 | + | values <- compute_pr(obs, pred) |
|
294 | + | caTools::trapz(values$x, values$y) |
|
295 | + | } |
|
296 | + | ||
297 | + | ||
298 | + | ||
299 | + | # JAGS-like input (rjags, R2jags, runjags) -------------------------------- |
|
216 | 300 | ||
217 | 301 | #' @rdname mcmcRocPrc |
|
218 | 302 | #' |
|
219 | 303 | #' @export |
|
220 | - | print.mcmcRocPrc <- function(x, ...) { |
|
221 | - | ||
222 | - | auc_roc <- x$area_under_roc |
|
223 | - | auc_prc <- x$area_under_prc |
|
224 | - | ||
225 | - | has_curves <- !is.null(x$roc_dat) |
|
226 | - | has_sims <- length(auc_roc) > 1 |
|
227 | - | ||
228 | - | if (!has_sims) { |
|
229 | - | roc_msg <- sprintf("%.3f", round(auc_roc, 3)) |
|
230 | - | prc_msg <- sprintf("%.3f", round(auc_prc, 3)) |
|
231 | - | } else { |
|
232 | - | roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
233 | - | round(mean(auc_roc), 3), |
|
234 | - | round(quantile(auc_roc, 0.1), 3), |
|
235 | - | round(quantile(auc_roc, 0.9), 3)) |
|
236 | - | prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
237 | - | round(mean(auc_prc), 3), |
|
238 | - | round(quantile(auc_prc, 0.1), 3), |
|
239 | - | round(quantile(auc_prc, 0.9), 3)) |
|
304 | + | mcmcRocPrc.jags <- function(object, curves = FALSE, fullsims = FALSE, yname, |
|
305 | + | xnames, posterior_samples, ...) { |
|
306 | + | ||
307 | + | stopifnot( |
|
308 | + | inherits(posterior_samples, c("mcmc", "mcmc.list")) |
|
309 | + | ) |
|
310 | + | ||
311 | + | link_logit <- any(grepl("logit", object$model())) |
|
312 | + | link_probit <- any(grepl("probit", object$model())) |
|
313 | + | ||
314 | + | if (isFALSE(link_logit | link_probit)) { |
|
315 | + | stop("Could not identify model link function") |
|
240 | 316 | } |
|
241 | 317 | ||
242 | - | cat("mcmcRocPrc object\n") |
|
243 | - | cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims)) |
|
244 | - | cat(sprintf("AUC-ROC: %s\n", roc_msg)) |
|
245 | - | cat(sprintf("AUC-PR: %s\n", prc_msg)) |
|
318 | + | mdl_data <- object$data() |
|
319 | + | stopifnot(all(xnames %in% names(mdl_data))) |
|
320 | + | stopifnot(all(yname %in% names(mdl_data))) |
|
321 | + | ||
322 | + | # add intercept by default, maybe revisit this |
|
323 | + | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
324 | + | yvec <- mdl_data[[yname]] |
|
325 | + | ||
326 | + | pardraws <- as.matrix(posterior_samples) |
|
327 | + | # this is not very robust, assumes pars are 'b[x]' |
|
328 | + | # for both this and the intercept addition above, maybe a more robust solution |
|
329 | + | # down the road would be to dig into the object$model$model() string |
|
330 | + | betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))] |
|
331 | + | ||
332 | + | if(isTRUE(link_logit)) { |
|
333 | + | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
334 | + | } else if (isTRUE(link_probit)) { |
|
335 | + | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
336 | + | } |
|
246 | 337 | ||
247 | - | invisible(x) |
|
338 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
339 | + | fullsims = fullsims) |
|
248 | 340 | } |
|
249 | 341 | ||
250 | 342 | #' @rdname mcmcRocPrc |
|
251 | 343 | #' |
|
252 | - | #' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw? |
|
253 | - | #' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1 |
|
344 | + | #' @export |
|
345 | + | mcmcRocPrc.rjags <- function(object, curves = FALSE, fullsims = FALSE, yname, |
|
346 | + | xnames, ...) { |
|
347 | + | ||
348 | + | if (!requireNamespace("R2jags", quietly = TRUE)) { |
|
349 | + | stop("Package \"R2jags\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
350 | + | } |
|
351 | + | ||
352 | + | jags_object <- object$model |
|
353 | + | pardraws <- coda::as.mcmc(object) |
|
354 | + | ||
355 | + | # pass it on to the "jags" method |
|
356 | + | mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, |
|
357 | + | yname = yname, xnames = xnames, posterior_samples = pardraws, ...) |
|
358 | + | } |
|
359 | + | ||
360 | + | #' @rdname mcmcRocPrc |
|
254 | 361 | #' |
|
255 | 362 | #' @export |
|
256 | - | plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) { |
|
363 | + | mcmcRocPrc.runjags <- function(object, curves = FALSE, fullsims = FALSE, yname, |
|
364 | + | xnames, ...) { |
|
365 | + | jags_object <- runjags::as.jags(object, quiet = TRUE) |
|
366 | + | # as.mcmc.runjags will issue a warning when converting multiple chains |
|
367 | + | # because it combines them |
|
368 | + | pardraws <- suppressWarnings(coda::as.mcmc(object)) |
|
369 | + | ||
370 | + | # pass it on to the "jags" method |
|
371 | + | mcmcRocPrc(object = jags_object, curves = curves, fullsims = fullsims, |
|
372 | + | yname = yname, xnames = xnames, posterior_samples = pardraws, ...) |
|
373 | + | } |
|
257 | 374 | ||
258 | - | stopifnot( |
|
259 | - | "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)), |
|
260 | - | "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1), |
|
261 | - | "n must be > 0" = (n > 0) |
|
262 | - | ) |
|
375 | + | ||
376 | + | # STAN-like input (rstan, rstanarm, brms) --------------------------------- |
|
377 | + | ||
378 | + | ||
379 | + | ||
380 | + | #' @rdname mcmcRocPrc |
|
381 | + | #' |
|
382 | + | #' @param data the data that was used in the `stan(data = ?, ...)` call |
|
383 | + | #' |
|
384 | + | #' @export |
|
385 | + | mcmcRocPrc.stanfit <- function(object, curves = FALSE, fullsims = FALSE, data, |
|
386 | + | xnames, yname, ...) { |
|
387 | + | if (!requireNamespace("rstan", quietly = TRUE)) { |
|
388 | + | stop("Package \"rstan\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
389 | + | } |
|
263 | 390 | ||
264 | - | obj<- x |
|
265 | - | fullsims <- length(obj$roc_dat) > 1 |
|
391 | + | if (!is_binary_model(object)) { |
|
392 | + | stop("the input model does not seem to be a binary choice model; if this is a mistake please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/") |
|
393 | + | } |
|
394 | + | link_type <- identify_link_function(object) |
|
395 | + | if (is.na(link_type)) { |
|
396 | + | stop("could not identify model link function; please file an issue at https://github.com/ShanaScogin/BayesPostEst/issues/") |
|
397 | + | } |
|
266 | 398 | ||
267 | - | if (!fullsims) { |
|
268 | - | ||
269 | - | graphics::par(mfrow = c(1, 2)) |
|
270 | - | plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR") |
|
271 | - | graphics::abline(a = 0, b = 1, lty = 3, col = "gray50") |
|
272 | - | ||
273 | - | prc_dat <- obj$prc_dat[[1]] |
|
274 | - | # use first non-NaN y-value for y[1] |
|
275 | - | prc_dat$y[1] <- prc_dat$y[2] |
|
276 | - | plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision", |
|
277 | - | ylim = c(0, 1)) |
|
278 | - | graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50") |
|
279 | - | ||
280 | - | } else { |
|
281 | - | ||
282 | - | graphics::par(mfrow = c(1, 2)) |
|
283 | - | ||
284 | - | roc_dat <- obj$roc_dat |
|
285 | - | ||
286 | - | x <- lapply(roc_dat, `[[`, 1) |
|
287 | - | x <- do.call(cbind, x) |
|
288 | - | colnames(x) <- paste0("sim", 1:ncol(x)) |
|
289 | - | ||
290 | - | y <- lapply(roc_dat, `[[`, 2) |
|
291 | - | y <- do.call(cbind, y) |
|
292 | - | colnames(y) <- paste0("sim", 1:ncol(y)) |
|
293 | - | ||
294 | - | xavg <- rowMeans(x) |
|
295 | - | yavg <- rowMeans(y) |
|
296 | - | ||
297 | - | plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR") |
|
298 | - | samples <- sample(1:ncol(x), n) |
|
299 | - | for (i in samples) { |
|
300 | - | graphics::lines( |
|
301 | - | x[, i], y[, i], type = "s", |
|
302 | - | col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255) |
|
303 | - | ) |
|
304 | - | } |
|
305 | - | graphics::lines(xavg, yavg, type = "s") |
|
306 | - | ||
307 | - | # PRC |
|
308 | - | # The elements of prc_dat have different lengths, unlike roc_dat, so we |
|
309 | - | # have to do the central curve differently. |
|
310 | - | prc_dat <- obj$prc_dat |
|
399 | + | mdl_data <- data |
|
400 | + | stopifnot(all(xnames %in% names(mdl_data))) |
|
401 | + | stopifnot(all(yname %in% names(mdl_data))) |
|
402 | + | ||
403 | + | # add intercept by default, maybe revisit this |
|
404 | + | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
405 | + | yvec <- mdl_data[[yname]] |
|
406 | + | ||
407 | + | pardraws <- as.matrix(object) |
|
408 | + | # this is not very robust, assumes pars are 'b[x]' |
|
409 | + | betadraws <- pardraws[, c(sprintf("b[%s]", 1:ncol(xdata - 1)))] |
|
410 | + | ||
411 | + | if(link_type=="logit") { |
|
412 | + | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
413 | + | } else if (link_type=="probit") { |
|
414 | + | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
415 | + | } |
|
416 | + | ||
417 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
418 | + | fullsims = fullsims) |
|
419 | + | ||
420 | + | ||
421 | + | } |
|
311 | 422 | ||
312 | - | x <- lapply(prc_dat, `[[`, 1) |
|
313 | - | y <- lapply(prc_dat, `[[`, 2) |
|
314 | - | ||
315 | - | # Instead of combining the list of curve coordinates from each sample into |
|
316 | - | # two x and y matrices, we can first make a point cloud with all curve |
|
317 | - | # points from all samples, and then average the y values at all distinct |
|
318 | - | # x coordinates. The x-axis plots recall (TPR), which will only have as |
|
319 | - | # many distinct values as there are positives in the data, so this does |
|
320 | - | # not lose any information about the x coordinates. |
|
321 | - | point_cloud <- data.frame( |
|
322 | - | x = unlist(x), |
|
323 | - | y = unlist(y) |
|
324 | - | ) |
|
325 | - | point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], |
|
326 | - | # factor implicitly encodes distinct values only, |
|
327 | - | # since they will get the same labels |
|
328 | - | by = list(x = as.factor(point_cloud$x)), |
|
329 | - | FUN = mean) |
|
330 | - | point_cloud$x <- as.numeric(as.character(point_cloud$x)) |
|
331 | - | xavg <- point_cloud$x |
|
332 | - | yavg <- point_cloud$y |
|
333 | - | ||
334 | - | plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1)) |
|
335 | - | samples <- sample(1:length(prc_dat), n) |
|
336 | - | for (i in samples) { |
|
337 | - | graphics::lines( |
|
338 | - | x[[i]], y[[i]], |
|
339 | - | col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255) |
|
340 | - | ) |
|
341 | - | } |
|
342 | - | graphics::lines(xavg, yavg) |
|
343 | - | ||
423 | + | #' Try to identify if a stanfit model is a binary choice model |
|
424 | + | #' |
|
425 | + | #' @param obj stanfit object |
|
426 | + | #' |
|
427 | + | #' @keywords internal |
|
428 | + | is_binary_model <- function(obj) { |
|
429 | + | stopifnot(inherits(obj, "stanfit")) |
|
430 | + | model_string <- rstan::get_stancode(obj) |
|
431 | + | grepl("bernoulli", model_string) |
|
432 | + | } |
|
433 | + | ||
434 | + | #' Try to identify link function |
|
435 | + | #' |
|
436 | + | #' @param obj stanfit object |
|
437 | + | #' |
|
438 | + | #' @return Either "logit" or "probit"; if neither can be identified the function |
|
439 | + | #' will return `NA_character_`. |
|
440 | + | #' |
|
441 | + | #' @keywords internal |
|
442 | + | identify_link_function <- function(obj) { |
|
443 | + | stopifnot(inherits(obj, "stanfit")) |
|
444 | + | model_string <- rstan::get_stancode(obj) |
|
445 | + | if (grepl("logit", model_string)) return("logit") |
|
446 | + | if (grepl("Phi", model_string)) return("probit") |
|
447 | + | NA_character_ |
|
448 | + | } |
|
449 | + | ||
450 | + | #' @rdname mcmcRocPrc |
|
451 | + | #' |
|
452 | + | #' @export |
|
453 | + | mcmcRocPrc.stanreg <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
454 | + | if (!requireNamespace("rstanarm", quietly = TRUE)) { |
|
455 | + | stop("Package \"rstanarm\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
344 | 456 | } |
|
457 | + | if (!stats::family(object)$family=="binomial") { |
|
458 | + | stop("the input model does not seem to be a binary choice model; should be like 'obj <- stan_glm(family = binomial(), ...)'") |
|
459 | + | } |
|
460 | + | pred_prob <- rstanarm::posterior_linpred(object, transform = TRUE) |
|
461 | + | # posterior_linepred returns a matrix in which data cases are columns, and |
|
462 | + | # MCMC samples are row; we need to transpose this so that columns are samples |
|
463 | + | pred_prob <- t(pred_prob) |
|
464 | + | yvec <- unname(object$y) |
|
345 | 465 | ||
346 | - | invisible(x) |
|
466 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
467 | + | fullsims = fullsims) |
|
347 | 468 | } |
|
348 | 469 | ||
349 | 470 | #' @rdname mcmcRocPrc |
|
350 | 471 | #' |
|
351 | - | #' @param row.names see [base::as.data.frame()] |
|
352 | - | #' @param optional see [base::as.data.frame()] |
|
353 | - | #' @param what which information to extract and convert to a data frame? |
|
472 | + | #' @export |
|
473 | + | mcmcRocPrc.brmsfit <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
474 | + | if (!requireNamespace("brms", quietly = TRUE)) { |
|
475 | + | stop("Package \"brms\" is needed for this function to work. Please install it.", call. = FALSE) # nocov |
|
476 | + | } |
|
477 | + | if (!stats::family(object)$family=="bernoulli") { |
|
478 | + | stop("the input model does not seem to be a binary choice model; should be like 'obj <- brm(family = bernoulli(), ...)'") |
|
479 | + | } |
|
480 | + | ||
481 | + | pred_prob <- brms::posterior_epred(object) |
|
482 | + | # posterior_epred returns a matrix in which data cases are columns, and |
|
483 | + | # MCMC samples are row; we need to transpose this so that columns are samples |
|
484 | + | pred_prob <- t(pred_prob) |
|
485 | + | yvec <- stats::model.response(stats::model.frame(object)) |
|
486 | + | ||
487 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
488 | + | fullsims = fullsims) |
|
489 | + | } |
|
490 | + | ||
491 | + | ||
492 | + | # Other input types (MCMCpack, ...) --------------------------------------- |
|
493 | + | ||
494 | + | ||
495 | + | #' #' @rdname mcmcRocPrc |
|
496 | + | #' #' |
|
497 | + | #' #' @export |
|
498 | + | #' mcmcRocPrc.bugs <- function(object, curves = FALSE, fullsims = FALSE, ...) { |
|
499 | + | #' stop("not implemented yet") |
|
500 | + | #' } |
|
501 | + | ||
502 | + | ||
503 | + | #' @rdname mcmcRocPrc |
|
354 | 504 | #' |
|
505 | + | #' @param type "logit" or "probit" |
|
506 | + | #' @param force for MCMCpack models, suppress warning if the model does not |
|
507 | + | #' appear to be a binary choice model? |
|
508 | + | #' |
|
355 | 509 | #' @export |
|
356 | - | as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE, |
|
357 | - | what = c("auc", "roc", "prc"), ...) { |
|
358 | - | what <- match.arg(what) |
|
359 | - | if (what=="auc") { |
|
360 | - | # all 4 output types have AUC, so this should work across the board |
|
361 | - | return(as.data.frame(x[c("area_under_roc", "area_under_prc")])) |
|
362 | - | ||
363 | - | } else if (what %in% c("roc", "prc")) { |
|
364 | - | if (what=="roc") element <- "roc_dat" else element <- "prc_dat" |
|
365 | - | ||
366 | - | # if curves was FALSE, there will be no curve data... |
|
367 | - | if (is.null(x[[element]])) { |
|
368 | - | stop("No curve data; use mcmcRocPrc(..., curves = TRUE)") |
|
369 | - | } |
|
370 | - | ||
371 | - | # Otherwise, there will be either one set of coordinates if mcmcmRegPrc() |
|
372 | - | # was called with fullsims = FALSE, or else N_sims curve data sets. |
|
373 | - | # If the latter, we can return a long data frame with an identifying |
|
374 | - | # "sim" column to delineate the sim sets. To ensure consistency in output, |
|
375 | - | # also add this column when fullsims = FALSE. |
|
376 | - | ||
377 | - | # averaged, single coordinate set |
|
378 | - | if (length(x[[element]])==1L) { |
|
379 | - | return(data.frame(sim = 1L, x[[element]][[1]])) |
|
510 | + | mcmcRocPrc.mcmc <- function(object, curves = FALSE, fullsims = FALSE, data, |
|
511 | + | xnames, yname, type = c("logit", "probit"), |
|
512 | + | force = FALSE, ...) { |
|
513 | + | ||
514 | + | if (!force) { |
|
515 | + | if (is.null(attr(object, "call"))) { |
|
516 | + | stop("object does not have a 'call' attribute; was it generated with a MCMCpack function?") |
|
517 | + | } else { |
|
518 | + | func <- as.character(attr(object, "call"))[1] |
|
519 | + | if (!func %in% c("MCMClogit", "MCMCprobit")) { |
|
520 | + | stop("object does not appear to have been fitted using MCMCpack::MCMClogit() or MCMCprobit(); mcmcRocPrc only properly works for those function. To be safe, consider manually calculating the matrix of predicted probabilities.") |
|
521 | + | } |
|
380 | 522 | } |
|
381 | - | ||
382 | - | # full sims |
|
383 | - | # add a unique ID to each coordinate set |
|
384 | - | outlist <- x[[element]] |
|
385 | - | outlist <- Map(cbind, sim = (1:length(outlist)), outlist) |
|
386 | - | # combine into long data frame |
|
387 | - | outdf <- do.call(rbind, outlist) |
|
388 | - | return(outdf) |
|
523 | + | } |
|
524 | + | ||
525 | + | link_type <- match.arg(type) |
|
526 | + | mdl_data <- data |
|
527 | + | stopifnot( |
|
528 | + | all(xnames %in% names(mdl_data)), |
|
529 | + | all(yname %in% names(mdl_data)) |
|
530 | + | ) |
|
531 | + | ||
532 | + | # add intercept by default, maybe revisit this |
|
533 | + | xdata <- as.matrix(cbind(X0 = 1L, as.data.frame(mdl_data[xnames]))) |
|
534 | + | yvec <- mdl_data[[yname]] |
|
535 | + | ||
536 | + | betadraws <- as.matrix(object) |
|
537 | + | ||
538 | + | if(link_type=="logit") { |
|
539 | + | pred_prob <- plogis(xdata %*% t(betadraws)) |
|
540 | + | } else if (link_type=="probit") { |
|
541 | + | pred_prob <- pnorm(xdata %*% t(betadraws)) |
|
389 | 542 | } |
|
390 | - | stop("Developer error (I should not be here): please file an issue on GitHub") # nocov |
|
543 | + | ||
544 | + | new_mcmcRocPrc(pred_prob = pred_prob, yvec = yvec, curves = curves, |
|
545 | + | fullsims = fullsims) |
|
391 | 546 | } |
|
392 | 547 | ||
393 | 548 | ||
394 | 549 | ||
395 | - | # auc_roc and auc_pr are not really used, but keep around just in case |
|
396 | - | auc_roc <- function(obs, pred) { |
|
397 | - | values <- compute_roc(obs, pred) |
|
398 | - | caTools::trapz(values$x, values$y) |
|
399 | - | } |
|
400 | - | ||
401 | - | auc_pr <- function(obs, pred) { |
|
402 | - | values <- compute_pr(obs, pred) |
|
403 | - | caTools::trapz(values$x, values$y) |
|
404 | - | } |
@@ -0,0 +1,183 @@
Loading
1 | + | # |
|
2 | + | # Methods for class "mcmcRocPrc", generated by mcmcRocPrc() |
|
3 | + | # |
|
4 | + | ||
5 | + | #' @rdname mcmcRocPrc |
|
6 | + | #' |
|
7 | + | #' @export |
|
8 | + | print.mcmcRocPrc <- function(x, ...) { |
|
9 | + | ||
10 | + | auc_roc <- x$area_under_roc |
|
11 | + | auc_prc <- x$area_under_prc |
|
12 | + | ||
13 | + | has_curves <- !is.null(x$roc_dat) |
|
14 | + | has_sims <- length(auc_roc) > 1 |
|
15 | + | ||
16 | + | if (!has_sims) { |
|
17 | + | roc_msg <- sprintf("%.3f", round(auc_roc, 3)) |
|
18 | + | prc_msg <- sprintf("%.3f", round(auc_prc, 3)) |
|
19 | + | } else { |
|
20 | + | roc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
21 | + | round(mean(auc_roc), 3), |
|
22 | + | round(quantile(auc_roc, 0.1), 3), |
|
23 | + | round(quantile(auc_roc, 0.9), 3)) |
|
24 | + | prc_msg <- sprintf("%.3f [80%%: %.3f - %.3f]", |
|
25 | + | round(mean(auc_prc), 3), |
|
26 | + | round(quantile(auc_prc, 0.1), 3), |
|
27 | + | round(quantile(auc_prc, 0.9), 3)) |
|
28 | + | } |
|
29 | + | ||
30 | + | cat("mcmcRocPrc object\n") |
|
31 | + | cat(sprintf("curves: %s; fullsims: %s\n", has_curves, has_sims)) |
|
32 | + | cat(sprintf("AUC-ROC: %s\n", roc_msg)) |
|
33 | + | cat(sprintf("AUC-PR: %s\n", prc_msg)) |
|
34 | + | ||
35 | + | invisible(x) |
|
36 | + | } |
|
37 | + | ||
38 | + | #' @rdname mcmcRocPrc |
|
39 | + | #' |
|
40 | + | #' @param n plot method: if `fullsims = TRUE`, how many sample curves to draw? |
|
41 | + | #' @param alpha plot method: alpha value for plotting sampled curves; between 0 and 1 |
|
42 | + | #' |
|
43 | + | #' @export |
|
44 | + | plot.mcmcRocPrc <- function(x, n = 40, alpha = .5, ...) { |
|
45 | + | ||
46 | + | stopifnot( |
|
47 | + | "Use mcmcRocPrc(..., curves = TRUE) to generate data for plots" = (!is.null(x$roc_dat)), |
|
48 | + | "alpha must be between 0 and 1" = (alpha >= 0 & alpha <= 1), |
|
49 | + | "n must be > 0" = (n > 0) |
|
50 | + | ) |
|
51 | + | ||
52 | + | obj<- x |
|
53 | + | fullsims <- length(obj$roc_dat) > 1 |
|
54 | + | ||
55 | + | if (!fullsims) { |
|
56 | + | ||
57 | + | graphics::par(mfrow = c(1, 2)) |
|
58 | + | plot(obj$roc_dat[[1]], type = "s", xlab = "FPR", ylab = "TPR") |
|
59 | + | graphics::abline(a = 0, b = 1, lty = 3, col = "gray50") |
|
60 | + | ||
61 | + | prc_dat <- obj$prc_dat[[1]] |
|
62 | + | # use first non-NaN y-value for y[1] |
|
63 | + | prc_dat$y[1] <- prc_dat$y[2] |
|
64 | + | plot(prc_dat, type = "l", xlab = "TPR", ylab = "Precision", |
|
65 | + | ylim = c(0, 1)) |
|
66 | + | graphics::abline(a = attr(x, "y_pos_rate"), b = 0, lty = 3, col = "gray50") |
|
67 | + | ||
68 | + | } else { |
|
69 | + | ||
70 | + | graphics::par(mfrow = c(1, 2)) |
|
71 | + | ||
72 | + | roc_dat <- obj$roc_dat |
|
73 | + | ||
74 | + | x <- lapply(roc_dat, `[[`, 1) |
|
75 | + | x <- do.call(cbind, x) |
|
76 | + | colnames(x) <- paste0("sim", 1:ncol(x)) |
|
77 | + | ||
78 | + | y <- lapply(roc_dat, `[[`, 2) |
|
79 | + | y <- do.call(cbind, y) |
|
80 | + | colnames(y) <- paste0("sim", 1:ncol(y)) |
|
81 | + | ||
82 | + | xavg <- rowMeans(x) |
|
83 | + | yavg <- rowMeans(y) |
|
84 | + | ||
85 | + | plot(xavg, yavg, type = "n", xlab = "FPR", ylab = "TPR") |
|
86 | + | samples <- sample(1:ncol(x), n) |
|
87 | + | for (i in samples) { |
|
88 | + | graphics::lines( |
|
89 | + | x[, i], y[, i], type = "s", |
|
90 | + | col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255) |
|
91 | + | ) |
|
92 | + | } |
|
93 | + | graphics::lines(xavg, yavg, type = "s") |
|
94 | + | ||
95 | + | # PRC |
|
96 | + | # The elements of prc_dat have different lengths, unlike roc_dat, so we |
|
97 | + | # have to do the central curve differently. |
|
98 | + | prc_dat <- obj$prc_dat |
|
99 | + | ||
100 | + | x <- lapply(prc_dat, `[[`, 1) |
|
101 | + | y <- lapply(prc_dat, `[[`, 2) |
|
102 | + | ||
103 | + | # Instead of combining the list of curve coordinates from each sample into |
|
104 | + | # two x and y matrices, we can first make a point cloud with all curve |
|
105 | + | # points from all samples, and then average the y values at all distinct |
|
106 | + | # x coordinates. The x-axis plots recall (TPR), which will only have as |
|
107 | + | # many distinct values as there are positives in the data, so this does |
|
108 | + | # not lose any information about the x coordinates. |
|
109 | + | point_cloud <- data.frame( |
|
110 | + | x = unlist(x), |
|
111 | + | y = unlist(y) |
|
112 | + | ) |
|
113 | + | point_cloud <- stats::aggregate(point_cloud[, "y", drop = FALSE], |
|
114 | + | # factor implicitly encodes distinct values only, |
|
115 | + | # since they will get the same labels |
|
116 | + | by = list(x = as.factor(point_cloud$x)), |
|
117 | + | FUN = mean) |
|
118 | + | point_cloud$x <- as.numeric(as.character(point_cloud$x)) |
|
119 | + | xavg <- point_cloud$x |
|
120 | + | yavg <- point_cloud$y |
|
121 | + | ||
122 | + | plot(xavg, yavg, type = "n", xlab = "TPR", ylab = "Precision", ylim = c(0, 1)) |
|
123 | + | samples <- sample(1:length(prc_dat), n) |
|
124 | + | for (i in samples) { |
|
125 | + | graphics::lines( |
|
126 | + | x[[i]], y[[i]], |
|
127 | + | col = grDevices::rgb(127, 127, 127, alpha = alpha*255, maxColorValue = 255) |
|
128 | + | ) |
|
129 | + | } |
|
130 | + | graphics::lines(xavg, yavg) |
|
131 | + | ||
132 | + | } |
|
133 | + | ||
134 | + | invisible(x) |
|
135 | + | } |
|
136 | + | ||
137 | + | #' @rdname mcmcRocPrc |
|
138 | + | #' |
|
139 | + | #' @param row.names see [base::as.data.frame()] |
|
140 | + | #' @param optional see [base::as.data.frame()] |
|
141 | + | #' @param what which information to extract and convert to a data frame? |
|
142 | + | #' |
|
143 | + | #' @export |
|
144 | + | as.data.frame.mcmcRocPrc <- function(x, row.names = NULL, optional = FALSE, |
|
145 | + | what = c("auc", "roc", "prc"), ...) { |
|
146 | + | what <- match.arg(what) |
|
147 | + | if (what=="auc") { |
|
148 | + | # all 4 output types have AUC, so this should work across the board |
|
149 | + | return(as.data.frame(x[c("area_under_roc", "area_under_prc")])) |
|
150 | + | ||
151 | + | } else if (what %in% c("roc", "prc")) { |
|
152 | + | if (what=="roc") element <- "roc_dat" else element <- "prc_dat" |
|
153 | + | ||
154 | + | # if curves was FALSE, there will be no curve data... |
|
155 | + | if (is.null(x[[element]])) { |
|
156 | + | stop("No curve data; use mcmcRocPrc(..., curves = TRUE)") |
|
157 | + | } |
|
158 | + | ||
159 | + | # Otherwise, there will be either one set of coordinates if mcmcmRegPrc() |
|
160 | + | # was called with fullsims = FALSE, or else N_sims curve data sets. |
|
161 | + | # If the latter, we can return a long data frame with an identifying |
|
162 | + | # "sim" column to delineate the sim sets. To ensure consistency in output, |
|
163 | + | # also add this column when fullsims = FALSE. |
|
164 | + | ||
165 | + | # averaged, single coordinate set |
|
166 | + | if (length(x[[element]])==1L) { |
|
167 | + | return(data.frame(sim = 1L, x[[element]][[1]])) |
|
168 | + | } |
|
169 | + | ||
170 | + | # full sims |
|
171 | + | # add a unique ID to each coordinate set |
|
172 | + | outlist <- x[[element]] |
|
173 | + | outlist <- Map(cbind, sim = (1:length(outlist)), outlist) |
|
174 | + | # combine into long data frame |
|
175 | + | outdf <- do.call(rbind, outlist) |
|
176 | + | return(outdf) |
|
177 | + | } |
|
178 | + | stop("Developer error (I should not be here): please file an issue on GitHub") # nocov |
|
179 | + | } |
|
180 | + | ||
181 | + | ||
182 | + | ||
183 | + |
Files | Coverage |
---|---|
R | 83.11% |
Project Totals (11 files) | 83.11% |
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file.
The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files.
The size and color of each slice is representing the number of statements and the coverage, respectively.