pyro-ppl / pyro

Compare d216f0d ... +3 ... b7e21ff

Missing base report.

Unable to compare commits because the base of the compare did not upload a coverage report.

Learn more here.

Showing 4 of 5 files from the diff.
Newly tracked file
pyro/infer/svi.py changed.
Other files ignored by Codecov

@@ -19,6 +19,7 @@
Loading
19 19
        if msg["is_observed"]:
20 20
            msg["is_observed"] = False
21 21
            msg["infer"]["was_observed"] = True
22 +
            msg["infer"]["obs"] = msg["value"]
22 23
            msg["value"] = None
23 24
            msg["done"] = False
24 25
        return None

@@ -0,0 +1,138 @@
Loading
1 +
import operator
2 +
from collections import OrderedDict
3 +
from functools import reduce
4 +
5 +
import torch
6 +
7 +
import pyro
8 +
import pyro.poutine as poutine
9 +
from pyro.distributions.util import scale_and_mask
10 +
from pyro.infer.elbo import ELBO
11 +
from pyro.infer.util import is_validation_enabled
12 +
from pyro.poutine.util import prune_subsample_sites
13 +
from pyro.util import check_model_guide_match, check_site_shape, warn_if_nan
14 +
15 +
16 +
def _squared_error(x, y, scale, mask):
17 +
    diff = x - y
18 +
    error = torch.einsum("np,np->n", diff, diff)
19 +
    return scale_and_mask(error, scale, mask)
20 +
21 +
22 +
class Trace_CRPS:
23 +
    """
24 +
    Posterior predictive CRPS loss.
25 +
26 +
    This is a likelihood-free method; no densities are evaluated.
27 +
28 +
    :param num_particles: The number of particles/samples used to form the
29 +
        gradient estimators. Must be at least 2.
30 +
    :param int max_plate_nesting: Optional bound on max number of nested
31 +
        :func:`pyro.plate` contexts. This is only required when enumerating
32 +
        over sample sites in parallel, e.g. if a site sets
33 +
        ``infer={"enumerate": "parallel"}``. If omitted, ELBO may guess a valid
34 +
        value by running the (model,guide) pair once, however this guess may
35 +
        be incorrect if model or guide structure is dynamic.
36 +
    """
37 +
    def __init__(self,
38 +
                 num_particles=2,
39 +
                 max_plate_nesting=float('inf')):
40 +
        assert isinstance(num_particles, int) and num_particles >= 2
41 +
        self.num_particles = num_particles
42 +
        self.vectorize_particles = True
43 +
        self.max_plate_nesting = max_plate_nesting
44 +
45 +
    def _get_traces(self, model, guide, *args, **kwargs):
46 +
        if self.max_plate_nesting == float("inf"):
47 +
            ELBO._guess_max_plate_nesting(self, model, guide, *args, **kwargs)
48 +
        vectorize = pyro.plate("num_particles_vectorized", self.num_particles,
49 +
                               dim=-1 - self.max_plate_nesting)
50 +
51 +
        # Trace the guide as in ELBO.
52 +
        with poutine.trace() as tr, vectorize:
53 +
            guide(*args, **kwargs)
54 +
        guide_trace = tr.trace
55 +
56 +
        # Trace the model, saving obs in tr2 and posterior predictives in tr1.
57 +
        with poutine.trace() as tr, vectorize, poutine.uncondition():
58 +
            with poutine.replay(trace=guide_trace):
59 +
                model(*args, **kwargs)
60 +
        model_trace = tr.trace
61 +
        for site in model_trace.nodes.values():
62 +
            if site["type"] == "sample" and site["infer"].get("was_observed", False):
63 +
                site["is_observed"] = True
64 +
65 +
        if is_validation_enabled():
66 +
            check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)
67 +
68 +
        guide_trace = prune_subsample_sites(guide_trace)
69 +
        model_trace = prune_subsample_sites(model_trace)
70 +
        model_trace.compute_log_prob(site_filter=lambda name, site:
71 +
                                     not site["is_observed"] and site["mask"] is not False)
72 +
73 +
        if is_validation_enabled():
74 +
            for site in guide_trace.nodes.values():
75 +
                if site["type"] == "sample":
76 +
                    if not getattr(site["fn"], "has_rsample", False):
77 +
                        raise ValueError("Trace_CRPS only supports fully reparametrized guides")
78 +
            for trace in model_trace.nodes.values():
79 +
                if site["type"] == "sample" and "log_prob" in site:
80 +
                    check_site_shape(site, self.max_plate_nesting)
81 +
82 +
        return guide_trace, model_trace
83 +
84 +
    def __call__(self, model, guide, *args, **kwargs):
85 +
        guide_trace, model_trace = self._get_traces(model, guide, *args, **kwargs)
86 +
87 +
        # Extract observations and posterior predictive samples.
88 +
        data = OrderedDict()
89 +
        samples = OrderedDict()
90 +
        for name, site in model_trace.nodes.items():
91 +
            if site["type"] == "sample" and site["is_observed"]:
92 +
                data[name] = site["infer"]["obs"]
93 +
                samples[name] = site["value"]
94 +
        assert list(data.keys()) == list(samples.keys())
95 +
        if not data:
96 +
            raise ValueError("Found no observations")
97 +
98 +
        # Compute crps from mean average error and generalized entropy.
99 +
        squared_error = []  # E[ (X - x)^2 ]
100 +
        squared_entropy = []  # E[ (X - X')^2 ]
101 +
        prototype = next(iter(data.values()))
102 +
        pairs = prototype.new_ones(self.num_particles, self.num_particles).tril(-1).nonzero()
103 +
        for name, obs in data.items():
104 +
            sample = samples[name]
105 +
            scale = model_trace.nodes[name]["scale"]
106 +
            mask = model_trace.nodes[name]["mask"]
107 +
108 +
            # Flatten.
109 +
            batch_shape = obs.shape[:obs.dim() - model_trace.nodes[name]["fn"].event_dim]
110 +
            if isinstance(scale, torch.Tensor):
111 +
                scale = scale.expand(batch_shape).reshape(-1)
112 +
            if isinstance(mask, torch.Tensor):
113 +
                mask = mask.expand(batch_shape).reshape(-1)
114 +
            obs = obs.reshape(-1)
115 +
            sample = sample.reshape(self.num_particles, -1)
116 +
117 +
            squared_error.append(_squared_error(sample, obs, scale, mask))
118 +
            squared_entropy.append(_squared_error(*sample[pairs].unbind(1), scale, mask))
119 +
120 +
        squared_error = reduce(operator.add, squared_error)
121 +
        squared_entropy = reduce(operator.add, squared_entropy)
122 +
        error = squared_error.sqrt().mean()  # E[ |X-x| ]
123 +
        entropy = squared_entropy.sqrt().mean()  # E[ |X-X'| ]
124 +
        crps = error - 0.5 * entropy
125 +
126 +
        # Compute log p(z).
127 +
        logp = 0
128 +
        for site in model_trace.nodes.values():
129 +
            if site["type"] == "sample" and "log_prob_sum" in site:
130 +
                logp = logp + site["log_prob_sum"]
131 +
132 +
        # Compute final loss.
133 +
        loss = crps - logp
134 +
        warn_if_nan(loss, "loss")
135 +
        return loss
136 +
137 +
    def loss(self, *args, **kwargs):
138 +
        raise NotImplementedError("Trace_CRPS implements only surrogate loss")

@@ -66,7 +66,8 @@
Loading
66 66
            if loss_and_grads is None:
67 67
                def _loss_and_grads(*args, **kwargs):
68 68
                    loss_val = loss(*args, **kwargs)
69 -
                    loss_val.backward(retain_graph=True)
69 +
                    if getattr(loss_val, 'requires_grad', False):
70 +
                        loss_val.backward(retain_graph=True)
70 71
                    return loss_val
71 72
                loss_and_grads = _loss_and_grads
72 73
            self.loss = loss

@@ -10,15 +10,16 @@
Loading
10 10
from pyro.infer.predictive import Predictive
11 11
from pyro.infer.renyi_elbo import RenyiELBO
12 12
from pyro.infer.smcfilter import SMCFilter
13 +
from pyro.infer.svgd import SVGD, IMQSteinKernel, RBFSteinKernel
13 14
from pyro.infer.svi import SVI
15 +
from pyro.infer.trace_crps import Trace_CRPS
14 16
from pyro.infer.trace_elbo import JitTrace_ELBO, Trace_ELBO
15 17
from pyro.infer.trace_mean_field_elbo import JitTraceMeanField_ELBO, TraceMeanField_ELBO
18 +
from pyro.infer.trace_mmd import Trace_MMD
16 19
from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO
17 20
from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
18 21
from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
19 -
from pyro.infer.trace_mmd import Trace_MMD
20 22
from pyro.infer.util import enable_validation, is_validation_enabled
21 -
from pyro.infer.svgd import SVGD, RBFSteinKernel, IMQSteinKernel
22 23
23 24
__all__ = [
24 25
    "config_enumerate",
@@ -49,6 +50,7 @@
Loading
49 50
    "TracePosterior",
50 51
    "TracePredictive",
51 52
    "TraceTailAdaptive_ELBO",
53 +
    "Trace_CRPS",
52 54
    "Trace_ELBO",
53 55
    "Trace_MMD",
54 56
]

Unable to process changes.

No base report to compare against.

Files Coverage
pyro 94.29%
Project Totals (203 files) 94.29%
Loading