No flags found
Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.
e.g., #unittest #integration
#production #enterprise
#frontend #backend
76310c4
... +249 ...
39a1e97
Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.
e.g., #unittest #integration
#production #enterprise
#frontend #backend
1 | + | #' Extract MCMC samples from a model fit with [tmbstan::tmbstan()]. |
|
2 | + | #' |
|
3 | + | #' @description |
|
4 | + | #' `r lifecycle::badge("experimental")` |
|
5 | + | #' |
|
6 | + | #' Returns a matrix of parameter samples. Rows correspond to the order |
|
7 | + | #' of `your_model$tmb_obj$env$last.par.best`. Columns correspond to |
|
8 | + | #' posterior samples. Is used internally by [predict.sdmTMB()] to make |
|
9 | + | #' fully Bayesian predictions. See the `tmbstan_model` argument |
|
10 | + | #' in [predict.sdmTMB()]. |
|
11 | + | #' |
|
12 | + | #' @param object Output from [tmbstan::tmbstan()] run on the `tmb_obj` |
|
13 | + | #' element of an [sdmTMB()] model. E.g., `tmbstan(your_model$tmb_obj)`. |
|
14 | + | #' @examples |
|
15 | + | #' |
|
16 | + | #' \dontrun{ |
|
17 | + | #' pcod_spde <- make_mesh(pcod, c("X", "Y"), cutoff = 30) |
|
18 | + | #' plot(pcod_spde) |
|
19 | + | #' |
|
20 | + | #' # here we will fix the random field parameters at their approximate |
|
21 | + | #' # MLEs (maximum likelihood estimates) from a previous fit |
|
22 | + | #' # to improve speed of convergence: |
|
23 | + | #' m_tmb <- sdmTMB(density ~ 0 + as.factor(year), |
|
24 | + | #' data = pcod, mesh = pcod_spde, family = tweedie(link = "log"), time = "year", |
|
25 | + | #' control = sdmTMBcontrol(start = list(ln_kappa = rep(-1.58, 2), |
|
26 | + | #' ln_tau_E = -0.15, ln_tau_O = -0.65), |
|
27 | + | #' map = list(ln_kappa = rep(factor(NA), 2), |
|
28 | + | #' ln_tau_E = factor(NA), ln_tau_O = factor(NA)))) |
|
29 | + | #' m_tmb |
|
30 | + | #' |
|
31 | + | #' # will take 3-5 minutes: |
|
32 | + | #' library(tmbstan) |
|
33 | + | #' m_stan <- tmbstan(m_tmb$tmb_obj, iter = 200, chains = 1) |
|
34 | + | #' print(m_stan, pars = c("b_j", "thetaf", "ln_phi", "omega_s[1]", "epsilon_st[1]")) |
|
35 | + | #' |
|
36 | + | #' post <- extract_mcmc(m_stan) |
|
37 | + | #' dim(post) |
|
38 | + | #' |
|
39 | + | #' p <- predict(m_tmb, newdata = qcs_grid, tmbstan_model = m_stan) |
|
40 | + | #' p_last <- p[qcs_grid$year == max(qcs_grid$year), ] # just plot last year |
|
41 | + | #' pred <- qcs_grid[qcs_grid$year == max(qcs_grid$year), ] |
|
42 | + | #' pred$est <- apply(exp(p_last), 1, median) |
|
43 | + | #' pred$lwr <- apply(exp(p_last), 1, quantile, probs = 0.1) |
|
44 | + | #' pred$upr <- apply(exp(p_last), 1, quantile, probs = 0.9) |
|
45 | + | #' pred$cv <- apply(exp(p_last), 1, function(x) sd(x) / mean(x)) |
|
46 | + | #' |
|
47 | + | #' library(ggplot2) |
|
48 | + | #' ggplot(pred, aes(X, Y, fill = est)) + geom_raster() + |
|
49 | + | #' scale_fill_viridis_c(trans = "log") |
|
50 | + | #' ggplot(pred, aes(X, Y, fill = cv)) + geom_raster() + |
|
51 | + | #' scale_fill_viridis_c(trans = "log") |
|
52 | + | #' |
|
53 | + | #' index_quantiles <- get_index_sims(p) |
|
54 | + | #' ggplot(index_quantiles, aes(year, est, ymin = lwr, ymax = upr)) + |
|
55 | + | #' geom_line() + geom_ribbon(alpha = 0.5) |
|
56 | + | #' |
|
57 | + | #' index_samples <- get_index_sims(p, return_sims = TRUE) |
|
58 | + | #' ggplot(index_samples, aes(as.factor(year), .value)) + |
|
59 | + | #' geom_violin() |
|
60 | + | #' } |
|
61 | + | #' @export |
|
62 | + | extract_mcmc <- function(object) { |
|
63 | + | if (!requireNamespace("rstan", quietly = TRUE)) { |
|
64 | + | cli_abort("rstan must be installed to use `extract_mcmc()`.") |
|
65 | + | } |
|
66 | + | post <- rstan::extract(object) |
|
67 | + | p_names <- names(post)[-length(names(post))] # exclude "lp__" |
|
68 | + | p <- lapply(seq_len(length(post[["lp__"]])), function(i) { |
|
69 | + | post_pars <- list() |
|
70 | + | for (j in seq_along(p_names)) { |
|
71 | + | par_j <- p_names[j] |
|
72 | + | if (is.matrix(post[[par_j]])) { |
|
73 | + | post_pars[[j]] <- post[[par_j]][i, , drop = TRUE] |
|
74 | + | } else { |
|
75 | + | post_pars[[j]] <- post[[par_j]][i] |
|
76 | + | } |
|
77 | + | } |
|
78 | + | post_pars |
|
79 | + | }) |
|
80 | + | simplify2array(lapply(p, unlist)) |
|
81 | + | } |
1 | + | #' Plot sdmTMB models with the \pkg{visreg} package |
|
2 | + | #' |
|
3 | + | #' sdmTMB models fit with regular (non-delta) families can be passed to |
|
4 | + | #' [visreg::visreg()] or [visreg::visreg2d()] directly. Examples are shown |
|
5 | + | #' below. Delta models can use the helper functions `visreg_delta()` or |
|
6 | + | #' `visreg2d_delta()` described here. |
|
7 | + | #' |
|
8 | + | #' @param object Fit from [sdmTMB()] |
|
9 | + | #' @param model 1st or 2nd delta model |
|
10 | + | #' @param ... Any arguments passed to [visreg::visreg()] or |
|
11 | + | #' [visreg::visreg2d()] |
|
12 | + | #' |
|
13 | + | #' @details |
|
14 | + | #' Note the residuals are currently randomized quantile residuals, |
|
15 | + | #' *not* deviance residuals as is usual for GLMs with \pkg{visreg}. |
|
16 | + | #' |
|
17 | + | #' @return |
|
18 | + | #' A plot from the visreg package. Optionally, the data plotted invisibly if |
|
19 | + | #' `plot = FALSE`. This is useful if you want to make your own plot after. |
|
20 | + | #' |
|
21 | + | #' @export |
|
22 | + | #' @rdname visreg_delta |
|
23 | + | #' |
|
24 | + | #' @examples |
|
25 | + | #' if (inla_installed() && |
|
26 | + | #' require("ggplot2", quietly = TRUE) && |
|
27 | + | #' require("visreg", quietly = TRUE)) { |
|
28 | + | #' |
|
29 | + | #' pcod_2011$fyear <- as.factor(pcod_2011$year) |
|
30 | + | #' fit <- sdmTMB( |
|
31 | + | #' density ~ s(depth_scaled) + fyear, |
|
32 | + | #' data = pcod_2011, mesh = pcod_mesh_2011, |
|
33 | + | #' spatial = "off", |
|
34 | + | #' family = tweedie() |
|
35 | + | #' ) |
|
36 | + | #' visreg::visreg(fit, xvar = "depth_scaled") |
|
37 | + | #' visreg::visreg(fit, xvar = "fyear") |
|
38 | + | #' visreg::visreg(fit, xvar = "depth_scaled", scale = "response") |
|
39 | + | #' visreg::visreg2d(fit, xvar = "fyear", yvar = "depth_scaled") |
|
40 | + | #' v <- visreg::visreg(fit, xvar = "depth_scaled") |
|
41 | + | #' head(v$fit) |
|
42 | + | #' # now use ggplot2 etc. if desired |
|
43 | + | #' |
|
44 | + | #' # Delta model example: |
|
45 | + | #' fit_dg <- sdmTMB( |
|
46 | + | #' density ~ s(depth_scaled, year, k = 8), |
|
47 | + | #' data = pcod_2011, mesh = pcod_mesh_2011, |
|
48 | + | #' spatial = "off", |
|
49 | + | #' family = delta_gamma() |
|
50 | + | #' ) |
|
51 | + | #' visreg_delta(fit_dg, xvar = "depth_scaled", model = 1, gg = TRUE) |
|
52 | + | #' visreg_delta(fit_dg, xvar = "depth_scaled", model = 2, gg = TRUE) |
|
53 | + | #' visreg_delta(fit_dg, |
|
54 | + | #' xvar = "depth_scaled", model = 1, |
|
55 | + | #' scale = "response", gg = TRUE |
|
56 | + | #' ) |
|
57 | + | #' visreg_delta(fit_dg, |
|
58 | + | #' xvar = "depth_scaled", model = 2, |
|
59 | + | #' scale = "response" |
|
60 | + | #' ) |
|
61 | + | #' visreg_delta(fit_dg, |
|
62 | + | #' xvar = "depth_scaled", model = 2, |
|
63 | + | #' scale = "response", gg = TRUE, rug = FALSE |
|
64 | + | #' ) |
|
65 | + | #' visreg2d_delta(fit_dg, |
|
66 | + | #' xvar = "depth_scaled", yvar = "year", |
|
67 | + | #' model = 2, scale = "response" |
|
68 | + | #' ) |
|
69 | + | #' visreg2d_delta(fit_dg, |
|
70 | + | #' xvar = "depth_scaled", yvar = "year", |
|
71 | + | #' model = 1, scale = "response", plot.type = "persp" |
|
72 | + | #' ) |
|
73 | + | #' visreg2d_delta(fit_dg, |
|
74 | + | #' xvar = "depth_scaled", yvar = "year", |
|
75 | + | #' model = 2, scale = "response", plot.type = "gg" |
|
76 | + | #' ) |
|
77 | + | #' } |
|
78 | + | visreg_delta <- function(object, ..., model = c(1, 2)) { |
|
79 | + | object$visreg_model <- check_model_arg(model) |
|
80 | + | dat <- object$data[!is.na(object$tmb_data$y_i[, model]), , drop = FALSE] |
|
81 | + | visreg::visreg(fit = object, data = dat, ...) |
|
82 | + | } |
|
83 | + | ||
84 | + | #' @export |
|
85 | + | #' @rdname visreg_delta |
|
86 | + | visreg2d_delta <- function(object, ..., model = c(1, 2)) { |
|
87 | + | object$visreg_model <- check_model_arg(model) |
|
88 | + | dat <- object$data[!is.na(object$tmb_data$y_i[, model]), , drop = FALSE] |
|
89 | + | visreg::visreg2d(fit = object, data = dat, ...) |
|
90 | + | } |
|
91 | + | ||
92 | + | check_model_arg <- function(model) { |
|
93 | + | assert_that(is.numeric(model[[1]])) |
|
94 | + | model <- as.integer(model[[1]]) |
|
95 | + | assert_that(model %in% c(1L, 2L)) |
|
96 | + | model |
|
97 | + | } |
1 | 1 | #' Plot anisotropy |
|
2 | 2 | #' |
|
3 | 3 | #' @param object An object from [sdmTMB()]. |
|
4 | + | #' @param model Which model if a delta model. |
|
4 | 5 | #' |
|
5 | 6 | #' @export |
|
6 | 7 | #' @rdname plot_anisotropy |
|
8 | + | #' |
|
9 | + | #' @return A plot of eigenvectors illustrating the estimated anisotropy. A list |
|
10 | + | #' of the plotted data is invisibly returned. |
|
11 | + | #' @references Code adapted from VAST R package |
|
7 | 12 | #' @examples |
|
8 | 13 | #' \donttest{ |
|
9 | - | #' d <- pcod |
|
10 | - | #' m <- sdmTMB(data = d, |
|
11 | - | #' formula = density ~ 0 + as.factor(year), |
|
12 | - | #' time = "year", spde = make_spde(d$X, d$Y, n_knots = 80), |
|
13 | - | #' family = tweedie(link = "log"), anisotropy = TRUE, |
|
14 | - | #' include_spatial = FALSE) |
|
15 | - | #' plot_anisotropy(m) |
|
14 | + | #' if (inla_installed()) { |
|
15 | + | #' d <- pcod |
|
16 | + | #' m <- sdmTMB( |
|
17 | + | #' data = d, |
|
18 | + | #' formula = density ~ 0 + as.factor(year), |
|
19 | + | #' time = "year", mesh = make_mesh(d, c("X", "Y"), n_knots = 80, type = "kmeans"), |
|
20 | + | #' family = tweedie(link = "log"), anisotropy = TRUE, |
|
21 | + | #' spatial = "off" |
|
22 | + | #' ) |
|
23 | + | #' plot_anisotropy(m) |
|
24 | + | #' } |
|
16 | 25 | #' } |
|
17 | - | plot_anisotropy <- function(object) { |
|
18 | - | stopifnot(identical(class(object), "sdmTMB")) |
|
19 | - | report <- object$tmb_obj$report() |
|
20 | - | eig <- eigen(report$H) |
|
26 | + | plot_anisotropy <- function(object, model = 1) { |
|
27 | + | stopifnot(inherits(object, "sdmTMB")) |
|
28 | + | report <- object$tmb_obj$report(object$tmb_obj$env$last.par.best) |
|
29 | + | if (model == 1) eig <- eigen(report$H) |
|
30 | + | if (model == 2) eig <- eigen(report$H2) |
|
21 | 31 | dat <- data.frame( |
|
22 | 32 | x0 = c(0, 0), |
|
23 | 33 | y0 = c(0, 0), |
|
24 | 34 | x1 = eig$vectors[1, , drop = TRUE] * eig$values, |
|
25 | 35 | y1 = eig$vectors[2, , drop = TRUE] * eig$values |
|
26 | 36 | ) |
|
27 | - | plot(0, xlim = range(c(dat$x0, dat$x1)), |
|
28 | - | ylim = range(c(dat$y0, dat$y1)), type = "n", asp = 1, xlab = "", ylab = "") |
|
37 | + | plot(0, |
|
38 | + | xlim = range(c(dat$x0, dat$x1)), |
|
39 | + | ylim = range(c(dat$y0, dat$y1)), |
|
40 | + | type = "n", asp = 1, xlab = "", ylab = "" |
|
41 | + | ) |
|
29 | 42 | graphics::arrows(dat$x0, dat$y0, dat$x1, dat$y1) |
|
30 | 43 | invisible(list(eig = eig, dat = dat, H = report$H)) |
|
31 | 44 | } |
|
32 | 45 | ||
46 | + | #' Plot a smooth term from an sdmTMB model |
|
47 | + | #' |
|
48 | + | #' @param object An [sdmTMB()] model. |
|
49 | + | #' @param select The smoother term to plot. |
|
50 | + | #' @param n The number of equally spaced points to evaluate the smoother along. |
|
51 | + | #' @param level The confidence level. |
|
52 | + | #' @param ggplot Logical: use the \pkg{ggplot2} package? |
|
53 | + | #' @param rug Logical: add rug lines along the lower axis? |
|
54 | + | #' @param return_data Logical: return the predicted data instead of making a plot? |
|
55 | + | #' @export |
|
56 | + | #' |
|
57 | + | #' @details |
|
58 | + | #' Note: |
|
59 | + | #' * Any numeric predictor is set to its mean |
|
60 | + | #' * Any factor predictor is set to its first-level value |
|
61 | + | #' * The time element (if present) is set to its minimum value |
|
62 | + | #' * The x and y coordinates are set to their mean values |
|
63 | + | #' |
|
64 | + | #' @examples |
|
65 | + | #' if (inla_installed()) { |
|
66 | + | #' d <- subset(pcod, year >= 2000 & density > 0) |
|
67 | + | #' pcod_spde <- make_mesh(d, c("X", "Y"), cutoff = 30) |
|
68 | + | #' m <- sdmTMB( |
|
69 | + | #' data = d, |
|
70 | + | #' formula = log(density) ~ s(depth_scaled) + s(year, k = 5), |
|
71 | + | #' mesh = pcod_spde |
|
72 | + | #' ) |
|
73 | + | #' plot_smooth(m) |
|
74 | + | #' } |
|
75 | + | plot_smooth <- function(object, select = 1, n = 100, level = 0.95, |
|
76 | + | ggplot = FALSE, rug = TRUE, return_data = FALSE) { |
|
77 | + | msg <- c( |
|
78 | + | "This function will likely be deprecated.", |
|
79 | + | "Consider using `visreg::visreg()` or `visreg_delta()`.", |
|
80 | + | "See ?visreg_delta() for examples." |
|
81 | + | ) |
|
82 | + | cli_inform(msg) |
|
83 | + | se <- TRUE |
|
84 | + | if (isTRUE(object$delta)) |
|
85 | + | cli_abort("This function doesn't work with delta models yet") |
|
86 | + | ||
87 | + | assert_that(inherits(object, "sdmTMB")) |
|
88 | + | assert_that(is.logical(ggplot)) |
|
89 | + | assert_that(is.logical(return_data)) |
|
90 | + | assert_that(is.logical(se)) |
|
91 | + | assert_that(is.numeric(n)) |
|
92 | + | assert_that(is.numeric(level)) |
|
93 | + | assert_that(length(level) == 1L) |
|
94 | + | assert_that(length(select) == 1L) |
|
95 | + | assert_that(length(n) == 1L) |
|
96 | + | assert_that(is.numeric(select)) |
|
97 | + | assert_that(level > 0 & level < 1) |
|
98 | + | assert_that(n < 500) |
|
99 | + | ||
100 | + | if (ggplot) { |
|
101 | + | if (!requireNamespace("ggplot2", quietly = TRUE)) { |
|
102 | + | cli_abort("ggplot2 not installed") |
|
103 | + | } |
|
104 | + | } |
|
105 | + | ||
106 | + | sm <- parse_smoothers(object$formula[[1]], object$data) |
|
107 | + | sm_names <- unlist(lapply(sm$Zs, function(x) attr(x, "s.label"))) |
|
108 | + | sm_names <- gsub("\\)$", "", gsub("s\\(", "", sm_names)) |
|
109 | + | ||
110 | + | fe_names <- colnames(object$tmb_data$X_ij) |
|
111 | + | fe_names <- fe_names[!fe_names == "offset"] # FIXME |
|
112 | + | fe_names <- fe_names[!fe_names == "(Intercept)"] |
|
113 | + | ||
114 | + | all_names <- c(sm_names, fe_names) |
|
115 | + | if (select > length(sm_names)) { |
|
116 | + | cli_abort("`select` is greater than the number of smooths") |
|
117 | + | } |
|
118 | + | sel_name <- sm_names[select] |
|
119 | + | non_select_names <- all_names[!all_names %in% sel_name] |
|
120 | + | ||
121 | + | x <- object$data[[sel_name]] |
|
122 | + | nd <- data.frame(x = seq(min(x), max(x), length.out = n)) |
|
123 | + | names(nd)[1] <- sel_name |
|
124 | + | ||
125 | + | dat <- object$data |
|
126 | + | .t <- terms(object$formula[[1]]) |
|
127 | + | .t <- labels(.t) |
|
128 | + | checks <- c("^as\\.factor\\(", "^factor\\(") |
|
129 | + | for (ch in checks) { |
|
130 | + | if (any(grepl(ch, .t))) { # any factors from formula? if so, explicitly switch class |
|
131 | + | ft <- grep(ch, .t) |
|
132 | + | for (i in ft) { |
|
133 | + | x <- gsub(ch, "", .t[i]) |
|
134 | + | x <- gsub("\\)$", "", x) |
|
135 | + | dat[[x]] <- as.factor(dat[[x]]) |
|
136 | + | } |
|
137 | + | } |
|
138 | + | } |
|
139 | + | dat[, object$spde$xy_cols] <- NULL |
|
140 | + | dat[[object$time]] <- NULL |
|
141 | + | for (i in seq_len(ncol(dat))) { |
|
142 | + | if (names(dat)[i] != sel_name) { |
|
143 | + | if (is.factor(dat[, i, drop = TRUE])) { |
|
144 | + | nd[[names(dat)[[i]]]] <- sort(dat[, i, drop = TRUE])[[1]] # TODO note! |
|
145 | + | } else { |
|
146 | + | nd[[names(dat)[[i]]]] <- mean(dat[, i, drop = TRUE], na.rm = TRUE) # TODO note! |
|
147 | + | } |
|
148 | + | } |
|
149 | + | } |
|
150 | + | nd[object$time] <- min(object$data[[object$time]], na.rm = TRUE) # TODO note! |
|
151 | + | nd[[object$spde$xy_cols[1]]] <- mean(object$data[[object$spde$xy_cols[1]]], na.rm = TRUE) # TODO note! |
|
152 | + | nd[[object$spde$xy_cols[2]]] <- mean(object$data[[object$spde$xy_cols[2]]], na.rm = TRUE) # TODO note! |
|
153 | + | ||
154 | + | p <- predict(object, newdata = nd, se_fit = se, re_form = NA) |
|
155 | + | if (return_data) { |
|
156 | + | return(p) |
|
157 | + | } |
|
158 | + | inv <- object$family$linkinv |
|
159 | + | qv <- stats::qnorm(1 - (1 - level) / 2) |
|
160 | + | ||
161 | + | if (!ggplot) { |
|
162 | + | if (se) { |
|
163 | + | lwr <- inv(p$est - qv * p$est_se) |
|
164 | + | upr <- inv(p$est + qv * p$est_se) |
|
165 | + | ylim <- range(c(lwr, upr)) |
|
166 | + | } else { |
|
167 | + | ylim <- range(p$est) |
|
168 | + | } |
|
169 | + | plot(nd[[sel_name]], inv(p$est), |
|
170 | + | type = "l", ylim = ylim, |
|
171 | + | xlab = sel_name, ylab = paste0("s(", sel_name, ")") |
|
172 | + | ) |
|
173 | + | if (se) { |
|
174 | + | graphics::lines(nd[[sel_name]], lwr, lty = 2) |
|
175 | + | graphics::lines(nd[[sel_name]], upr, lty = 2) |
|
176 | + | } |
|
177 | + | if (rug) rug(object$data[[sel_name]]) |
|
178 | + | } else { |
|
179 | + | g <- ggplot2::ggplot(p, ggplot2::aes_string(sel_name, "inv(est)", |
|
180 | + | ymin = "inv(est - qv * est_se)", ymax = "inv(est + qv * est_se)" |
|
181 | + | )) + |
|
182 | + | ggplot2::geom_line() + |
|
183 | + | ggplot2::geom_ribbon(alpha = 0.4) + |
|
184 | + | ggplot2::labs(x = sel_name, y = paste0("s(", sel_name, ")")) |
|
185 | + | if (rug) { |
|
186 | + | g <- g + |
|
187 | + | ggplot2::geom_rug( |
|
188 | + | data = object$data, mapping = ggplot2::aes_string(x = sel_name), |
|
189 | + | sides = "b", inherit.aes = FALSE, alpha = 0.3 |
|
190 | + | ) |
|
191 | + | } |
|
192 | + | return(g) |
|
193 | + | } |
|
194 | + | } |
4 | 4 | ||
5 | 5 | #' Predict from an sdmTMB model |
|
6 | 6 | #' |
|
7 | - | #' Can predict on the original data locations or on new data. |
|
7 | + | #' Make predictions from an sdmTMB model; can predict on the original or new |
|
8 | + | #' data. |
|
8 | 9 | #' |
|
9 | 10 | #' @param object An object from [sdmTMB()]. |
|
10 | - | #' @param newdata An optional new data frame. This should be a data frame with |
|
11 | - | #' the same predictor columns as in the fitted data and a time column (if this |
|
12 | - | #' is a spatiotemporal model) with the same name as in the fitted data. There |
|
13 | - | #' should be predictor data for each year in the original data set. |
|
14 | - | #' @param se_fit Should standard errors on predictions at the new locations given by |
|
15 | - | #' `newdata` be calculated? Warning: the current implementation can be slow for |
|
16 | - | #' large data sets or high-resolution projections. |
|
17 | - | #' @param xy_cols A character vector of length 2 that gives the column names of |
|
18 | - | #' the x and y coordinates in `newdata`. |
|
19 | - | #' @param return_tmb_object Logical. If `TRUE`, will include the TMB object in |
|
20 | - | #' a list format output. Necessary for the [get_index()] or [get_cog()] functions. |
|
21 | - | #' @param area A vector of areas for survey grid cells. Only necessary if the |
|
22 | - | #' output will be passed to [get_index()] or [get_cog()]. Should be the same length |
|
23 | - | #' as the number of rows of `newdata`. Defaults to a sequence of 1s. |
|
24 | - | #' @param re_form `NULL` to specify individual-level predictions. `~0` or `NA` |
|
25 | - | #' for population-level predictions. Note that unlike lme4 or glmmTMB, this |
|
26 | - | #' only affects what the standard errors are calculated on if `se_fit = TRUE`. |
|
27 | - | #' Otherwise, predictions at various levels are returned in all cases. |
|
11 | + | #' @param newdata A data frame to make predictions on. This should be a data |
|
12 | + | #' frame with the same predictor columns as in the fitted data and a time |
|
13 | + | #' column (if this is a spatiotemporal model) with the same name as in the |
|
14 | + | #' fitted data. There should be predictor data for each year in the original |
|
15 | + | #' data set. |
|
16 | + | #' @param type Should the `est` column be in link (default) or response space? |
|
17 | + | #' @param se_fit Should standard errors on predictions at the new locations |
|
18 | + | #' given by `newdata` be calculated? Warning: the current implementation can |
|
19 | + | #' be very slow for large data sets or high-resolution projections. A *much* |
|
20 | + | #' faster option is to use the `nsim` argument below and calculate uncertainty |
|
21 | + | #' on the simulations from the joint precision matrix. |
|
22 | + | #' @param return_tmb_object Logical. If `TRUE`, will include the TMB object in a |
|
23 | + | #' list format output. Necessary for the [get_index()] or [get_cog()] |
|
24 | + | #' functions. |
|
25 | + | #' @param re_form `NULL` to specify including all spatial/spatiotemporal random |
|
26 | + | #' effects in predictions. `~0` or `NA` for population-level predictions. Note |
|
27 | + | #' that unlike lme4 or glmmTMB, this only affects what the standard errors are |
|
28 | + | #' calculated on if `se_fit = TRUE`. This does not affect [get_index()] |
|
29 | + | #' calculations. |
|
30 | + | #' @param re_form_iid `NULL` to specify including all random intercepts in the |
|
31 | + | #' predictions. `~0` or `NA` for population-level predictions. No other |
|
32 | + | #' options (e.g., some but not all random intercepts) are implemented yet. |
|
33 | + | #' Only affects predictions with `newdata`. This also affects [get_index()]. |
|
34 | + | #' @param nsim **Experimental.** If > 0, simulate from the joint precision matrix with `sims` |
|
35 | + | #' draws Returns a matrix of `nrow(data)` by `sim` representing the estimates |
|
36 | + | #' of the linear predictor (i.e., in link space). Can be useful for deriving |
|
37 | + | #' uncertainty on predictions (e.g., `apply(x, 1, sd)`) or propagating |
|
38 | + | #' uncertainty. This is currently the fastest way to generate estimates of |
|
39 | + | #' uncertainty on predictions in space with sdmTMB. |
|
40 | + | #' @param area **Deprecated**. Please use `area` in [get_index()]. |
|
41 | + | #' @param sims **Deprecated**. Please use `nsim` instead. |
|
42 | + | #' @param sims_var **Experimental.** Which TMB reported variable from the model |
|
43 | + | #' should be extracted from the joint precision matrix simulation draws? |
|
44 | + | #' Defaults to the link-space predictions. Options include: `"omega_s"`, |
|
45 | + | #' `"zeta_s"`, `"epsilon_st"`, and `"est_rf"` (as described below). |
|
46 | + | #' Other options will be passed verbatim. |
|
47 | + | #' @param tmbstan_model A model fit with [tmbstan::tmbstan()]. See |
|
48 | + | #' [extract_mcmc()] for more details and an example. If specified, the |
|
49 | + | #' predict function will return a matrix of a similar form as if `nsim > 0` |
|
50 | + | #' but representing Bayesian posterior samples from the Stan model. |
|
51 | + | #' @param model Type of prediction if a delta/hurdle model: |
|
52 | + | #' `NA` returns the combined prediction from both components on |
|
53 | + | #' the response scale; `1` or `2` return the first or second model |
|
54 | + | #' component only on the link or response scale depending on the argument |
|
55 | + | #' `type`. |
|
56 | + | #' @param return_tmb_report Logical: return the output from the TMB |
|
57 | + | #' report? For regular prediction this is all the reported variables |
|
58 | + | #' at the MLE parameter values. For `nsim > 0` or when `tmbstan_model` |
|
59 | + | #' is supplied, this is a list where each element is a sample and the |
|
60 | + | #' contents of each element is the output of the report for that sample. |
|
28 | 61 | #' @param ... Not implemented. |
|
29 | 62 | #' |
|
30 | 63 | #' @return |
|
31 | - | #' If `return_tmb_object = FALSE`: |
|
64 | + | #' If `return_tmb_object = FALSE` (and `nsim = 0` and `tmbstan_model = NULL`): |
|
65 | + | #' |
|
32 | 66 | #' A data frame: |
|
33 | 67 | #' * `est`: Estimate in link space (everything is in link space) |
|
34 | 68 | #' * `est_non_rf`: Estimate from everything that isn't a random field |
|
35 | 69 | #' * `est_rf`: Estimate from all random fields combined |
|
36 | 70 | #' * `omega_s`: Spatial (intercept) random field that is constant through time |
|
37 | 71 | #' * `zeta_s`: Spatial slope random field |
|
38 | - | #' * `epsilon_st`: Spatiotemporal (intercept) random fields (could be |
|
39 | - | #' independent draws each year or AR1) |
|
72 | + | #' * `epsilon_st`: Spatiotemporal (intercept) random fields, could be |
|
73 | + | #' off (zero), IID, AR1, or random walk |
|
74 | + | #' |
|
75 | + | #' If `return_tmb_object = TRUE` (and `nsim = 0` and `tmbstan_model = NULL`): |
|
40 | 76 | #' |
|
41 | - | #' If `return_tmb_object = TRUE`: |
|
42 | 77 | #' A list: |
|
43 | 78 | #' * `data`: The data frame described above |
|
44 | 79 | #' * `report`: The TMB report on parameter values |
|
45 | - | #' * `obj`: The TMB object returned from the prediction run. |
|
46 | - | #' * `fit_obj`: The original TMB model object. |
|
80 | + | #' * `obj`: The TMB object returned from the prediction run |
|
81 | + | #' * `fit_obj`: The original TMB model object |
|
82 | + | #' |
|
83 | + | #' In this case, you likely only need the `data` element as an end user. |
|
84 | + | #' The other elements are included for other functions. |
|
85 | + | #' |
|
86 | + | #' If `nsim > 0` or `tmbstan_model` is not `NULL`: |
|
47 | 87 | #' |
|
48 | - | #' You likely only need the `data` element as an end user. The other elements |
|
49 | - | #' are included for other functions. |
|
88 | + | #' A matrix: |
|
89 | + | #' |
|
90 | + | #' * Columns represent samples |
|
91 | + | #' * Rows represent predictions with one row per row of `newdata` |
|
50 | 92 | #' |
|
51 | 93 | #' @export |
|
52 | 94 | #' |
|
53 | 95 | #' @examples |
|
54 | - | #' # We'll only use a small number of knots so this example runs quickly |
|
55 | - | #' # but you will likely want to use many more in applied situations. |
|
96 | + | #' if (require("ggplot2", quietly = TRUE) && inla_installed()) { |
|
56 | 97 | #' |
|
57 | - | #' library(ggplot2) |
|
58 | - | #' d <- pcod |
|
59 | - | #' pcod_spde <- make_spde(d$X, d$Y, n_knots = 50) # just 50 for example speed |
|
98 | + | #' d <- pcod_2011 |
|
99 | + | #' mesh <- make_mesh(d, c("X", "Y"), cutoff = 30) # a coarse mesh for example speed |
|
60 | 100 | #' m <- sdmTMB( |
|
61 | 101 | #' data = d, formula = density ~ 0 + as.factor(year) + depth_scaled + depth_scaled2, |
|
62 | - | #' time = "year", spde = pcod_spde, family = tweedie(link = "log") |
|
102 | + | #' time = "year", mesh = mesh, family = tweedie(link = "log") |
|
63 | 103 | #' ) |
|
64 | 104 | #' |
|
65 | 105 | #' # Predictions at original data locations ------------------------------- |
68 | 108 | #' head(predictions) |
|
69 | 109 | #' |
|
70 | 110 | #' predictions$resids <- residuals(m) # randomized quantile residuals |
|
111 | + | #' |
|
71 | 112 | #' ggplot(predictions, aes(X, Y, col = resids)) + scale_colour_gradient2() + |
|
72 | 113 | #' geom_point() + facet_wrap(~year) |
|
73 | 114 | #' hist(predictions$resids) |
|
74 | 115 | #' qqnorm(predictions$resids);abline(a = 0, b = 1) |
|
75 | 116 | #' |
|
76 | 117 | #' # Predictions onto new data -------------------------------------------- |
|
77 | 118 | #' |
|
78 | - | #' predictions <- predict(m, newdata = qcs_grid) |
|
119 | + | #' qcs_grid_2011 <- subset(qcs_grid, year >= min(pcod_2011$year)) |
|
120 | + | #' predictions <- predict(m, newdata = qcs_grid_2011) |
|
79 | 121 | #' |
|
80 | 122 | #' # A short function for plotting our predictions: |
|
81 | 123 | #' plot_map <- function(dat, column = "est") { |
105 | 147 | #' ggtitle("Spatiotemporal random effects only") + |
|
106 | 148 | #' scale_fill_gradient2() |
|
107 | 149 | #' |
|
108 | - | #' \donttest{ |
|
109 | 150 | #' # Visualizing a marginal effect ---------------------------------------- |
|
110 | - | #' # Also demonstrates getting standard errors on population-level predictions |
|
111 | 151 | #' |
|
112 | 152 | #' nd <- data.frame(depth_scaled = |
|
113 | 153 | #' seq(min(d$depth_scaled), max(d$depth_scaled), length.out = 100)) |
|
114 | 154 | #' nd$depth_scaled2 <- nd$depth_scaled^2 |
|
115 | 155 | #' |
|
116 | - | #' # You'll need at least one time element. If time isn't also a fixed effect |
|
117 | - | #' # then it doesn't matter what you pick: |
|
118 | - | #' nd$year <- 2003L |
|
156 | + | #' # Because this is a spatiotemporal model, you'll need at least one time |
|
157 | + | #' # element. If time isn't also a fixed effect then it doesn't matter what you pick: |
|
158 | + | #' nd$year <- 2011L # L: integer to match original data |
|
119 | 159 | #' p <- predict(m, newdata = nd, se_fit = TRUE, re_form = NA) |
|
120 | 160 | #' ggplot(p, aes(depth_scaled, exp(est), |
|
121 | 161 | #' ymin = exp(est - 1.96 * est_se), ymax = exp(est + 1.96 * est_se))) + |
124 | 164 | #' # Plotting marginal effect of a spline --------------------------------- |
|
125 | 165 | #' |
|
126 | 166 | #' m_gam <- sdmTMB( |
|
127 | - | #' data = d, formula = density ~ 0 + as.factor(year) + s(depth_scaled, k = 3), |
|
128 | - | #' time = "year", spde = pcod_spde, family = tweedie(link = "log") |
|
167 | + | #' data = d, formula = density ~ 0 + as.factor(year) + s(depth_scaled, k = 5), |
|
168 | + | #' time = "year", mesh = mesh, family = tweedie(link = "log") |
|
129 | 169 | #' ) |
|
170 | + | #' plot_smooth(m_gam) |
|
171 | + | #' |
|
172 | + | #' # or manually: |
|
130 | 173 | #' nd <- data.frame(depth_scaled = |
|
131 | 174 | #' seq(min(d$depth_scaled), max(d$depth_scaled), length.out = 100)) |
|
132 | - | #' nd$year <- 2003L |
|
175 | + | #' nd$year <- 2011L |
|
133 | 176 | #' p <- predict(m_gam, newdata = nd, se_fit = TRUE, re_form = NA) |
|
134 | 177 | #' ggplot(p, aes(depth_scaled, exp(est), |
|
135 | 178 | #' ymin = exp(est - 1.96 * est_se), ymax = exp(est + 1.96 * est_se))) + |
|
136 | 179 | #' geom_line() + geom_ribbon(alpha = 0.4) |
|
137 | 180 | #' |
|
138 | 181 | #' # Forecasting ---------------------------------------------------------- |
|
182 | + | #' mesh <- make_mesh(d, c("X", "Y"), cutoff = 15) |
|
139 | 183 | #' |
|
140 | - | #' # Forecasting using Eric Ward's hack with the `weights` argument. |
|
141 | - | #' |
|
142 | - | #' # Add on a fake year of data with the year to forecast: |
|
143 | - | #' dfake <- rbind(d, d[nrow(d), ]) |
|
144 | - | #' dfake[nrow(dfake), "year"] <- max(d$year) + 1 |
|
145 | - | #' tail(dfake) # last row is fake! |
|
146 | - | #' |
|
147 | - | #' weights <- rep(1, nrow(dfake)) |
|
148 | - | #' weights[length(weights)] <- 0 # set last row weight to 0 |
|
149 | - | #' dfake$year_factor <- dfake$year |
|
150 | - | #' dfake$year_factor[nrow(dfake)] <- max(d$year) # share fixed effect for last 2 years |
|
151 | - | #' |
|
152 | - | #' pcod_spde <- make_spde(dfake$X, dfake$Y, n_knots = 50) |
|
153 | - | #' |
|
184 | + | #' unique(d$year) |
|
154 | 185 | #' m <- sdmTMB( |
|
155 | - | #' data = dfake, formula = density ~ 0 + as.factor(year_factor), |
|
156 | - | #' ar1_fields = TRUE, # using an AR1 to have something to forecast with |
|
157 | - | #' weights = weights, |
|
158 | - | #' include_spatial = TRUE, # could also be `FALSE` |
|
159 | - | #' time = "year", spde = pcod_spde, family = tweedie(link = "log") |
|
186 | + | #' data = d, formula = density ~ 1, |
|
187 | + | #' spatiotemporal = "AR1", # using an AR1 to have something to forecast with |
|
188 | + | #' extra_time = 2019L, # `L` for integer to match our data |
|
189 | + | #' spatial = "off", |
|
190 | + | #' time = "year", mesh = mesh, family = tweedie(link = "log") |
|
160 | 191 | #' ) |
|
161 | 192 | #' |
|
162 | 193 | #' # Add a year to our grid: |
|
163 | - | #' qcs_grid$year_factor <- qcs_grid$year |
|
164 | - | #' grid2018 <- qcs_grid[qcs_grid$year == 2017L, ] |
|
165 | - | #' grid2018$year <- 2018L # `L` because `year` is an integer in the data |
|
166 | - | #' qcsgrid_forecast <- rbind(qcs_grid, grid2018) |
|
194 | + | #' grid2019 <- qcs_grid_2011[qcs_grid_2011$year == max(qcs_grid_2011$year), ] |
|
195 | + | #' grid2019$year <- 2019L # `L` because `year` is an integer in the data |
|
196 | + | #' qcsgrid_forecast <- rbind(qcs_grid_2011, grid2019) |
|
167 | 197 | #' |
|
168 | 198 | #' predictions <- predict(m, newdata = qcsgrid_forecast) |
|
169 | 199 | #' plot_map(predictions, "exp(est)") + |
173 | 203 | #' |
|
174 | 204 | #' # Estimating local trends ---------------------------------------------- |
|
175 | 205 | #' |
|
176 | - | #' pcod_spde <- make_spde(d$X, d$Y, n_knots = 100) |
|
177 | - | #' m <- sdmTMB(data = pcod, formula = density ~ depth_scaled + depth_scaled2, |
|
178 | - | #' spde = pcod_spde, family = tweedie(link = "log"), |
|
179 | - | #' spatial_trend = TRUE, time = "year", spatial_only = TRUE) |
|
180 | - | #' p <- predict(m, newdata = qcs_grid) |
|
206 | + | #' d <- pcod |
|
207 | + | #' d$year_scaled <- as.numeric(scale(d$year)) |
|
208 | + | #' mesh <- make_mesh(pcod, c("X", "Y"), cutoff = 25) |
|
209 | + | #' m <- sdmTMB(data = d, formula = density ~ depth_scaled + depth_scaled2, |
|
210 | + | #' mesh = mesh, family = tweedie(link = "log"), |
|
211 | + | #' spatial_varying = ~ 0 + year_scaled, time = "year", spatiotemporal = "off") |
|
212 | + | #' nd <- qcs_grid |
|
213 | + | #' nd$year_scaled <- (nd$year - mean(d$year)) / sd(d$year) |
|
214 | + | #' p <- predict(m, newdata = nd) |
|
181 | 215 | #' |
|
182 | 216 | #' plot_map(p, "zeta_s") + |
|
183 | 217 | #' ggtitle("Spatial slopes") + |
196 | 230 | #' scale_fill_viridis_c(trans = "sqrt") |
|
197 | 231 | #' } |
|
198 | 232 | ||
199 | - | predict.sdmTMB <- function(object, newdata = NULL, se_fit = FALSE, |
|
200 | - | xy_cols = c("X", "Y"), return_tmb_object = FALSE, |
|
201 | - | area = 1, re_form = NULL, ...) { |
|
233 | + | predict.sdmTMB <- function(object, newdata = object$data, |
|
234 | + | type = c("link", "response"), |
|
235 | + | se_fit = FALSE, |
|
236 | + | return_tmb_object = FALSE, |
|
237 | + | area = deprecated(), re_form = NULL, re_form_iid = NULL, nsim = 0, |
|
238 | + | sims = deprecated(), |
|
239 | + | tmbstan_model = NULL, |
|
240 | + | sims_var = "est", |
|
241 | + | model = c(NA, 1, 2), |
|
242 | + | return_tmb_report = FALSE, |
|
243 | + | ...) { |
|
202 | 244 | ||
203 | - | test <- suppressWarnings(tryCatch(object$tmb_obj$report(), error = function(e) NA)) |
|
204 | - | if (all(is.na(test))) object <- update_model(object) |
|
245 | + | if ("version" %in% names(object)) { |
|
246 | + | check_sdmTMB_version(object$version) |
|
247 | + | } else { |
|
248 | + | cli_abort(c("This looks like a very old version of a model fit.", |
|
249 | + | "Re-fit the model before predicting with it.")) |
|
250 | + | } |
|
251 | + | if (!"xy_cols" %in% names(object$spde)) { |
|
252 | + | cli_warn(c("It looks like this model was fit with make_spde().", |
|
253 | + | "Using `xy_cols`, but future versions of sdmTMB may not be compatible with this.", |
|
254 | + | "Please replace make_spde() with make_mesh().")) |
|
255 | + | } else { |
|
256 | + | xy_cols <- object$spde$xy_cols |
|
257 | + | } |
|
258 | + | ||
259 | + | if (is_present(area)) { |
|
260 | + | deprecate_warn("0.0.22", "predict.sdmTMB(area)", "get_index(area)") |
|
261 | + | } else { |
|
262 | + | area <- 1 |
|
263 | + | } |
|
264 | + | ||
265 | + | if (is_present(sims)) { |
|
266 | + | deprecate_warn("0.0.21", "predict.sdmTMB(sims)", "predict.sdmTMB(nsim)") |
|
267 | + | } else { |
|
268 | + | sims <- nsim |
|
269 | + | } |
|
270 | + | ||
271 | + | type <- match.arg(type) |
|
272 | + | ||
273 | + | n_orig <- suppressWarnings(TMB::openmp(NULL)) |
|
274 | + | if (n_orig > 0 && .Platform$OS.type == "unix") { # openMP is supported |
|
275 | + | TMB::openmp(n = object$control$parallel) |
|
276 | + | on.exit({TMB::openmp(n = n_orig)}) |
|
277 | + | } |
|
278 | + | ||
279 | + | sys_calls <- unlist(lapply(sys.calls(), deparse)) # retrieve function that called this |
|
280 | + | vr <- check_visreg(sys_calls) |
|
281 | + | visreg_df <- vr$visreg_df |
|
282 | + | if (visreg_df) { |
|
283 | + | re_form <- vr$re_form |
|
284 | + | se_fit <- vr$se_fit |
|
285 | + | } |
|
205 | 286 | ||
206 | 287 | # from glmmTMB: |
|
207 | 288 | pop_pred <- (!is.null(re_form) && ((re_form == ~0) || identical(re_form, NA))) |
|
289 | + | pop_pred_iid <- (!is.null(re_form_iid) && ((re_form_iid == ~0) || identical(re_form_iid, NA))) |
|
290 | + | if (pop_pred_iid) { |
|
291 | + | exclude_RE <- rep(1L, length(object$tmb_data$exclude_RE)) |
|
292 | + | } else { |
|
293 | + | exclude_RE <- object$tmb_data$exclude_RE |
|
294 | + | } |
|
208 | 295 | ||
209 | 296 | tmb_data <- object$tmb_data |
|
210 | 297 | tmb_data$do_predict <- 1L |
|
211 | 298 | ||
212 | 299 | if (!is.null(newdata)) { |
|
213 | 300 | if (any(!xy_cols %in% names(newdata)) && isFALSE(pop_pred)) |
|
214 | - | stop("`xy_cols` (the column names for the x and y coordinates) ", |
|
215 | - | "are not in `newdata`. Did you miss specifying the argument ", |
|
216 | - | "`xy_cols` to match your data?", call. = FALSE) |
|
301 | + | cli_abort(c("`xy_cols` (the column names for the x and y coordinates) are not in `newdata`.", |
|
302 | + | "Did you miss specifying the argument `xy_cols` to match your data?", |
|
303 | + | "The newer `make_mesh()` (vs. `make_spde()`) takes care of this for you.")) |
|
217 | 304 | ||
218 | 305 | if (object$time == "_sdmTMB_time") newdata[[object$time]] <- 0L |
|
219 | - | if (!identical(class(object$data[[object$time]]), class(newdata[[object$time]]))) |
|
220 | - | stop("Class of fitted time column does not match class of `newdata` time column.", |
|
221 | - | call. = FALSE) |
|
222 | - | original_time <- sort(unique(object$data[[object$time]])) |
|
223 | - | new_data_time <- sort(unique(newdata[[object$time]])) |
|
306 | + | ||
307 | + | check_time_class(object, newdata) |
|
308 | + | original_time <- as.numeric(sort(unique(object$data[[object$time]]))) |
|
309 | + | new_data_time <- as.numeric(sort(unique(newdata[[object$time]]))) |
|
224 | 310 | ||
225 | 311 | if (!all(new_data_time %in% original_time)) |
|
226 | - | stop("Some new time elements were found in `newdata`. ", |
|
227 | - | "For now, make sure only time elements from the original dataset are present. If you would like to predict on new time elements, see the example hack with the `weights` argument in the help for `?predict.sdmTMB`.", |
|
228 | - | call. = FALSE) |
|
229 | - | ||
230 | - | if (!all(original_time %in% new_data_time)) { |
|
231 | - | newdata[["sdmTMB_fake_year"]] <- FALSE |
|
232 | - | missing_time_elements <- original_time[!original_time %in% new_data_time] |
|
233 | - | nd2 <- do.call("rbind", |
|
234 | - | replicate(length(missing_time_elements), newdata[1L,,drop=FALSE], simplify = FALSE)) |
|
235 | - | nd2[[object$time]] <- rep(missing_time_elements, each = 1L) |
|
236 | - | nd2$sdmTMB_fake_year <- TRUE |
|
237 | - | newdata <- rbind(newdata, nd2) |
|
312 | + | cli_abort(c("Some new time elements were found in `newdata`. ", |
|
313 | + | "For now, make sure only time elements from the original dataset are present.", |
|
314 | + | "If you would like to predict on new time elements,", |
|
315 | + | "see the `extra_time` argument in `?predict.sdmTMB`.") |
|
316 | + | ) |
|
317 | + | ||
318 | + | if (!identical(new_data_time, original_time) & isFALSE(pop_pred)) { |
|
319 | + | cli_abort(c("The time elements in `newdata` are not identical to those in the original dataset.", |
|
320 | + | "For now, please predict on all time elements and filter out those you don't need after.", |
|
321 | + | "Please let us know on the GitHub issues tracker if this is important to you.")) |
|
238 | 322 | } |
|
239 | 323 | ||
240 | 324 | # If making population predictions (with standard errors), we don't need |
|
241 | 325 | # to worry about space, so fill in dummy values if the user hasn't made any: |
|
242 | 326 | fake_spatial_added <- FALSE |
|
243 | 327 | if (pop_pred) { |
|
244 | - | for (i in 1:2) { |
|
328 | + | for (i in c(1, 2)) { |
|
245 | 329 | if (!xy_cols[[i]] %in% names(newdata)) { |
|
246 | 330 | newdata[[xy_cols[[i]]]] <- mean(object$data[[xy_cols[[i]]]], na.rm = TRUE) |
|
247 | 331 | fake_spatial_added <- TRUE |
|
248 | 332 | } |
|
249 | 333 | } |
|
250 | 334 | } |
|
251 | 335 | ||
252 | - | if (sum(is.na(new_data_time)) > 1) |
|
253 | - | stop("There is at least one NA value in the time column. ", |
|
254 | - | "Please remove it.", call. = FALSE) |
|
255 | - | ||
256 | - | newdata$sdm_orig_id <- seq(1, nrow(newdata)) |
|
257 | - | fake_newdata <- unique(newdata[,xy_cols]) |
|
258 | - | fake_newdata[["sdm_spatial_id"]] <- seq(1, nrow(fake_newdata)) - 1L |
|
336 | + | if (sum(is.na(new_data_time)) > 0) |
|
337 | + | cli_abort(c("There is at least one NA value in the time column.", |
|
338 | + | "Please remove it.")) |
|
259 | 339 | ||
260 | - | newdata <- base::merge(newdata, fake_newdata, by = xy_cols, |
|
261 | - | all.x = TRUE, all.y = FALSE) |
|
262 | - | newdata <- newdata[order(newdata$sdm_orig_id),, drop=FALSE] |
|
340 | + | # newdata$sdm_orig_id <- seq(1, nrow(newdata)) |
|
341 | + | # fake_newdata <- unique(newdata[,xy_cols]) |
|
342 | + | # fake_newdata[["sdm_spatial_id"]] <- seq(1, nrow(fake_newdata)) - 1L |
|
343 | + | newdata$sdm_spatial_id <- seq(1, nrow(newdata)) # FIXME doing nothing now |
|
344 | + | # |
|
345 | + | # newdata <- base::merge(newdata, fake_newdata, by = xy_cols, |
|
346 | + | # all.x = TRUE, all.y = FALSE) |
|
347 | + | # newdata <- newdata[order(newdata$sdm_orig_id),, drop = FALSE] |
|
263 | 348 | ||
264 | 349 | proj_mesh <- INLA::inla.spde.make.A(object$spde$mesh, |
|
265 | - | loc = as.matrix(fake_newdata[,xy_cols, drop = FALSE])) |
|
350 | + | loc = as.matrix(newdata[,xy_cols, drop = FALSE])) |
|
266 | 351 | ||
267 | - | # this formula has breakpt() etc. in it: |
|
268 | - | thresh <- check_and_parse_thresh_params(object$formula, newdata) |
|
269 | - | formula <- thresh$formula # this one does not |
|
352 | + | if (length(object$formula) == 1L) { |
|
353 | + | # this formula has breakpt() etc. in it: |
|
354 | + | thresh <- list(check_and_parse_thresh_params(object$formula[[1]], newdata)) |
|
355 | + | formula <- list(thresh[[1]]$formula) # this one does not |
|
356 | + | } else { |
|
357 | + | thresh <- list(check_and_parse_thresh_params(object$formula[[1]], newdata), |
|
358 | + | check_and_parse_thresh_params(object$formula[[2]], newdata)) |
|
359 | + | formula <- list(thresh[[1]]$formula, thresh[[2]]$formula) |
|
360 | + | } |
|
270 | 361 | ||
271 | 362 | nd <- newdata |
|
272 | - | response <- get_response(object$formula) |
|
363 | + | response <- get_response(object$formula[[1]]) |
|
273 | 364 | sdmTMB_fake_response <- FALSE |
|
274 | 365 | if (!response %in% names(nd)) { |
|
275 | 366 | nd[[response]] <- 0 # fake for model.matrix |
|
276 | 367 | sdmTMB_fake_response <- TRUE |
|
277 | 368 | } |
|
278 | 369 | ||
279 | 370 | if (!"mgcv" %in% names(object)) object[["mgcv"]] <- FALSE |
|
280 | - | proj_X_ij <- matrix(999) |
|
281 | - | if (!object$mgcv) { |
|
282 | - | proj_X_ij <- tryCatch({model.matrix(object$formula, data = nd)}, |
|
283 | - | error = function(e) NA) |
|
284 | - | } |
|
285 | - | if (object$mgcv || identical(proj_X_ij, NA)) { |
|
286 | - | proj_X_ij <- mgcv::predict.gam(object$mgcv_mod, type = "lpmatrix", newdata = nd) |
|
371 | + | ||
372 | + | # deal with prediction IID random intercepts: |
|
373 | + | RE_names <- barnames(object$split_formula[[1]]$reTrmFormulas) # TODO DELTA HARDCODED TO 1 here; fine for now |
|
374 | + | ## not checking so that not all factors need to be in prediction: |
|
375 | + | # fct_check <- vapply(RE_names, function(x) check_valid_factor_levels(data[[x]], .name = x), TRUE) |
|
376 | + | proj_RE_indexes <- vapply(RE_names, function(x) as.integer(nd[[x]]) - 1L, rep(1L, nrow(nd))) |
|
377 | + | ||
378 | + | proj_X_ij <- list() |
|
379 | + | for (i in seq_along(object$formula)) { |
|
380 | + | f2 <- remove_s_and_t2(object$split_formula[[i]]$fixedFormula) |
|
381 | + | tt <- stats::terms(f2) |
|
382 | + | attr(tt, "predvars") <- attr(object$terms[[i]], "predvars") |
|
383 | + | Terms <- stats::delete.response(tt) |
|
384 | + | mf <- model.frame(Terms, newdata, xlev = object$xlevels[[i]]) |
|
385 | + | proj_X_ij[[i]] <- model.matrix(Terms, mf, contrasts.arg = object$contrasts[[i]]) |
|
287 | 386 | } |
|
387 | + | ||
388 | + | # TODO DELTA hardcoded to 1: |
|
389 | + | sm <- parse_smoothers(object$formula[[1]], data = object$data, newdata = nd) |
|
390 | + | ||
288 | 391 | if (!is.null(object$time_varying)) |
|
289 | 392 | proj_X_rw_ik <- model.matrix(object$time_varying, data = nd) |
|
290 | 393 | else |
|
291 | 394 | proj_X_rw_ik <- matrix(0, ncol = 1, nrow = 1) # dummy |
|
292 | 395 | ||
293 | - | tmb_data$proj_X_threshold <- thresh$X_threshold |
|
294 | - | tmb_data$area_i <- if (length(area) == 1L && area[[1]] == 1) rep(1, nrow(proj_X_ij)) else area |
|
396 | + | ||
397 | + | if (length(area) != nrow(proj_X_ij[[1]]) && length(area) != 1L) { |
|
398 | + | cli_abort("`area` should be of the same length as `nrow(newdata)` or of length 1.") |
|
399 | + | } |
|
400 | + | ||
401 | + | tmb_data$proj_X_threshold <- thresh[[1]]$X_threshold # TODO DELTA HARDCODED TO 1 |
|
402 | + | tmb_data$area_i <- if (length(area) == 1L) rep(area, nrow(proj_X_ij[[1]])) else area |
|
295 | 403 | tmb_data$proj_mesh <- proj_mesh |
|
296 | 404 | tmb_data$proj_X_ij <- proj_X_ij |
|
297 | 405 | tmb_data$proj_X_rw_ik <- proj_X_rw_ik |
|
406 | + | tmb_data$proj_RE_indexes <- proj_RE_indexes |
|
298 | 407 | tmb_data$proj_year <- make_year_i(nd[[object$time]]) |
|
299 | 408 | tmb_data$proj_lon <- newdata[[xy_cols[[1]]]] |
|
300 | 409 | tmb_data$proj_lat <- newdata[[xy_cols[[2]]]] |
|
301 | 410 | tmb_data$calc_se <- as.integer(se_fit) |
|
302 | 411 | tmb_data$pop_pred <- as.integer(pop_pred) |
|
303 | - | tmb_data$calc_time_totals <- as.integer(!se_fit) |
|
304 | - | tmb_data$proj_spatial_index <- newdata$sdm_spatial_id |
|
305 | - | tmb_data$proj_t_i <- as.numeric(newdata[[object$time]]) |
|
306 | - | tmb_data$proj_t_i <- tmb_data$proj_t_i - mean(unique(tmb_data$proj_t_i)) # center on mean |
|
412 | + | tmb_data$exclude_RE <- exclude_RE |
|
413 | + | # tmb_data$calc_index_totals <- as.integer(!se_fit) |
|
414 | + | # tmb_data$calc_cog <- as.integer(!se_fit) |
|
415 | + | tmb_data$proj_spatial_index <- newdata$sdm_spatial_id - 1L |
|
416 | + | tmb_data$proj_Zs <- sm$Zs |
|
417 | + | tmb_data$proj_Xs <- sm$Xs |
|
418 | + | ||
419 | + | # SVC: |
|
420 | + | if (!is.null(object$spatial_varying)) { |
|
421 | + | z_i <- model.matrix(object$spatial_varying_formula, newdata) |
|
422 | + | .int <- grep("(Intercept)", colnames(z_i)) |
|
423 | + | if (sum(.int) > 0) z_i <- z_i[,-.int,drop=FALSE] |
|
424 | + | } else { |
|
425 | + | z_i <- matrix(0, nrow(newdata), 0L) |
|
426 | + | } |
|
427 | + | tmb_data$proj_z_i <- z_i |
|
428 | + | ||
429 | + | epsilon_covariate <- rep(0, length(unique(newdata[[object$time]]))) |
|
430 | + | if (tmb_data$est_epsilon_model) { |
|
431 | + | # covariate vector dimensioned by number of time steps |
|
432 | + | time_steps <- unique(newdata[[object$time]]) |
|
433 | + | for (i in seq_along(time_steps)) { |
|
434 | + | epsilon_covariate[i] <- newdata[newdata[[object$time]] == time_steps[i], |
|
435 | + | object$epsilon_predictor, drop = TRUE][[1]] |
|
436 | + | } |
|
437 | + | } |
|
438 | + | tmb_data$epsilon_predictor <- epsilon_covariate |
|
439 | + | ||
307 | 440 | new_tmb_obj <- TMB::MakeADFun( |
|
308 | 441 | data = tmb_data, |
|
309 | - | parameters = object$tmb_obj$env$parList(), |
|
442 | + | parameters = get_pars(object), |
|
310 | 443 | map = object$tmb_map, |
|
311 | 444 | random = object$tmb_random, |
|
312 | 445 | DLL = "sdmTMB", |
316 | 449 | old_par <- object$model$par |
|
317 | 450 | # need to initialize the new TMB object once: |
|
318 | 451 | new_tmb_obj$fn(old_par) |
|
319 | - | lp <- new_tmb_obj$env$last.par.best |
|
320 | 452 | ||
453 | + | if (sims > 0 && is.null(tmbstan_model)) { |
|
454 | + | if (!"jointPrecision" %in% names(object$sd_report) && !has_no_random_effects(object)) { |
|
455 | + | message("Rerunning TMB::sdreport() with `getJointPrecision = TRUE`.") |
|
456 | + | sd_report <- TMB::sdreport(object$tmb_obj, getJointPrecision = TRUE) |
|
457 | + | } else { |
|
458 | + | sd_report <- object$sd_report |
|
459 | + | } |
|
460 | + | if (has_no_random_effects(object)) { |
|
461 | + | t_draws <- t(mvtnorm::rmvnorm(n = sims, mean = sd_report$par.fixed, |
|
462 | + | sigma = sd_report$cov.fixed)) |
|
463 | + | row.names(t_draws) <- NULL |
|
464 | + | } else { |
|
465 | + | t_draws <- rmvnorm_prec(mu = new_tmb_obj$env$last.par.best, |
|
466 | + | tmb_sd = sd_report, n_sims = sims) |
|
467 | + | } |
|
468 | + | r <- apply(t_draws, 2L, new_tmb_obj$report) |
|
469 | + | } |
|
470 | + | if (!is.null(tmbstan_model)) { |
|
471 | + | if (!"stanfit" %in% class(tmbstan_model)) |
|
472 | + | cli_abort("`tmbstan_model` must be output from `tmbstan::tmbstan()`.") |
|
473 | + | t_draws <- extract_mcmc(tmbstan_model) |
|
474 | + | if (nsim > 0) { |
|
475 | + | if (nsim > ncol(t_draws)) { |
|
476 | + | cli_abort("`nsim` must be <= number of MCMC samples.") |
|
477 | + | } else { |
|
478 | + | t_draws <- t_draws[,seq(ncol(t_draws) - nsim + 1, ncol(t_draws)), drop = FALSE] |
|
479 | + | } |
|
480 | + | } |
|
481 | + | r <- apply(t_draws, 2L, new_tmb_obj$report) |
|
482 | + | } |
|
483 | + | if (!is.null(tmbstan_model) || sims > 0) { |
|
484 | + | if (return_tmb_report) return(r) |
|
485 | + | .var <- switch(sims_var, |
|
486 | + | "est" = "proj_eta", |
|
487 | + | "est_rf" = "proj_rf", |
|
488 | + | "omega_s" = "proj_omega_s_A", |
|
489 | + | "zeta_s" = "proj_zeta_s_A", |
|
490 | + | "epsilon_st" = "proj_epsilon_st_A_vec", |
|
491 | + | sims_var) |
|
492 | + | out <- lapply(r, `[[`, .var) |
|
493 | + | ||
494 | + | if (isTRUE(object$family$delta)) { |
|
495 | + | assert_that(model[[1]] %in% c(NA, 1, 2), |
|
496 | + | msg = "`model` argument not valid; should be one of NA, 1, 2") |
|
497 | + | predtype <- as.integer(model[[1]]) |
|
498 | + | if (predtype %in% c(1L, NA)) { |
|
499 | + | out1 <- lapply(out, function(x) x[, 1L, drop = TRUE]) |
|
500 | + | out1 <- do.call("cbind", out1) |
|
501 | + | } |
|
502 | + | if (predtype %in% c(2L, NA)) { |
|
503 | + | out2 <- lapply(out, function(x) x[, 2L, drop = TRUE]) |
|
504 | + | out2 <- do.call("cbind", out2) |
|
505 | + | } |
|
506 | + | if (is.na(predtype)) { |
|
507 | + | out <- object$family[[1]]$linkinv(out1) * |
|
508 | + | object$family[[2]]$linkinv(out2) |
|
509 | + | } else if (predtype == 1L) { |
|
510 | + | out <- out1 |
|
511 | + | if (type == "response") out <- object$family[[1]]$linkinv(out) |
|
512 | + | } else if (predtype == 2L) { |
|
513 | + | out <- out2 |
|
514 | + | if (type == "response") out <- object$family[[2]]$linkinv(out) |
|
515 | + | } else { |
|
516 | + | cli_abort("`model` type not valid.") |
|
517 | + | } |
|
518 | + | } else { # not a delta model: |
|
519 | + | out <- do.call("cbind", out) |
|
520 | + | if (type == "response") out <- object$family$linkinv(out) |
|
521 | + | } |
|
522 | + | ||
523 | + | rownames(out) <- nd[[object$time]] # for use in index calcs |
|
524 | + | attr(out, "time") <- object$time |
|
525 | + | return(out) |
|
526 | + | } |
|
527 | + | ||
528 | + | lp <- new_tmb_obj$env$last.par.best |
|
321 | 529 | r <- new_tmb_obj$report(lp) |
|
530 | + | if (return_tmb_report) return(r) |
|
322 | 531 | ||
323 | 532 | if (isFALSE(pop_pred)) { |
|
324 | - | nd$est <- r$proj_eta |
|
325 | - | nd$est_non_rf <- r$proj_fe |
|
326 | - | nd$est_rf <- r$proj_rf |
|
327 | - | nd$omega_s <- r$proj_re_sp_st |
|
328 | - | nd$zeta_s <- r$proj_re_sp_slopes |
|
329 | - | nd$epsilon_st <- r$proj_re_st_vector |
|
533 | + | if (isTRUE(object$family$delta)) { |
|
534 | + | nd$est1 <- r$proj_eta[,1] |
|
535 | + | nd$est2 <- r$proj_eta[,2] |
|
536 | + | nd$est_non_rf1 <- r$proj_fe[,1] |
|
537 | + | nd$est_non_rf2 <- r$proj_fe[,2] |
|
538 | + | nd$est_rf1 <- r$proj_rf[,1] |
|
539 | + | nd$est_rf2 <- r$proj_rf[,2] |
|
540 | + | nd$omega_s1 <- r$proj_omega_s_A[,1] |
|
541 | + | nd$omega_s2 <- r$proj_omega_s_A[,2] |
|
542 | + | for (z in seq_len(dim(r$proj_zeta_s_A)[2])) { # SVC: |
|
543 | + | nd[[paste0("zeta_s_", object$spatial_varying[z], "1")]] <- r$proj_zeta_s_A[,z,1] |
|
544 | + | nd[[paste0("zeta_s_", object$spatial_varying[z], "2")]] <- r$proj_zeta_s_A[,z,2] |
|
545 | + | } |
|
546 | + | nd$epsilon_st1 <- r$proj_epsilon_st_A_vec[,1] |
|
547 | + | nd$epsilon_st2 <- r$proj_epsilon_st_A_vec[,2] |
|
548 | + | if (type == "response") { |
|
549 | + | nd$est1 <- object$family[[1]]$linkinv(nd$est1) |
|
550 | + | nd$est2 <- object$family[[2]]$linkinv(nd$est2) |
|
551 | + | if (object$tmb_data$poisson_link_delta) { |
|
552 | + | .n <- nd$est1 # expected group density (already exp()) |
|
553 | + | .p <- 1 - exp(-.n) # expected encounter rate |
|
554 | + | .w <- nd$est2 # expected biomass per group (already exp()) |
|
555 | + | .r <- (.n * .w) / .p # (n * w)/p # positive expectation |
|
556 | + | nd$est1 <- .p # expected encounter rate |
|
557 | + | nd$est2 <- .r # positive expectation |
|
558 | + | nd$est <- .n * .w # expected combined value |
|
559 | + | } else { |
|
560 | + | nd$est <- nd$est1 * nd$est2 |
|
561 | + | } |
|
562 | + | } |
|
563 | + | } else { |
|
564 | + | nd$est <- r$proj_eta[,1] |
|
565 | + | nd$est_non_rf <- r$proj_fe[,1] |
|
566 | + | nd$est_rf <- r$proj_rf[,1] |
|
567 | + | nd$omega_s <- r$proj_omega_s_A[,1] |
|
568 | + | for (z in seq_len(dim(r$proj_zeta_s_A)[2])) { # SVC: |
|
569 | + | nd[[paste0("zeta_s_", object$spatial_varying[z])]] <- r$proj_zeta_s_A[,z,1] |
|
570 | + | } |
|
571 | + | nd$epsilon_st <- r$proj_epsilon_st_A_vec[,1] |
|
572 | + | if (type == "response") { |
|
573 | + | nd$est <- object$family$linkinv(nd$est) |
|
574 | + | } |
|
575 | + | } |
|
330 | 576 | } |
|
331 | 577 | ||
332 | 578 | nd$sdm_spatial_id <- NULL |
|
333 | 579 | nd$sdm_orig_id <- NULL |
|
334 | 580 | ||
335 | 581 | obj <- new_tmb_obj |
|
336 | 582 | ||
583 | + | if ("visreg_model" %in% names(object)) { |
|
584 | + | model <- object$visreg_model |
|
585 | + | } else { |
|
586 | + | model <- 1L |
|
587 | + | } |
|
588 | + | ||
337 | 589 | if (se_fit) { |
|
338 | 590 | sr <- TMB::sdreport(new_tmb_obj, bias.correct = FALSE) |
|
339 | - | ssr <- summary(sr, "report") |
|
591 | + | sr_est_rep <- as.list(sr, "Estimate", report = TRUE) |
|
592 | + | sr_se_rep <- as.list(sr, "Std. Error", report = TRUE) |
|
340 | 593 | if (pop_pred) { |
|
341 | - | proj_eta <- ssr[row.names(ssr) == "proj_fe", , drop = FALSE] |
|
594 | + | proj_eta <- sr_est_rep[["proj_fe"]] |
|
595 | + | se <- sr_se_rep[["proj_fe"]] |
|
342 | 596 | } else { |
|
343 | - | proj_eta <- ssr[row.names(ssr) == "proj_eta", , drop = FALSE] |
|
597 | + | proj_eta <- sr_est_rep[["proj_eta"]] |
|
598 | + | se <- sr_se_rep[["proj_eta"]] |
|
599 | + | } |
|
600 | + | proj_eta <- proj_eta[,model,drop=TRUE] |
|
601 | + | se <- se[,model,drop=TRUE] |
|
602 | + | nd$est <- proj_eta |
|
603 | + | nd$est_se <- se |
|
604 | + | } |
|
605 | + | ||
606 | + | if (pop_pred) { |
|
607 | + | if (!se_fit) { |
|
608 | + | nd$est <- r$proj_fe[,model,drop=TRUE] # FIXME re_form_iid?? |
|
344 | 609 | } |
|
345 | - | row.names(proj_eta) <- NULL |
|
346 | - | d <- as.data.frame(proj_eta) |
|
347 | - | names(d) <- c("est", "se") |
|
348 | - | nd$est <- d$est |
|
349 | - | nd$est_se <- d$se |
|
610 | + | } |
|
611 | + | ||
612 | + | orig_dat <- object$tmb_data$y_i |
|
613 | + | if (model == 2L && nrow(nd) == nrow(orig_dat) && visreg_df) { |
|
614 | + | nd <- nd[!is.na(orig_dat[,2]),,drop=FALSE] # drop NAs from delta positive component |
|
350 | 615 | } |
|
351 | 616 | ||
352 | 617 | if ("sdmTMB_fake_year" %in% names(nd)) { |
362 | 627 | ||
363 | 628 | } else { # We are not dealing with new data: |
|
364 | 629 | if (se_fit) { |
|
365 | - | warning("Standard errors have not been implemented yet unless you ", |
|
630 | + | cli_warn(paste0("Standard errors have not been implemented yet unless you ", |
|
366 | 631 | "supply `newdata`. In the meantime you could supply your original data frame ", |
|
367 | - | "to the `newdata` argument.", call. = FALSE) |
|
632 | + | "to the `newdata` argument.")) |
|
633 | + | } |
|
634 | + | if (isTRUE(object$family$delta)) { |
|
635 | + | cli_abort(c("Delta model prediction not implemented for `newdata = NULL` yet.", |
|
636 | + | "Please provide your data to `newdata`.")) |
|
368 | 637 | } |
|
369 | 638 | nd <- object$data |
|
370 | - | lp <- object$tmb_obj$env$last.par |
|
639 | + | lp <- object$tmb_obj$env$last.par.best |
|
371 | 640 | # object$tmb_obj$fn(lp) # call once to update internal structures? |
|
372 | 641 | r <- object$tmb_obj$report(lp) |
|
373 | 642 | ||
374 | - | nd$est <- r$eta_i |
|
375 | - | # Following is not an error: rw effects baked into fixed effects for new data in above code: |
|
376 | - | nd$est_non_rf <- r$eta_fixed_i + r$eta_rw_i |
|
377 | - | nd$est_rf <- r$omega_s_A + r$epsilon_st_A_vec + r$omega_s_trend_A |
|
378 | - | nd$omega_s <- r$omega_s_A |
|
379 | - | nd$zeta_s <- r$omega_s_trend_A |
|
380 | - | nd$epsilon_st <- r$epsilon_st_A_vec |
|
643 | + | nd$est <- r$eta_i[,1] # DELTA FIXME |
|
644 | + | # The following is not an error, |
|
645 | + | # IID and RW effects are baked into fixed effects for `newdata` in above code: |
|
646 | + | nd$est_non_rf <- r$eta_fixed_i[,1] + r$eta_rw_i[,1] + r$eta_iid_re_i[,1] # DELTA FIXME |
|
647 | + | nd$est_rf <- r$omega_s_A[,1] + r$epsilon_st_A_vec[,1] # DELTA FIXME |
|
648 | + | if (!is.null(object$spatial_varying_formula)) |
|
649 | + | cli_abort(c("Prediction with `newdata = NULL` is not supported with spatially varying coefficients yet.", |
|
650 | + | "Please provide your data to `newdata`.")) |
|
651 | + | # + r$zeta_s_A |
|
652 | + | nd$omega_s <- r$omega_s_A[,1]# DELTA FIXME |
|
653 | + | # for (z in seq_len(dim(r$zeta_s_A)[2])) { # SVC: |
|
654 | + | # nd[[paste0("zeta_s_", object$spatial_varying[z])]] <- r$zeta_s_A[,z,1] |
|
655 | + | # } |
|
656 | + | nd$epsilon_st <- r$epsilon_st_A_vec[,1]# DELTA FIXME |
|
381 | 657 | obj <- object |
|
382 | 658 | } |
|
383 | 659 | ||
384 | - | if (return_tmb_object) |
|
385 | - | return(list(data = nd, report = r, obj = obj, fit_obj = object)) |
|
386 | - | else |
|
387 | - | return(nd) |
|
660 | + | # clean up: |
|
661 | + | if (!object$tmb_data$include_spatial) { |
|
662 | + | nd$omega_s1 <- NULL |
|
663 | + | nd$omega_s2 <- NULL |
|
664 | + | nd$omega_s <- NULL |
|
665 | + | } |
|
666 | + | if (as.logical(object$tmb_data$spatial_only)[1]) { |
|
667 | + | nd$epsilon_st1 <- NULL |
|
668 | + | nd$epsilon_st <- NULL |
|
669 | + | } |
|
670 | + | if (isTRUE(object$family$delta)) { |
|
671 | + | if (as.logical(object$tmb_data$spatial_only)[2]) { |
|
672 | + | nd$epsilon_st2 <- NULL |
|
673 | + | } |
|
674 | + | } |
|
675 | + | if (!object$tmb_data$spatial_covariate) { |
|
676 | + | nd$zeta_s1 <- NULL |
|
677 | + | nd$zeta_s1 <- NULL |
|
678 | + | nd$zeta_s <- NULL |
|
679 | + | } |
|
680 | + | ||
681 | + | if (return_tmb_object) { |
|
682 | + | return(list(data = nd, report = r, obj = obj, fit_obj = object, pred_tmb_data = tmb_data)) |
|
683 | + | } else { |
|
684 | + | if (visreg_df) { |
|
685 | + | # for visreg & related, return consistent objects with lm(), gam() etc. |
|
686 | + | if (isTRUE(se_fit)) { |
|
687 | + | return(list(fit = nd$est, se.fit = nd$est_se)) |
|
688 | + | } else { |
|
689 | + | return(nd$est) |
|
690 | + | } |
|
691 | + | } else { |
|
692 | + | return(nd) # data frame by default |
|
693 | + | } |
|
694 | + | } |
|
388 | 695 | } |
|
389 | 696 | ||
390 | 697 | # https://stackoverflow.com/questions/13217322/how-to-reliably-get-dependent-variable-name-from-formula-object |
394 | 701 | response <- attr(tt, "response") # index of response var |
|
395 | 702 | vars[response] |
|
396 | 703 | } |
|
704 | + | ||
705 | + | remove_9000 <- function(x) { |
|
706 | + | as.package_version(paste0( |
|
707 | + | strsplit(as.character(x), ".", fixed = TRUE)[[1]][1:3], |
|
708 | + | collapse = "." |
|
709 | + | )) |
|
710 | + | } |
|
711 | + | ||
712 | + | check_sdmTMB_version <- function(version) { |
|
713 | + | if (remove_9000(utils::packageVersion("sdmTMB")) > |
|
714 | + | remove_9000(version)) { |
|
715 | + | msg <- paste0( |
|
716 | + | "The installed version of sdmTMB is newer than the version ", |
|
717 | + | "that was used to fit this model. It is possible new parameters ", |
|
718 | + | "have been added to the TMB model since you fit this model and ", |
|
719 | + | "that prediction will fail. We recommend you fit and predict ", |
|
720 | + | "from an sdmTMB model with the same version." |
|
721 | + | ) |
|
722 | + | cli_warn(msg) |
|
723 | + | } |
|
724 | + | } |
|
725 | + | ||
726 | + | check_time_class <- function(object, newdata) { |
|
727 | + | cls1 <- class(object$data[[object$time]]) |
|
728 | + | cls2 <- class(newdata[[object$time]]) |
|
729 | + | if (!identical(cls1, cls2)) { |
|
730 | + | if (!identical(sort(c(cls1, cls2)), c("integer", "numeric"))) { |
|
731 | + | msg <- paste0( |
|
732 | + | "Class of fitted time column (", cls1, ") does not match class of ", |
|
733 | + | "`newdata` time column (", cls2 ,")." |
|
734 | + | ) |
|
735 | + | cli_abort(msg) |
|
736 | + | } |
|
737 | + | } |
|
738 | + | } |
|
739 | + | ||
740 | + | check_visreg <- function(sys_calls) { |
|
741 | + | visreg_df <- FALSE |
|
742 | + | re_form <- NULL |
|
743 | + | se_fit <- FALSE |
|
744 | + | if (any(grepl("setupV", substr(sys_calls, 1, 7)))) { |
|
745 | + | visreg_df <- TRUE |
|
746 | + | re_form <- NA |
|
747 | + | if (any(sys_calls == "residuals(fit)")) visreg_df <- FALSE |
|
748 | + | # turn on standard error if in a function call |
|
749 | + | indx <- which(substr(sys_calls, 1, 10) == "visregPred") |
|
750 | + | if (length(indx) > 0 && any(unlist(strsplit(sys_calls[indx], ",")) == " se.fit = TRUE")) |
|
751 | + | se_fit <- TRUE |
|
752 | + | } |
|
753 | + | named_list(visreg_df, se_fit, re_form) |
|
754 | + | } |
1 | + | # from brms:::rm_wsp() |
|
2 | + | rm_wsp <- function (x) { |
|
3 | + | out <- gsub("[ \t\r\n]+", "", x, perl = TRUE) |
|
4 | + | dim(out) <- dim(x) |
|
5 | + | out |
|
6 | + | } |
|
7 | + | # from brms:::all_terms() |
|
8 | + | all_terms <- function (x) { |
|
9 | + | if (!length(x)) { |
|
10 | + | return(character(0)) |
|
11 | + | } |
|
12 | + | if (!inherits(x, "terms")) { |
|
13 | + | x <- terms(stats::as.formula(x)) |
|
14 | + | } |
|
15 | + | rm_wsp(attr(x, "term.labels")) |
|
16 | + | } |
|
17 | + | ||
18 | + | get_smooth_terms <- function(terms) { |
|
19 | + | x1 <- grep("s\\(", terms) |
|
20 | + | x2 <- grep("t2\\(", terms) |
|
21 | + | if (length(x2) > 0L) |
|
22 | + | cli_abort("t2() smoothers are not yet supported due to issues with prediction on newdata.") |
|
23 | + | x1 |
|
24 | + | } |
|
25 | + | ||
26 | + | parse_smoothers <- function(formula, data, newdata = NULL) { |
|
27 | + | terms <- all_terms(formula) |
|
28 | + | if (!is.null(newdata)) { |
|
29 | + | if (any(grepl("t2\\(", terms))) cli_abort("Prediction on newdata with t2() still has issues.") |
|
30 | + | } |
|
31 | + | smooth_i <- get_smooth_terms(terms) |
|
32 | + | basis <- list() |
|
33 | + | Zs <- list() |
|
34 | + | Xs <- list() |
|
35 | + | if (length(smooth_i) > 0) { |
|
36 | + | has_smooths <- TRUE |
|
37 | + | smterms <- terms[smooth_i] |
|
38 | + | ns <- 0 |
|
39 | + | ns_Xf <- 0 |
|
40 | + | for (i in seq_along(smterms)) { |
|
41 | + | if (grepl('bs\\=\\"re', smterms[i])) stop("Error: bs = 're' is not currently supported for smooths") |
|
42 | + | obj <- eval(str2expression(smterms[i])) |
|
43 | + | basis[[i]] <- mgcv::smoothCon( |
|
44 | + | object = obj, data = data, |
|
45 | + | knots = NULL, absorb.cons = TRUE, |
|
46 | + | diagonal.penalty = TRUE |
|
47 | + | ) |
|
48 | + | for (j in seq_along(basis[[i]])) { # elements > 1 with `by` terms |
|
49 | + | ns_Xf <- ns_Xf + 1 |
|
50 | + | rasm <- mgcv::smooth2random(basis[[i]][[j]], names(data), type = 2) |
|
51 | + | if (!is.null(newdata)) { |
|
52 | + | rasm <- s2rPred(basis[[i]][[j]], rasm, data = newdata) |
|
53 | + | } |
|
54 | + | for (k in seq_along(rasm$rand)) { # elements > 1 with if s(x, y) or t2() |
|
55 | + | ns <- ns + 1 |
|
56 | + | Zs[[ns]] <- rasm$rand[[k]] |
|
57 | + | } |
|
58 | + | Xs[[ns_Xf]] <- rasm$Xf |
|
59 | + | } |
|
60 | + | } |
|
61 | + | sm_dims <- unlist(lapply(Zs, ncol)) |
|
62 | + | Xs <- do.call(cbind, Xs) # combine 'em all into one design matrix |
|
63 | + | b_smooth_start <- c(0, cumsum(sm_dims)[-length(sm_dims)]) |
|
64 | + | } else { |
|
65 | + | has_smooths <- FALSE |
|
66 | + | sm_dims <- 0L |
|
67 | + | b_smooth_start <- 0L |
|
68 | + | Xs <- matrix(nrow = 0L, ncol = 0L) |
|
69 | + | } |
|
70 | + | list(Xs = Xs, Zs = Zs, has_smooths = has_smooths, |
|
71 | + | sm_dims = sm_dims, b_smooth_start = b_smooth_start) |
|
72 | + | } |
|
73 | + | ||
74 | + | # from mgcv docs ?mgcv::smooth2random |
|
75 | + | s2rPred <- function(sm, re, data) { |
|
76 | + | ## Function to aid prediction from smooths represented as type==2 |
|
77 | + | ## random effects. re must be the result of smooth2random(sm,...,type=2). |
|
78 | + | X <- mgcv::PredictMat(sm, data) ## get prediction matrix for new data |
|
79 | + | ## transform to r.e. parameterization |
|
80 | + | if (!is.null(re$trans.U)) { |
|
81 | + | X <- X %*% re$trans.U |
|
82 | + | } |
|
83 | + | X <- t(t(X) * re$trans.D) |
|
84 | + | ## re-order columns according to random effect re-ordering... |
|
85 | + | X[, re$rind] <- X[, re$pen.ind != 0] |
|
86 | + | ## re-order penalization index in same way |
|
87 | + | pen.ind <- re$pen.ind |
|
88 | + | pen.ind[re$rind] <- pen.ind[pen.ind > 0] |
|
89 | + | ## start return object... |
|
90 | + | r <- list(rand = list(), Xf = X[, which(re$pen.ind == 0), drop = FALSE]) |
|
91 | + | for (i in seq_along(re$rand)) { ## loop over random effect matrices |
|
92 | + | r$rand[[i]] <- X[, which(pen.ind == i), drop = FALSE] |
|
93 | + | attr(r$rand[[i]], "s.label") <- attr(re$rand[[i]], "s.label") |
|
94 | + | } |
|
95 | + | names(r$rand) <- names(re$rand) |
|
96 | + | r |
|
97 | + | } |
|
98 | + | ## use function to obtain prediction random and fixed effect matrices |
|
99 | + | ## for first 10 elements of 'dat'. Then confirm that these match the |
|
100 | + | ## first 10 rows of the original model matrices, as they should... |
|
101 | + | # r <- s2rPred(sm,re,dat[1:10,]) |
|
102 | + | # range(r$Xf-re$Xf[1:10,]) |
|
103 | + | # range(r$rand[[1]]-re$rand[[1]][1:10,]) |
Learn more Showing 20 files with coverage changes found.
R/check.R
R/tmb-sim.R
R/print.R
R/smoothers.R
R/stacking.R
R/tidy.R
R/crs.R
R/stan.R
R/zzz.R
R/visreg.R
R/priors.R
R/utils.R
R/plot-pc-matern.R
R/gather-spread.R
src/utils.h
R/get-index-sims.R
src/sdmTMB.cpp
R/predict.R
R/cross-val.R
R/residuals.R
39a1e97
b107e5a
1e8d0a5
3864ceb
4cd59e5
7501410
bac7d86
8c94d23
71232e1
2ea5df3
9201f6c
11d4d45
c5b7bc2
e0d3905
5919246
1ccba11
99d6161
0c24dcd
d428e52
38cde3a
0e434df
b1ec526
25c80e2
b9dcec0
3b135c8
63d93a3
6eed373
7884d77
9849cac
db969bd
6e26a5b
f02cb1b
8c9abc1
e74013c
32293ab
c6d4aaf
72556d3
51fc95f
60d205c
e4d6757
090132b
9d4e96d
0fa6c3f
449cc1c
8b5ad2c
1689a0b
6296d0e
927fdad
330830f
fe232da
57fdc12
ab512b1
8a87414
67a488f
fdd1feb
f7c70cd
7528173
07242f4
35f24f0
2b1a5fb
1ffe02a
6f63855
f922743
50c0e3c
b6a31be
d126dfa
d3330e4
27337a2
eae1c59
bdb8197
3af447f
edc8d2f
8e17122
3afeb05
154fa29
9368467
6e0cb00
dde2138
cdb68ad
82ab8ac
5ea525a
c9a4d87
9ae822d
780ec7c
868b4c7
b20ac83
9205456
bdf7f0b
426076d
12aff21
df3f6c9
24f7dab
48ac91e
eff3512
d51723d
136049e
f5e1794
f319e3e
525d5c4
ff95188
97ff4e6
1f05158
126b9a1
2702c8d
1c232b1
bc87fd8
cec6f5f
48dd089
e558719
f05bf3e
9244a74
397214d
624027f
470861b
72e6766
1b558f9
bd8b524
f01452d
1d4aedd
9854c20
373ceda
e5aafb4
78e68e3
ad76963
a5f167b
8a92992
b48441e
e77d897
3f36dab
80f19de
9e0b91a
6dab916
2c927e5
d3efb86
f2e6a18
ae4cf5f
7d69d03
9c418e9
87acb44
aff40d5
9af905d
cf179d7
99f533c
aacf4ca
509e593
8b669c6
712c4b9
e0a6a95
6f23693
a333b69
828b802
e712f04
6dd62cb
11fcc4b
2f70255
66abd71
beb4894
6d2eb3b
1a1c869
203e020
d3fbd16
3d0edd3
16569bf
f862810
43dbe50
dd95c7d
004148c
1356519
625b4f3
91990ae
82994b5
2fe996d
e2e9bf0
345e06b
6677dce
f57d17d
ad6a750
ae6d0c4
564353c
60c1e6e
0cda622
0781410
e8e07c9
5de9352
ed04e84
6f71d23
bf3a291
af2225b
314d4fe
98a2e85
686e58e
c7ad3ca
c695fbf
1167f6f
b10b3fd
aaceb8b
379681d
0429010
7ac5bca
2c2d43a
4041bea
f376a4a
994a018
c66d129
2dea45c
153e20d
212c89d
8781208
cf4906a
2d667f8
14a545f
f7bddb2
6403fbe
bbb521f
8ae37c6
8e79627
c292c62
be8fd0a
a48137e
9540541
166b915
b521200
93e12da
203a572
9f5d88e
c104df7
b71d331
fbe3293
a93c773
323ba73
637f4f4
5a7027c
714f526
1d01eaa
3ed483a
84e224a
3322093
b157e06
d621edd
ae6ce46
c7ee9ae
04ea682
6e970a1
b42ce5d
2ad8eea
bbaa7ca