pyro-ppl / pyro
Showing 1 of 2 files from the diff.

@@ -21,11 +21,26 @@
Loading
21 21
22 22
class Trace_CRPS:
23 23
    """
24 -
    Posterior predictive CRPS loss.
24 +
    Posterior predictive CRPS loss with ``KL(q,p)`` regularization.
25 25
26 -
    This is a likelihood-free method; no densities are evaluated.
26 +
    This is a likelihood-free method, and can be used for likelihoods without
27 +
    tractible density functions. CRPS is a robust loss function, and is well
28 +
    defined for any distribution with finite absolute moment ``E[|data|]``.
27 29
28 -
    :param num_particles: The number of particles/samples used to form the
30 +
    This requires static model structure, fully reparametrized guide, and
31 +
    reparametrized likelihood distributions in the model. Model latent
32 +
    distributions may be non-reparametrized.
33 +
34 +
    Note that in the loss ``CRPS + KL(q,p)``, the ``CRPS`` term has data units
35 +
    whereas the ``KL(q,p)`` term has units of nats.  To calibrate these units,
36 +
    you can wrap likelihood sites in ``poutine.scale``.
37 +
38 +
    References
39 +
    [1] `Strictly Proper Scoring Rules, Prediction, and Estimation`
40 +
    Tilmann Gneiting, Adrian E. Raftery (2007)
41 +
    https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf
42 +
43 +
    :param int num_particles: The number of particles/samples used to form the
29 44
        gradient estimators. Must be at least 2.
30 45
    :param int max_plate_nesting: Optional bound on max number of nested
31 46
        :func:`pyro.plate` contexts. This is only required when enumerating
@@ -37,23 +52,25 @@
Loading
37 52
    def __init__(self,
38 53
                 num_particles=2,
39 54
                 max_plate_nesting=float('inf')):
40 -
        assert isinstance(num_particles, int) and num_particles >= 2
55 +
        if not isinstance(num_particles, int) and num_particles >= 2:
56 +
            raise ValueError("Expected num_particles >= 2, actual {}".format(num_particles))
41 57
        self.num_particles = num_particles
42 58
        self.vectorize_particles = True
43 59
        self.max_plate_nesting = max_plate_nesting
44 60
45 61
    def _get_traces(self, model, guide, *args, **kwargs):
46 62
        if self.max_plate_nesting == float("inf"):
63 +
            # TODO factor this out as a stand-alone helper.
47 64
            ELBO._guess_max_plate_nesting(self, model, guide, *args, **kwargs)
48 65
        vectorize = pyro.plate("num_particles_vectorized", self.num_particles,
49 -
                               dim=-1 - self.max_plate_nesting)
66 +
                               dim=-self.max_plate_nesting)
50 67
51 68
        # Trace the guide as in ELBO.
52 69
        with poutine.trace() as tr, vectorize:
53 70
            guide(*args, **kwargs)
54 71
        guide_trace = tr.trace
55 72
56 -
        # Trace the model, saving obs in tr2 and posterior predictives in tr1.
73 +
        # Trace the model, drawing posterior predictive samples.
57 74
        with poutine.trace() as tr, poutine.uncondition():
58 75
            with poutine.replay(trace=guide_trace), vectorize:
59 76
                model(*args, **kwargs)
@@ -61,7 +78,6 @@
Loading
61 78
        for site in model_trace.nodes.values():
62 79
            if site["type"] == "sample" and site["infer"].get("was_observed", False):
63 80
                site["is_observed"] = True
64 -
65 81
        if is_validation_enabled():
66 82
            check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)
67 83
@@ -69,7 +85,6 @@
Loading
69 85
        model_trace = prune_subsample_sites(model_trace)
70 86
        guide_trace.compute_log_prob()
71 87
        model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"])
72 -
73 88
        if is_validation_enabled():
74 89
            for site in guide_trace.nodes.values():
75 90
                if site["type"] == "sample":
@@ -87,6 +102,10 @@
Loading
87 102
        return guide_trace, model_trace
88 103
89 104
    def __call__(self, model, guide, *args, **kwargs):
105 +
        """
106 +
        Computes the surrogate loss that can be differentiated with autograd
107 +
        to produce gradient estimates for the model and guide parameters.
108 +
        """
90 109
        guide_trace, model_trace = self._get_traces(model, guide, *args, **kwargs)
91 110
92 111
        # Extract observations and posterior predictive samples.
@@ -143,4 +162,7 @@
Loading
143 162
        return loss
144 163
145 164
    def loss(self, *args, **kwargs):
165 +
        """
166 +
        Not implemented. Added for compatibility with unit tests only.
167 +
        """
146 168
        raise NotImplementedError("Trace_CRPS implements only surrogate loss")
Files Coverage
pyro 93.97%
Project Totals (203 files) 93.97%
5478.5
3.5=.5
TRAVIS_OS_NAME=linux
5478.7
3.5=.5
TRAVIS_OS_NAME=linux
5478.6
3.5=.5
TRAVIS_OS_NAME=linux
5478.8
3.5=.5
TRAVIS_OS_NAME=linux
1
ignore:
2
  - "pyro/docutil.py"
3
  - "pyro/logger.py"
4

5
coverage:
6
  range: 60..95
7
  round: nearest
8
  precision: 2
9

10
comment: false
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading