--- a/snorkel/labeling/model/label_model.py
+++ b/snorkel/labeling/model/label_model.py
@@ -1,5 +1,4 @@
 import logging
-import pickle
 import random
 from collections import Counter, defaultdict
 from itertools import chain
@@ -11,12 +10,11 @@
 import torch.optim as optim
 from munkres import Munkres  # type: ignore
 
-from snorkel.analysis import Scorer
 from snorkel.labeling.analysis import LFAnalysis
+from snorkel.labeling.model.base_labeler import BaseLabeler
 from snorkel.labeling.model.graph_utils import get_clique_tree
 from snorkel.labeling.model.logger import Logger
 from snorkel.types import Config
-from snorkel.utils import probs_to_preds
 from snorkel.utils.config_utils import merge_config
 from snorkel.utils.lr_schedulers import LRSchedulerConfig
 from snorkel.utils.optimizers import OptimizerConfig
@@ -87,7 +85,7 @@
     max_cliques: Set[int]
 
 
-class LabelModel(nn.Module):
+class LabelModel(nn.Module, BaseLabeler):
     r"""A model for learning the LF accuracies and combining their output labels.
 
     This class learns a model of the labeling functions' conditional probabilities
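With `BaseLabeler` mixed in alongside `nn.Module`, Python's MRO is what makes the `super()` delegations in the next two hunks resolve correctly: `nn.Module` defines neither `predict` nor `score`, so lookup falls through to `BaseLabeler`. A minimal runnable sketch with stand-in classes (not the real torch types):

class Module:                  # stand-in for torch.nn.Module; defines no predict()
    pass


class Labeler:                 # stand-in for BaseLabeler
    def predict(self):
        return "Labeler.predict"


class Model(Module, Labeler):  # same base order as LabelModel(nn.Module, BaseLabeler)
    def predict(self):
        # MRO: Model -> Module -> Labeler -> object. Module lacks predict,
        # so super() resolves to Labeler.predict.
        return super().predict()


assert Model().predict() == "Labeler.predict"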
@@ -454,11 +452,7 @@
         >>> label_model.predict(L)
         array([0, 1, 0])
         """
-        Y_probs = self.predict_proba(L)
-        Y_p = probs_to_preds(Y_probs, tie_break_policy)
-        if return_probs:
-            return Y_p, Y_probs
-        return Y_p
+        return super(LabelModel, self).predict(L, return_probs, tie_break_policy)
 
     def score(
         self,
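`predict` now defers to the inherited implementation, which is just `probs_to_preds` applied to the output of `predict_proba` (see the new file below). A small illustration of the default tie-break behavior, assuming snorkel is installed:

import numpy as np
from snorkel.utils import probs_to_preds

probs = np.array([[0.5, 0.5], [0.9, 0.1]])
# The tied first row abstains (-1) under the "abstain" policy;
# the second row maps to its argmax, class 0.
probs_to_preds(probs, tie_break_policy="abstain")  # -> array([-1,  0])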
@@ -496,18 +490,7 @@
         >>> label_model.score(L, Y=np.array([1, 1, 1]), metrics=["f1"])
         {'f1': 0.8}
         """
-        if tie_break_policy == "abstain":  # pragma: no cover
-            logging.warning(
-                "Metrics calculated over data points with non-abstain labels only"
-            )
-
-        Y_pred, Y_prob = self.predict(
-            L, return_probs=True, tie_break_policy=tie_break_policy
-        )
-
-        scorer = Scorer(metrics=metrics)
-        results = scorer.score(Y, Y_pred, Y_prob)
-        return results
+        return super(LabelModel, self).score(L, Y, metrics, tie_break_policy)
 
     # These loss functions get all their data directly from the LabelModel
     # (for better or worse). The unused *args make these compatible with the
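`score` likewise defers to the base class, whose body (unchanged, shown in the new file below) is a thin wrapper over `Scorer`. For reference, `Scorer` on its own behaves like this toy example:

import numpy as np
from snorkel.analysis import Scorer

scorer = Scorer(metrics=["accuracy", "f1"])
# One of three predictions is wrong -> accuracy 2/3; precision 1.0 and
# recall 0.5 for class 1 -> f1 = 2/3.
scorer.score(golds=np.array([1, 1, 0]), preds=np.array([1, 0, 0]))
# -> {'accuracy': 0.666..., 'f1': 0.666...}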
@@ -928,38 +911,3 @@
         # Print confusion matrix if applicable
         if self.config.verbose:  # pragma: no cover
             logging.info("Finished Training")
-
-    def save(self, destination: str) -> None:
-        """Save label model.
-
-        Parameters
-        ----------
-        destination
-            Filename for saving model
-
-        Example
-        -------
-        >>> label_model.save('./saved_label_model.pkl')  # doctest: +SKIP
-        """
-        f = open(destination, "wb")
-        pickle.dump(self.__dict__, f)
-        f.close()
-
-    def load(self, source: str) -> None:
-        """Load existing label model.
-
-        Parameters
-        ----------
-        source
-            Filename to load model from
-
-        Example
-        -------
-        Load parameters saved in ``saved_label_model``
-
-        >>> label_model.load('./saved_label_model.pkl')  # doctest: +SKIP
-        """
-        f = open(source, "rb")
-        tmp_dict = pickle.load(f)
-        f.close()
-        self.__dict__.update(tmp_dict)

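The `save`/`load` pair deleted here reappears verbatim in `base_labeler.py` below, still pickling `self.__dict__` through bare `open`/`close` calls. A functionally equivalent sketch using context managers, offered only as a style alternative to what the diff actually does:

import pickle

def save(self, destination: str) -> None:
    # Same behavior as the moved method, but closes the file even on error.
    with open(destination, "wb") as f:
        pickle.dump(self.__dict__, f)

def load(self, source: str) -> None:
    with open(source, "rb") as f:
        self.__dict__.update(pickle.load(f))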
--- a/snorkel/labeling/model/baseline.py
+++ b/snorkel/labeling/model/baseline.py
@@ -2,26 +2,10 @@
 
 import numpy as np
 
-from snorkel.labeling.model.label_model import LabelModel
+from snorkel.labeling.model.base_labeler import BaseLabeler
 
 
-class BaselineVoter(LabelModel):
-    """Parent baseline label model class with method fit()."""
-
-    def fit(self, *args: Any, **kwargs: Any) -> None:
-        """Train majority class model.
-
-        Set class balance for majority class label model.
-
-        Parameters
-        ----------
-        balance
-            A [k] array of class probabilities
-        """
-        pass
-
-
-class RandomVoter(BaselineVoter):
+class RandomVoter(BaseLabeler):
     """Random vote label model.
 
     Example
@@ -57,7 +41,7 @@
         return Y_p
 
 
-class MajorityClassVoter(LabelModel):
+class MajorityClassVoter(BaseLabeler):
     """Majority class label model."""
 
     def fit(  # type: ignore
@@ -110,7 +94,7 @@
         return Y_p
 
 
-class MajorityLabelVoter(BaselineVoter):
+class MajorityLabelVoter(BaseLabeler):
     """Majority vote label model."""
 
     def predict_proba(self, L: np.ndarray) -> np.ndarray:

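With the pass-through `BaselineVoter` class removed, each voter inherits `predict`, `score`, `save`, and `load` straight from `BaseLabeler`. A quick usage sketch after this change (the label matrix is made up):

import numpy as np
from snorkel.labeling.model.baseline import MajorityLabelVoter

L = np.array([[0, 0, -1],   # clear majority for class 0
              [-1, 0, 1],   # 0/1 tie -> abstain (-1) under the default policy
              [1, 1, 1]])   # unanimous class 1
MajorityLabelVoter(cardinality=2).predict(L)  # -> array([ 0, -1,  1])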
--- /dev/null
+++ b/snorkel/labeling/model/base_labeler.py
@@ -0,0 +1,146 @@
+import logging
+import pickle
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from snorkel.analysis import Scorer
+from snorkel.utils import probs_to_preds
+
+
+class BaseLabeler(ABC):
+    """Abstract baseline label voter class."""
+
+    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
+        self.cardinality = cardinality
+
+    @abstractmethod
+    def predict_proba(self, L: np.ndarray) -> np.ndarray:
+        """Abstract method for predicting probabilistic labels given a label matrix.
+
+        Parameters
+        ----------
+        L
+            An [n,m] matrix with values in {-1,0,1,...,k-1}
+
+        Returns
+        -------
+        np.ndarray
+            An [n,k] array of probabilistic labels
+        """
+        pass
+
+    def predict(
+        self,
+        L: np.ndarray,
+        return_probs: Optional[bool] = False,
+        tie_break_policy: str = "abstain",
+    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        """Return predicted labels, with ties broken according to policy.
+
+        Policies to break ties include:
+        "abstain": return an abstain vote (-1)
+        "true-random": randomly choose among the tied options
+        "random": randomly choose among tied options using a deterministic hash
+
+        NOTE: if tie_break_policy="true-random", repeated runs may have slightly different
+        results due to differences in broken ties
+
+
+        Parameters
+        ----------
+        L
+            An [n,m] matrix with values in {-1,0,1,...,k-1}
+        return_probs
+            Whether to return probs along with preds
+        tie_break_policy
+            Policy to break ties when converting probabilistic labels to predictions
+
+        Returns
+        -------
+        np.ndarray
+            An [n,1] array of integer labels
+
+        (np.ndarray, np.ndarray)
+            An [n,1] array of integer labels and an [n,k] array of probabilistic labels
+        """
+        Y_probs = self.predict_proba(L)
+        Y_p = probs_to_preds(Y_probs, tie_break_policy)
+        if return_probs:
+            return Y_p, Y_probs
+        return Y_p
+
+    def score(
+        self,
+        L: np.ndarray,
+        Y: np.ndarray,
+        metrics: Optional[List[str]] = ["accuracy"],
+        tie_break_policy: str = "abstain",
+    ) -> Dict[str, float]:
+        """Calculate one or more scores from user-specified and/or user-defined metrics.
+
+        Parameters
+        ----------
+        L
+            An [n,m] matrix with values in {-1,0,1,...,k-1}
+        Y
+            Gold labels associated with data points in L
+        metrics
+            A list of metric names
+        tie_break_policy
+            Policy to break ties when converting probabilistic labels to predictions
+
+
+        Returns
+        -------
+        Dict[str, float]
+            A dictionary mapping metric names to metric scores
+        """
+        if tie_break_policy == "abstain":  # pragma: no cover
+            logging.warning(
+                "Metrics calculated over data points with non-abstain labels only"
+            )
+
+        Y_pred, Y_prob = self.predict(
+            L, return_probs=True, tie_break_policy=tie_break_policy
+        )
+
+        scorer = Scorer(metrics=metrics)
+        results = scorer.score(Y, Y_pred, Y_prob)
+        return results
+
+    def save(self, destination: str) -> None:
+        """Save label model.
+
+        Parameters
+        ----------
+        destination
+            Filename for saving model
+
+        Example
+        -------
+        >>> label_model.save('./saved_label_model.pkl')  # doctest: +SKIP
+        """
+        f = open(destination, "wb")
+        pickle.dump(self.__dict__, f)
+        f.close()
+
+    def load(self, source: str) -> None:
+        """Load existing label model.
+
+        Parameters
+        ----------
+        source
+            Filename to load model from
+
+        Example
+        -------
+        Load parameters saved in ``saved_label_model``
+
+        >>> label_model.load('./saved_label_model.pkl')  # doctest: +SKIP
+        """
+        f = open(source, "rb")
+        tmp_dict = pickle.load(f)
+        f.close()
+        self.__dict__.update(tmp_dict)
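After this file lands, a new heuristic labeler only has to implement `predict_proba`; everything else comes from the base class. A minimal concrete subclass as a sketch (`ConstantVoter` is hypothetical, not part of this diff):

import numpy as np

from snorkel.labeling.model.base_labeler import BaseLabeler


class ConstantVoter(BaseLabeler):
    """Hypothetical labeler that always predicts class 0 with full confidence."""

    def predict_proba(self, L: np.ndarray) -> np.ndarray:
        Y_p = np.zeros((L.shape[0], self.cardinality))
        Y_p[:, 0] = 1.0
        return Y_p


cv = ConstantVoter(cardinality=2)
cv.predict(np.array([[0, 1], [1, -1]]))  # -> array([0, 0])
# predict/score/save/load all come from BaseLabeler for free.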