- Move `predict_proba` to the `BaseLabeler` class as an abstract method
- Move the `predict`, `score`, `save`, and `load` methods to the `BaseLabeler` class as shared parent-class methods
- Update `RandomVoter`, `MajorityClassVoter`, `MajorityLabelVoter`, and `LabelModel` to be subclasses of `BaseLabeler` (a toy subclass sketch follows this list)
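For orientation, here is a minimal sketch of what the refactor enables; `ConstantVoter` is a hypothetical toy class (not part of this diff) that only implements `predict_proba` and inherits everything else from `BaseLabeler`:

```python
import numpy as np

from snorkel.labeling.model.base_labeler import BaseLabeler


class ConstantVoter(BaseLabeler):
    """Hypothetical labeler: puts all probability mass on class 0."""

    def predict_proba(self, L: np.ndarray) -> np.ndarray:
        # One row per data point, full confidence on class 0.
        Y_probs = np.zeros((L.shape[0], self.cardinality))
        Y_probs[:, 0] = 1.0
        return Y_probs


L = np.array([[0, 1, -1], [1, 1, 0]])
ConstantVoter().predict(L)  # array([0, 0]) -- predict() is inherited
```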
Showing 3 of 3 files from the diff.
Newly tracked file created: snorkel/labeling/model/base_labeler.py

snorkel/labeling/model/label_model.py:

@@ -1,5 +1,4 @@
 import logging
-import pickle
 import random
 from collections import Counter, defaultdict
 from itertools import chain
@@ -11,12 +10,11 @@
 import torch.optim as optim
 from munkres import Munkres  # type: ignore

-from snorkel.analysis import Scorer
 from snorkel.labeling.analysis import LFAnalysis
+from snorkel.labeling.model.base_labeler import BaseLabeler
 from snorkel.labeling.model.graph_utils import get_clique_tree
 from snorkel.labeling.model.logger import Logger
 from snorkel.types import Config
-from snorkel.utils import probs_to_preds
 from snorkel.utils.config_utils import merge_config
 from snorkel.utils.lr_schedulers import LRSchedulerConfig
 from snorkel.utils.optimizers import OptimizerConfig
@@ -87,7 +85,7 @@
     max_cliques: Set[int]


-class LabelModel(nn.Module):
+class LabelModel(nn.Module, BaseLabeler):
     r"""A model for learning the LF accuracies and combining their output labels.

     This class learns a model of the labeling functions' conditional probabilities
@@ -454,11 +452,7 @@
         >>> label_model.predict(L)
         array([0, 1, 0])
         """
-        Y_probs = self.predict_proba(L)
-        Y_p = probs_to_preds(Y_probs, tie_break_policy)
-        if return_probs:
-            return Y_p, Y_probs
-        return Y_p
+        return super(LabelModel, self).predict(L, return_probs, tie_break_policy)

     def score(
         self,
@@ -496,18 +490,7 @@
         >>> label_model.score(L, Y=np.array([1, 1, 1]), metrics=["f1"])
         {'f1': 0.8}
         """
-        if tie_break_policy == "abstain":  # pragma: no cover
-            logging.warning(
-                "Metrics calculated over data points with non-abstain labels only"
-            )
-
-        Y_pred, Y_prob = self.predict(
-            L, return_probs=True, tie_break_policy=tie_break_policy
-        )
-
-        scorer = Scorer(metrics=metrics)
-        results = scorer.score(Y, Y_pred, Y_prob)
-        return results
+        return super(LabelModel, self).score(L, Y, metrics, tie_break_policy)

     # These loss functions get all their data directly from the LabelModel
     # (for better or worse). The unused *args make these compatible with the
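Since `LabelModel` now inherits from both `nn.Module` and `BaseLabeler`, the `super(LabelModel, self)` calls above resolve through the MRO: `nn.Module` defines neither `predict` nor `score`, so lookup falls through to `BaseLabeler`'s shared implementations. A toy sketch of the same resolution (the names `Module`, `Base`, and `Model` are illustrative only, not from this diff):

```python
from abc import ABC, abstractmethod


class Module:
    """Stand-in for nn.Module; defines no predict()."""


class Base(ABC):
    @abstractmethod
    def predict_proba(self) -> int:
        ...

    def predict(self) -> int:
        return 2 * self.predict_proba()  # shared parent-class logic


class Model(Module, Base):
    def predict_proba(self) -> int:
        return 21

    def predict(self) -> int:
        # Skips Model itself; Module has no predict, so Base.predict runs.
        return super(Model, self).predict()


print(Model.__mro__)      # (Model, Module, Base, ABC, object)
print(Model().predict())  # 42
```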
@@ -928,38 +911,3 @@
         # Print confusion matrix if applicable
         if self.config.verbose:  # pragma: no cover
             logging.info("Finished Training")
-
-    def save(self, destination: str) -> None:
-        """Save label model.
-
-        Parameters
-        ----------
-        destination
-            Filename for saving model
-
-        Example
-        -------
-        >>> label_model.save('./saved_label_model.pkl')  # doctest: +SKIP
-        """
-        f = open(destination, "wb")
-        pickle.dump(self.__dict__, f)
-        f.close()
-
-    def load(self, source: str) -> None:
-        """Load existing label model.
-
-        Parameters
-        ----------
-        source
-            Filename to load model from
-
-        Example
-        -------
-        Load parameters saved in ``saved_label_model``
-
-        >>> label_model.load('./saved_label_model.pkl')  # doctest: +SKIP
-        """
-        f = open(source, "rb")
-        tmp_dict = pickle.load(f)
-        f.close()
-        self.__dict__.update(tmp_dict)
snorkel/labeling/model/baselines.py:

@@ -2,26 +2,10 @@

 import numpy as np

-from snorkel.labeling.model.label_model import LabelModel
+from snorkel.labeling.model.base_labeler import BaseLabeler


-class BaselineVoter(LabelModel):
-    """Parent baseline label model class with method fit()."""
-
-    def fit(self, *args: Any, **kwargs: Any) -> None:
-        """Train majority class model.
-
-        Set class balance for majority class label model.
-
-        Parameters
-        ----------
-        balance
-            A [k] array of class probabilities
-        """
-        pass
-
-
-class RandomVoter(BaselineVoter):
+class RandomVoter(BaseLabeler):
     """Random vote label model.

     Example
@@ -57,7 +41,7 @@
         return Y_p


-class MajorityClassVoter(LabelModel):
+class MajorityClassVoter(BaseLabeler):
     """Majority class label model."""

     def fit(  # type: ignore
@@ -110,7 +94,7 @@
         return Y_p


-class MajorityLabelVoter(BaselineVoter):
+class MajorityLabelVoter(BaseLabeler):
     """Majority vote label model."""

     def predict_proba(self, L: np.ndarray) -> np.ndarray:
snorkel/labeling/model/base_labeler.py (new file):

@@ -0,0 +1,146 @@
+import logging
+import pickle
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from snorkel.analysis import Scorer
+from snorkel.utils import probs_to_preds
+
+
+class BaseLabeler(ABC):
+    """Abstract baseline label voter class."""
+
+    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
+        self.cardinality = cardinality
+
+    @abstractmethod
+    def predict_proba(self, L: np.ndarray) -> np.ndarray:
+        """Abstract method for predicting probabilistic labels given a label matrix.
+
+        Parameters
+        ----------
+        L
+            An [n,m] matrix with values in {-1,0,1,...,k-1}
+
+        Returns
+        -------
+        np.ndarray
+            An [n,k] array of probabilistic labels
+        """
+        pass
+
+    def predict(
+        self,
+        L: np.ndarray,
+        return_probs: Optional[bool] = False,
+        tie_break_policy: str = "abstain",
+    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        """Return predicted labels, with ties broken according to policy.
+
+        Policies to break ties include:
+        "abstain": return an abstain vote (-1)
+        "true-random": randomly choose among the tied options
+        "random": randomly choose among tied option using deterministic hash
+
+        NOTE: if tie_break_policy="true-random", repeated runs may have slightly different
+        results due to difference in broken ties
+
+
+        Parameters
+        ----------
+        L
+            An [n,m] matrix with values in {-1,0,1,...,k-1}
+        return_probs
+            Whether to return probs along with preds
+        tie_break_policy
+            Policy to break ties when converting probabilistic labels to predictions
+
+        Returns
+        -------
+        np.ndarray
+            An [n,1] array of integer labels
+
+        (np.ndarray, np.ndarray)
+            An [n,1] array of integer labels and an [n,k] array of probabilistic labels
+        """
+        Y_probs = self.predict_proba(L)
+        Y_p = probs_to_preds(Y_probs, tie_break_policy)
+        if return_probs:
+            return Y_p, Y_probs
+        return Y_p
+
+    def score(
+        self,
+        L: np.ndarray,
+        Y: np.ndarray,
+        metrics: Optional[List[str]] = ["accuracy"],
+        tie_break_policy: str = "abstain",
+    ) -> Dict[str, float]:
+        """Calculate one or more scores from user-specified and/or user-defined metrics.
+
+        Parameters
+        ----------
+        L
+            An [n,m] matrix with values in {-1,0,1,...,k-1}
+        Y
+            Gold labels associated with data points in L
+        metrics
+            A list of metric names
+        tie_break_policy
+            Policy to break ties when converting probabilistic labels to predictions
+
+
+        Returns
+        -------
+        Dict[str, float]
+            A dictionary mapping metric names to metric scores
+        """
+        if tie_break_policy == "abstain":  # pragma: no cover
+            logging.warning(
+                "Metrics calculated over data points with non-abstain labels only"
+            )
+
+        Y_pred, Y_prob = self.predict(
+            L, return_probs=True, tie_break_policy=tie_break_policy
+        )
+
+        scorer = Scorer(metrics=metrics)
+        results = scorer.score(Y, Y_pred, Y_prob)
+        return results
+
+    def save(self, destination: str) -> None:
+        """Save label model.
+
+        Parameters
+        ----------
+        destination
+            Filename for saving model
+
+        Example
+        -------
+        >>> label_model.save('./saved_label_model.pkl')  # doctest: +SKIP
+        """
+        f = open(destination, "wb")
+        pickle.dump(self.__dict__, f)
+        f.close()
+
+    def load(self, source: str) -> None:
+        """Load existing label model.
+
+        Parameters
+        ----------
+        source
+            Filename to load model from
+
+        Example
+        -------
+        Load parameters saved in ``saved_label_model``
+
+        >>> label_model.load('./saved_label_model.pkl')  # doctest: +SKIP
+        """
+        f = open(source, "rb")
+        tmp_dict = pickle.load(f)
+        f.close()
+        self.__dict__.update(tmp_dict)
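A usage sketch of the consolidated API, assuming the baseline voters live in `snorkel.labeling.model.baselines` as in this diff; the commented outputs are illustrative:

```python
import numpy as np

from snorkel.labeling.model.baselines import MajorityLabelVoter

# Label matrix: 3 data points, 3 labeling functions, -1 = abstain.
L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])

mv = MajorityLabelVoter(cardinality=2)
preds = mv.predict(L)  # inherited from BaseLabeler; ties abstain by default
metrics = mv.score(L, Y=np.array([0, 0, 0]))  # e.g. {'accuracy': ...}

mv.save("./mv.pkl")  # pickles self.__dict__ to disk
mv.load("./mv.pkl")  # restores attributes in place
```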
| Files | Coverage |
|---|---|
| snorkel | 97.13% |
| Project Totals (56 files) | 97.13% |
Builds:
- 278.1: TRAVIS_PYTHON_VERSION=3.6 TRAVIS_OS_NAME=linux TOXENV=coverage,doctest,type,check
- 278.2: TRAVIS_PYTHON_VERSION=3.7 TRAVIS_OS_NAME=linux TOXENV=coverage,doctest,type,check