greta-dev / greta
Showing 29 of 78 files from the diff.
Newly tracked files:
R/simulate.R created.
R/chol2symm.R created.
Other files ignored by Codecov:
man/greta.Rd has changed.
man/inference.Rd has changed.
man/transforms.Rd has changed.
LICENSE has changed.
man/calculate.Rd has changed.
man/functions.Rd has changed.
man/as_data.Rd has changed.
R/overloaded.R has changed.
man/mixture.Rd has changed.
.lintr has changed.
.Rbuildignore has changed.
man/overloaded.Rd has changed.
codemeta.json has changed.
man/variable.Rd has changed.
man/model.Rd has changed.
inst/WORDLIST has changed.
man/joint.Rd has changed.
man/structures.Rd has changed.
NAMESPACE has changed.
LICENSE.md is new.
man/samplers.Rd has changed.
DESCRIPTION has changed.
man/optimisers.Rd has changed.
man/operators.Rd has changed.
R/internals.R has changed.
NEWS.md has changed.
man/internals.Rd has changed.
.travis.yml has changed.

@@ -44,7 +44,8 @@

   if (!is.numeric(pb_update) || length(pb_update) != 1 ||
       !is.finite(pb_update) || pb_update <= 0) {
-    stop("pb_update must be a finite, positive, scalar integer")
+    stop("pb_update must be a finite, positive, scalar integer",
+         call. = FALSE)
   }

   assign("pb_update", pb_update, envir = pb$.__enclos_env__)
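Note: this change, repeated across many hunks below, passes `call. = FALSE` to `stop()` so users see only the message, not the internal call that raised it. A minimal standalone sketch of the pattern (the `check_pb_update()` helper is hypothetical, not part of the diff):

check_pb_update <- function(pb_update) {
  if (!is.numeric(pb_update) || length(pb_update) != 1 ||
      !is.finite(pb_update) || pb_update <= 0) {
    # call. = FALSE drops the "Error in check_pb_update(...)" prefix,
    # so only the message itself is printed
    stop("pb_update must be a finite, positive, scalar integer",
         call. = FALSE)
  }
  as.integer(pb_update)
}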

@@ -3,12 +3,12 @@
 #' @description define an object in an R session as a data greta array for use
 #'   as data in a greta model.
 #' @param x an R object that can be coerced to a greta_array (see details).
-#' @details \code{as_data()} can currently convert R objects to greta_arrays if
+#' @details `as_data()` can currently convert R objects to greta_arrays if
 #'   they are numeric or logical vectors, matrices or arrays; or if they are
 #'   dataframes with only numeric (including integer) or logical elements.
 #'   Logical elements are always converted to numerics. R objects cannot be
-#'   converted if they contain missing (\code{NA}) or infinite (\code{-Inf} or
-#'   \code{Inf}) values.
+#'   converted if they contain missing (`NA`) or infinite (`-Inf` or
+#'   `Inf`) values.
 #' @export
 #' @examples
 #' \dontrun{
@@ -38,10 +38,10 @@


 # if it's already a *data* greta_array fine, else error
-# Begin Exclude Linting
+# nolint start
 #' @export
 as_data.greta_array <- function(x) {
-  # End Exclude Linting
+  # nolint end
   if (!inherits(get_node(x), "data_node")) {
     stop("cannot coerce a non-data greta_array to data",
          call. = FALSE)
@@ -50,9 +50,9 @@
 }

 # otherwise try to coerce to a greta array
-# Begin Exclude Linting
+# nolint start
 #' @export
 as_data.default <- function(x) {
-  # End Exclude Linting
+  # nolint end
   as.greta_array(x)
 }
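Per the documentation updated above, `as_data()` converts numeric or logical vectors, matrices, arrays, and all-numeric or logical data frames, and refuses missing or infinite values. An illustrative sketch (not part of the diff):

library(greta)
as_data(1:3)                 # numeric vector -> greta data array
as_data(matrix(TRUE, 2, 2))  # logicals are converted to numerics
as_data(iris[, 1:4])         # data frame with only numeric columns is fine
# these would error, since NA and Inf values cannot be converted:
# as_data(c(1, NA))
# as_data(c(1, Inf))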

@@ -0,0 +1,82 @@
+#' @title Simulate Responses From `greta_model` Object
+#'
+#' @description Simulate values of all named greta arrays associated with a
+#'   greta model from the model priors, including the response variable.
+#'
+#' @param object a [`greta_model()`][greta::model] object
+#' @param nsim positive integer scalar - the number of responses to simulate
+#' @param seed an optional seed to be used in set.seed immediately before the
+#'   simulation so as to generate a reproducible sample
+#' @param precision the floating point precision to use when calculating values.
+#' @param ... optional additional arguments, none are used at present
+#'
+#' @details This is essentially a wrapper around [calculate()] that
+#'   finds all relevant greta arrays. See that function for more functionality,
+#'   including simulation conditional on fixed values or posterior samples.
+#'
+#'   To simulate values of the response variable, it must be both a named object
+#'   (in the calling environment) and a greta array. If you don't see it
+#'   showing up in the output, you may need to use `as_data` to convert it
+#'   to a greta array before defining the model.
+#'
+#' @return A named list of vectors, matrices or arrays containing independent
+#'   samples of the greta arrays associated with the model. The number of
+#'   samples will be prepended as the first dimension of the greta array, so
+#'   that a vector of samples is returned for each scalar greta array, and a
+#'   matrix is returned for each vector greta array, etc.
+#'
+#' @importFrom stats simulate
+#' @export
+#'
+#' @examples
+#' \dontrun{
+#' # build a greta model
+#' n <- 10
+#' y <- rnorm(n)
+#' y <- as_data(y)
+#'
+#' library(greta)
+#' sd <- lognormal(1, 2)
+#' mu <- normal(0, 1, dim = n)
+#' distribution(y) <- normal(mu, sd)
+#' m <- model(mu, sd)
+#'
+#' # simulate one random draw of y, mu and sd from the model prior:
+#' sims <- simulate(m)
+#'
+#' # 100 simulations of y, mu and sd
+#' sims <- simulate(m, nsim = 100)
+#'
+#' }
+# nolint start
+simulate.greta_model <- function(
+  object,
+  nsim = 1,
+  seed = NULL,
+  precision = c("double", "single"),
+  ...
+) {
+  # nolint end
+  # find all the greta arrays in the calling environment
+  target_greta_arrays <- all_greta_arrays(parent.frame())
+
+  # subset these to only those that are associated with the model
+  target_nodes <- lapply(target_greta_arrays, get_node)
+  target_node_names <- vapply(target_nodes,
+                              member,
+                              "unique_name",
+                              FUN.VALUE = character(1))
+  object_node_names <- vapply(object$dag$node_list,
+                              member,
+                              "unique_name",
+                              FUN.VALUE = character(1))
+  keep <- target_node_names %in% object_node_names
+  target_greta_arrays <- target_greta_arrays[keep]
+
+  other_args <- list(precision = precision,
+                     nsim = nsim,
+                     seed = seed)
+
+  do.call(calculate, c(target_greta_arrays, other_args))
+
+}
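Because `simulate()` forwards `nsim` and `seed` to `calculate()`, seeded calls are reproducible. A small usage sketch, assuming the model `m` from the roxygen example above:

sims1 <- simulate(m, nsim = 100, seed = 2019)
sims2 <- simulate(m, nsim = 100, seed = 2019)
identical(sims1, sims2)  # TRUE - the same seed gives the same prior draws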

@@ -2,17 +2,17 @@
 #' @aliases distribution
 #' @title define a distribution over data
 #'
-#' @description \code{distribution} defines probability distributions over
+#' @description `distribution` defines probability distributions over
 #'   observed data, e.g. to set a model likelihood.
 #'
 #' @param greta_array a data greta array. For the assignment method it must not
 #'   already have a probability distribution assigned
 #'
 #' @param value a greta array with a distribution (see
-#'   \code{\link{distributions}})
+#'   [distributions()])
 #'
 #' @details The extract method returns the greta array if it has a distribution,
-#'   or \code{NULL} if it doesn't. It has no real use-case, but is included for
+#'   or `NULL` if it doesn't. It has no real use-case, but is included for
 #'   completeness
 #'
 #' @export
@@ -33,7 +33,7 @@
 #' # get the distribution over y
 #' distribution(y)
 #' }
-`distribution<-` <- function(greta_array, value) {  # Exclude Linting
+`distribution<-` <- function(greta_array, value) {  # nolint

   # stash the old greta array to return
   greta_array_tmp <- greta_array
@@ -44,7 +44,7 @@
   node <- get_node(greta_array)

   # only for greta arrays without distributions
-  if (!is.null(node$distribution)) {
+  if (has_distribution(node)) {
     stop("left hand side already has a distribution assigned",
          call. = FALSE)
   }
@@ -93,6 +93,10 @@
   distribution_node$remove_target()
   distribution_node$add_target(node)

+  # if possible, expand the dimensions of the distribution's parameters to match
+  # the target
+  distribution_node$expand_parameters_to(node$dim)
+
   # remove the distribution from the RHS variable greta array
   value_node$distribution <- NULL
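The new `expand_parameters_to()` call means scalar distribution parameters are broadcast to the dimension of the data on the left-hand side. A sketch of the documented behaviour:

y <- as_data(rnorm(10))
# the scalar parameters of normal(0, 1) are expanded to match dim(y),
# giving one density term per element of y
distribution(y) <- normal(0, 1)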

@@ -3,31 +3,57 @@
 #' @name calculate
 #' @title calculate greta arrays given fixed values
 #' @description Calculate the values that greta arrays would take, given
-#'   temporary values for the greta arrays on which they depend, and return them
-#'   as numeric R arrays. This can be used to check the behaviour of your model
-#'   or make predictions to new data after model fitting.
+#'   temporary, or simulated values for the greta arrays on which they depend.
+#'   This can be used to check the behaviour of your model, make predictions to
+#'   new data after model fitting, or simulate datasets from either the prior or
+#'   posterior of your model.
 #'
-#' @param target a greta array for which to calculate the value
+#' @param ... one or more greta_arrays for which to calculate the value
 #' @param values a named list giving temporary values of the greta arrays with
-#'   which \code{target} is connected, or a \code{greta_mcmc_list} object
-#'   returned by \code{\link{mcmc}}.
+#'   which `target` is connected, or a `greta_mcmc_list` object
+#'   returned by [mcmc()].
+#' @param nsim an optional positive integer scalar for the number of responses
+#'   to simulate if stochastic greta arrays are present in the model - see
+#'   Details.
+#' @param seed an optional seed to be used in set.seed immediately before the
+#'   simulation so as to generate a reproducible sample
 #' @param precision the floating point precision to use when calculating values.
 #' @param trace_batch_size the number of posterior samples to process at a time
-#'   when \code{target} is a \code{greta_mcmc_list} object; reduce this to
+#'   when `target` is a `greta_mcmc_list` object; reduce this to
 #'   reduce memory demands
 #'
-#' @return A numeric R array with the same dimensions as \code{target}, giving
-#'   the values it would take conditioned on the fixed values given by
-#'   \code{values}.
+#' @return Values of the target greta array(s), given values of the greta arrays
+#'   on which they depend (either specified in `values` or sampled from
+#'   their priors). If `values` is a
+#'   [`greta_mcmc_list()`][greta::mcmc] and `nsim = NULL`, this will
+#'   be a `greta_mcmc_list` object of posterior samples for the target
+#'   greta arrays. Otherwise, the result will be a named list of numeric R
+#'   arrays. If `nsim = NULL` the dimensions of returned numeric R arrays
+#'   will be the same as the corresponding greta arrays, otherwise an additional
+#'   dimension with `nsim` elements will be prepended, to represent
+#'   multiple simulations.
 #'
-#' @details The greta arrays named in \code{values} need not be variables, they
+#' @details The greta arrays named in `values` need not be variables, they
 #'   can also be other operations or even data.
 #'
-#'   At present, if \code{values} is a named list it must contain values for
-#'   \emph{all} of the variable greta arrays with which \code{target} is
+#'   At present, if `values` is a named list it must contain values for
+#'   *all* of the variable greta arrays with which `target` is
 #'   connected, even if values are given for intermediate operations, or the
 #'   target doesn't depend on the variable. That may be relaxed in a future
 #'   release.
 #'
+#'   If the model contains stochastic greta arrays (those with a distribution),
+#'   calculate can be used to sample from these distributions (and all greta
+#'   arrays that depend on them) by setting the `nsim` argument to a
+#'   positive integer for the required number of samples. If `values` is
+#'   specified (either as a list of fixed values or as draws), those values will
+#'   be used, and remaining variables will be sampled conditional on them.
+#'   Observed data with distributions (i.e. response variables defined with
+#'   `distribution()`) can also be sampled, provided they are defined as
+#'   greta arrays. This behaviour can be used for a number of tasks, like
+#'   simulating datasets for known parameter sets, simulating parameters and
+#'   data from a set of priors, or simulating datasets from a model posterior.
+#'   See some examples of these below.
+#'
 #' @export
 #'
 #' @examples
@@ -38,7 +64,13 @@
 #' x <- normal(0, 1, dim = 3)
 #' a <- lognormal(0, 1)
 #' y <- sum(x ^ 2) + a
-#' calculate(y, list(x = c(0.1, 0.2, 0.3), a = 2))
+#' calculate(y, values = list(x = c(0.1, 0.2, 0.3), a = 2))
+#'
+#' # by setting nsim, you can also sample values from their priors
+#' calculate(y, nsim = 3)
+#'
+#' # you can combine sampling and fixed values
+#' calculate(y, values = list(a = 2), nsim = 3)
 #'
 #' # if the greta array only depends on data,
 #' # you can pass an empty list to values (this is the default)
@@ -50,24 +82,38 @@
 #' alpha <- normal(0, 1)
 #' beta <- normal(0, 1)
 #' sigma <- lognormal(1, 0.1)
+#' y <- as_data(iris$Petal.Width)
 #' mu <- alpha + iris$Petal.Length * beta
-#' distribution(iris$Petal.Width) <- normal(mu, sigma)
+#' distribution(y) <- normal(mu, sigma)
 #' m <- model(alpha, beta, sigma)
 #'
-#' # calculate intermediate greta arrays, given some parameter values
-#' calculate(mu[1:5], list(alpha = 1, beta = 2, sigma = 0.5))
-#' calculate(mu[1:5], list(alpha = -1, beta = 0.2, sigma = 0.5))
+#' # sample values of the parameters, or different observation data (y), from
+#' # the priors (useful for prior predictive checking) - see also
+#' # ?simulate.greta_model
+#' calculate(alpha, beta, sigma, nsim = 100)
+#' calculate(y, nsim = 100)
 #'
+#' # calculate intermediate greta arrays, given some parameter values (useful
+#' # for debugging models)
+#' calculate(mu[1:5], values = list(alpha = 1, beta = 2, sigma = 0.5))
+#' calculate(mu[1:5], values = list(alpha = -1, beta = 0.2, sigma = 0.5))
 #'
-#' # fit the model then calculate samples at a new greta array
+#' # simulate datasets given fixed parameter values
+#' calculate(y, values = list(alpha = -1, beta = 0.2, sigma = 0.5), nsim = 10)
+#'
+#' # you can use calculate in conjunction with posterior samples from MCMC, e.g.
+#' # sampling different observation datasets, given a random set of these
+#' # posterior samples - useful for posterior predictive model checks
 #' draws <- mcmc(m, n_samples = 500)
+#' calculate(y, values = draws, nsim = 100)
+#'
+#' # you can use calculate on greta arrays created even after inference on
+#' # the model - e.g. to plot response curves
 #' petal_length_plot <- seq(min(iris$Petal.Length),
 #'                          max(iris$Petal.Length),
 #'                          length.out = 100)
 #' mu_plot <- alpha + petal_length_plot * beta
-#' mu_plot_draws <- calculate(mu_plot, draws)
-#'
-#' # plot the draws
+#' mu_plot_draws <- calculate(mu_plot, values = draws)
 #' mu_est <- colMeans(mu_plot_draws[[1]])
 #' plot(mu_est ~ petal_length_plot, type = "n",
 #'      ylim = range(mu_plot_draws[[1]]))
@@ -78,123 +124,285 @@
 #' # trace_batch_size can be changed to trade off speed against memory usage
 #' # when calculating. These all produce the same result, but have increasing
 #' # memory requirements:
-#' mu_plot_draws_1 <- calculate(mu_plot, draws, trace_batch_size = 1)
-#' mu_plot_draws_10 <- calculate(mu_plot, draws, trace_batch_size = 10)
-#' mu_plot_draws_inf <- calculate(mu_plot, draws, trace_batch_size = Inf)
-#' }
-#'
+#' mu_plot_draws_1 <- calculate(mu_plot,
+#'                              values = draws,
+#'                              trace_batch_size = 1)
+#' mu_plot_draws_10 <- calculate(mu_plot,
+#'                               values = draws,
+#'                               trace_batch_size = 10)
+#' mu_plot_draws_inf <- calculate(mu_plot,
+#'                                values = draws,
+#'                                trace_batch_size = Inf)
 #'
-calculate <- function(target, values = list(),
+#' }
+calculate <- function(...,
+                      values = list(),
+                      nsim = NULL,
+                      seed = NULL,
                       precision = c("double", "single"),
                       trace_batch_size = 100) {

-  target_name <- deparse(substitute(target))
+  # turn the provided greta arrays into a list and try to find the names
+  target <- list(...)
+  names <- names(target)
+
+  # see if any names are missing and try to fill them in
+  if (is.null(names)) {
+    names_missing <- rep(TRUE, length(target))
+  } else {
+    names_missing <- names == ""
+  }
+
+  if (any(names_missing)) {
+    scraped_names <- substitute(list(...))[-1]
+    missing_names <- vapply(scraped_names[names_missing], deparse, "")
+    names[names_missing] <- missing_names
+    names(target) <- names
+  }
+
+  # catch empty lists here, since check_greta_arrays assumes data greta arrays
+  # have been stripped out
+  if (identical(target, list())) {
+    stop("no greta arrays to calculate were provided",
+         call. = FALSE)
+  }
+
+  # check the inputs
+  check_greta_arrays(target,
+                     "calculate",
+                     "\nPerhaps you forgot to explicitly name other arguments?")
+
+  # checks and RNG seed setting if we're sampling
+  if (!is.null(nsim)) {
+
+    # check nsim is valid
+    nsim <- check_positive_integer(nsim)
+
+    # if an RNG seed was provided use it and reset the RNG on exiting
+    if (!is.null(seed)) {
+
+      if (!exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
+        runif(1)
+      }
+
+      r_seed <- get(".Random.seed", envir = .GlobalEnv)
+      on.exit(assign(".Random.seed", r_seed, envir = .GlobalEnv))
+      set.seed(seed)
+
+    }
+  }
+
+  # set precision
   tf_float <- switch(match.arg(precision),
                      double = "float64",
                      single = "float32")

-  if (!inherits(target, "greta_array"))
-    stop("'target' is not a greta array")
-
   if (inherits(values, "greta_mcmc_list")) {
-    calculate_greta_mcmc_list(
+
+    result <- calculate_greta_mcmc_list(
       target = target,
-      target_name = target_name,
       values = values,
+      nsim = nsim,
       tf_float = tf_float,
       trace_batch_size = trace_batch_size
     )
+
+  } else {
+
+    result <- calculate_list(target = target,
+                             values = values,
+                             nsim = nsim,
+                             tf_float = tf_float,
+                             env = parent.frame())
+
   }
-  else {
-    calculate_list(target = target,
-                   values = values,
-                   tf_float = tf_float)
+
+  if (!inherits(result, "greta_mcmc_list")) {
+
+    # if it's not mcmc samples, make sure the results are in the right order
+    # (tensorflow order seems to be platform specific?!?)
+    order <- match(names(result), names(target))
+    result <- result[order]
+
+    # if the result wasn't mcmc samples or simulations, drop the batch dimension
+    if (is.null(nsim)) {
+      result <- lapply(result, drop_first_dim)
+    }
+
   }
+  result

 }

 #' @importFrom coda thin
+#' @importFrom stats start end
 calculate_greta_mcmc_list <- function(target,
-                                      target_name,
                                       values,
+                                      nsim,
                                       tf_float,
                                       trace_batch_size) {

+  stochastic <- !is.null(nsim)
+
   # check trace_batch_size is valid
   trace_batch_size <- check_trace_batch_size(trace_batch_size)

+  # get the free state draws and old dag from the samples
   model_info <- get_model_info(values)
+  mcmc_dag <- model_info$model$dag
+  draws <- model_info$raw_draws
+
+  # build a new dag from the targets
+  dag <- dag_class$new(target, tf_float = tf_float)
+  dag$mode <- ifelse(stochastic, "hybrid", "all_forward")
+
+  # rearrange the nodes in the dag so that any mcmc dag variables are first and
+  # in the right order (otherwise the free state will be incorrectly defined)
+  in_draws <- names(dag$node_list) %in% names(mcmc_dag$node_list)
+  order <- order(match(names(dag$node_list[in_draws]),
+                       names(mcmc_dag$node_list)))
+  dag$node_list <- c(dag$node_list[in_draws][order], dag$node_list[!in_draws])
+
+  # find variable nodes in the new dag without a free state in the old one.
+  mcmc_dag_variables <- mcmc_dag$node_list[mcmc_dag$node_types == "variable"]
+  dag_variables <- dag$node_list[dag$node_types == "variable"]
+  stateless_names <- setdiff(names(dag_variables), names(mcmc_dag_variables))
+  dag$variables_without_free_state <- dag_variables[stateless_names]
+
+  # check there's some commonality between the two dags
+  connected_to_draws <- names(dag$node_list) %in% names(mcmc_dag$node_list)
+  if (!any(connected_to_draws)) {
+    stop("the target greta arrays do not appear to be connected ",
+         "to those in the greta_mcmc_list object",
+         call. = FALSE)
+  }

-  # copy and refresh the dag
-  dag <- model_info$model$dag$clone()
-  dag$new_tf_environment()
+  # if they didn't specify nsim, check we can deterministically compute the
+  # targets from the draws
+  if (!stochastic) {
+
+    # see if the new dag introduces any new variables
+    new_types <- dag$node_types[!connected_to_draws]
+    if (any(new_types == "variable")) {
+      stop("the target greta arrays are related to new variables ",
+           "that are not in the MCMC samples so cannot be calculated ",
+           "from the samples alone. Set 'nsim' if you want to sample them ",
+           "conditionally on the MCMC samples",
+           call. = FALSE)
+    }

-  # set the precision in the dag
-  dag$tf_float <- tf_float
+    # see if any of the targets are stochastic and not sampled in the mcmc
+    target_nodes <- lapply(target, get_node)
+    target_node_names <- vapply(target_nodes,
+                                member,
+                                "unique_name",
+                                FUN.VALUE = character(1))
+    existing_variables <- target_node_names %in% names(mcmc_dag_variables)
+    have_distributions <- vapply(target_nodes,
+                                 has_distribution,
+                                 FUN.VALUE = logical(1))
+    new_stochastics <- have_distributions & !existing_variables
+    if (any(new_stochastics)) {
+      n_stoch <- sum(new_stochastics)
+      stop("the greta array",
+           ngettext(n_stoch, " ", "s "),
+           paste(names(target)[new_stochastics], collapse = ", "),
+           ngettext(n_stoch,
+                    " has a distribution and is ",
+                    " have distributions and are "),
+           "not in the MCMC samples, so cannot be calculated ",
+           "from the samples alone. Set 'nsim' if you want to sample them ",
+           "conditionally on the MCMC samples",
+           call. = FALSE)
+    }

-  # extend the dag to include this node, as the target
-  dag$build_dag(list(target))
+  }

-  self <- dag  # mock for scoping
-  self
   dag$define_tf()

-  dag$target_nodes <- list(get_node(target))
-  names(dag$target_nodes) <- target_name
+  dag$target_nodes <- lapply(target, get_node)
+  names(dag$target_nodes) <- names(target)

-  param <- dag$example_parameters()
-  param[] <- 0
+  # if we're doing stochastic sampling, subsample the draws
+  if (stochastic) {
+
+    draws <- as.matrix(draws)
+    n_samples <- nrow(draws)
+
+    # if needed, sample with replacement and warn
+    replace <- FALSE
+    if (nsim > n_samples) {
+      replace <- TRUE
+      warning("nsim was greater than the number of posterior samples in ",
+              "values, so posterior samples had to be drawn with replacement",
+              call. = FALSE)
+    }
+
+    rows <- sample.int(n_samples, nsim, replace = replace)
+    draws <- draws[rows, , drop = FALSE]
+
+    # add the batch size to the data list
+    dag$set_tf_data_list("batch_size", as.integer(nsim))
+
+    # pass these values in as the free state
+    trace <- dag$trace_values(draws,
+                              trace_batch_size = trace_batch_size,
+                              flatten = FALSE)

-  # raw draws are either an attribute, or this object
-  model_info <- attr(values, "model_info")
-  draws <- model_info$raw_draws
-
-  # trace the target for each chain
-  values <- lapply(draws, dag$trace_values, trace_batch_size = trace_batch_size)
-
-  # convert to a greta_mcmc_list object, retaining windowing info
-  trace <- lapply(
-    values,
-    coda::mcmc,
-    start = start(draws),
-    end = end(draws),
-    thin = coda::thin(draws)
-  )
-  trace <- coda::mcmc.list(trace)
-  trace <- as_greta_mcmc_list(trace, model_info)
+  } else {
+
+    # for deterministic posterior prediction, just trace the target for each
+    # chain
+    values <- lapply(draws,
+                     dag$trace_values,
+                     trace_batch_size = trace_batch_size)
+
+    # convert to a greta_mcmc_list object, retaining windowing info
+    trace <- lapply(
+      values,
+      coda::mcmc,
+      start = stats::start(draws),
+      end = stats::end(draws),
+      thin = coda::thin(draws)
+    )
+    trace <- coda::mcmc.list(trace)
+    trace <- as_greta_mcmc_list(trace, model_info)
+  }

   trace

 }

-calculate_list <- function(target, values, tf_float) {
+calculate_list <- function(target, values, nsim, tf_float, env) {

-  # get the values and their names
-  names <- names(values)
-  stopifnot(length(names) == length(values))
+  stochastic <- !is.null(nsim)

-  # get the corresponding greta arrays
-  fixed_greta_arrays <- lapply(names,
-                               get,
-                               envir = parent.frame(n = 2))
+  fixed_greta_arrays <- list()

-  # make sure that's what they are
-  are_greta_arrays <- vapply(fixed_greta_arrays,
-                             inherits,
-                             "greta_array",
-                             FUN.VALUE = FALSE)
+  if (!identical(values, list())) {

-  stopifnot(all(are_greta_arrays))
+    # check the list of values makes sense, and return these and the
+    # corresponding greta arrays (looked up by name in environment env)
+    values_list <- check_values_list(values, env)
+    fixed_greta_arrays <- values_list$fixed_greta_arrays
+    values <- values_list$values

-  # make sure the values have the correct dimensions
-  values <- mapply(assign_dim,
-                   values,
-                   fixed_greta_arrays,
-                   SIMPLIFY = FALSE)
-
-  all_greta_arrays <- c(fixed_greta_arrays, list(target))
+  }

+  all_greta_arrays <- c(fixed_greta_arrays, target)
   # define the dag and TF graph
   dag <- dag_class$new(all_greta_arrays, tf_float = tf_float)
+
+  # convert to nodes, and add tensor names to values
+  fixed_nodes <- lapply(fixed_greta_arrays, get_node)
+  names(values) <- vapply(fixed_nodes, dag$tf_name, FUN.VALUE = character(1))
+
+  # change dag mode to sampling
+  dag$mode <- "all_sampling"
+
   dag$define_tf()
   tfe <- dag$tf_environment

@@ -206,86 +414,67 @@
                           dag$tf_name,
                           FUN.VALUE = "")

-  # check that there are no unspecified variables on which the target depends
-
-  # find all the nodes depended on by this one
-  dependencies <- get_node(target)$parent_names(recursive = TRUE)
-
-  # find all the nodes depended on by the new values, and remove them from the
-  # list
-  complete_dependencies <- lapply(fixed_greta_arrays,
-                                  function(x)
-                                    get_node(x)$parent_names(recursive = TRUE))
-  complete_dependencies <- unique(unlist(complete_dependencies))
-
-  unmet_dependencies <- dependencies[!dependencies %in% complete_dependencies]
-
-  # find all of the remaining nodes that are variables
-  unmet_nodes <- dag$node_list[unmet_dependencies]
-  is_variable <- vapply(unmet_nodes, node_type, FUN.VALUE = "") == "variable"
-
-  # if there are any undefined variables
-  if (any(is_variable)) {
-
-    # try to find the associated greta arrays to provide a more informative
-    # error message
-    greta_arrays <- all_greta_arrays(parent.frame(2),
-                                     include_data = FALSE)
-
-    greta_array_node_names <- vapply(greta_arrays,
-                                function(x) get_node(x)$unique_name,
-                                FUN.VALUE = "")
-
-    unmet_variables <- unmet_nodes[is_variable]
-
-    matches <- names(unmet_variables) %in% greta_array_node_names
+  # check we can do the calculation
+  if (stochastic) {
+
+    # check there are no variables without distributions (or whose children have
+    # distributions - for lkj & wishart) that aren't given fixed values
+    variables <- dag$node_list[dag$node_types == "variable"]
+    have_distributions <- vapply(variables,
+                                 has_distribution,
+                                 FUN.VALUE = logical(1))
+    any_child_has_distribution <- function(variable) {
+      have_distributions <- vapply(variable$children,
+                                   has_distribution,
+                                   FUN.VALUE = logical(1))
+      any(have_distributions)
+    }
+    children_have_distributions <- vapply(variables,
+                                          any_child_has_distribution,
+                                          FUN.VALUE = logical(1))
+
+    unsampleable <- !have_distributions & !children_have_distributions
+    fixed_node_names <- vapply(fixed_nodes,
+                               member,
+                               "unique_name",
+                               FUN.VALUE = character(1))
+    unfixed <- !names(variables) %in% fixed_node_names
+
+    if (any(unsampleable & unfixed)) {
+      stop("the target greta arrays are related to variables ",
+           "that do not have distributions so cannot be sampled",
+           call. = FALSE)
+    }

+  } else {

-    unmet_names_idx <- greta_array_node_names %in% names(unmet_variables)
-    unmet_names <- names(greta_array_node_names)[unmet_names_idx]
+    # check there are no unspecified variables on which the target depends
+    lapply(target, check_dependencies_satisfied, fixed_greta_arrays, dag, env)

-    # build the message
-    msg <- paste("values have not been provided for all greta arrays on which",
-                 "the target depends.")
+  }

-    if (any(matches)) {
-      names_text <- paste(unmet_names, collapse = ", ")
-      msg <- paste(msg,
-                   sprintf("Please provide values for the greta array%s: %s",
-                           ifelse(length(matches) > 1, "s", ""),
-                           names_text))
-    } else {
-      msg <- paste(msg,
-                   "\nThe names of the missing greta arrays",
-                   "could not be detected")
-    }
+  # look up the tf names of the target greta arrays (under sampling)
+  # create an object in the environment that's a list of these, and sample that
+  target_nodes <- lapply(target, get_node)
+  target_names_list <- lapply(target_nodes, dag$tf_name)
+  target_tensor_list <- lapply(target_names_list, get, envir = tfe)
+  assign("calculate_target_tensor_list", target_tensor_list, envir = tfe)

-    stop(msg,
-         call. = FALSE)
-  }
+  # add the batch size to the data list
+  batch_size <- ifelse(stochastic, as.integer(nsim), 1L)
+  dag$set_tf_data_list("batch_size", batch_size)

   # add values or data not specified by the user
-  data_list <- tfe$data_list
+  data_list <- dag$get_tf_data_list()
   missing <- !names(data_list) %in% names(values)

   # send list to tf environment and roll into a dict
   values <- lapply(values, add_first_dim)
-  dag$build_feed_dict(values, data_list = data_list[missing])
+  values <- lapply(values, tile_first_dim, batch_size)

-  name <- dag$tf_name(get_node(target))
-  result <- dag$tf_sess_run(name, as_text = TRUE)
+  dag$build_feed_dict(values, data_list = data_list[missing])

-  drop_first_dim(result)
+  # run the sampling
+  dag$tf_sess_run("calculate_target_tensor_list", as_text = TRUE)

 }
-
-# coerce value to have the correct dimensions
-assign_dim <- function(value, greta_array) {
-  array <- strip_unknown_class(get_node(greta_array)$value())
-  if (length(array) != length(value)) {
-    stop("a provided value has different number of elements",
-         " than the greta array", call. = FALSE)
-  }
-  array[] <- value
-  array
-}
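The seed handling added to `calculate()` above is the save-and-restore idiom also used by `stats::simulate()`: stash `.Random.seed`, seed the RNG, and put the old state back on exit, so the caller's RNG stream is unaffected. A standalone sketch of the idiom (`with_seed()` is a hypothetical helper, not part of the diff):

with_seed <- function(seed, expr) {
  # make sure .Random.seed exists before stashing it
  if (!exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
    runif(1)
  }
  r_seed <- get(".Random.seed", envir = .GlobalEnv)
  on.exit(assign(".Random.seed", r_seed, envir = .GlobalEnv))
  set.seed(seed)
  expr
}

with_seed(1, rnorm(2))  # reproducible draws
rnorm(1)                # the global RNG stream continues where it left off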

@@ -2,39 +2,39 @@

 #' @name model
 #' @title greta model objects
-#' @description Create a \code{greta_model} object representing a statistical
-#'   model (using \code{model}), and plot a graphical representation of the
-#'   model. Statistical inference can be performed on \code{greta_model} objects
-#'   with \code{\link{mcmc}}
+#' @description Create a `greta_model` object representing a statistical
+#'   model (using `model`), and plot a graphical representation of the
+#'   model. Statistical inference can be performed on `greta_model` objects
+#'   with [mcmc()]
 NULL

 #' @rdname model
 #' @export
 #'
-#' @param \dots for \code{model}: \code{greta_array} objects to be tracked by
+#' @param \dots for `model`: `greta_array` objects to be tracked by
 #'   the model (i.e. those for which samples will be retained during mcmc). If
-#'   not provided, all of the non-data \code{greta_array} objects defined in the
-#'   calling environment will be tracked. For \code{print} and
-#'   \code{plot}:further arguments passed to or from other methods (currently
+#'   not provided, all of the non-data `greta_array` objects defined in the
+#'   calling environment will be tracked. For `print` and
+#'   `plot`: further arguments passed to or from other methods (currently
 #'   ignored).
 #'
 #' @param precision the floating point precision to use when evaluating this
-#'   model. Switching from \code{"double"} (the default) to \code{"single"} may
+#'   model. Switching from `"double"` (the default) to `"single"` may
 #'   decrease the computation time but increase the risk of numerical
 #'   instability during sampling.
 #'
 #' @param compile whether to apply
-#'   \href{https://www.tensorflow.org/performance/xla/}{XLA JIT compilation} to
+#'   [XLA JIT compilation](https://www.tensorflow.org/performance/xla/) to
 #'   the TensorFlow graph representing the model. This may slow down model
 #'   definition, and speed up model evaluation.
 #'
-#' @details \code{model()} takes greta arrays as arguments, and defines a
+#' @details `model()` takes greta arrays as arguments, and defines a
 #'   statistical model by finding all of the other greta arrays on which they
-#'   depend, or which depend on them. Further arguments to \code{model} can be
+#'   depend, or which depend on them. Further arguments to `model` can be
 #'   used to configure the TensorFlow graph representing the model, to tweak
 #'   performance.
 #'
-#' @return \code{model} - a \code{greta_model} object.
+#' @return `model` - a `greta_model` object.
 #'
 #' @examples
 #' \dontrun{
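A brief usage sketch of the arguments documented above (illustrative, not part of the diff):

library(greta)
theta <- normal(0, 1)
# single precision is faster but less numerically stable;
# compile = TRUE turns on XLA JIT compilation of the TensorFlow graph
m <- model(theta, precision = "single", compile = TRUE)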
@@ -78,31 +78,7 @@

   }

-  # check they are greta arrays
-  are_greta_arrays <- vapply(target_greta_arrays,
-                             inherits, "greta_array",
-                             FUN.VALUE = FALSE)
-
-  if (!all(are_greta_arrays)) {
-
-    unexpected_items <- names(target_greta_arrays)[!are_greta_arrays]
-
-    msg <- ifelse(length(unexpected_items) > 1,
-                  paste("The following objects passed to model()",
-                        "are not greta arrays: "),
-                  paste("The following object passed to model()",
-                        "is not a greta array: "))
-
-    stop(msg,
-         paste(unexpected_items, sep = ", "),
-         call. = FALSE)
-
-  }
-
-  if (length(target_greta_arrays) == 0) {
-    stop("could not find any non-data greta arrays",
-         call. = FALSE)
-  }
+  target_greta_arrays <- check_greta_arrays(target_greta_arrays, "model")

   # get the dag containing the target nodes
   dag <- dag_class$new(target_greta_arrays,
@@ -183,17 +159,17 @@
 }

 # register generic method to coerce objects to a greta model
-as.greta_model <- function(x, ...)  # Exclude Linting
+as.greta_model <- function(x, ...)  # nolint
   UseMethod("as.greta_model", x)

-as.greta_model.dag_class <- function(x, ...) {  # Exclude Linting
+as.greta_model.dag_class <- function(x, ...) {  # nolint
   ans <- list(dag = x)
   class(ans) <- "greta_model"
   ans
 }

 #' @rdname model
-#' @param x a \code{greta_model} object
+#' @param x a `greta_model` object
 #' @export
 print.greta_model <- function(x, ...) {
   cat("greta model")
@@ -201,19 +177,19 @@

 #' @rdname model
 #' @param y unused default argument
-#' @param colour base colour used for plotting. Defaults to \code{greta} colours
+#' @param colour base colour used for plotting. Defaults to `greta` colours
 #'   in violet.
 #'
 #' @details The plot method produces a visual representation of the defined
-#'   model. It uses the \code{DiagrammeR} package, which must be installed
+#'   model. It uses the `DiagrammeR` package, which must be installed
 #'   first. Here's a key to the plots:
 #'   \if{html}{\figure{plotlegend.png}{options: width="100\%"}}
 #'   \if{latex}{\figure{plotlegend.pdf}{options: width=7cm}}
 #'
-#' @return \code{plot} - a \code{\link[DiagrammeR:grViz]{DiagrammeR::grViz}}
+#' @return `plot` - a [DiagrammeR::grViz()]
 #'   object, with the
-#'   \code{\link[DiagrammeR:create_graph]{DiagrammeR::dgr_graph}} object used to
-#'   create it as an attribute \code{"dgr_graph"}.
+#'   [`DiagrammeR::dgr_graph()`][DiagrammeR::create_graph] object used to
+#'   create it as an attribute `"dgr_graph"`.
 #'
 #' @export
 plot.greta_model <- function(x,

@@ -5,19 +5,19 @@
 # create dag class
 dag_class <- R6Class(
   "dag_class",
+
   public = list(

+    mode = "all_forward",
     node_list = list(),
-    node_types = NA,
-    node_tf_names = NA,
+    target_nodes = list(),
+    variables_without_free_state = list(),
     tf_environment = NA,
     tf_graph = NA,
-    target_nodes = NA,
-    parameters_example = NA,
+
     tf_float = NA,
     n_cores = 0L,
     compile = NA,
-    adjacency_matrix = NULL,
     trace_names = NULL,

     # create a dag from some target nodes
@@ -31,15 +31,9 @@
       # find the nodes we care about
       self$target_nodes <- lapply(target_greta_arrays, get_node)

-      # create an adjacency matrix
-      self$build_adjacency_matrix()
-
       # set up the tf environment, with a graph
       self$new_tf_environment()

-      # stash an example list to relist parameters
-      self$parameters_example <- self$example_parameters(flat = FALSE)
-
       # store the performance control info
       self$tf_float <- tf_float
       self$compile <- compile
@@ -50,7 +44,9 @@

       self$tf_environment <- new.env()
       self$tf_graph <- tf$Graph()
-      self$tf_environment$data_list <- list()
+      self$tf_environment$all_forward_data_list <- list()
+      self$tf_environment$all_sampling_data_list <- list()
+      self$tf_environment$hybrid_data_list <- list()

     },

@@ -61,8 +57,13 @@

       # temporarily pass float type info to options, so it can be accessed by
      # nodes on definition, without clunky explicit passing
       old_float_type <- options()$greta_tf_float
-      on.exit(options(greta_tf_float = old_float_type))
-      options(greta_tf_float = self$tf_float)
+      old_batch_size <- options()$greta_batch_size
+
+      on.exit(options(greta_tf_float = old_float_type,
+                      greta_batch_size = old_batch_size))
+
+      options(greta_tf_float = self$tf_float,
+              greta_batch_size = self$tf_environment$batch_size)

       with(self$tf_graph$as_default(), expr)
     },
@@ -86,8 +87,9 @@
     # sess$run() an expression in the tensorflow environment, with the feed dict
     tf_sess_run = function(expr, as_text = FALSE) {

-      if (!as_text)
+      if (!as_text) {
         expr <- deparse(substitute(expr))
+      }

       expr <- paste0("sess$run(", expr, ", feed_dict = feed_dict)")

@@ -105,52 +107,163 @@
         node$register_family(self)
       }

-      # stash the node names, types, and tf names
-      self$node_types <- vapply(self$node_list, node_type, FUN.VALUE = "")
-      self$node_tf_names <- self$make_names()
-
-    },
-
-    # create human-readable names for TF tensors
-    make_names = function() {
-
-      types <- self$node_types
-
-      for (type in c("variable", "data", "operation", "distribution")) {
-        idx <- which(types == type)
-        types[idx] <- paste(type, seq_along(idx), sep = "_")
-      }
-
-      self$node_tf_names <- types
-
     },

     # get the TF names for different node types
     get_tf_names = function(types = NULL) {
+
+      # get tf basenames
       names <- self$node_tf_names
       if (!is.null(types))
         names <- names[which(self$node_types %in% types)]
-      names
+
+      # prepend mode
+      if (length(names) > 0) {
+        names <- paste(self$mode, names, sep = "_")
+      }
+
     },

     # look up the TF name for a single node
     tf_name = function(node) {
+
+      # get tf basename from node name
       name <- self$node_tf_names[node$unique_name]
       if (length(name) == 0) {
         name <- ""
       }
+
+      # prepend mode
+      if (!is.na(name)) {
+        name <- paste(self$mode, name, sep = "_")
+      }
+
       name
+
     },
+
+    # how to define a node if the sampling mode is hybrid (this is quite knotty,
+    # so gets its own function)
+    how_to_define_hybrid = function(node) {
+
+      node_type <- node_type(node)
+
+      # names of variable nodes not connected to the free state in this dag
+      stateless_names <- names(self$variables_without_free_state)
+
+      # if the node is data, use sampling mode if it has a distribution and
+      # forward mode if not
+      if (node_type == "data") {
+        node_mode <- ifelse(has_distribution(node), "sampling", "forward")
+      }
+
+      # if the node is a variable, use forward mode if it has a free state,
+      # and sampling mode if not
+      if (node_type == "variable") {
+        to_sample <- node$unique_name %in% stateless_names
+        node_mode <- ifelse(to_sample, "sampling", "forward")
+      }
+
+      # if it's an operation, see if it has a distribution (for lkj and
+      # wishart) and get mode based on whether the parent has a free state
+      if (node_type == "operation") {
+
+        parent_name <- node$parents[[1]]$unique_name
+        parent_stateless <- parent_name %in% stateless_names
+        to_sample <- has_distribution(node) & parent_stateless
+        node_mode <- ifelse(to_sample, "sampling", "forward")
+
+      }
+
+      # if the node is a distribution, decide based on its target
+      if (node_type == "distribution") {
+
+        target <- node$target
+        target_type <- node_type(target)
+
+        # if it has no target (e.g. for a mixture distribution), define it in
+        # sampling mode (so it defines before the things that depend on it)
+        if (is.null(target)) {
+          node_mode <- "sampling"
+        }
+
+        # if the target is data, use sampling mode
+        if (target_type == "data") {
+          node_mode <- "sampling"
+        }
+
+        # if the target is a variable, use forward mode if it has a free
+        # state, and sampling mode if not
+        if (target_type == "variable") {
+          to_sample <- target$unique_name %in% stateless_names
+          node_mode <- ifelse(to_sample, "sampling", "forward")
+        }
+
+        # if the target is an operation, see if that operation has a single
+        # parent that is a variable, and see if that has a free state
+        if (target_type == "operation") {
+
+          target_parent_name <- target$parents[[1]]$unique_name
+          target_parent_stateless <- target_parent_name %in% stateless_names
+          node_mode <- ifelse(target_parent_stateless, "sampling", "forward")
+
+        }
+
+      }
+
+      node_mode
+
+    },
+
+    # how to define the node if we're sampling everything (no free state)
+    how_to_define_all_sampling = function(node) {
+
+      switch(node_type(node),
+             data = ifelse(has_distribution(node), "sampling", "forward"),
+             operation = ifelse(has_distribution(node), "sampling", "forward"),
+             "sampling"
+      )
+    },
+
+    # tell a node whether to define itself in forward mode (deterministically
+    # from an existing free state), or in sampling mode (generate a random
+    # version of itself)
+    how_to_define = function(node) {
+
+      switch(
+        self$mode,
+
+        # if doing inference, everything is push-forward
+        all_forward = "forward",
+
+        # sampling from the prior, most nodes are in sampling mode
+        all_sampling = self$how_to_define_all_sampling(node),
+
+        # sampling from the posterior, some nodes are defined forward,
+        # others sampled
+        hybrid = self$how_to_define_hybrid(node)
+
+      )
+
+    },
+
+    define_batch_size = function() {
+
+      self$tf_run(
+        batch_size <- tf$compat$v1$placeholder(dtype = tf$int32)
+      )
+
+    },

     define_free_state = function(type = c("variable", "placeholder"),
                                  name = "free_state") {

       type <- match.arg(type)
-
       tfe <- self$tf_environment

-      vals <- self$example_parameters()
-
+      vals <- self$example_parameters(free = TRUE)
+      vals <- unlist_tf(vals)

       if (type == "variable") {

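For orientation, the three values of the internal `dag$mode` field that `how_to_define()` dispatches on correspond to the three ways the dag is used elsewhere in this diff (a summary, not new API):

# inference (mcmc, opt): every node is computed forward from the free state
dag$mode <- "all_forward"
# calculate()/simulate() with nsim and no draws: sample everything from priors
dag$mode <- "all_sampling"
# calculate() with posterior draws and nsim: nodes with draws are computed
# forward; the rest are sampled conditional on them
dag$mode <- "hybrid"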
@@ -175,18 +288,16 @@

     },

-    # define the body of the tensorflow graph in the environment env; without
-    # defining the free_state, or the densities etc.
-    define_tf_body = function(target_nodes = self$node_list) {
+    # split the overall free state vector into free versions of variables
+    split_free_state = function() {

       tfe <- self$tf_environment

-      # split up into separate free state variables and assign
       free_state <- get("free_state", envir = tfe)

-      params <- self$parameters_example
+      params <- self$example_parameters(free = TRUE)
       lengths <- vapply(params,
-                        function(x) as.integer(prod(dim(x))),
+                        function(x) length(x),
                         FUN.VALUE = 1L)

       if (length(lengths) > 1) {
@@ -196,13 +307,26 @@
       }

       names <- paste0(names(params), "_free")
-      for (i in seq_along(names))
+
+      for (i in seq_along(names)) {
         assign(names[i], args[[i]], envir = tfe)
+      }
+
+    },
+
+    # define the body of the tensorflow graph in the environment env; without
+    # defining the free_state, or the densities etc.
+    define_tf_body = function(target_nodes = self$node_list) {

-      # define all nodes, node densities and free states in the environment, and
-      # on the graph
-      self$on_graph(lapply(target_nodes,
-                           function(x) x$define_tf(self)))
+      # if in forward or hybrid mode, split up the free state
+      if (self$mode %in% c("all_forward", "hybrid")) {
+        self$split_free_state()
+      }
+
+      # define all nodes in the environment and on the graph
+      self$on_graph(
+        lapply(target_nodes, function(x) x$define_tf(self))
+      )

       invisible(NULL)

@@ -213,17 +337,25 @@

       tfe <- self$tf_environment
       tfe$n_cores <- self$n_cores
-      # Begin Exclude Linting
+
+      # nolint start
       self$tf_run(
-        config <- tf$compat$v1$ConfigProto(inter_op_parallelism_threads = n_cores,
-                                 intra_op_parallelism_threads = n_cores))
+        config <- tf$compat$v1$ConfigProto(
+          inter_op_parallelism_threads = n_cores,
+          intra_op_parallelism_threads = n_cores
+        )
+      )

       if (self$compile) {
-        self$tf_run(py_set_attr(config$graph_options$optimizer_options,
-                                "global_jit_level",
-                                tf$compat$v1$OptimizerOptions$ON_1))
+        self$tf_run(
+          py_set_attr(
+            config$graph_options$optimizer_options,
+            "global_jit_level",
+            tf$compat$v1$OptimizerOptions$ON_1
+          )
+        )
       }
-      # End Exclude Linting
+      # nolint end

       # start a session and initialise all variables
       self$tf_run(sess <- tf$compat$v1$Session(config = config))
@@ -231,12 +363,18 @@

     },

-    # define tf graph in environment
-    define_tf = function() {
+    # define tf graph in environment; either for forward-mode computation from a
+    # free state variable, or for sampling
+    define_tf = function(target_nodes = self$node_list) {

-      # define the free state variable, rest of the graph, and the session
-      self$define_free_state("placeholder", "free_state")
-      self$define_tf_body()
+      # define the free state variable
+      if (self$mode %in% c("all_forward", "hybrid")) {
+        self$define_free_state("placeholder")
+      }
+
+      # define the body of the graph (depending on the mode) and the session
+      self$define_batch_size()
+      self$define_tf_body(target_nodes = target_nodes)
       self$define_tf_session()

     },
@@ -246,45 +384,23 @@

       tfe <- self$tf_environment

-      # get all distribution nodes
-      distributions <- self$node_list[self$node_types == "distribution"]
-
-      # keep only those with a target node
-      targets <- lapply(distributions, member, "get_tf_target_node()")
-      has_target <- !vapply(targets, is.null, FUN.VALUE = TRUE)
-
-      distributions <- distributions[has_target]
-      targets <- targets[has_target]
-
-      # find and get these functions
-      density_names <- vapply(distributions,
-                              self$tf_name,
-                              FUN.VALUE = "")
-      target_names <- vapply(targets,
-                             self$tf_name,
-                             FUN.VALUE = "")
-
-      target_tensors <- lapply(target_names,
-                               get,
-                               envir = tfe)
-      density_functions <- lapply(density_names,
-                                  get,
-                                  envir = tfe)
-
-      # make the target names lists, for do.call
-      target_lists <- lapply(target_tensors, list)
-
-      # execute them
-      densities <- mapply(do.call,
-                          density_functions,
-                          target_lists,
-                          MoreArgs = list(envir = tfe),
+      # get all distribution nodes that have a target
+      distribution_nodes <- self$node_list[self$node_types == "distribution"]
+      target_nodes <- lapply(distribution_nodes, member, "get_tf_target_node()")
+      has_target <- !vapply(target_nodes, is.null, FUN.VALUE = TRUE)
+      distribution_nodes <- distribution_nodes[has_target]
+      target_nodes <- target_nodes[has_target]
+
+      # get the densities, evaluated at these targets
+      densities <- mapply(self$evaluate_density,
+                          distribution_nodes,
+                          target_nodes,
                           SIMPLIFY = FALSE)

-      # reduce_sum them
+      # reduce_sum each of them (skipping the batch dimension)
       self$on_graph(summed_densities <- lapply(densities, tf_sum, drop = TRUE))

-      # remove their names and sum them together
+      # sum them together
       names(summed_densities) <- NULL
       self$on_graph(joint_density <- tf$add_n(summed_densities))

@@ -295,14 +411,16 @@
295 411
296 412
      # define adjusted joint density
297 413
298 -
      # get names of adjustment tensors for all variable nodes
414 +
      # get names of Jacobian adjustment tensors for all variable nodes
299 415
      adj_names <- paste0(self$get_tf_names(types = "variable"), "_adj")
300 416
301 417
      # get the TF adjustment tensors for all variable nodes
302 418
      adj <- lapply(adj_names, get, envir = self$tf_environment)
303 419
304 -
      # remove their names and sum them together
420 +
      # remove their names and sum them together (accounting for tfp bijectors
421 +
      # sometimes returning a scalar tensor)
305 422
      names(adj) <- NULL
423 +
      adj <- match_batches(adj)
306 424
      self$on_graph(total_adj <- tf$add_n(adj))
307 425
308 426
      # assign overall density to environment
@@ -312,6 +430,81 @@
312 430
313 431
    },
314 432
433 +
    # evaluate the (truncation-corrected) density of a tfp distribution on its
434 +
    # target tensor
435 +
    evaluate_density = function(distribution_node, target_node) {
436 +
437 +
      tfe <- self$tf_environment
438 +
439 +
      parameter_nodes <- distribution_node$parameters
440 +
441 +
      # get the tensorflow objects for these
442 +
      distrib_constructor <- self$get_tf_object(distribution_node)
443 +
      tf_target <- self$get_tf_object(target_node)
444 +
      tf_parameter_list <- lapply(parameter_nodes, self$get_tf_object)
445 +
446 +
      # execute the distribution constructor functions to return a tfp
447 +
      # distribution object
448 +
      tfp_distribution <- distrib_constructor(tf_parameter_list, dag = self)
449 +
450 +
      self$tf_evaluate_density(tfp_distribution,
451 +
                               tf_target,
452 +
                               truncation = distribution_node$truncation,
453 +
                               bounds = distribution_node$bounds)
454 +
455 +
    },
456 +
457 +
    tf_evaluate_density = function(tfp_distribution,
458 +
                                   tf_target,
459 +
                                   truncation = NULL,
460 +
                                   bounds = NULL) {
461 +
462 +
      # get the uncorrected log density
463 +
      ld <- tfp_distribution$log_prob(tf_target)
464 +
465 +
      # if required, calculate the log-adjustment to the truncation term of
466 +
      # the density function; i.e. the density of a distribution, truncated
467 +
      # between a and b, is the non-truncated density, divided by the integral
468 +
      # of the density function between the truncation bounds. This can be
469 +
      # calculated from the distribution's CDF
470 +
      if (!is.null(truncation)) {
471 +
472 +
        lower <- truncation[[1]]
473 +
        upper <- truncation[[2]]
474 +
475 +
        if (all(lower == bounds[1])) {
476 +
477 +
          # if only upper is constrained, just need the cdf at the upper
478 +
          offset <- tfp_distribution$log_cdf(fl(upper))
479 +
480 +
        } else if (all(upper == bounds[2])) {
481 +
482 +
          # if only lower is constrained, get the log of the integral above it
483 +
          offset <- tf$math$log(fl(1) - tfp_distribution$cdf(fl(lower)))
484 +
485 +
        } else {
486 +
487 +
          # if both are constrained, get the log of the integral between them
488 +
          offset <- tf$math$log(tfp_distribution$cdf(fl(upper)) -
489 +
                                  tfp_distribution$cdf(fl(lower)))
490 +
491 +
        }
492 +
493 +
        ld <- ld - offset
494 +
495 +
      }
496 +
497 +
498 +
      ld
499 +
500 +
501 +
    },
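The correction above is the standard conditional-density identity: if X has density f and CDF F, the density of X truncated to [a, b] is f(x) / (F(b) - F(a)), so the log of that integral is subtracted from the log density. A minimal base-R sketch for a standard normal truncated to [a, b] (an illustration of the arithmetic only, not the tensor code above; the helper name is hypothetical):

# log density of a standard normal truncated to [a, b] (illustrative helper)
truncated_normal_log_density <- function(x, a, b) {
  ld <- dnorm(x, log = TRUE)          # uncorrected log density
  offset <- log(pnorm(b) - pnorm(a))  # log integral between the bounds
  ld - offset
}
truncated_normal_log_density(0.5, a = 0, b = 1)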
502 +
503 +
    # get the tf object in the tf environment corresponding to 'node'
504 +
    get_tf_object = function(node) {
505 +
      get(self$tf_name(node), envir = self$tf_environment)
506 +
    },
507 +
315 508
    # return a function to obtain the model log probability from a tensor for
316 509
    # the free state
317 510
    generate_log_prob_function = function(which = c("adjusted",
@@ -331,6 +524,7 @@
331 524
        data_names <- self$get_tf_names(types = "data")
332 525
        for (name in data_names)
333 526
          tfe[[name]] <- tfe_old[[name]]
527 +
        tfe$batch_size <- tfe_old$batch_size
334 528
335 529
        # put the free state in the environment, and build out the tf graph
336 530
        tfe$free_state <- free_state
@@ -354,24 +548,48 @@
354 548
355 549
    },
356 550
357 -
    # return the expected free parameter format either in list or vector form
358 -
    example_parameters = function(flat = TRUE) {
551 +
    # return the expected parameter format either in free state vector form, or
552 +
    # list of transformed parameters
553 +
    example_parameters = function(free = TRUE) {
359 554
360 -
      # find all variable nodes in the graph and get their values
555 +
      # find all variable nodes in the graph
361 556
      nodes <- self$node_list[self$node_types == "variable"]
362 557
      names(nodes) <- self$get_tf_names(types = "variable")
363 -
      current_parameters <- lapply(nodes, member, "value()")
364 558
365 -
      # optionally flatten them
366 -
      if (flat)
367 -
        current_parameters <- unlist_tf(current_parameters)
559 +
      # get their values in either free or non-free form
560 +
      if (free) {
561 +
        parameters <- lapply(nodes, member, "value(free = TRUE)")
562 +
      } else {
563 +
        parameters <- lapply(nodes, member, "value()")
564 +
      }
565 +
566 +
      # remove any of these that don't need a free state here (for calculate())
567 +
      stateless_names <- vapply(self$variables_without_free_state,
568 +
                                self$tf_name,
569 +
                                FUN.VALUE = character(1))
570 +
      keep <- !names(parameters) %in% stateless_names
571 +
      parameters <- parameters[keep]
572 +
573 +
      parameters
574 +
575 +
    },
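Free-form values live on the sampler's unconstrained scale, while non-free values are on the model scale. For a positive-constrained scalar the map between the two is exponential (see tf_scalar_pos_bijector further down), so the two forms relate as in this base-R illustration (not greta's internal code):

sd_value <- 2.5           # value on the constrained (model) scale
sd_free <- log(sd_value)  # the corresponding free-state value
exp(sd_free)              # back-transforming recovers 2.5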
576 +
577 +
    get_tf_data_list = function() {
578 +
579 +
      data_list_name <- paste0(self$mode, "_data_list")
580 +
      self$tf_environment[[data_list_name]]
368 581
369 -
      current_parameters
582 +
    },
583 +
584 +
    set_tf_data_list = function(element_name, value) {
585 +
586 +
      data_list_name <- paste0(self$mode, "_data_list")
587 +
      self$tf_environment[[data_list_name]][[element_name]] <- value
370 588
371 589
    },
372 590
373 591
    build_feed_dict = function(dict_list = list(),
374 -
                               data_list = self$tf_environment$data_list) {
592 +
                               data_list = self$get_tf_data_list()) {
375 593
376 594
      tfe <- self$tf_environment
377 595
@@ -394,6 +612,9 @@
394 612
      # create a feed dict in the TF environment
395 613
      parameter_list <- list(free_state = parameters)
396 614
615 +
      # set the batch size to match parameters
616 +
      self$set_tf_data_list("batch_size", nrow(parameters))
617 +
397 618
      self$build_feed_dict(parameter_list)
398 619
399 620
    },
@@ -490,6 +711,7 @@
490 711
      }
491 712
      elements <- seq_along(trace_list_batches[[1]])
492 713
      trace_list <- lapply(elements, stack_elements, trace_list_batches)
714 +
      names(trace_list) <- names(trace_list_batches[[1]])
493 715
494 716
      # if they are flattened, e.g. for MCMC tracing
495 717
      if (flatten) {
@@ -504,10 +726,6 @@
504 726
505 727
      } else {
506 728
507 -
        # prepare for return to R
508 -
        trace_list <- lapply(trace_list, drop_first_dim)
509 -
        trace_list <- lapply(trace_list, drop_column_dim)
510 -
511 729
        out <- trace_list
512 730
513 731
      }
@@ -563,7 +781,98 @@
563 781
564 782
    },
565 783
566 -
    build_adjacency_matrix = function() {
784 +
    # get the tfp distribution object for a distribution node
785 +
    get_tfp_distribution = function(distrib_node) {
786 +
787 +
      # build the tfp distribution object for the distribution node; from
788 +
      # this, densities can be evaluated and samples drawn
789 +
      distrib_constructor <- self$get_tf_object(distrib_node)
790 +
      parameter_nodes <- distrib_node$parameters
791 +
      tf_parameter_list <- lapply(parameter_nodes, self$get_tf_object)
792 +
793 +
      # execute the distribution constructor functions to return a tfp
794 +
      # distribution object
795 +
      tfp_distribution <- distrib_constructor(tf_parameter_list, dag = self)
796 +
797 +
    },
798 +
799 +
    # try to draw a random sample from a distribution node
800 +
    draw_sample = function(distribution_node) {
801 +
802 +
      tfp_distribution <- self$get_tfp_distribution(distribution_node)
803 +
804 +
      sample <- tfp_distribution$sample
805 +
806 +
      if (is.null(sample)) {
807 +
        stop("sampling is not yet implemented for ",
808 +
             distribution_node$distribution_name,
809 +
             " distributions",
810 +
             call. = FALSE)
811 +
      }
812 +
813 +
      truncation <- distribution_node$truncation
814 +
815 +
      if (is.null(truncation)) {
816 +
817 +
        # if we're not dealing with truncation, sample directly
818 +
        tensor <- sample(seed = get_seed())
819 +
820 +
      } else {
821 +
822 +
        # if we're dealing with truncation (therefore univariate and continuous)
823 +
        # sample a random uniform (tensor), and pass through the truncated
824 +
        # quantile (inverse cdf) function
825 +
826 +
        cdf <- tfp_distribution$cdf
827 +
        quantile <- tfp_distribution$quantile
828 +
829 +
        if (is.null(cdf) || is.null(quantile)) {
830 +
          stop("sampling is not yet implemented for truncated ",
831 +
               distribution_node$distribution_name,
832 +
               " distributions",
833 +
               call. = FALSE)
834 +
        }
835 +
836 +
        # generate a random uniform sample of the correct shape and transform
837 +
        # through truncated inverse CDF to get draws on truncated scale
838 +
        u <- tf_randu(distribution_node$dim, self)
839 +
840 +
        lower <- cdf(fl(truncation[1]))
841 +
        upper <- cdf(fl(truncation[2]))
842 +
        range <- upper - lower
843 +
844 +
        tensor <- quantile(lower + u * range)
845 +
846 +
      }
847 +
848 +
      tensor
849 +
850 +
    }
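The truncated branch is plain inverse-CDF sampling restricted to the truncation bounds. The same trick in base R, for an exponential distribution truncated to [lower, upper] (illustration only; the helper name is hypothetical, and tf_randu above does the uniform draw on tensors):

# illustrative truncated-exponential sampler via the inverse CDF
r_truncated_exp <- function(n, rate, lower, upper) {
  p_lower <- pexp(lower, rate)  # CDF at the bounds
  p_upper <- pexp(upper, rate)
  u <- runif(n)                 # standard uniforms
  qexp(p_lower + u * (p_upper - p_lower), rate)
}
r_truncated_exp(3, rate = 1, lower = 0.5, upper = 2)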
851 +
852 +
  ),
853 +
854 +
  active = list(
855 +
856 +
    node_types = function(value) {
857 +
      vapply(self$node_list, node_type, FUN.VALUE = "")
858 +
    },
859 +
860 +
    # create human-readable base names for TF tensors. these will actually be
861 +
    # defined with "all_forward_", "all_sampling_", or "hybrid_" prepended
862 +
    node_tf_names = function(value) {
863 +
864 +
      types <- self$node_types
865 +
866 +
      for (type in c("variable", "data", "operation", "distribution")) {
867 +
        idx <- which(types == type)
868 +
        types[idx] <- paste(type, seq_along(idx), sep = "_")
869 +
      }
870 +
871 +
      types
872 +
873 +
    },
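A sketch of what node_tf_names yields for a toy node list, mirroring the loop above: each node type gets its own counter, so names are unique within a type:

types <- c("data", "variable", "data")
for (type in c("variable", "data", "operation", "distribution")) {
  idx <- which(types == type)
  types[idx] <- paste(type, seq_along(idx), sep = "_")
}
types  # "data_1" "variable_1" "data_2"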
874 +
875 +
    adjacency_matrix = function(value) {
567 876
568 877
      # make dag matrix
569 878
      n_node <- length(self$node_list)
@@ -573,11 +882,11 @@
573 882
      rownames(dag_mat) <- colnames(dag_mat) <- node_names
574 883
575 884
      children <- lapply(self$node_list,
576 -
                        member,
577 -
                        "child_names()")
578 -
      parents <- lapply(self$node_list,
579 885
                         member,
580 -
                         "parent_names(recursive = FALSE)")
886 +
                         "child_names()")
887 +
      parents <- lapply(self$node_list,
888 +
                        member,
889 +
                        "parent_names(recursive = FALSE)")
581 890
582 891
      # for distribution nodes, remove target nodes from parents, and put them
583 892
      # in children to send the arrow in the opposite direction when plotting
@@ -608,9 +917,7 @@
608 917
        dag_mat[parents[[i]], i] <- 1
609 918
      }
610 919
611 -
      self$adjacency_matrix <- dag_mat
612 -
920 +
      dag_mat
613 921
    }
614 -
615 922
  )
616 923
)

@@ -33,22 +33,22 @@
33 33
#'
34 34
#' @param x a greta array
35 35
#' @param i,j indices specifying elements to extract or replace
36 -
#' @param n a single integer, as in \code{utils::head()} and
37 -
#'   \code{utils::tail()}
36 +
#' @param n a single integer, as in `utils::head()` and
37 +
#'   `utils::tail()`
38 38
#' @param nrow,ncol optional dimensions for the resulting greta array when x is
39 39
#'   not a matrix.
40 -
#' @param value for \code{`[<-`} a greta array to replace elements, for
41 -
#'   \code{`dim<-`} either NULL or a numeric vector of dimensions
40 +
#' @param value for ``[<-`` a greta array to replace elements, for
41 +
#'   ``dim<-`` either NULL or a numeric vector of dimensions
42 42
#' @param ... either further indices specifying elements to extract or replace
43 -
#'   (\code{[}), or multiple greta arrays to combine (\code{cbind()},
44 -
#'   \code{rbind()} & \code{c()}), or additional arguments (\code{rep()},
45 -
#'   \code{head()}, \code{tail()})
43 +
#'   (`[`), or multiple greta arrays to combine (`cbind()`,
44 +
#'   `rbind()` & `c()`), or additional arguments (`rep()`,
45 +
#'   `head()`, `tail()`)
46 46
#' @param drop,recursive generic arguments that are ignored for greta arrays
47 47
#'
48 -
#' @details \code{diag()} can be used to extract or replace the diagonal part of
48 +
#' @details `diag()` can be used to extract or replace the diagonal part of
49 49
#'   a square and two-dimensional greta array, but it cannot be used to create a
50 50
#'   matrix-like greta array from a scalar or vector-like greta array. A static
51 -
#'   diagonal matrix can always be created with e.g. \code{diag(3)}, and then
51 +
#'   diagonal matrix can always be created with e.g. `diag(3)`, and then
52 52
#'   converted into a greta array.
53 53
#'
54 54
#' @examples
@@ -157,7 +157,7 @@
157 157
158 158
# replace syntax for greta array objects
159 159
#' @export
160 -
`[<-.greta_array` <- function(x, ..., value) {  # Exclude Linting
160 +
`[<-.greta_array` <- function(x, ..., value) {  # nolint
161 161
162 162
  node <- get_node(x)
163 163
@@ -292,7 +292,7 @@
292 292
293 293
}
294 294
295 -
# Begin Exclude Linting
295 +
# nolint start
296 296
#' @rdname overloaded
297 297
#' @export
298 298
abind <- function(...,
@@ -302,20 +302,20 @@
302 302
                  hier.names = FALSE, use.dnns = FALSE) {
303 303
  UseMethod("abind")
304 304
}
305 -
# End Exclude Linting
305 +
# nolint end
306 306
307 307
# clear CRAN checks spotting floating global variables
308 308
#' @importFrom utils globalVariables
309 309
utils::globalVariables("N", "greta")
310 310
311 -
# Begin Exclude Linting
311 +
# nolint start
312 312
#' @export
313 313
abind.default <- function(...,
314 314
                          along = N, rev.along = NULL, new.names = NULL,
315 315
                          force.array = TRUE, make.names = use.anon.names,
316 316
                          use.anon.names = FALSE, use.first.dimnames = FALSE,
317 317
                          hier.names = FALSE, use.dnns = FALSE) {
318 -
# End Exclude Linting
318 +
# nolint end
319 319
320 320
  # error nicely if they don't have abind installed
321 321
  abind_installed <- requireNamespace("abind", quietly = TRUE)
@@ -331,7 +331,7 @@
331 331
332 332
}
333 333
334 -
# Begin Exclude Linting
334 +
# nolint start
335 335
#' @export
336 336
abind.greta_array <- function(...,
337 337
                              along = N, rev.along = NULL, new.names = NULL,
@@ -339,10 +339,10 @@
339 339
                              use.anon.names = FALSE,
340 340
                              use.first.dimnames = FALSE, hier.names = FALSE,
341 341
                              use.dnns = FALSE) {
342 -
# End Exclude Linting
342 +
# nolint end
343 343
344 344
  # warn if any of the arguments have been changed
345 -
  # Begin Exclude Linting
345 +
  # nolint start
346 346
  user_set_args <- !is.null(rev.along) |
347 347
    !is.null(new.names) |
348 348
    !isTRUE(force.array) |
@@ -351,7 +351,7 @@
351 351
    !identical(use.first.dimnames, FALSE) |
352 352
    !identical(hier.names, FALSE) |
353 353
    !identical(use.dnns, FALSE)
354 -
  # End Exclude Linting
354 +
  # nolint end
355 355
356 356
  if (user_set_args) {
357 357
    warning("only the argument 'along' is supported when using abind ",
@@ -371,7 +371,7 @@
371 371
  n <- max(vapply(dims, length, FUN.VALUE = 1L))
372 372
373 373
  # needed to keep the same formals as abind
374 -
  N <- n  # Exclude Linting
374 +
  N <- n  # nolint
375 375
  along <- as.integer(force(along))
376 376
377 377
  # rationalise along, and pad N if we're prepending/appending a dimension
@@ -482,7 +482,7 @@
482 482
483 483
# reshape greta arrays
484 484
#' @export
485 -
`dim<-.greta_array` <- function(x, value) {  # Exclude Linting
485 +
`dim<-.greta_array` <- function(x, value) {  # nolint
486 486
487 487
  dims <- value
488 488
@@ -552,7 +552,7 @@
552 552
# arrays
553 553
#' @export
554 554
#' @importFrom utils head
555 -
head.greta_array <- function(x, n = 6L, ...) {  # Exclude Linting
555 +
head.greta_array <- function(x, n = 6L, ...) {  # nolint
556 556
557 557
  stopifnot(length(n) == 1L)
558 558
@@ -585,7 +585,7 @@
585 585
586 586
#' @export
587 587
#' @importFrom utils tail
588 -
tail.greta_array <- function(x, n = 6L, ...) {  # Exclude Linting
588 +
tail.greta_array <- function(x, n = 6L, ...) {  # nolint
589 589
590 590
  stopifnot(length(n) == 1L)
591 591

@@ -161,36 +161,6 @@
161 161
162 162
}
163 163
164 -
165 -
# given a flat tensor, convert it into a square symmetric matrix by considering
166 -
# it as the non-zero elements of the lower-triangular decomposition of the
167 -
# square matrix
168 -
tf_flat_to_chol <- function(x, dims) {
169 -
  # drop trailing dimension, and biject forward to an upper triangular matrix
170 -
171 -
  # indices to the cholesky factor
172 -
  l_dummy <- dummy(dims)
173 -
  indices_diag <- diag(l_dummy)
174 -
  indices_offdiag <- sort(l_dummy[upper.tri(l_dummy, diag = FALSE)])
175 -
176 -
  # indices to the free state
177 -
  x_index_diag <- seq_along(indices_diag) - 1
178 -
  x_index_offdiag <- length(indices_diag) + seq_along(indices_offdiag) - 1
179 -
180 -
  # create an empty vector to fill with the values
181 -
  values_0 <- tf$zeros(shape(1, prod(dims), 1), dtype = tf_float())
182 -
  values_0_diag <- tf_recombine(values_0,
183 -
                                indices_diag,
184 -
                                tf$exp(x[, x_index_diag, , drop = FALSE]))
185 -
  values_z <- tf_recombine(values_0_diag,
186 -
                           indices_offdiag,
187 -
                           x[, x_index_offdiag, , drop = FALSE])
188 -
189 -
  # reshape into lower triangular and return
190 -
  tf$reshape(values_z, shape(-1, dims[1], dims[2]))
191 -
192 -
}
193 -
194 164
# given a (batched, column) vector tensor of elements, corresponding to the
195 165
# correlation-constrained (between -1 and 1) free state of a single row of the
196 166
# cholesky factor of a correlation matrix, return the (upper-triangular
@@ -239,14 +209,14 @@
239 209
  cond <- function(z, x, sumsq, lp, iter, maxiter)
240 210
    tf$less(iter, maxiter)
241 211
242 -
  # Begin Exclude Linting
212 +
  # nolint start
243 213
  shapes <- list(tf$TensorShape(shape(NULL, n)),
244 214
                 tf$TensorShape(shape(NULL, NULL)),
245 215
                 tf$TensorShape(shape(NULL)),
246 216
                 tf$TensorShape(shape(NULL)),
247 217
                 tf$TensorShape(shape()),
248 218
                 tf$TensorShape(shape()))
249 -
  # End Exclude Linting
219 +
  # nolint end
250 220
251 221
  body <- switch(which,
252 222
                 values = body_values,
@@ -275,58 +245,10 @@
275 245
276 246
}
277 247
278 -
# convert an unconstrained vector into symmetric correlation matrix
279 -
tf_flat_to_chol_correl <- function(x, dims) {
280 -
281 -
  dims <- dim(x)
282 -
  k <- dims[[2]]
283 -
  n <- (1 + sqrt(8 * k + 1)) / 2
284 -
285 -
  # drop the third dimension
286 -
  if (length(dims) == 3) {
287 -
    x <- tf$squeeze(x, axis = 2L)
288 -
  }
289 -
290 -
  # convert to -1, 1 scale
291 -
  z <- tf$tanh(x)
292 -
293 -
  # split z up into rows
294 -
  z_rows <- tf$split(z, 1:(n - 1), axis = 1L)
295 -
296 -
  # accumulate sum of squares within each row
297 -
  x_rows <- lapply(z_rows, tf_corrmat_row)
298 -
299 -
  # append 0s to all rows for the empty triangle
300 -
  zero_rows <- lapply((n - 2):0,
301 -
                      function(n) {
302 -
                        zeros <- tf$constant(rep(0, n),
303 -
                                             dtype = tf_float(),
304 -
                                             shape = shape(1, n))
305 -
                        expand_to_batch(zeros, x)
306 -
                      })
307 -
308 -
  lists <- mapply(list, x_rows, zero_rows, SIMPLIFY = FALSE)
309 -
  rows <- lapply(lists, tf$concat, axis = 1L)
310 -
311 -
  # add a fixed first row
312 -
  row_one <- tf$constant(c(1, rep(0, n - 1)),
313 -
                         dtype = tf_float(),
314 -
                         shape = shape(1, n))
315 -
  row_one <- expand_to_batch(row_one, x)
316 -
  rows <- c(row_one, rows)
317 -
318 -
  rows <- lapply(rows, tf$expand_dims, 2L)
319 -
320 -
  # combine into upper-triangular matrix
321 -
  mat <- tf$concat(rows, axis = 2L)
322 -
323 -
  mat
324 -
248 +
tf_chol2symm <- function(x) {
249 +
  tf$matmul(x, x, adjoint_a = TRUE)
325 250
}
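tf$matmul() with adjoint_a = TRUE multiplies by the (conjugate) transpose of the first argument, so tf_chol2symm() computes t(U) %*% U for an upper-triangular factor U. The base-R equivalent, for comparison:

S <- matrix(c(4, 2, 2, 3), 2, 2)  # a symmetric positive definite matrix
U <- chol(S)                      # upper-triangular, so S = t(U) %*% U
crossprod(U)                      # same as t(U) %*% U; recovers S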
326 251
327 -
tf_chol_to_symmetric <- function(u)
328 -
  tf$matmul(tf_transpose(u), u)
329 -
330 252
tf_colmeans <- function(x, dims) {
331 253
332 254
  idx <- rowcol_idx(x, dims, "col")
@@ -495,13 +417,6 @@
495 417
  tf$nn$softmax(latent)
496 418
}
497 419
498 -
# a version of tf$concat that automatically expands out the first dimension if
499 -
# necessary
500 -
tf_concat <- function(values, axis) {
501 -
  values <- match_batches(values)
502 -
  tf$concat(values = values, axis = axis)
503 -
}
504 -
505 420
# map R's extract and replace syntax to tensorflow, for use in operation nodes
506 421
# the following arguments are required:
507 422
#   nelem - number of elements in the original array,
@@ -530,11 +445,6 @@
530 445
# 0-indexing)
531 446
tf_recombine <- function(ref, index, updates) {
532 447
533 -
  # expand out any data to match the batch dimensions
534 -
  out_list <- match_batches(list(ref, updates))
535 -
  ref <- out_list[[1]]
536 -
  updates <- out_list[[2]]
537 -
538 448
  # vector denoting whether an element is being updated
539 449
  nelem <- dim(ref)[[2]]
540 450
  replaced <- rep(0, nelem)
@@ -655,6 +565,184 @@
655 565
656 566
}
657 567
568 +
# common construction of a chained bijector for scalars, optionally adding a
569 +
# final reshaping step
570 +
tf_scalar_biject <- function(..., dim) {
571 +
572 +
  steps <- list(...)
573 +
574 +
  if (!is.null(dim)) {
575 +
    steps <- c(tfp$bijectors$Reshape(dim), steps)
576 +
  }
577 +
578 +
  tfp$bijectors$Chain(steps)
579 +
580 +
}
581 +
582 +
tf_scalar_bijector <- function(dim, lower, upper) {
583 +
584 +
  tf_scalar_biject(
585 +
    tfp$bijectors$Identity(),
586 +
    dim = dim
587 +
  )
588 +
589 +
}
590 +
591 +
tf_scalar_pos_bijector <- function(dim, lower, upper) {
592 +
593 +
  tf_scalar_biject(
594 +
    tfp$bijectors$AffineScalar(shift = fl(lower)),
595 +
    tfp$bijectors$Exp(),
596 +
    dim = dim
597 +
  )
598 +
599 +
}
600 +
601 +
tf_scalar_neg_bijector <- function(dim, lower, upper) {
602 +
603 +
  tf_scalar_biject(
604 +
    tfp$bijectors$AffineScalar(shift = fl(upper), scale = fl(-1)),
605 +
    tfp$bijectors$Exp(),
606 +
    dim = dim
607 +
  )
608 +
609 +
}
610 +
611 +
tf_scalar_neg_pos_bijector <- function(dim, lower, upper) {
612 +
613 +
  tf_scalar_biject(
614 +
    tfp$bijectors$AffineScalar(shift = fl(lower), scale = fl(upper - lower)),
615 +
    tfp$bijectors$Sigmoid(),
616 +
    dim = dim
617 +
  )
618 +
619 +
}
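Unwinding these chains (TFP bijector chains apply their last step first), the forward maps from a free value x to a constrained value are, in base R (an illustration of the maths, not the tensor code):

x <- 0.3; lower <- 1; upper <- 4
x                                    # unconstrained: identity
lower + exp(x)                       # bounded below
upper - exp(x)                       # bounded above
lower + (upper - lower) * plogis(x)  # bounded both sides (plogis is the sigmoid)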
620 +
621 +
# a blockwise combination of other transformations, with final reshaping
622 +
tf_scalar_mixed_bijector <- function(dim, lower, upper, constraints) {
623 +
624 +
  constructors <-
625 +
    list(
626 +
      none = tf_scalar_bijector,
627 +
      low = tf_scalar_neg_bijector,
628 +
      high = tf_scalar_pos_bijector,
629 +
      both = tf_scalar_neg_pos_bijector
630 +
    )
631 +
632 +
  # get the constructors, lower and upper bounds for each block
633 +
  rle <- rle(constraints)
634 +
  blocks <- rep(seq_along(rle$lengths), rle$lengths)
635 +
  constructor_idx <- match(rle$values, names(constructors))
636 +
  block_constructors <- constructors[constructor_idx]
637 +
  lowers <- split(lower, blocks)
638 +
  uppers <- split(upper, blocks)
639 +
640 +
  # combine into lists of arguments
641 +
  n_blocks <- length(rle$lengths)
642 +
  dims <- replicate(n_blocks, NULL, simplify = FALSE)
643 +
  block_parameters <- mapply(list, dims, lowers, uppers, SIMPLIFY = FALSE)
644 +
  block_parameters <- lapply(block_parameters,
645 +
                             `names<-`,
646 +
                             c("dim", "lower", "upper"))
647 +
648 +
  # create bijectors for each block
649 +
  names(block_constructors) <- NULL
650 +
  bijectors <- mapply(do.call,
651 +
                      block_constructors,
652 +
                      block_parameters,
653 +
                      SIMPLIFY = FALSE)
654 +
655 +
  # roll into single bijector
656 +
  tf_scalar_biject(
657 +
    tfp$bijectors$Blockwise(bijectors, block_sizes = rle$lengths),
658 +
    dim = dim
659 +
  )
660 +
661 +
}
662 +
663 +
tf_correlation_cholesky_bijector <- function() {
664 +
665 +
  steps <- list(
666 +
    tfp$bijectors$Transpose(perm = 1:0),
667 +
    tfp$bijectors$CorrelationCholesky()
668 +
  )
669 +
  bijector <- tfp$bijectors$Chain(steps)
670 +
671 +
  # forward_log_det_jacobian doesn't seem to work with unknown dimensions yet,
672 +
  # so replace for now with our own
673 +
  ljac_corr_mat <- function(x, event_ndims) {
674 +
675 +
    # find dimension
676 +
    k <- dim(x)[[2]]
677 +
    n <- (1 + sqrt(8 * k + 1)) / 2
678 +
679 +
    # convert to correlation-scale (-1, 1) & get log jacobian
680 +
    z <- tf$tanh(x)
681 +
682 +
    free_to_correl_lp <- tf_sum(log(fl(1) - tf$square(z)))
683 +
    free_to_correl_lp <- tf$squeeze(free_to_correl_lp, 1L)
684 +
685 +
    # split z up into rows
686 +
    z_rows <- tf$split(z, 1:(n - 1), axis = 1L)
687 +
688 +
    # accumulate log prob within each row
689 +
    lps <- lapply(z_rows, tf_corrmat_row, which = "ljac")
690 +
    correl_to_mat_lp <- tf$add_n(lps)
691 +
692 +
    free_to_correl_lp + correl_to_mat_lp
693 +
694 +
  }
695 +
696 +
  list(forward = bijector$forward,
697 +
       inverse = bijector$inverse,
698 +
       forward_log_det_jacobian = ljac_corr_mat)
699 +
700 +
}
701 +
702 +
tf_covariance_cholesky_bijector <- function() {
703 +
  tfp$bijectors$FillTriangular(upper = TRUE)
704 +
}
705 +
706 +
tf_simplex_bijector <- function(dim) {
707 +
708 +
  n_dim <- length(dim)
709 +
  last_dim <- dim[n_dim]
710 +
  raw_dim <- dim
711 +
  raw_dim[n_dim] <- last_dim - 1L
712 +
713 +
  steps <- list(
714 +
    tfp$bijectors$IteratedSigmoidCentered(),
715 +
    tfp$bijectors$Reshape(raw_dim)
716 +
  )
717 +
  tfp$bijectors$Chain(steps)
718 +
719 +
}
720 +
721 +
tf_ordered_bijector <- function(dim) {
722 +
723 +
  steps <- list(
724 +
    tfp$bijectors$Invert(tfp$bijectors$Ordered()),
725 +
    tfp$bijectors$Reshape(dim)
726 +
  )
727 +
  tfp$bijectors$Chain(steps)
728 +
729 +
}
730 +
731 +
# generate a tensor of random standard uniforms with a given shape,
732 +
# including the batch dimension
733 +
tf_randu <- function(dim, dag) {
734 +
  uniform <- tfp$distributions$Uniform(low = fl(0), high = fl(1))
735 +
  shape <- c(dag$tf_environment$batch_size, as.list(dim))
736 +
  uniform$sample(sample_shape = shape, seed = get_seed())
737 +
}
738 +
739 +
# generate an integer tensor of values up to n (indexed from 0) with a given
740 +
# shape, including the batch dimension
741 +
tf_randint <- function(n, dim, dag) {
742 +
  u <- tf_randu(dim, dag)
743 +
  tf$floor(u * as.integer(n))
744 +
}
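The floor-of-a-scaled-uniform trick in base R, for comparison:

floor(runif(5) * 3)  # five draws from {0, 1, 2}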
745 +
658 746
# combine as module for export via internals
659 747
tf_functions_module <- module(tf_as_logical,
660 748
                              tf_as_float,
@@ -663,10 +751,8 @@
663 751
                              tf_lbeta,
664 752
                              tf_chol,
665 753
                              tf_chol2inv,
666 -
                              tf_flat_to_chol,
667 754
                              tf_corrmat_row,
668 -
                              tf_flat_to_chol_correl,
669 -
                              tf_chol_to_symmetric,
755 +
                              tf_chol2symm,
670 756
                              tf_colmeans,
671 757
                              tf_rowmeans,
672 758
                              tf_colsums,
@@ -696,4 +782,10 @@
696 782
                              tf_extract_eigenvectors,
697 783
                              tf_extract_eigenvalues,
698 784
                              tf_self_distance,
699 -
                              tf_distance)
785 +
                              tf_distance,
786 +
                              tf_scalar_bijector,
787 +
                              tf_scalar_neg_bijector,
788 +
                              tf_scalar_pos_bijector,
789 +
                              tf_scalar_neg_pos_bijector,
790 +
                              tf_correlation_cholesky_bijector,
791 +
                              tf_covariance_cholesky_bijector)

@@ -5,7 +5,6 @@
5 5
6 6
    min = NA,
7 7
    max = NA,
8 -
    log_density = NULL,
9 8
10 9
    initialize = function(min, max, dim) {
11 10
@@ -40,19 +39,17 @@
40 39
      # initialisation)
41 40
      self$min <- min
42 41
      self$max <- max
43 -
44 42
      self$bounds <- c(min, max)
45 43
46 44
      # initialize the rest
47 45
      super$initialize("uniform", dim)
48 46
49 47
      # add them as parents and greta arrays
48 +
      min <- as.greta_array(min)
49 +
      max <- as.greta_array(max)
50 50
      self$add_parameter(min, "min")
51 51
      self$add_parameter(max, "max")
52 52
53 -
      # the density is fixed, so calculate it now
54 -
      self$log_density <- -log(max - min)
55 -
56 53
    },
57 54
58 55
    # default value (ignore any truncation arguments)
@@ -63,14 +60,10 @@
63 60
64 61
    tf_distrib = function(parameters, dag) {
65 62
66 -
      tf_ld <- fl(self$log_density)
67 -
68 -
      # weird hack to make TF see a gradient here
69 -
      log_prob <- function(x) {
70 -
        tf_ld + tf_flatten(x) * fl(0)
71 -
      }
72 -
73 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
63 +
      tfp$distributions$Uniform(
64 +
        low = parameters$min,
65 +
        high = parameters$max
66 +
      )
74 67
75 68
    }
76 69
@@ -120,12 +113,12 @@
120 113
      self$add_parameter(sdlog, "sdlog")
121 114
    },
122 115
123 -
    # Begin Exclude Linting
116 +
    # nolint start
124 117
    tf_distrib = function(parameters, dag) {
125 118
      tfp$distributions$LogNormal(loc = parameters$meanlog,
126 119
                                  scale = parameters$sdlog)
127 120
    }
128 -
    # End Exclude Linting
121 +
    # nolint end
129 122
130 123
  )
131 124
)
@@ -175,18 +168,14 @@
175 168
          x * lprob + (fl(1) - x) * lprobnot
176 169
        }
177 170
178 -
        list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
171 +
        list(log_prob = log_prob)
179 172
180 173
      } else {
181 174
182 175
        tfp$distributions$Bernoulli(probs = parameters$prob)
183 176
184 177
      }
185 -
    },
186 -
187 -
    # no CDF for discrete distributions
188 -
    tf_cdf_function = NULL,
189 -
    tf_log_cdf_function = NULL
178 +
    }
190 179
191 180
  )
192 181
)
@@ -242,17 +231,13 @@
242 231
          log_choose + x * lprob + (size - x) * lprobnot
243 232
        }
244 233
245 -
        list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
234 +
        list(log_prob = log_prob)
246 235
247 236
      } else {
248 237
        tfp$distributions$Binomial(total_count = parameters$size,
249 238
                                   probs = parameters$prob)
250 239
      }
251 -
    },
252 -
253 -
    # no CDF for discrete distributions
254 -
    tf_cdf_function = NULL,
255 -
    tf_log_cdf_function = NULL
240 +
    }
256 241
257 242
  )
258 243
)
@@ -289,13 +274,21 @@
289 274
          tf_lbeta(alpha, beta)
290 275
      }
291 276
292 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
277 +
      # generate a beta, then a binomial
278 +
      sample <- function(seed) {
293 279
294 -
    },
280 +
        beta <- tfp$distributions$Beta(concentration1 = alpha,
281 +
                                       concentration0 = beta)
282 +
        probs <- beta$sample(seed = seed)
283 +
        binomial <- tfp$distributions$Binomial(total_count = size,
284 +
                                               probs = probs)
285 +
        binomial$sample(seed = seed)
286 +
287 +
      }
295 288
296 -
    # no CDF for discrete distributions
297 -
    tf_cdf_function = NULL,
298 -
    tf_log_cdf_function = NULL
289 +
      list(log_prob = log_prob, sample = sample)
290 +
291 +
    }
299 292
300 293
  )
301 294
)
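The compound scheme above (a beta draw feeding a binomial draw) has a direct base-R analogue; this hypothetical helper is for illustration only:

rbetabinom <- function(n, size, alpha, beta) {
  probs <- rbeta(n, alpha, beta)  # per-draw success probabilities
  rbinom(n, size, probs)          # binomial counts given those probabilities
}
rbetabinom(3, size = 10, alpha = 2, beta = 5)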
@@ -332,11 +325,7 @@
332 325
333 326
      tfp$distributions$Poisson(log_rate = log_lambda)
334 327
335 -
    },
336 -
337 -
    # no CDF for discrete distributions
338 -
    tf_cdf_function = NULL,
339 -
    tf_log_cdf_function = NULL
328 +
    }
340 329
341 330
  )
342 331
)
@@ -358,16 +347,12 @@
358 347
      self$add_parameter(prob, "prob")
359 348
    },
360 349
361 -
    # Begin Exclude Linting
350 +
    # nolint start
362 351
    tf_distrib = function(parameters, dag) {
363 352
      tfp$distributions$NegativeBinomial(total_count = parameters$size,
364 353
                                         probs = fl(1) - parameters$prob)
365 -
    },
366 -
    # End Exclude Linting
367 -
368 -
    # no CDF for discrete distributions
369 -
    tf_cdf_function = NULL,
370 -
    tf_log_cdf_function = NULL
354 +
    }
355 +
    # nolint end
371 356
372 357
  )
373 358
)
@@ -403,13 +388,9 @@
403 388
          tf_lchoose(m + n, k)
404 389
      }
405 390
406 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
391 +
      list(log_prob = log_prob)
407 392
408 -
    },
409 -
410 -
    # no CDF for discrete distributions
411 -
    tf_cdf_function = NULL,
412 -
    tf_log_cdf_function = NULL
393 +
    }
413 394
414 395
  )
415 396
)
@@ -460,12 +441,12 @@
460 441
      self$add_parameter(beta, "beta")
461 442
    },
462 443
463 -
    # Begin Exclude Linting
444 +
    # nolint start
464 445
    tf_distrib = function(parameters, dag) {
465 446
      tfp$distributions$InverseGamma(concentration = parameters$alpha,
466 447
                                     rate = parameters$beta)
467 448
    }
468 -
    # End Exclude Linting
449 +
    # nolint end
469 450
470 451
  )
471 452
)
@@ -494,19 +475,40 @@
494 475
      a <- parameters$shape
495 476
      b <- parameters$scale
496 477
478 +
      # use the TFP Weibull CDF bijector
479 +
      bijector <- tfp$bijectors$Weibull(scale = b, concentration = a)
480 +
497 481
      log_prob <- function(x) {
498 482
        log(a) - log(b) + (a - fl(1)) * (log(x) - log(b)) - (x / b) ^ a
499 483
      }
500 484
501 485
      cdf <- function(x) {
502 -
        fl(1) - exp(fl(-1) * (x / b) ^ a)
486 +
        bijector$forward(x)
503 487
      }
504 488
505 489
      log_cdf <- function(x) {
506 490
        log(cdf(x))
507 491
      }
508 492
509 -
      list(log_prob = log_prob, cdf = cdf, log_cdf = log_cdf)
493 +
      quantile <- function(x) {
494 +
        bijector$inverse(x)
495 +
      }
496 +
497 +
      sample <- function(seed) {
498 +
499 +
        # sample by pushing standard uniforms through the inverse cdf
500 +
        u <- tf_randu(self$dim, dag)
501 +
        quantile(u)
502 +
503 +
      }
504 +
505 +
      list(
506 +
        log_prob = log_prob,
507 +
        cdf = cdf,
508 +
        log_cdf = log_cdf,
509 +
        quantile = quantile,
510 +
        sample = sample
511 +
      )
510 512
511 513
    }
512 514
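The quantile/sample pair above is inverse-CDF sampling; for the Weibull the inverse CDF has a closed form, as this base-R sketch shows (illustration only):

a <- 2; b <- 1.5  # shape and scale
u <- runif(5)
b * (-log(1 - u)) ^ (1 / a)  # same as qweibull(u, a, b); distributed as rweibull(5, a, b)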
@@ -584,13 +586,13 @@
584 586
      self$add_parameter(sigma, "sigma")
585 587
    },
586 588
587 -
    # Begin Exclude Linting
589 +
    # nolint start
588 590
    tf_distrib = function(parameters, dag) {
589 591
      tfp$distributions$StudentT(df = parameters$df,
590 592
                                 loc = parameters$mu,
591 593
                                 scale = parameters$sigma)
592 594
    }
593 -
    # End Exclude Linting
595 +
    # nolint end
594 596
595 597
  )
596 598
)
@@ -717,11 +719,6 @@
717 719
    tf_distrib = function(parameters, dag) {
718 720
      tfp$distributions$Logistic(loc = parameters$location,
719 721
                                 scale = parameters$scale)
720 -
    },
721 -
722 -
    # log_cdf in tf$contrib$distributions has the wrong sign :/
723 -
    tf_log_cdf_function = function(x, parameters) {
724 -
      tf$math$log(self$tf_cdf_function(x, parameters))
725 722
    }
726 723
727 724
  )
@@ -741,7 +738,7 @@
741 738
      dim <- check_dims(df1, df2, target_dim = dim)
742 739
      check_positive(truncation)
743 740
      self$bounds <- c(0, Inf)
744 -
      super$initialize("d", dim, truncation)
741 +
      super$initialize("f", dim, truncation)
745 742
      self$add_parameter(df1, "df1")
746 743
      self$add_parameter(df2, "df2")
747 744
    },
@@ -771,7 +768,25 @@
771 768
      log_cdf <- function(x)
772 769
        log(cdf(x))
773 770
774 -
      list(log_prob = log_prob, cdf = cdf, log_cdf = log_cdf)
771 +
      sample <- function(seed) {
772 +
773 +
        # sample as the ratio of two scaled chi squared distributions
774 +
        d1 <- tfp$distributions$Chi2(df = df1)
775 +
        d2 <- tfp$distributions$Chi2(df = df2)
776 +
777 +
        u1 <- d1$sample(seed = seed)
778 +
        u2 <- d2$sample(seed = seed)
779 +
780 +
        (u1 / df1) / (u2 / df2)
781 +
782 +
      }
783 +
784 +
      list(
785 +
        log_prob = log_prob,
786 +
        cdf = cdf,
787 +
        log_cdf = log_cdf,
788 +
        sample = sample
789 +
      )
775 790
776 791
    }
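The construction used in sample() is the textbook definition of the F distribution, easily checked in base R:

df1 <- 3; df2 <- 10
(rchisq(5, df1) / df1) / (rchisq(5, df2) / df2)  # same distribution as rf(5, df1, df2)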
777 792
@@ -798,18 +813,15 @@
798 813
      # parameters
799 814
      self$bounds <- c(0, Inf)
800 815
      super$initialize("dirichlet", dim,
801 -
                       truncation = c(0, Inf))
816 +
                       truncation = c(0, Inf),
817 +
                       multivariate = TRUE)
802 818
      self$add_parameter(alpha, "alpha")
803 819
804 820
    },
805 821
806 822
    create_target = function(truncation) {
807 823
808 -
      # handle simplex via a greta array
809 -
      free_greta_array <- variable(lower = 0, upper = 1, dim = self$dim)
810 -
811 -
      sums <- rowSums(free_greta_array)
812 -
      simplex_greta_array <- sweep(free_greta_array, 1, sums, "/")
824 +
      simplex_greta_array <- simplex_variable(self$dim)
813 825
814 826
      # return the node for the simplex
815 827
      target_node <- get_node(simplex_greta_array)
@@ -820,16 +832,11 @@
820 832
    tf_distrib = function(parameters, dag) {
821 833
      alpha <- parameters$alpha
822 834
      tfp$distributions$Dirichlet(concentration = alpha)
823 -
    },
824 -
825 -
    # no CDF for multivariate distributions
826 -
    tf_cdf_function = NULL,
827 -
    tf_log_cdf_function = NULL
835 +
    }
828 836
829 837
  )
830 838
)
831 839
832 -
833 840
dirichlet_multinomial_distribution <- R6Class(
834 841
  "dirichlet_multinomial_distribution",
835 842
  inherit = distribution_node,
@@ -853,25 +860,21 @@
853 860
      # parameters
854 861
      super$initialize("dirichlet_multinomial",
855 862
                       dim = dim,
856 -
                       discrete = TRUE)
857 -
      self$add_parameter(size, "size", expand_scalar_to = NULL)
863 +
                       discrete = TRUE,
864 +
                       multivariate = TRUE)
865 +
      self$add_parameter(size, "size", shape_matches_output = FALSE)
858 866
      self$add_parameter(alpha, "alpha")
859 867
860 868
    },
861 869
862 -
    # Begin Exclude Linting
870 +
    # nolint start
863 871
    tf_distrib = function(parameters, dag) {
864 -
      parameters <- match_batches(parameters)
865 872
      parameters$size <- tf_flatten(parameters$size)
866 873
      distrib <- tfp$distributions$DirichletMultinomial
867 874
      distrib(total_count = parameters$size,
868 875
              concentration = parameters$alpha)
869 -
    },
870 -
    # End Exclude Linting
871 -
872 -
    # no CDF for multivariate distributions
873 -
    tf_cdf_function = NULL,
874 -
    tf_log_cdf_function = NULL
876 +
    }
877 +
    # nolint end
875 878
876 879
  )
877 880
)
@@ -898,25 +901,21 @@
898 901
      # parameters
899 902
      super$initialize("multinomial",
900 903
                       dim = dim,
901 -
                       discrete = TRUE)
902 -
      self$add_parameter(size, "size", expand_scalar_to = NULL)
904 +
                       discrete = TRUE,
905 +
                       multivariate = TRUE)
906 +
      self$add_parameter(size, "size", shape_matches_output = FALSE)
903 907
      self$add_parameter(prob, "prob")
904 908
905 909
    },
906 910
907 911
    tf_distrib = function(parameters, dag) {
908 -
      parameters <- match_batches(parameters)
909 912
      parameters$size <- tf_flatten(parameters$size)
910 913
      # scale probs to get absolute density correct
911 914
      parameters$prob <- parameters$prob / tf_sum(parameters$prob)
912 915
913 916
      tfp$distributions$Multinomial(total_count = parameters$size,
914 917
                                    probs = parameters$prob)
915 -
    },
916 -
917 -
    # no CDF for multivariate distributions
918 -
    tf_cdf_function = NULL,
919 -
    tf_log_cdf_function = NULL
918 +
    }
920 919
921 920
  )
922 921
)
@@ -937,7 +936,10 @@
937 936
938 937
      # coerce the parameter arguments to nodes and add as parents and
939 938
      # parameters
940 -
      super$initialize("categorical", dim = dim, discrete = TRUE)
939 +
      super$initialize("categorical",
940 +
                       dim = dim,
941 +
                       discrete = TRUE,
942 +
                       multivariate = TRUE)
941 943
      self$add_parameter(prob, "prob")
942 944
943 945
    },
@@ -948,11 +950,7 @@
948 950
      probs <- probs / tf_sum(probs)
949 951
      tfp$distributions$Multinomial(total_count = fl(1),
950 952
                                    probs = probs)
951 -
    },
952 -
953 -
    # no CDF for multivariate distributions
954 -
    tf_cdf_function = NULL,
955 -
    tf_log_cdf_function = NULL
953 +
    }
956 954
957 955
  )
958 956
)
@@ -963,9 +961,9 @@
963 961
  public = list(
964 962
965 963
    sigma_is_cholesky = FALSE,
966 -
    # Begin Exclude Linting
964 +
    # nolint start
967 965
    initialize = function(mean, Sigma, n_realisations, dimension) {
968 -
    # End Exclude Linting
966 +
    # nolint end
969 967
      # coerce to greta arrays
970 968
      mean <- as.greta_array(mean)
971 969
      sigma <- as.greta_array(Sigma)
@@ -1001,7 +999,7 @@
1001 999
1002 1000
      # coerce the parameter arguments to nodes and add as parents and
1003 1001
      # parameters
1004 -
      super$initialize("multivariate_normal", dim)
1002 +
      super$initialize("multivariate_normal", dim, multivariate = TRUE)
1005 1003
1006 1004
      if (has_representation(sigma, "cholesky")) {
1007 1005
        sigma <- representation(sigma, "cholesky")
@@ -1028,15 +1026,11 @@
1028 1026
      l <- tf$expand_dims(l, 1L)
1029 1027
1030 1028
      mu <- parameters$mean
1031 -
      # Begin Exclude Linting
1029 +
      # nolint start
1032 1030
      tfp$distributions$MultivariateNormalTriL(loc = mu,
1033 1031
                                               scale_tril = l)
1034 -
      # End Exclude Linting
1035 -
    },
1036 -
1037 -
    # no CDF for multivariate distributions
1038 -
    tf_cdf_function = NULL,
1039 -
    tf_log_cdf_function = NULL
1032 +
      # nolint end
1033 +
    }
1040 1034
1041 1035
  )
1042 1036
)
@@ -1052,7 +1046,7 @@
1052 1046
    # set when defining the graph
1053 1047
    target_is_cholesky = FALSE,
1054 1048
1055 -
    initialize = function(df, Sigma) {  # Exclude Linting
1049 +
    initialize = function(df, Sigma) {  # nolint
1056 1050
      # add the nodes as parents and parameters
1057 1051
1058 1052
      df <- as.greta_array(df)
@@ -1071,14 +1065,14 @@
1071 1065
      dim <- nrow(sigma)
1072 1066
1073 1067
      # initialize with a cholesky factor
1074 -
      super$initialize("wishart", dim(sigma))
1068 +
      super$initialize("wishart", dim(sigma), multivariate = TRUE)
1075 1069
1076 1070
      # set parameters
1077 1071
      if (has_representation(sigma, "cholesky")) {
1078 1072
        sigma <- representation(sigma, "cholesky")
1079 1073
        self$sigma_is_cholesky <- TRUE
1080 1074
      }
1081 -
      self$add_parameter(df, "df", expand_scalar_to = NULL)
1075 +
      self$add_parameter(df, "df", shape_matches_output = FALSE)
1082 1076
      self$add_parameter(sigma, "sigma")
1083 1077
1084 1078
      # make the initial value PD (no idea whether this does anything)
@@ -1090,16 +1084,11 @@
1090 1084
    # factor representation)
1091 1085
    create_target = function(truncation) {
1092 1086
1093 -
      # create a flat variable greta array
1094 -
      k <- self$dim[1]
1095 -
      free_greta_array <- vble(truncation = c(-Inf, Inf),
1096 -
                               dim = k + k * (k - 1) / 2)
1097 -
      free_greta_array$constraint <- "covariance_matrix"
1087 +
      # create cholesky factor variable greta array
1088 +
      chol_greta_array <- cholesky_variable(self$dim[1])
1098 1089
1099 -
      # reshape to a cholesky factor and then to a symmetric matrix (which
1100 -
      # retains the cholesky representation)
1101 -
      chol_greta_array <- flat_to_chol(free_greta_array, self$dim)
1102 -
      matrix_greta_array <- chol_to_symmetric(chol_greta_array)
1090 +
      # reshape to a symmetric matrix (retaining cholesky representation)
1091 +
      matrix_greta_array <- chol2symm(chol_greta_array)
1103 1092
1104 1093
      # return the node for the symmetric matrix
1105 1094
      target_node <- get_node(matrix_greta_array)
@@ -1159,13 +1148,35 @@
1159 1148
1160 1149
      }
1161 1150
1162 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
1151 +
      sample <- function(seed) {
1163 1152
1164 -
    },
1153 +
        df <- tf$squeeze(parameters$df, 1:2)
1154 +
        sigma <- parameters$sigma
1165 1155
1166 -
    # no CDF for multivariate distributions
1167 -
    tf_cdf_function = NULL,
1168 -
    tf_log_cdf_function = NULL
1156 +
        # get the cholesky factor of Sigma in tf orientation
1157 +
        if (self$sigma_is_cholesky) {
1158 +
          sigma_chol <- tf$linalg$matrix_transpose(sigma)
1159 +
        } else {
1160 +
          sigma_chol <- tf$linalg$cholesky(sigma)
1161 +
        }
1162 +
1163 +
        # build the distribution with choleskied Sigma, and sample from it
1164 +
        distrib <- tfp$distributions$Wishart(df = df,
1165 +
                                             scale_tril = sigma_chol)
1166 +
1167 +
        draws <- distrib$sample(seed = seed)
1168 +
1169 +
        if (self$target_is_cholesky) {
1170 +
          draws <- tf_chol(draws)
1171 +
        }
1172 +
1173 +
        draws
1174 +
1175 +
      }
1176 +
1177 +
      list(log_prob = log_prob, sample = sample)
1178 +
1179 +
    }
1169 1180
1170 1181
  )
1171 1182
)
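The draw-then-recholesky round trip can be mimicked in base R with stats::rWishart (illustration only, not the tensor code):

S <- stats::rWishart(1, df = 5, Sigma = diag(3))[, , 1]  # one wishart draw
U <- chol(S)   # upper-triangular factor, as when the target is choleskied
crossprod(U)   # recovers S (cf. chol2symm above)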
@@ -1203,10 +1214,10 @@
1203 1214
      }
1204 1215
1205 1216
      dim <- c(dimension, dimension)
1206 -
      super$initialize("lkj_correlation", dim)
1217 +
      super$initialize("lkj_correlation", dim, multivariate = TRUE)
1207 1218
1208 1219
      # don't try to expand scalar eta out to match the target size
1209 -
      self$add_parameter(eta, "eta", expand_scalar_to = NULL)
1220 +
      self$add_parameter(eta, "eta", shape_matches_output = FALSE)
1210 1221
1211 1222
      # make the initial value PD
1212 1223
      self$value(unknowns(dims = dim, data = diag(dimension)))
@@ -1216,18 +1227,11 @@
1216 1227
    # default (cholesky factor, ignores truncation)
1217 1228
    create_target = function(truncation) {
1218 1229
1219 -
      # handle reshaping via a greta array
1220 -
      k <- self$dim[1]
1221 -
      free_greta_array <- vble(truncation = c(-Inf, Inf),
1222 -
                               dim = k * (k - 1) / 2)
1223 -
      free_greta_array$constraint <- "correlation_matrix"
1230 +
      # create (correlation matrix) cholesky factor variable greta array
1231 +
      chol_greta_array <- cholesky_variable(self$dim[1], correlation = TRUE)
1224 1232
1225 -
      # reshape to a cholesky factor and then to a symmetric correlation matrix
1226 -
      # (which retains the cholesky representation)
1227 -
      chol_greta_array <- flat_to_chol(free_greta_array,
1228 -
                                       self$dim,
1229 -
                                       correl = TRUE)
1230 -
      matrix_greta_array <- chol_to_symmetric(chol_greta_array)
1233 +
      # reshape to a symmetric matrix (retaining cholesky representation)
1234 +
      matrix_greta_array <- chol2symm(chol_greta_array)
1231 1235
1232 1236
      # return the node for the symmetric matrix
1233 1237
      target_node <- get_node(matrix_greta_array)
@@ -1254,40 +1258,40 @@
1254 1258
1255 1259
    tf_distrib = function(parameters, dag) {
1256 1260
1257 -
      eta <- parameters$eta
1261 +
      eta <- tf$squeeze(parameters$eta, 1:2)
1262 +
      dim <- self$dim[1]
1258 1263
1259 -
      log_prob <- function(x) {
1264 +
      distrib <- tfp$distributions$LKJ(
1265 +
        dimension = dim,
1266 +
        concentration = eta,
1267 +
        input_output_cholesky = self$target_is_cholesky
1268 +
      )
1260 1269
1261 -
        n <- self$dim[1]
1270 +
      # tfp's lkj sampling can't detect the size of the output from eta, for
1271 +
      # some reason. But we can use map_fn to apply their simulation to each
1272 +
      # element of eta.
1273 +
      sample <- function(seed) {
1262 1274
1263 -
        # normalising constant
1264 -
        k <- 1:n
1265 -
        a <- fl(1 - n) * tf$math$lgamma(eta + fl(0.5 * (n - 1)))
1266 -
        b <- tf_sum(fl(0.5 * k * log(pi)) +
1267 -
                      tf$math$lgamma(eta + fl(0.5 * (n - 1 - k))))
1268 -
        norm <- a + b
1275 +
        sample_once <- function(eta) {
1269 1276
1270 -
        # get the cholesky factor of the target in tf_orientation
1271 -
        if (self$target_is_cholesky) {
1272 -
          x_chol <- tf$linalg$matrix_transpose(x)
1273 -
        } else {
1274 -
          x_chol <- tf$linalg$cholesky(x)
1275 -
        }
1277 +
          d <- tfp$distributions$LKJ(
1278 +
            dimension = dim,
1279 +
            concentration = eta,
1280 +
            input_output_cholesky = self$target_is_cholesky
1281 +
          )
1276 1282
1277 -
        diags <- tf$linalg$diag_part(x_chol)
1278 -
        det <- tf$square(tf_prod(diags))
1283 +
          d$sample(seed = seed)
1279 1284
1280 -
        (eta - fl(1)) * tf$math$log(det) + norm
1285 +
        }
1286 +
1287 +
        tf$map_fn(sample_once, eta)
1281 1288
1282 1289
      }
1283 1290
1284 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
1291 +
      list(log_prob = distrib$log_prob,
1292 +
           sample = sample)
1285 1293
1286 -
    },
1287 -
1288 -
    # no CDF for multivariate distributions
1289 -
    tf_cdf_function = NULL,
1290 -
    tf_log_cdf_function = NULL
1294 +
    }
1291 1295
1292 1296
  )
1293 1297
)
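tf$map_fn behaves much like lapply over the first axis of a tensor, which is what lets each batch element of eta drive its own draw. A toy base-R analogue of the pattern (rbeta here is just a stand-in one-draw sampler, not the LKJ density):

etas <- c(0.5, 1, 2)
sapply(etas, function(eta) rbeta(1, eta, eta))  # one draw per parameter value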
@@ -1324,37 +1328,37 @@
1324 1328
1325 1329
# export constructors
1326 1330
1327 -
# Begin Exclude Linting
1331 +
# nolint start
1328 1332
#' @name distributions
1329 1333
#' @title probability distributions
1330 1334
#' @description These functions can be used to define random variables in a
1331 1335
#'   greta model. They return a variable greta array that follows the specified
1332 1336
#'   distribution. This variable greta array can be used to represent a
1333 1337
#'   parameter with prior distribution, combined into a mixture distribution
1334 -
#'   using \code{\link{mixture}}, or used with \code{\link{distribution}} to
1338 +
#'   using [mixture()], or used with [distribution()] to
1335 1339
#'   define a distribution over a data greta array.
1336 1340
#'
1337 1341
#' @param truncation a length-two vector giving values between which to truncate
1338 -
#'   the distribution, similarly to the \code{lower} and \code{upper} arguments
1339 -
#'   to \code{\link{variable}}
1342 +
#'   the distribution, similarly to the `lower` and `upper` arguments
1343 +
#'   to [variable()]
1340 1344
#'
1341 -
#' @param min,max scalar values giving optional limits to \code{uniform}
1342 -
#'   variables. Like \code{lower} and \code{upper}, these must be specified as
1345 +
#' @param min,max scalar values giving optional limits to `uniform`
1346 +
#'   variables. Like `lower` and `upper`, these must be specified as
1343 1347
#'   numerics, they cannot be greta arrays (though see details for a
1344 -
#'   workaround). Unlike \code{lower} and \code{upper}, they must be finite.
1345 -
#'   \code{min} must always be less than \code{max}.
1348 +
#'   workaround). Unlike `lower` and `upper`, they must be finite.
1349 +
#'   `min` must always be less than `max`.
1346 1350
#'
1347 1351
#' @param mean,meanlog,location,mu unconstrained parameters
1348 1352
#'
1349 1353
#' @param
1350 1354
#'   sd,sdlog,sigma,lambda,shape,rate,df,scale,shape1,shape2,alpha,beta,df1,df2,a,b,eta
1351 -
#'    positive parameters, \code{alpha} must be a vector for \code{dirichlet}
1352 -
#'   and \code{dirichlet_multinomial}.
1355 +
#'    positive parameters, `alpha` must be a vector for `dirichlet`
1356 +
#'   and `dirichlet_multinomial`.
1353 1357
#'
1354 1358
#' @param size,m,n,k positive integer parameter
1355 1359
#'
1356 -
#' @param prob probability parameter (\code{0 < prob < 1}), must be a vector for
1357 -
#'   \code{multinomial} and \code{categorical}
1360 +
#' @param prob probability parameter (`0 < prob < 1`), must be a vector for
1361 +
#'   `multinomial` and `categorical`
1358 1362
#'
1359 1363
#' @param Sigma positive definite variance-covariance matrix parameter
1360 1364
#'
@@ -1366,77 +1370,77 @@
1366 1370
#' @param n_realisations the number of independent realisations of a multivariate
1367 1371
#'   distribution
1368 1372
#'
1369 -
#' @details The discrete probability distributions (\code{bernoulli},
1370 -
#'   \code{binomial}, \code{negative_binomial}, \code{poisson},
1371 -
#'   \code{multinomial}, \code{categorical}, \code{dirichlet_multinomial}) can
1373 +
#' @details The discrete probability distributions (`bernoulli`,
1374 +
#'   `binomial`, `negative_binomial`, `poisson`,
1375 +
#'   `multinomial`, `categorical`, `dirichlet_multinomial`) can
1372 1376
#'   be used when they have fixed values (e.g. defined as a likelihood using
1373 -
#'   \code{\link{distribution}}, but not as unknown variables.
1377 +
#'   [distribution()], but not as unknown variables.
1374 1378
#'
1375 -
#'   For univariate distributions \code{dim} gives the dimensions of the greta
1379 +
#'   For univariate distributions `dim` gives the dimensions of the greta
1376 1380
#'   array to create. Each element of the greta array will be (independently)
1377 -
#'   distributed according to the distribution. \code{dim} can also be left at
1378 -
#'   its default of \code{NULL}, in which case the dimension will be detected
1381 +
#'   distributed according to the distribution. `dim` can also be left at
1382 +
#'   its default of `NULL`, in which case the dimension will be detected
1379 1383
#'   from the dimensions of the parameters (provided they are compatible with
1380 1384
#'   one another).
1381 1385
#'
1382 -
#'   For multivariate distributions (\code{multivariate_normal()},
1383 -
#'   \code{multinomial()}, \code{categorical()}, \code{dirichlet()}, and
1384 -
#'   \code{dirichlet_multinomial()}) each row of the output and parameters
1386 +
#'   For multivariate distributions (`multivariate_normal()`,
1387 +
#'   `multinomial()`, `categorical()`, `dirichlet()`, and
1388 +
#'   `dirichlet_multinomial()`) each row of the output and parameters
1385 1389
#'   corresponds to an independent realisation. If a single realisation or
1386 1390
#'   parameter value is specified, it must therefore be a row vector (see
1387 -
#'   example). \code{n_realisations} gives the number of rows/realisations, and
1388 -
#'   \code{dimension} gives the dimension of the distribution. I.e. a bivariate
1389 -
#'   normal distribution would be produced with \code{multivariate_normal(...,
1390 -
#'   dimension = 2)}. The dimension can usually be detected from the parameters.
1391 +
#'   example). `n_realisations` gives the number of rows/realisations, and
1392 +
#'   `dimension` gives the dimension of the distribution. For example, a bivariate
1393 +
#'   normal distribution would be produced with `multivariate_normal(...,
1394 +
#'   dimension = 2)`. The dimension can usually be detected from the parameters.
1391 1395
#'
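A brief sketch of the row-per-realisation convention (illustrative, assuming
the multivariate_normal() signature documented here):

# two independent draws from a bivariate normal: the mean is a 1 x 2 row vector
mu <- t(rep(0, 2))
Sig <- diag(2)
x <- multivariate_normal(mu, Sig, n_realisations = 2)
dim(x)  # 2 2: one row per realisation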
1392 -
#'   \code{multinomial()} does not check that observed values sum to
1393 -
#'   \code{size}, and \code{categorical()} does not check that only one of the
1396 +
#'   `multinomial()` does not check that observed values sum to
1397 +
#'   `size`, and `categorical()` does not check that only one of the
1394 1398
#'   observed entries is 1. It's the user's responsibility to check their data
1395 1399
#'   matches the distribution!
1396 1400
#'
1397 -
#'   The parameters of \code{uniform} must be fixed, not greta arrays. This
1401 +
#'   The parameters of `uniform` must be fixed, not greta arrays. This
1398 1402
#'   ensures these values can always be transformed to a continuous scale to run
1399 -
#'   the samplers efficiently. However, a hierarchical \code{uniform} parameter
1400 -
#'   can always be created by defining a \code{uniform} variable constrained
1403 +
#'   the samplers efficiently. However, a hierarchical `uniform` parameter
1404 +
#'   can always be created by defining a `uniform` variable constrained
1401 1405
#'   between 0 and 1, and then transforming it to the required scale. See below
1402 1406
#'   for an example.
1403 1407
#'
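The canonical version of this workaround is in the (elided) examples section;
a minimal sketch of the idea, assuming greta arrays a < b give the required
scale:

a <- normal(0, 1, truncation = c(0, Inf))
b <- a + normal(0, 1, truncation = c(0, Inf))
u01 <- uniform(min = 0, max = 1)
u <- a + u01 * (b - a)  # behaves as a uniform between greta arrays a and b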
1404 1408
#'   Wherever possible, the parameterisations and argument names of greta
1405 1409
#'   distributions match commonly used R functions for distributions, such as
1406 -
#'   those in the \code{stats} or \code{extraDistr} packages. The following
1410 +
#'   those in the `stats` or `extraDistr` packages. The following
1407 1411
#'   table states the distribution function to which greta's implementation
1408 1412
#'   corresponds:
1409 1413
#'
1410 -
#'   \tabular{ll}{ greta \tab reference\cr \code{uniform} \tab
1411 -
#'   \link[stats:dunif]{stats::dunif}\cr \code{normal} \tab
1412 -
#'   \link[stats:dnorm]{stats::dnorm}\cr \code{lognormal} \tab
1413 -
#'   \link[stats:dlnorm]{stats::dlnorm}\cr \code{bernoulli} \tab
1414 -
#'   \link[extraDistr:dbern]{extraDistr::dbern}\cr \code{binomial} \tab
1415 -
#'   \link[stats:dbinom]{stats::dbinom}\cr \code{beta_binomial} \tab
1416 -
#'   \link[extraDistr:dbbinom]{extraDistr::dbbinom}\cr \code{negative_binomial}
1417 -
#'   \tab \link[stats:dnbinom]{stats::dnbinom}\cr \code{hypergeometric} \tab
1418 -
#'   \link[stats:dhyper]{stats::dhyper}\cr \code{poisson} \tab
1419 -
#'   \link[stats:dpois]{stats::dpois}\cr \code{gamma} \tab
1420 -
#'   \link[stats:dgamma]{stats::dgamma}\cr \code{inverse_gamma} \tab
1421 -
#'   \link[extraDistr:dinvgamma]{extraDistr::dinvgamma}\cr \code{weibull} \tab
1422 -
#'   \link[stats:dweibull]{stats::dweibull}\cr \code{exponential} \tab
1423 -
#'   \link[stats:dexp]{stats::dexp}\cr \code{pareto} \tab
1424 -
#'   \link[extraDistr:dpareto]{extraDistr::dpareto}\cr \code{student} \tab
1425 -
#'   \link[extraDistr:dlst]{extraDistr::dlst}\cr \code{laplace} \tab
1426 -
#'   \link[extraDistr:dlaplace]{extraDistr::dlaplace}\cr \code{beta} \tab
1427 -
#'   \link[stats:dbeta]{stats::dbeta}\cr \code{cauchy} \tab
1428 -
#'   \link[stats:dcauchy]{stats::dcauchy}\cr \code{chi_squared} \tab
1429 -
#'   \link[stats:dchisq]{stats::dchisq}\cr \code{logistic} \tab
1430 -
#'   \link[stats:dlogis]{stats::dlogis}\cr \code{f} \tab
1431 -
#'   \link[stats:df]{stats::df}\cr \code{multivariate_normal} \tab
1432 -
#'   \link[mvtnorm:dmvnorm]{mvtnorm::dmvnorm}\cr \code{multinomial} \tab
1433 -
#'   \link[stats:dmultinom]{stats::dmultinom}\cr \code{categorical} \tab
1434 -
#'   {\link[stats:dmultinom]{stats::dmultinom} (size = 1)}\cr \code{dirichlet}
1435 -
#'   \tab \link[extraDistr:ddirichlet]{extraDistr::ddirichlet}\cr
1436 -
#'   \code{dirichlet_multinomial} \tab
1437 -
#'   \link[extraDistr:ddirmnom]{extraDistr::ddirmnom}\cr \code{wishart} \tab
1438 -
#'   \link[stats:rWishart]{stats::rWishart}\cr \code{lkj_correlation} \tab
1439 -
#'   \href{https://rdrr.io/github/rmcelreath/rethinking/man/dlkjcorr.html}{rethinking::dlkjcorr}
1414 +
#'   \tabular{ll}{ greta \tab reference\cr `uniform` \tab
1415 +
#'   [stats::dunif]\cr `normal` \tab
1416 +
#'   [stats::dnorm]\cr `lognormal` \tab
1417 +
#'   [stats::dlnorm]\cr `bernoulli` \tab
1418 +
#'   [extraDistr::dbern]\cr `binomial` \tab
1419 +
#'   [stats::dbinom]\cr `beta_binomial` \tab
1420 +
#'   [extraDistr::dbbinom]\cr `negative_binomial`
1421 +
#'   \tab [stats::dnbinom]\cr `hypergeometric` \tab
1422 +
#'   [stats::dhyper]\cr `poisson` \tab
1423 +
#'   [stats::dpois]\cr `gamma` \tab
1424 +
#'   [stats::dgamma]\cr `inverse_gamma` \tab
1425 +
#'   [extraDistr::dinvgamma]\cr `weibull` \tab
1426 +
#'   [stats::dweibull]\cr `exponential` \tab
1427 +
#'   [stats::dexp]\cr `pareto` \tab
1428 +
#'   [extraDistr::dpareto]\cr `student` \tab
1429 +
#'   [extraDistr::dlst]\cr `laplace` \tab
1430 +
#'   [extraDistr::dlaplace]\cr `beta` \tab
1431 +
#'   [stats::dbeta]\cr `cauchy` \tab
1432 +
#'   [stats::dcauchy]\cr `chi_squared` \tab
1433 +
#'   [stats::dchisq]\cr `logistic` \tab
1434 +
#'   [stats::dlogis]\cr `f` \tab
1435 +
#'   [stats::df]\cr `multivariate_normal` \tab
1436 +
#'   [mvtnorm::dmvnorm]\cr `multinomial` \tab
1437 +
#'   [stats::dmultinom]\cr `categorical` \tab
1438 +
#'   [stats::dmultinom] (size = 1)\cr `dirichlet`
1439 +
#'   \tab [extraDistr::ddirichlet]\cr
1440 +
#'   `dirichlet_multinomial` \tab
1441 +
#'   [extraDistr::ddirmnom]\cr `wishart` \tab
1442 +
#'   [stats::rWishart]\cr `lkj_correlation` \tab
1443 +
#'   [rethinking::dlkjcorr](https://rdrr.io/github/rmcelreath/rethinking/man/dlkjcorr.html)
1440 1444
#'   }
1441 1445
#'
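For instance, greta's normal() takes a mean and standard deviation, matching
the parameterisation of stats::dnorm (an illustrative sketch, not from this
diff):

y <- normal(mean = 0, sd = 2)             # greta array, sd parameterisation
dnorm(1, mean = 0, sd = 2, log = TRUE)    # the corresponding stats density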
1442 1446
#' @examples
@@ -1486,7 +1490,7 @@
1486 1490
#'
1487 1491
#' }
1488 1492
NULL
1489 -
# End Exclude Linting
1493 +
# nolint end
1490 1494
1491 1495
#' @rdname distributions
1492 1496
#' @export
@@ -1597,19 +1601,19 @@
1597 1601
f <- function(df1, df2, dim = NULL, truncation = c(0, Inf))
1598 1602
  distrib("f", df1, df2, dim, truncation)
1599 1603
1600 -
# Begin Exclude Linting
1604 +
# nolint start
1601 1605
#' @rdname distributions
1602 1606
#' @export
1603 1607
multivariate_normal <- function(mean, Sigma,
1604 1608
                                n_realisations = NULL, dimension = NULL) {
1605 -
# End Exclude Linting
1609 +
# nolint end
1606 1610
  distrib("multivariate_normal", mean, Sigma,
1607 1611
          n_realisations, dimension)
1608 1612
}
1609 1613
1610 1614
#' @rdname distributions
1611 1615
#' @export
1612 -
wishart <- function(df, Sigma)  # Exclude Linting
1616 +
wishart <- function(df, Sigma)  # nolint
1613 1617
  distrib("wishart", df, Sigma)
1614 1618
1615 1619
#' @rdname distributions

@@ -1,12 +1,13 @@
1 1
#' @name joint
2 2
#' @title define joint distributions
3 3
#'
4 -
#' @description \code{joint} combines univariate probability distributions
5 -
#'   together into a multivariate (and \emph{a priori} independent between
4 +
#' @description `joint` combines univariate probability distributions
5 +
#'   together into a multivariate (and *a priori* independent between
6 6
#'   dimensions) joint distribution, either over a variable, or for fixed data.
7 7
#'
8 -
#' @param ... variable greta arrays following probability distributions (see
9 -
#'   \code{\link{distributions}}); the components of the joint distribution.
8 +
#' @param ... scalar variable greta arrays following probability distributions
9 +
#'   (see [distributions()]); the components of the joint
10 +
#'   distribution.
10 11
#'
11 12
#' @param dim the dimensions of the greta array to be returned, either a scalar
12 13
#'   or a vector of positive integers. The final dimension of the greta array
@@ -19,7 +20,7 @@
19 20
#'   result can usually be achieved by combining variables with separate
20 21
#'   distributions. It is included for situations where it is more convenient to
21 22
#'   consider these as a single distribution, e.g. for use with
22 -
#'   \code{distribution} or \code{mixture}.
23 +
#'   `distribution` or `mixture`.
23 24
#'
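A hypothetical usage sketch, consistent with the scalar-components check added
below (every argument must be a scalar variable greta array):

# a priori independent components with different marginal priors
x <- joint(normal(0, 1), normal(0, 2), exponential(1))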
24 25
#' @export
25 26
#' @examples
@@ -58,9 +59,10 @@
58 59
      }
59 60
60 61
      # check the dimensions of the variables in dots
61 -
      dim <- do.call(check_dims, c(dots, target_dim = dim))
62 +
      single_dim <- do.call(check_dims, c(dots, target_dim = dim))
62 63
63 64
      # add the joint dimension as the last dimension
65 +
      dim <- single_dim
64 66
      ndim <- length(dim)
65 67
      if (dim[ndim] == 1) {
66 68
        dim[ndim] <- n_distributions
@@ -70,6 +72,13 @@
70 72
71 73
      dot_nodes <- lapply(dots, get_node)
72 74
75 +
      # check they are all scalar
76 +
      are_scalar <- vapply(dot_nodes, is_scalar, logical(1))
77 +
      if (!all(are_scalar)) {
78 +
        stop("joint only accepts probability distributions over scalars",
79 +
             call. = FALSE)
80 +
      }
81 +
73 82
      # get the distributions and strip away their variables
74 83
      distribs <- lapply(dot_nodes, member, "distribution")
75 84
      lapply(distribs, function(x) x$remove_target())
@@ -85,47 +94,78 @@
85 94
             "of discrete and continuous distributions",
86 95
             call. = FALSE)
87 96
      }
97 +
      n_components <- length(dot_nodes)
98 +
99 +
      # work out the support of the resulting distribution, and add as the
100 +
      # bounds of this one, to use when creating target variable
101 +
      lower <- lapply(dot_nodes, member, "lower")
102 +
      lower <- lapply(lower, array, dim = single_dim)
103 +
      upper <- lapply(dot_nodes, member, "upper")
104 +
      upper <- lapply(upper, array, dim = single_dim)
88 105
89 -
      # for any discrete ones, tell them they are fixed
106 +
      self$bounds <- list(
107 +
        lower = do.call(abind::abind, lower),
108 +
        upper = do.call(abind::abind, upper)
109 +
      )
90 110
91 111
      super$initialize("joint", dim, discrete = discrete[1])
92 112
93 113
      for (i in seq_len(n_distributions)) {
94 114
        self$add_parameter(distribs[[i]],
95 115
                           paste("distribution", i),
96 -
                           expand_scalar_to = NULL)
116 +
                           shape_matches_output = FALSE)
97 117
      }
98 118
99 119
    },
100 120
121 +
    create_target = function(truncation) {
122 +
      vble(self$bounds, dim = self$dim)
123 +
    },
124 +
101 125
    tf_distrib = function(parameters, dag) {
102 126
103 -
      densities <- parameters
104 -
      names(densities) <- NULL
127 +
      # get information from the *nodes* for component distributions, not the tf
128 +
      # objects passed in here
129 +
130 +
      # get tfp distributions, truncations, & bounds of component distributions
131 +
      distribution_nodes <- self$parameters
132 +
      truncations <- lapply(distribution_nodes, member, "truncation")
133 +
      bounds <- lapply(distribution_nodes, member, "bounds")
134 +
      tfp_distributions <- lapply(distribution_nodes, dag$get_tfp_distribution)
135 +
      names(tfp_distributions) <- NULL
105 136
106 137
      log_prob <- function(x) {
107 138
108 139
        # split x on the joint dimension, and loop through computing the
109 140
        # densities
110 141
        last_dim <- length(dim(x)) - 1L
111 -
        x_vals <- tf$split(x, length(densities), axis = last_dim)
112 -
        log_probs <- list()
113 -
        for (i in seq_along(densities)) {
114 -
          log_probs[[i]] <- densities[[i]](x_vals[[i]])
115 -
        }
142 +
        x_vals <- tf$split(x, length(tfp_distributions), axis = last_dim)
143 +
144 +
        log_probs <- mapply(
145 +
          dag$tf_evaluate_density,
146 +
          tfp_distributions,
147 +
          x_vals,
148 +
          truncations,
149 +
          bounds,
150 +
          SIMPLIFY = FALSE
151 +
        )
116 152
117 153
        # sum them elementwise
118 154
        tf$add_n(log_probs)
119 155
120 156
      }
121 157
122 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
158 +
      sample <- function(seed) {
123 159
124 -
    },
160 +
        samples <- lapply(distribution_nodes, dag$draw_sample)
161 +
        names(samples) <- NULL
162 +
        tf$concat(samples, axis = 2L)
163 +
164 +
      }
125 165
126 -
    tf_cdf_function = NULL,
127 -
    tf_log_cdf_function = NULL
166 +
      list(log_prob = log_prob, sample = sample)
128 167
168 +
    }
129 169
  )
130 170
)
131 171

@@ -1,11 +1,11 @@
1 1
#' @name mixture
2 2
#' @title mixtures of probability distributions
3 3
#'
4 -
#' @description \code{mixture} combines other probability distributions into a
4 +
#' @description `mixture` combines other probability distributions into a
5 5
#'   single mixture distribution, either over a variable, or for fixed data.
6 6
#'
7 7
#' @param ... variable greta arrays following probability distributions (see
8 -
#'   \code{\link{distributions}}); the component distributions in a mixture
8 +
#'   [distributions()]); the component distributions in a mixture
9 9
#'   distribution.
10 10
#'
11 11
#' @param weights a column vector or array of mixture weights, which must be
@@ -16,10 +16,10 @@
16 16
#' @param dim the dimensions of the greta array to be returned, either a scalar
17 17
#'   or a vector of positive integers.
18 18
#'
19 -
#' @details The \code{weights} are rescaled to sum to one along the first
19 +
#' @details The `weights` are rescaled to sum to one along the first
20 20
#'   dimension, and are then used as the mixing weights of the distribution.
21 21
#'   That is, the probability density is calculated as a weighted sum of the
22 -
#'   component probability distributions passed in via \code{\dots}
22 +
#'   component probability distributions passed in via `...`.
23 23
#'
24 24
#'   The component probability distributions must all be either continuous or
25 25
#'   discrete, and must have the same dimensions.
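A minimal usage sketch (assuming the mixture() signature documented here; the
values are illustrative):

# two-component mixture of normals; weights are rescaled to sum to one
z <- mixture(normal(-3, 1), normal(3, 1), weights = c(0.3, 0.7))

On the log scale, the weighted sum is computed in the code below as
log_prob(x) = logsumexp_i(log(w_i) + log_prob_i(x)).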
@@ -134,7 +134,7 @@
134 134
      lapply(dot_nodes, function(x) x$distribution <- NULL)
135 135
136 136
      # check the distributions are all either discrete or continuous
137 -
      discrete <- vapply(distribs, member, "discrete", FUN.VALUE = FALSE)
137 +
      discrete <- vapply(distribs, member, "discrete", FUN.VALUE = logical(1))
138 138
139 139
      if (!all(discrete) & !all(!discrete)) {
140 140
        stop("cannot construct a mixture from a combination of discrete ",
@@ -142,22 +142,84 @@
142 142
             call. = FALSE)
143 143
      }
144 144
145 +
      # check the distributions are all either multivariate or univariate
146 +
      multivariate <- vapply(distribs,
147 +
                             member,
148 +
                             "multivariate",
149 +
                             FUN.VALUE = logical(1))
150 +
151 +
      if (!all(multivariate) & !all(!multivariate)) {
152 +
        stop("cannot construct a mixture from a combination of multivariate ",
153 +
             "and univariate distributions",
154 +
             call. = FALSE)
155 +
      }
156 +
157 +
      # ensure the support and bounds of each of the distributions is the same
158 +
      truncations <- lapply(distribs, member, "truncation")
159 +
      bounds <- lapply(distribs, member, "bounds")
160 +
161 +
      truncated <- !vapply(truncations, is.null, logical(1))
162 +
      supports <- bounds
163 +
      supports[truncated] <- truncations[truncated]
164 +
165 +
      n_supports <- length(unique(supports))
166 +
      if (n_supports != 1) {
167 +
        supports_text <- vapply(
168 +
          X = unique(supports),
169 +
          FUN = paste,
170 +
          collapse = " to ",
171 +
          FUN.VALUE = character(1)
172 +
        )
173 +
        stop("component distributions have different support: ",
174 +
              paste(supports_text, collapse = " vs. "),
175 +
             call. = FALSE)
176 +
      }
177 +
178 +
      # get the maximal bounds for all component distributions
179 +
      bounds <- c(do.call(min, bounds),
180 +
                  do.call(max, bounds))
181 +
182 +
      # if the support is smaller than this, treat the distribution as truncated
183 +
      support <- supports[[1]]
184 +
      if (identical(support, bounds)) {
185 +
        truncation <- NULL
186 +
      } else {
187 +
        truncation <- support
188 +
      }
189 +
190 +
      self$bounds <- support
191 +
145 192
      # initialise the distribution, passing on whether it is discrete and
      # multivariate
146 -
      super$initialize("mixture", dim, discrete = discrete[1])
193 +
      super$initialize("mixture",
194 +
                       dim,
195 +
                       discrete = discrete[1],
196 +
                       multivariate = multivariate[1])
147 197
148 198
      for (i in seq_len(n_distributions)) {
149 199
        self$add_parameter(distribs[[i]],
150 200
                           paste("distribution", i),
151 -
                           expand_scalar_to = NULL)
201 +
                           shape_matches_output = FALSE)
152 202
      }
153 203
154 204
      self$add_parameter(weights, "weights")
155 205
    },
156 206
207 +
    create_target = function(truncation) {
208 +
      vble(self$bounds, dim = self$dim)
209 +
    },
210 +
157 211
    tf_distrib = function(parameters, dag) {
158 212
159 -
      densities <- parameters[names(parameters) != "weights"]
160 -
      names(densities) <- NULL
213 +
      # get information from the *nodes* for component distributions, not the tf
214 +
      # objects passed in here
215 +
216 +
      # get tfp distributions, truncations, & bounds of component distributions
217 +
      distribution_nodes <- self$parameters[names(self$parameters) != "weights"]
218 +
      truncations <- lapply(distribution_nodes, member, "truncation")
219 +
      bounds <- lapply(distribution_nodes, member, "bounds")
220 +
      tfp_distributions <- lapply(distribution_nodes, dag$get_tfp_distribution)
221 +
      names(tfp_distributions) <- NULL
222 +
161 223
      weights <- parameters$weights
162 224
163 225
      # use log weights if available
@@ -178,7 +240,14 @@
178 240
      log_prob <- function(x) {
179 241
180 242
        # get component densities in an array
181 -
        log_probs <- lapply(densities, do.call, list(x))
243 +
        log_probs <- mapply(
244 +
          dag$tf_evaluate_density,
245 +
          tfp_distribution = tfp_distributions,
246 +
          truncation = truncations,
247 +
          bounds = bounds,
248 +
          MoreArgs = list(tf_target = x),
249 +
          SIMPLIFY = FALSE
250 +
        )
182 251
        log_probs_arr <- tf$stack(log_probs, 1L)
183 252
184 253
        # massage log_weights into the same shape as log_probs_arr
@@ -199,12 +268,55 @@
199 268
        tf$reduce_logsumexp(log_probs_weighted_arr, axis = 1L)
200 269
      }
201 270
202 -
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
271 +
      sample <- function(seed) {
203 272
204 -
    },
273 +
        # draw samples from each component
274 +
        samples <- lapply(distribution_nodes, dag$draw_sample)
275 +
        names(samples) <- NULL
276 +
277 +
        ndim <- length(self$dim)
278 +
279 +
        # in univariate case, tile log_weights to match dim, so each element can
280 +
        # be selected independently (otherwise each row is kept together)
281 +
        log_weights <- tf$squeeze(log_weights, 2L)
282 +
283 +
        if (!self$multivariate) {
284 +
285 +
          for (i in seq_len(ndim)) {
286 +
            log_weights <- tf$expand_dims(log_weights, 1L)
287 +
          }
288 +
          log_weights <- tf$tile(log_weights, c(1L, self$dim, 1L))
289 +
290 +
        }
291 +
292 +
        # for each observation, select a random component to sample from
293 +
        cat <- tfp$distributions$Categorical(logits = log_weights)
294 +
        indices <- cat$sample(seed = seed)
295 +
296 +
        # how many dimensions to consider a batch differs between multivariate
297 +
        # and univariate
298 +
        collapse_axis <- ndim + 1L
299 +
        n_batches <- ifelse(self$multivariate, 1L, collapse_axis)
300 +
301 +
        # combine the random components on an extra last dimension
302 +
        samples_padded <- lapply(samples, tf$expand_dims, axis = collapse_axis)
303 +
        samples_array <- tf$concat(samples_padded, axis = collapse_axis)
304 +
305 +
        # extract the relevant component
306 +
        indices <- tf$expand_dims(indices, n_batches)
307 +
        draws <- tf$gather(samples_array,
308 +
                           indices,
309 +
                           axis = collapse_axis,
310 +
                           batch_dims = n_batches)
311 +
        draws <- tf$squeeze(draws, collapse_axis)
312 +
313 +
        draws
314 +
315 +
      }
316 +
317 +
      list(log_prob = log_prob, sample = sample)
205 318
206 -
    tf_cdf_function = NULL,
207 -
    tf_log_cdf_function = NULL
319 +
    }
208 320
209 321
  )
210 322
)
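The sample() method above draws from every component, then picks one component
per observation with a categorical draw on the log weights, via
tf$gather(..., batch_dims = ...). An illustrative base R analogue of that
select-then-gather step (not greta code):

comp <- cbind(rnorm(5, -3), rnorm(5, 3))             # one column per component
idx <- sample.int(2, 5, replace = TRUE, prob = c(0.3, 0.7))
draws <- comp[cbind(seq_len(5), idx)]                # per-row component choice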

@@ -1,6 +1,6 @@
1 1
# define a greta_array S3 class for the objects users manipulate
2 2
3 -
# Begin Exclude Linting
3 +
# nolint start
4 4
5 5