pyro/infer/trace_crps.py
changed.
Other files ignored by Codecov
docs/source/inference_algos.rst
has changed.
21 | 21 | ||
22 | 22 | class Trace_CRPS: |
|
23 | 23 | """ |
|
24 | - | Posterior predictive CRPS loss. |
|
24 | + | Posterior predictive CRPS loss with ``KL(q,p)`` regularization. |
|
25 | 25 | ||
26 | - | This is a likelihood-free method; no densities are evaluated. |
|
26 | + | This is a likelihood-free method, and can be used for likelihoods without |
|
27 | + | tractable density functions. CRPS is a robust loss function, and is well |
|
28 | + | defined for any distribution with finite absolute moment ``E[|data|]``. |
|
27 | 29 | ||
28 | - | :param num_particles: The number of particles/samples used to form the |
|
30 | + | This requires static model structure, fully reparametrized guide, and |
|
31 | + | reparametrized likelihood distributions in the model. Model latent |
|
32 | + | distributions may be non-reparametrized. |
|
33 | + | ||
34 | + | Note that in the loss ``CRPS + KL(q,p)``, the ``CRPS`` term has data units |
|
35 | + | whereas the ``KL(q,p)`` term has units of nats. To calibrate these units, |
|
36 | + | you can wrap likelihood sites in ``poutine.scale``. |
|
37 | + | ||
38 | + | References |
|
39 | + | [1] `Strictly Proper Scoring Rules, Prediction, and Estimation` |
|
40 | + | Tilmann Gneiting, Adrian E. Raftery (2007) |
|
41 | + | https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf |
|
42 | + | ||
43 | + | :param int num_particles: The number of particles/samples used to form the |
|
29 | 44 | gradient estimators. Must be at least 2. |
|
30 | 45 | :param int max_plate_nesting: Optional bound on max number of nested |
|
31 | 46 | :func:`pyro.plate` contexts. This is only required when enumerating |
37 | 52 | def __init__(self, |
|
38 | 53 | num_particles=2, |
|
39 | 54 | max_plate_nesting=float('inf')): |
|
40 | - | assert isinstance(num_particles, int) and num_particles >= 2 |
|
55 | + | if not isinstance(num_particles, int) and num_particles >= 2: |
|
56 | + | raise ValueError("Expected num_particles >= 2, actual {}".format(num_particles)) |
|
41 | 57 | self.num_particles = num_particles |
|
42 | 58 | self.vectorize_particles = True |
|
43 | 59 | self.max_plate_nesting = max_plate_nesting |
|
44 | 60 | ||
45 | 61 | def _get_traces(self, model, guide, *args, **kwargs): |
|
46 | 62 | if self.max_plate_nesting == float("inf"): |
|
63 | + | # TODO factor this out as a stand-alone helper. |
|
47 | 64 | ELBO._guess_max_plate_nesting(self, model, guide, *args, **kwargs) |
|
48 | 65 | vectorize = pyro.plate("num_particles_vectorized", self.num_particles, |
|
49 | - | dim=-1 - self.max_plate_nesting) |
|
66 | + | dim=-self.max_plate_nesting) |
|
50 | 67 | ||
51 | 68 | # Trace the guide as in ELBO. |
|
52 | 69 | with poutine.trace() as tr, vectorize: |
|
53 | 70 | guide(*args, **kwargs) |
|
54 | 71 | guide_trace = tr.trace |
|
55 | 72 | ||
56 | - | # Trace the model, saving obs in tr2 and posterior predictives in tr1. |
|
73 | + | # Trace the model, drawing posterior predictive samples. |
|
57 | 74 | with poutine.trace() as tr, poutine.uncondition(): |
|
58 | 75 | with poutine.replay(trace=guide_trace), vectorize: |
|
59 | 76 | model(*args, **kwargs) |
61 | 78 | for site in model_trace.nodes.values(): |
|
62 | 79 | if site["type"] == "sample" and site["infer"].get("was_observed", False): |
|
63 | 80 | site["is_observed"] = True |
|
64 | - | ||
65 | 81 | if is_validation_enabled(): |
|
66 | 82 | check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting) |
|
67 | 83 |
69 | 85 | model_trace = prune_subsample_sites(model_trace) |
|
70 | 86 | guide_trace.compute_log_prob() |
|
71 | 87 | model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"]) |
|
72 | - | ||
73 | 88 | if is_validation_enabled(): |
|
74 | 89 | for site in guide_trace.nodes.values(): |
|
75 | 90 | if site["type"] == "sample": |
87 | 102 | return guide_trace, model_trace |
|
88 | 103 | ||
89 | 104 | def __call__(self, model, guide, *args, **kwargs): |
|
105 | + | """ |
|
106 | + | Computes the surrogate loss that can be differentiated with autograd |
|
107 | + | to produce gradient estimates for the model and guide parameters. |
|
108 | + | """ |
|
90 | 109 | guide_trace, model_trace = self._get_traces(model, guide, *args, **kwargs) |
|
91 | 110 | ||
92 | 111 | # Extract observations and posterior predictive samples. |
143 | 162 | return loss |
|
144 | 163 | ||
145 | 164 | def loss(self, *args, **kwargs): |
|
165 | + | """ |
|
166 | + | Not implemented. Added for compatibility with unit tests only. |
|
167 | + | """ |
|
146 | 168 | raise NotImplementedError("Trace_CRPS implements only surrogate loss") |
Files | Coverage |
---|---|
pyro | 93.97% |
Project Totals (203 files) | 93.97% |
3.5=.5 TRAVIS_OS_NAME=linux
3.5=.5 TRAVIS_OS_NAME=linux
3.5=.5 TRAVIS_OS_NAME=linux
3.5=.5 TRAVIS_OS_NAME=linux