pyro-ppl / pyro

Compare 3085029 ... +2 ... 10a54ce

Coverage Reach
infer/mcmc/util.py infer/mcmc/api.py infer/mcmc/hmc.py infer/mcmc/nuts.py infer/mcmc/adaptation.py infer/mcmc/logger.py infer/mcmc/mcmc_kernel.py infer/mcmc/__init__.py infer/autoguide/guides.py infer/autoguide/initialization.py infer/autoguide/utils.py infer/autoguide/__init__.py infer/traceenum_elbo.py infer/abstract_infer.py infer/tracegraph_elbo.py infer/util.py infer/svgd.py infer/trace_elbo.py infer/discrete.py infer/predictive.py infer/trace_crps.py infer/renyi_elbo.py infer/trace_mmd.py infer/importance.py infer/smcfilter.py infer/trace_mean_field_elbo.py infer/enum.py infer/csis.py infer/elbo.py infer/svi.py infer/trace_tail_adaptive_elbo.py infer/__init__.py contrib/gp/models/sgpr.py contrib/gp/models/gpr.py contrib/gp/models/vsgp.py contrib/gp/models/vgp.py contrib/gp/models/gplvm.py contrib/gp/models/model.py contrib/gp/models/__init__.py contrib/gp/kernels/kernel.py contrib/gp/kernels/isotropic.py contrib/gp/kernels/periodic.py contrib/gp/kernels/dot_product.py contrib/gp/kernels/coregionalize.py contrib/gp/kernels/static.py contrib/gp/kernels/brownian.py contrib/gp/kernels/__init__.py contrib/gp/parameterized.py contrib/gp/likelihoods/multi_class.py contrib/gp/likelihoods/gaussian.py contrib/gp/likelihoods/binary.py contrib/gp/likelihoods/poisson.py contrib/gp/likelihoods/__init__.py contrib/gp/likelihoods/likelihood.py contrib/gp/util.py contrib/gp/__init__.py contrib/oed/eig.py contrib/oed/glmm/glmm.py contrib/oed/glmm/guides.py contrib/oed/glmm/__init__.py contrib/oed/util.py contrib/oed/search.py contrib/oed/__init__.py contrib/tracking/assignment.py contrib/tracking/dynamic_models.py contrib/tracking/hashing.py contrib/tracking/extended_kalman_filter.py contrib/tracking/distributions.py contrib/tracking/measurements.py contrib/tracking/__init__.py contrib/timeseries/gp.py contrib/timeseries/lgssmgp.py contrib/timeseries/lgssm.py contrib/timeseries/__init__.py contrib/timeseries/base.py contrib/minipyro.py contrib/autoname/named.py contrib/autoname/scoping.py contrib/autoname/__init__.py contrib/easyguide/easyguide.py contrib/easyguide/__init__.py contrib/conjugate/infer.py contrib/examples/multi_mnist.py contrib/examples/util.py contrib/bnn/hidden_layer.py contrib/bnn/utils.py contrib/bnn/__init__.py contrib/util.py contrib/autoguide.py contrib/__init__.py distributions/transforms/neural_autoregressive.py distributions/transforms/block_autoregressive.py distributions/transforms/affine_autoregressive.py distributions/transforms/planar.py distributions/transforms/sylvester.py distributions/transforms/polynomial.py distributions/transforms/affine_coupling.py distributions/transforms/householder.py distributions/transforms/cholesky.py distributions/transforms/radial.py distributions/transforms/batchnorm.py distributions/transforms/permute.py distributions/transforms/__init__.py distributions/transforms/utils.py distributions/spanning_tree.py distributions/testing/rejection_gamma.py distributions/testing/rejection_exponential.py distributions/testing/naive_dirichlet.py distributions/testing/fakes.py distributions/hmm.py distributions/conjugate.py distributions/util.py distributions/torch_distribution.py distributions/mixture.py distributions/diag_normal_mixture_shared_cov.py distributions/diag_normal_mixture.py distributions/von_mises.py distributions/lkj.py distributions/empirical.py distributions/gaussian_scale_mixture.py distributions/torch.py distributions/delta.py distributions/relaxed_straight_through.py distributions/rejector.py distributions/torch_patch.py distributions/kl.py distributions/zero_inflated_poisson.py distributions/conditional.py distributions/__init__.py distributions/unit.py distributions/avf_mvn.py distributions/omt_mvn.py distributions/von_mises_3d.py distributions/constraints.py distributions/folded.py distributions/inverse_gamma.py distributions/distribution.py distributions/score_parts.py distributions/torch_transform.py ops/einsum/adjoint.py ops/einsum/torch_sample.py ops/einsum/torch_marginal.py ops/einsum/torch_map.py ops/einsum/util.py ops/einsum/torch_log.py ops/einsum/__init__.py ops/contract.py ops/stats.py ops/gaussian.py ops/rings.py ops/packed.py ops/ssm_gp.py ops/jit.py ops/newton.py ops/linalg.py ops/tensor_utils.py ops/integrator.py ops/indexing.py ops/welford.py ops/dual_averaging.py ops/hessian.py poutine/trace_struct.py poutine/enumerate_messenger.py poutine/handlers.py poutine/runtime.py poutine/trace_messenger.py poutine/subsample_messenger.py poutine/indep_messenger.py poutine/messenger.py poutine/lift_messenger.py poutine/markov_messenger.py poutine/util.py poutine/block_messenger.py poutine/replay_messenger.py poutine/do_messenger.py poutine/broadcast_messenger.py poutine/reentrant_messenger.py poutine/mask_messenger.py poutine/escape_messenger.py poutine/condition_messenger.py poutine/scale_messenger.py poutine/uncondition_messenger.py poutine/infer_config_messenger.py poutine/plate_messenger.py poutine/__init__.py nn/module.py nn/auto_reg_nn.py nn/dense_nn.py nn/__init__.py optim/optim.py optim/multi.py optim/adagrad_rmsprop.py optim/clipped_adam.py optim/pytorch_optimizers.py optim/lr_scheduler.py optim/__init__.py util.py params/param_store.py params/__init__.py primitives.py logger.py __init__.py generic.py

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.

Showing 1 of 3 files from the diff.

@@ -21,7 +21,7 @@
Loading
21 21
22 22
class Trace_CRPS:
23 23
    """
24 -
    Posterior predictive CRPS loss with ``KL(q,p)`` regularization.
24 +
    Posterior predictive CRPS loss with optional ``KL(q,p)`` regularization.
25 25
26 26
    This is a likelihood-free method, and can be used for likelihoods without
27 27
    tractible density functions. CRPS is a robust loss function, and is well
@@ -31,10 +31,6 @@
Loading
31 31
    reparametrized likelihood distributions in the model. Model latent
32 32
    distributions may be non-reparametrized.
33 33
34 -
    Note that in the loss ``CRPS + KL(q,p)``, the ``CRPS`` term has data units
35 -
    whereas the ``KL(q,p)`` term has units of nats.  To calibrate these units,
36 -
    you can wrap likelihood sites in ``poutine.scale``.
37 -
38 34
    References
39 35
    [1] `Strictly Proper Scoring Rules, Prediction, and Estimation`
40 36
    Tilmann Gneiting, Adrian E. Raftery (2007)
@@ -43,20 +39,24 @@
Loading
43 39
    :param int num_particles: The number of particles/samples used to form the
44 40
        gradient estimators. Must be at least 2.
45 41
    :param int max_plate_nesting: Optional bound on max number of nested
46 -
        :func:`pyro.plate` contexts. This is only required when enumerating
47 -
        over sample sites in parallel, e.g. if a site sets
48 -
        ``infer={"enumerate": "parallel"}``. If omitted, ELBO may guess a valid
49 -
        value by running the (model,guide) pair once, however this guess may
50 -
        be incorrect if model or guide structure is dynamic.
42 +
        :func:`pyro.plate` contexts. If omitted, ELBO may guess a valid value
43 +
        by running the (model,guide) pair once, however this guess may be
44 +
        incorrect if model or guide structure is dynamic.
45 +
    :param float kl_scale: Nonnegative scale for ``KL(q,p)`` regularization.
46 +
        If zero (default), then log densities will not be computed.
51 47
    """
52 48
    def __init__(self,
53 49
                 num_particles=2,
54 -
                 max_plate_nesting=float('inf')):
55 -
        if not isinstance(num_particles, int) and num_particles >= 2:
50 +
                 max_plate_nesting=float('inf'),
51 +
                 kl_scale=0.):
52 +
        if not (isinstance(num_particles, int) and num_particles >= 2):
56 53
            raise ValueError("Expected num_particles >= 2, actual {}".format(num_particles))
54 +
        if not (isinstance(kl_scale, (float, int)) and kl_scale >= 0):
55 +
            raise ValueError("Expected kl_scale >= 0, actual {}".format(kl_scale))
57 56
        self.num_particles = num_particles
58 57
        self.vectorize_particles = True
59 58
        self.max_plate_nesting = max_plate_nesting
59 +
        self.kl_scale = kl_scale
60 60
61 61
    def _get_traces(self, model, guide, *args, **kwargs):
62 62
        if self.max_plate_nesting == float("inf"):
@@ -83,21 +83,28 @@
Loading
83 83
84 84
        guide_trace = prune_subsample_sites(guide_trace)
85 85
        model_trace = prune_subsample_sites(model_trace)
86 -
        guide_trace.compute_log_prob()
87 -
        model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"])
88 86
        if is_validation_enabled():
89 87
            for site in guide_trace.nodes.values():
90 88
                if site["type"] == "sample":
91 -
                    check_site_shape(site, self.max_plate_nesting)
92 89
                    if not getattr(site["fn"], "has_rsample", False):
93 90
                        raise ValueError("Trace_CRPS requires fully reparametrized guides")
94 91
            for trace in model_trace.nodes.values():
95 92
                if site["type"] == "sample":
96 93
                    if site["is_observed"]:
97 94
                        if not getattr(site["fn"], "has_rsample", False):
98 95
                            raise ValueError("Trace_CRPS requires reparametrized likelihoods")
99 -
                    else:
96 +
97 +
        if self.kl_scale > 0:
98 +
            guide_trace.compute_log_prob()
99 +
            model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"])
100 +
            if is_validation_enabled():
101 +
                for site in guide_trace.nodes.values():
102 +
                    if site["type"] == "sample":
100 103
                        check_site_shape(site, self.max_plate_nesting)
104 +
                for trace in model_trace.nodes.values():
105 +
                    if site["type"] == "sample":
106 +
                        if not site["is_observed"]:
107 +
                            check_site_shape(site, self.max_plate_nesting)
101 108
102 109
        return guide_trace, model_trace
103 110
@@ -147,17 +154,18 @@
Loading
147 154
        entropy = squared_entropy.sqrt().mean()  # E[ |X-X'| ]
148 155
        crps = error - 0.5 * entropy
149 156
150 -
        # Compute KL(guide||model).
157 +
        # Compute KL(guide,model).
151 158
        kl_qp = 0
152 -
        for site in model_trace.nodes.values():
153 -
            if site["type"] == "sample" and not site["is_observed"]:
154 -
                kl_qp = kl_qp + site["log_prob_sum"]
155 -
        for site in guide_trace.nodes.values():
156 -
            if site["type"] == "sample":
157 -
                kl_qp = kl_qp - site["log_prob_sum"]
159 +
        if self.kl_scale > 0:
160 +
            for site in guide_trace.nodes.values():
161 +
                if site["type"] == "sample":
162 +
                    kl_qp = kl_qp + site["log_prob_sum"]
163 +
            for site in model_trace.nodes.values():
164 +
                if site["type"] == "sample" and not site["is_observed"]:
165 +
                    kl_qp = kl_qp - site["log_prob_sum"]
158 166
159 167
        # Compute final loss.
160 -
        loss = crps + kl_qp
168 +
        loss = crps + self.kl_scale * kl_qp
161 169
        warn_if_nan(loss, "loss")
162 170
        return loss
163 171

Learn more Showing 137 files with coverage changes found.

Changes in pyro/contrib/minipyro.py
-2
+2
Loading file...
Changes in pyro/ops/jit.py
-1
+1
Loading file...
Changes in pyro/infer/elbo.py
-1
+1
Loading file...
Changes in pyro/ops/indexing.py
-1
+1
Loading file...
Changes in pyro/poutine/util.py
-1
+1
Loading file...
Changes in pyro/contrib/gp/models/gpr.py
-2
+2
Loading file...
Changes in pyro/infer/util.py
-4
+4
Loading file...
Changes in pyro/poutine/indep_messenger.py
-2
+2
Loading file...
Changes in pyro/distributions/conditional.py
-1
+1
Loading file...
Changes in pyro/ops/einsum/torch_log.py
-1
+1
Loading file...
Changes in pyro/contrib/gp/models/vsgp.py
-2
+2
Loading file...
Changes in pyro/poutine/enumerate_messenger.py
-4
+4
Loading file...
Changes in pyro/infer/mcmc/adaptation.py
-5
+5
Loading file...
Changes in pyro/infer/mcmc/logger.py
-5
+5
Loading file...
Changes in pyro/infer/tracegraph_elbo.py
-7
+7
Loading file...
Changes in pyro/poutine/trace_messenger.py
-4
+4
Loading file...
Changes in pyro/poutine/reentrant_messenger.py
-1
+1
Loading file...
Changes in pyro/poutine/condition_messenger.py
-1
+1
Loading file...
Changes in pyro/distributions/util.py
-9
+9
Loading file...
Changes in pyro/poutine/block_messenger.py
-3
+3
Loading file...
Changes in pyro/infer/abstract_infer.py
-13
+13
Loading file...
Changes in pyro/nn/auto_reg_nn.py
-7
+7
Loading file...
Changes in pyro/infer/enum.py
-6
+6
Loading file...
Changes in pyro/ops/einsum/util.py
-3
+3
Loading file...
Changes in pyro/distributions/constraints.py
-2
+2
Loading file...
Changes in pyro/poutine/subsample_messenger.py
-7
+7
Loading file...
Changes in pyro/poutine/handlers.py
-10
+10
Loading file...
Changes in pyro/contrib/bnn/utils.py
-1
+1
Loading file...
Changes in pyro/distributions/rejector.py
-4
+4
Loading file...
Changes in pyro/infer/mcmc/nuts.py
-17
+17
Loading file...
Changes in pyro/ops/integrator.py
-5
+5
Loading file...
Changes in pyro/contrib/oed/glmm/glmm.py
-16
+16
Loading file...
Changes in pyro/infer/svi.py
-6
+6
Loading file...
Changes in pyro/ops/einsum/torch_marginal.py
-5
+5
Loading file...
Changes in pyro/ops/packed.py
-12
+12
Loading file...
Changes in pyro/infer/mcmc/mcmc_kernel.py
-2
+2
Loading file...
Changes in pyro/util.py
-24
+24
Loading file...
Changes in pyro/ops/rings.py
-22
+22
Loading file...
Changes in pyro/infer/autoguide/initialization.py
-7
+7
Loading file...
Changes in pyro/contrib/gp/util.py
-9
+9
Loading file...
Changes in pyro/ops/welford.py
-4
+4
Loading file...
Changes in pyro/distributions/empirical.py
-8
+8
Loading file...
Changes in pyro/ops/einsum/adjoint.py
-14
+14
Loading file...
Changes in pyro/contrib/autoname/named.py
-15
+15
Loading file...
Changes in pyro/distributions/lkj.py
-11
+11
Loading file...
Changes in pyro/contrib/gp/kernels/isotropic.py
-14
+14
Loading file...
Changes in pyro/distributions/von_mises.py
-14
+14
Loading file...
Changes in pyro/contrib/easyguide/easyguide.py
-30
+30
Loading file...
Changes in pyro/distributions/unit.py
-5
+5
Loading file...
Changes in pyro/params/param_store.py
-22
+22
Loading file...
Changes in pyro/infer/predictive.py
-23
+23
Loading file...
Changes in pyro/contrib/oed/glmm/guides.py
-34
+34
Loading file...
Changes in pyro/distributions/testing/naive_dirichlet.py
-5
+5
Loading file...
Changes in pyro/primitives.py
-24
+24
Loading file...
Changes in pyro/ops/gaussian.py
-44
+44
Loading file...
Changes in pyro/infer/mcmc/util.py
-77
+77
Loading file...
Changes in pyro/distributions/transforms/affine_autoregressive.py
-21
+21
Loading file...
Changes in pyro/infer/mcmc/hmc.py
-44
+44
Loading file...
Changes in pyro/optim/optim.py
-22
+22
Loading file...
Changes in pyro/distributions/torch_patch.py
-10
+10
Loading file...
Changes in pyro/infer/trace_mean_field_elbo.py
-20
+20
Loading file...
Changes in pyro/poutine/trace_struct.py
-66
+66
Loading file...
Changes in pyro/ops/ssm_gp.py
-20
+20
Loading file...
Changes in pyro/distributions/torch_distribution.py
-26
+26
Loading file...
Changes in pyro/infer/traceenum_elbo.py
-77
+77
Loading file...
Changes in pyro/contrib/gp/likelihoods/gaussian.py
-5
+5
Loading file...
Changes in pyro/distributions/mixture.py
-22
+22
Loading file...
Changes in pyro/optim/lr_scheduler.py
-5
+5
Loading file...
Changes in pyro/contrib/gp/kernels/dot_product.py
-10
+10
Loading file...
Changes in pyro/infer/autoguide/guides.py
-103
+103
Loading file...
Changes in pyro/nn/module.py
-74
+74
Loading file...
Changes in pyro/contrib/autoname/scoping.py
-17
+17
Loading file...
Changes in pyro/distributions/testing/rejection_gamma.py
-46
+46
Loading file...
Changes in pyro/ops/contract.py
-79
+79
Loading file...
Changes in pyro/distributions/hmm.py
-58
+58
Loading file...
Changes in pyro/distributions/transforms/permute.py
-10
+10
Loading file...
Changes in pyro/contrib/gp/kernels/coregionalize.py
-10
+10
Loading file...
Changes in pyro/infer/mcmc/api.py
-97
+97
Loading file...
Changes in pyro/distributions/omt_mvn.py
-8
+8
Loading file...
Changes in pyro/contrib/util.py
-16
+16
Loading file...
Changes in pyro/distributions/kl.py
-12
+12
Loading file...
Changes in pyro/ops/stats.py
-70
+70
Loading file...
Changes in pyro/contrib/gp/kernels/static.py
-10
+10
Loading file...
Changes in pyro/contrib/gp/kernels/kernel.py
-32
+32
Loading file...
Changes in pyro/contrib/gp/likelihoods/poisson.py
-6
+6
Loading file...
Changes in pyro/contrib/gp/kernels/periodic.py
-14
+14
Loading file...
Changes in pyro/distributions/avf_mvn.py
-10
+10
Loading file...
Changes in pyro/contrib/tracking/measurements.py
-16
+16
Loading file...
Changes in pyro/contrib/gp/kernels/brownian.py
-9
+9
Loading file...
Changes in pyro/distributions/von_mises_3d.py
-10
+10
Loading file...
Changes in pyro/distributions/transforms/affine_coupling.py
-23
+23
Loading file...
Changes in pyro/contrib/tracking/distributions.py
-20
+20
Loading file...
Changes in pyro/distributions/testing/rejection_exponential.py
-10
+10
Loading file...
Changes in pyro/infer/importance.py
-41
+41
Loading file...
Changes in pyro/contrib/timeseries/gp.py
-92
+92
Loading file...
Changes in pyro/distributions/transforms/planar.py
-34
+34
Loading file...
Changes in pyro/distributions/transforms/neural_autoregressive.py
-53
+53
Loading file...
Changes in pyro/infer/trace_crps.py
-46
+46
Loading file...
Changes in pyro/distributions/transforms/radial.py
-22
+22
Loading file...
Changes in pyro/distributions/transforms/batchnorm.py
-21
+21
Loading file...
Changes in pyro/distributions/folded.py
-9
+9
Loading file...
Changes in pyro/distributions/relaxed_straight_through.py
-21
+21
Loading file...
Changes in pyro/distributions/transforms/householder.py
-26
+26
Loading file...
Changes in pyro/distributions/conjugate.py
-75
+75
Loading file...
Changes in pyro/optim/multi.py
-31
+31
Loading file...
Changes in pyro/distributions/transforms/sylvester.py
-34
+34
Loading file...
Changes in pyro/distributions/transforms/polynomial.py
-32
+32
Loading file...
Changes in pyro/distributions/zero_inflated_poisson.py
-20
+20
Loading file...
Changes in pyro/contrib/gp/parameterized.py
-54
+54
Loading file...
Changes in pyro/distributions/gaussian_scale_mixture.py
-31
+31
Loading file...
Changes in pyro/contrib/timeseries/lgssm.py
-36
+36
Loading file...
Changes in pyro/contrib/tracking/dynamic_models.py
-83
+83
Loading file...
Changes in pyro/contrib/gp/models/gplvm.py
-15
+15
Loading file...
Changes in pyro/poutine/do_messenger.py
-15
+15
Loading file...
Changes in pyro/poutine/uncondition_messenger.py
-8
+8
Loading file...
Changes in pyro/contrib/oed/util.py
-13
+13
Loading file...
Changes in pyro/distributions/transforms/block_autoregressive.py
-63
+63
Loading file...
Changes in pyro/contrib/bnn/hidden_layer.py
-30
+30
Loading file...
Changes in pyro/contrib/oed/eig.py
-242
+242
Loading file...
Changes in pyro/contrib/tracking/extended_kalman_filter.py
-57
+57
Loading file...
Changes in pyro/infer/svgd.py
-94
+94
Loading file...
Changes in pyro/contrib/gp/models/vgp.py
-38
+38
Loading file...
Changes in pyro/contrib/conjugate/infer.py
-90
+90
Loading file...
Changes in pyro/infer/csis.py
-46
+46
Loading file...
Changes in pyro/poutine/lift_messenger.py
-44
+44
Loading file...
Changes in pyro/distributions/diag_normal_mixture.py
-53
+53
Loading file...
Changes in pyro/distributions/diag_normal_mixture_shared_cov.py
-54
+54
Loading file...
Changes in pyro/contrib/timeseries/lgssmgp.py
-66
+66
Loading file...
Changes in pyro/nn/dense_nn.py
-23
+23
Loading file...
Changes in pyro/infer/discrete.py
-84
+84
Loading file...
Changes in pyro/ops/tensor_utils.py
-37
+37
Loading file...
Changes in pyro/contrib/tracking/hashing.py
-67
+67
Loading file...
Changes in pyro/distributions/spanning_tree.py
-180
+180
Loading file...
Changes in pyro/contrib/gp/models/sgpr.py
-78
+78
Loading file...
Changes in pyro/ops/newton.py
-50
+50
Loading file...
Changes in pyro/ops/linalg.py
-46
+46
Loading file...
Changes in pyro/contrib/tracking/assignment.py
-146
+146
Loading file...
Files Coverage
pyro +31.12% 93.80%
Project Totals (203 files) 93.80%
Loading