#305 Add negative binom to pygam

Open michalk8

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.


@@ -89,6 +89,14 @@
Loading
89 89
_copy = """Return a copy of self."""
90 90
_root = "root"
91 91
_final = "final"
92 +
_nb_y = """\
93 +
y
94 +
    Target values, array of of shape `(n,)`.
95 +
"""
96 +
_nb_mu = """\
97 +
mu
98 +
    Expected values, array of of shape `(n,)`.
99 +
"""
92 100
93 101
94 102
def inject_docs(**kwargs):
@@ -119,4 +127,6 @@
Loading
119 127
    time_range=_time_range,
120 128
    velocity_mode=_velocity_mode,
121 129
    velocity_backward_mode=_velocity_backward_mode,
130 +
    nb_y=_nb_y,
131 +
    nb_mu=_nb_mu,
122 132
)

@@ -22,7 +22,10 @@
Loading
22 22
    ExpectileGAM,
23 23
)
24 24
from pygam.terms import s
25 +
from pygam.utils import ylogydu
26 +
from scipy.stats import nbinom
25 27
from sklearn.base import BaseEstimator
28 +
from pygam.distributions import Distribution, divide_weights, multiply_weights
26 29
from scipy.ndimage.filters import convolve
27 30
28 31
import matplotlib as mpl
@@ -53,6 +56,7 @@
Loading
53 56
class GamDistribution(ModeEnum):  # noqa
54 57
    NORMAL = "normal"
55 58
    BINOMIAL = "binomial"
59 +
    NEGATIVE_BINOMIAL = "nb"
56 60
    POISSON = "poisson"
57 61
    GAMMA = "gamma"
58 62
    GAUSS = "gaussian"
@@ -1065,6 +1069,8 @@
Loading
1065 1069
    link
1066 1070
        Name of the link function. Available functions can be found
1067 1071
        `here <https://pygam.readthedocs.io/en/latest/notebooks/tour_of_pygam.html#Link-function:>`_.
1072 +
    r
1073 +
        Number of successes when :paramref:`distribution` is `'nb'`.
1068 1074
    max_iter
1069 1075
        Maximum number of iterations for optimization.
1070 1076
    expectile
@@ -1075,7 +1081,7 @@
Loading
1075 1081
        use the :paramref:`model`'s method.
1076 1082
    grid
1077 1083
        Whether to perform a grid search. Keys correspond to a parameter names and values to range to be searched.
1078 -
        If an empty :class:`dict`, don't perform a grid search. If `None`, uses a default grid.
1084 +
        If `'default'`, use the default grid. If `None`, don't perform grid search.
1079 1085
    spline_kwargs
1080 1086
        Keyword arguments for :class:`pygam.s`.
1081 1087
    **kwargs
@@ -1087,12 +1093,13 @@
Loading
1087 1093
        adata: AnnData,
1088 1094
        n_splines: Optional[int] = 10,
1089 1095
        spline_order: int = 3,
1090 -
        distribution: str = "gamma",
1096 +
        distribution: str = "nb",
1091 1097
        link: str = "log",
1098 +
        r: Optional[float] = None,
1092 1099
        max_iter: int = 2000,
1093 1100
        expectile: Optional[float] = None,
1094 1101
        use_default_conf_int: bool = False,
1095 -
        grid: Optional[Mapping] = MappingProxyType({}),
1102 +
        grid: Optional[Union[str, Mapping]] = None,
1096 1103
        spline_kwargs: Mapping = MappingProxyType({}),
1097 1104
        **kwargs,
1098 1105
    ):
@@ -1105,8 +1112,11 @@
Loading
1105 1112
        )
1106 1113
        link = GamLinkFunction(link)
1107 1114
        distribution = GamDistribution(distribution)
1115 +
1108 1116
        if distribution == GamDistribution.GAUSS:
1109 1117
            distribution = GamDistribution.NORMAL
1118 +
        elif distribution == GamDistribution.NEGATIVE_BINOMIAL:
1119 +
            distribution = NegativeBinomial(r=r)
1110 1120
1111 1121
        if expectile is not None:
1112 1122
            if not (0 < expectile < 1):
@@ -1116,7 +1126,7 @@
Loading
1116 1126
            if distribution != "normal" or link != "identity":
1117 1127
                logg.warning(
1118 1128
                    f"Expectile GAM works only with `normal` distribution and `identity` link function,"
1119 -
                    f"found `{distribution!r}` distribution and {link!r} link functions."
1129 +
                    f"found `{distribution!r}` distribution and `{link!r}` link functions."
1120 1130
                )
1121 1131
            model = ExpectileGAM(
1122 1132
                term, expectile=expectile, max_iter=max_iter, verbose=False, **kwargs
@@ -1126,7 +1136,11 @@
Loading
1126 1136
                distribution, link
1127 1137
            ]  # doing it like this ensure that user can specify scale
1128 1138
            kwargs["link"] = link.s
1129 -
            kwargs["distribution"] = distribution.s
1139 +
            kwargs["distribution"] = (
1140 +
                distribution.s
1141 +
                if isinstance(distribution, GamDistribution)
1142 +
                else distribution
1143 +
            )
1130 1144
            model = gam(
1131 1145
                term,
1132 1146
                max_iter=max_iter,
@@ -1135,16 +1149,16 @@
Loading
1135 1149
            )
1136 1150
        super().__init__(adata, model=model)
1137 1151
        self._use_default_conf_int = use_default_conf_int
1138 -
        self._grid = object()  # sentinel value, `None` performs a grid search
1139 1152
1140 1153
        if grid is None:
1141 1154
            self._grid = None
1142 -
        elif isinstance(grid, (dict, MappingProxyType)):
1143 -
            if len(grid):
1144 -
                self._grid = dict(grid)
1155 +
        elif isinstance(grid, dict):
1156 +
            self._grid = _copy(grid)
1157 +
        elif isinstance(grid, str):
1158 +
            self._grid = object() if grid == "default" else None
1145 1159
        else:
1146 1160
            raise TypeError(
1147 -
                f"Expected `grid` to be `dict` or `None`, found `{type(grid).__name__!r}`."
1161 +
                f"Expected `grid` to be `dict`, `str` or `None`, found `{type(grid).__name__!r}`."
1148 1162
            )
1149 1163
1150 1164
    @d.dedent
@@ -1171,11 +1185,11 @@
Loading
1171 1185
        super().fit(x, y, w, **kwargs)
1172 1186
1173 1187
        if self._grid is not None:
1174 -
1188 +
            # use default grid
1175 1189
            grid = {} if not isinstance(self._grid, dict) else self._grid
1190 +
            print(grid)
1176 1191
            try:
1177 1192
                # workaround for: https://github.com/dswah/pyGAM/issues/273
1178 -
                self.model.fit(self.x, self.y, weights=self.w, **kwargs)
1179 1193
                self.model.gridsearch(
1180 1194
                    self.x,
1181 1195
                    self.y,
@@ -1187,6 +1201,7 @@
Loading
1187 1201
                )
1188 1202
                return self
1189 1203
            except Exception as e:
1204 +
                self.model.fit(self.x, self.y, weights=self.w, **kwargs)
1190 1205
                logg.error(
1191 1206
                    f"Grid search failed, reason: `{e}`. Fitting with default values"
1192 1207
                )
@@ -1276,7 +1291,7 @@
Loading
1276 1291
    smoothing_param
1277 1292
        Smoothing parameter. Increasing this increases the smootheness of splines.
1278 1293
    distribution
1279 -
        Family in `rpy2.robjects.r`, such as `'gaussian'` or `'poisson'`.
1294 +
        Family in `rpy2.robjects.r`, such as `'gaussian'`, 'poisson'` or `'nb'`.
1280 1295
    backend
1281 1296
        R library used to fit GAMs. Valid options are `'mgcv'` and `'gam'`.
1282 1297
        Option `'gam'` ignores the number of splines, as well as family and smoothing parameter.
@@ -1285,6 +1300,7 @@
Loading
1285 1300
    _fallback_backends = {
1286 1301
        "gam": "mgcv",
1287 1302
        "mgcv": "gam",
1303 +
        "brms": "mgcv",
1288 1304
    }
1289 1305
1290 1306
    def __init__(
@@ -1303,6 +1319,11 @@
Loading
1303 1319
        self._lib_name = None
1304 1320
        self._family = distribution
1305 1321
1322 +
        if distribution == "zinb":
1323 +
            backend = "brms"
1324 +
        elif backend == "brms":
1325 +
            distribution = "zinb"
1326 +
1306 1327
        if backend not in self._fallback_backends.keys():
1307 1328
            raise ValueError(
1308 1329
                f"Invalid backend library `{backend!r}`. Valid options are `{list(self._fallback_backends.keys())}`."
@@ -1351,7 +1372,11 @@
Loading
1351 1372
        self._w = self.w[use_ixs]
1352 1373
1353 1374
        family = getattr(robjects.r, self._family, None)
1354 -
        if family is None:
1375 +
        if family is None and self._family != "zinb":
1376 +
            logg.debug(
1377 +
                f"Unable to find distribution `{self._family!r}`. Defaulting to `'gaussian'`"
1378 +
            )
1379 +
            self._family = "gaussian"
1355 1380
            family = robjects.r.gaussian
1356 1381
1357 1382
        pandas2ri.activate()
@@ -1362,9 +1387,17 @@
Loading
1362 1387
                Formula(f'y ~ s(x, k={self._n_splines}, bs="cs")'),
1363 1388
                data=df,
1364 1389
                sp=self._sp,
1365 -
                family=family,
1390 +
                family=self._family,
1366 1391
                weights=pd.Series(self.w),
1367 1392
            )
1393 +
        elif self._lib_name == "brms":
1394 +
            self._model = self._lib.brm(
1395 +
                Formula(f'y ~ s(x, k={self._n_splines}, bs="cs")'),
1396 +
                data=df,
1397 +
                family=self._lib.zero_inflated_negbinomial(),
1398 +
                chains=4,
1399 +
                cores=4,
1400 +
            )
1368 1401
        elif self._lib_name == "gam":
1369 1402
            self._model = self._lib.gam(
1370 1403
                Formula("y ~ s(x)"),
@@ -1467,6 +1500,19 @@
Loading
1467 1500
        self.__dict__ = state
1468 1501
        self._lib, self._lib_name = _maybe_import_r_lib(self._lib_name, raise_exc=True)
1469 1502
1503 +
    def __str__(self) -> str:
1504 +
        return repr(self)
1505 +
1506 +
    def __repr__(self) -> str:
1507 +
        return "{}[{}]".format(
1508 +
            self.__class__.__name__,
1509 +
            None
1510 +
            if self.model is None
1511 +
            else f"gam[family={self._family!r}]"
1512 +
            if self._lib_name == "gam"
1513 +
            else _dup_spaces.sub(" ", str(self.model).replace("\n", " ")).strip(),
1514 +
        )
1515 +
1470 1516
1471 1517
def _maybe_import_r_lib(
1472 1518
    name: str, raise_exc: bool = False
@@ -1501,3 +1547,118 @@
Loading
1501 1547
        raise RuntimeError(
1502 1548
            f"Install R library `{name!r}` first as `install.packages({name!r}).`"
1503 1549
        ) from e
1550 +
1551 +
1552 +
class NegativeBinomial(Distribution):
1553 +
    """
1554 +
    Negative binomial distribution.
1555 +
1556 +
    Parameters
1557 +
    ----------
1558 +
    r
1559 +
        Number of successes.
1560 +
    """
1561 +
1562 +
    def __init__(self, r: float):
1563 +
        super().__init__(name="nb", scale=1.0)
1564 +
        if r is None:
1565 +
            raise ValueError(
1566 +
                "Number of successes `r` must an `int` or `float`, got `None`."
1567 +
            )
1568 +
        if r <= 0:
1569 +
            raise ValueError(f"Expected number of successes to be > 0, got `{r}`.")
1570 +
        self._r = r
1571 +
        self._exclude.append("scale")
1572 +
1573 +
    @d.dedent
1574 +
    def log_pdf(
1575 +
        self, y: np.ndarray, mu: np.ndarray, weights: Optional[np.ndarray] = None
1576 +
    ) -> np.ndarray:  # noqa
1577 +
        """
1578 +
        Computes the log of the pmf of the values under the current distribution.
1579 +
1580 +
        Parameters
1581 +
        ----------
1582 +
        %(nb_y)s
1583 +
        %(nb_mu)s
1584 +
        weights : array-like shape (n,) or None, default: None
1585 +
            Sample weights, array of shape `(n,)`. If `None`, defaults to array of 1s.
1586 +
1587 +
        Returns
1588 +
        -------
1589 +
            The log pmf.
1590 +
        """
1591 +
        p = mu / (mu + self._r)  # success prob
1592 +
1593 +
        return nbinom.logpmf(y, 1, p)
1594 +
1595 +
    @divide_weights
1596 +
    @d.dedent
1597 +
    def V(self, mu: np.ndarray, alpha: Optional[float] = None) -> np.ndarray:
1598 +
        """
1599 +
        Variance function of negative binomial distribution.
1600 +
1601 +
        Parameters
1602 +
        ----------
1603 +
        %(nb_mu)s
1604 +
        alpha
1605 +
            The ancillary parameter for the negative binomial variance function.
1606 +
            :paramref:`alpha` is assumed to be nonstochastic.
1607 +
1608 +
        Returns
1609 +
        -------
1610 +
            The variance, array of shape `(n,)`.
1611 +
        """
1612 +
        if alpha is None:
1613 +
            alpha = 1 / self._r
1614 +
1615 +
        return mu + alpha * (mu ** 2)
1616 +
1617 +
    @multiply_weights
1618 +
    @d.dedent
1619 +
    def deviance(self, y: np.ndarray, mu: np.ndarray, scaled=True) -> np.ndarray:
1620 +
        """
1621 +
        Model deviance.
1622 +
1623 +
        Parameters
1624 +
        ----------
1625 +
        %(nb_y)s
1626 +
        %(nb_mu)s
1627 +
        scaled
1628 +
            Whether to divide the deviance by the distribution scaled.
1629 +
1630 +
        Returns
1631 +
        -------
1632 +
           The deviances, array of shape `(n,)`.
1633 +
        """
1634 +
        dev = 2 * (
1635 +
            ylogydu(y, mu) - (y + self._r) * np.log((self._r + y) / (self._r + mu))
1636 +
        )
1637 +
        if scaled:
1638 +
            dev /= self.scale
1639 +
1640 +
        return dev
1641 +
1642 +
    @d.dedent
1643 +
    def sample(self, mu: np.ndarray) -> np.ndarray:
1644 +
        """
1645 +
        Return random samples from this distribution.
1646 +
1647 +
        Parameters
1648 +
        ----------
1649 +
        %(nb_mu)s
1650 +
1651 +
        Returns
1652 +
        -------
1653 +
            Random samples, array of the same shape as :paramref:`mu`.
1654 +
        """
1655 +
1656 +
        p = mu / (mu + self._r)  # success prob
1657 +
1658 +
        return np.random.negative_binomial(n=1, p=p, size=None)
1659 +
1660 +
    def __repr__(self):
1661 +
        return f"{self.__class__.__name__}(r={self._r})"
1662 +
1663 +
    def __str__(self):
1664 +
        return repr(self)

@@ -361,6 +361,7 @@
Loading
361 361
        color = list(V_.T)
362 362
        if cluster_key is not None:
363 363
            color = [cluster_key] + color
364 +
        kwargs["cmap"] = kwargs.pop("cmap", "viridis")
364 365
365 366
        logg.debug(f"Showing `{use}` {name}vectors")
366 367

Everything is accounted for!

No changes detected that need to be reviewed.
What changes does Codecov check for?
Lines, not adjusted in diff, that have changed coverage data.
Files that introduced coverage data that had none before.
Files that have missing coverage data that once were tracked.
Files Coverage
cellrank -0.34% 78.05%
setup.py 0.00%
Project Totals (46 files) 77.92%
Loading