import logging
import random
from collections import Counter, defaultdict
from itertools import chain
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from munkres import Munkres  # type: ignore

from snorkel.labeling.analysis import LFAnalysis
from snorkel.labeling.model.base_labeler import BaseLabeler
from snorkel.labeling.model.graph_utils import get_clique_tree
from snorkel.labeling.model.logger import Logger
from snorkel.types import Config
from snorkel.utils.config_utils import merge_config
from snorkel.utils.lr_schedulers import LRSchedulerConfig
from snorkel.utils.optimizers import OptimizerConfig

Metrics = Dict[str, float]


class TrainConfig(Config):
    """Settings for the fit() method of LabelModel.

    Parameters
    ----------
    n_epochs
        The number of epochs to train (where each epoch is a single optimization step)
    lr
        Base learning rate (will also be affected by lr_scheduler choice and settings)
    l2
        Centered L2 regularization strength
    optimizer
        Which optimizer to use (one of ["sgd", "adam", "adamax"])
    optimizer_config
        Settings for the optimizer
    lr_scheduler
        Which lr_scheduler to use (one of ["constant", "linear", "exponential", "step"])
    lr_scheduler_config
        Settings for the LRScheduler
    prec_init
        LF precision initializations / priors
    seed
        A random seed to initialize the random number generator with
    log_freq
        Report loss every this many epochs (steps)
    mu_eps
        Restrict the learned conditional probabilities to [mu_eps, 1-mu_eps]
    """

    n_epochs: int = 100
    lr: float = 0.01
    l2: float = 0.0
    optimizer: str = "sgd"
    optimizer_config: OptimizerConfig = OptimizerConfig()  # type: ignore
    lr_scheduler: str = "constant"
    lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig()  # type: ignore
    prec_init: float = 0.7
    seed: int = np.random.randint(1e6)
    log_freq: int = 10
    mu_eps: Optional[float] = None


class LabelModelConfig(Config):
    """Settings for the LabelModel initialization.

    Parameters
    ----------
    verbose
        Whether to include print statements
    device
        What device to place the model on ('cpu' or 'cuda:0', for example)
    """

    verbose: bool = True
    device: str = "cpu"


class _CliqueData(NamedTuple):
    start_index: int
    end_index: int
    max_cliques: Set[int]


class LabelModel(nn.Module, BaseLabeler):
    r"""A model for learning the LF accuracies and combining their output labels.

    This class learns a model of the labeling functions' conditional probabilities
    of outputting the true (unobserved) label `Y`, `P(\lf | Y)`, and uses this learned
    model to re-weight and combine their output labels.

    This class is based on the approach in [Training Complex Models with Multi-Task
    Weak Supervision](https://arxiv.org/abs/1810.02840), published in AAAI'19. In this
    approach, we compute the inverse generalized covariance matrix of the junction tree
    of a given LF dependency graph, and perform a matrix completion-style approach with
    respect to these empirical statistics. The result is an estimate of the conditional
    LF probabilities, `P(\lf | Y)`, which are then set as the parameters of the label
    model used to re-weight and combine the labels output by the LFs.

    Currently this class uses a conditionally independent label model, in which the LFs
    are assumed to be conditionally independent given `Y`.

    Examples
    --------
    >>> label_model = LabelModel()
    >>> label_model = LabelModel(cardinality=3)
    >>> label_model = LabelModel(cardinality=3, device='cpu')
    >>> label_model = LabelModel(cardinality=3)

    Parameters
    ----------
    cardinality
        Number of classes, by default 2
    **kwargs
        Arguments for changing config defaults

    Raises
    ------
    ValueError
        If the config device is set to CUDA but only CPU is available

    Attributes
    ----------
    cardinality
        Number of classes, by default 2
    config
        Training configuration
    seed
        Random seed
    """

    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
        super().__init__()
        self.config: LabelModelConfig = LabelModelConfig(**kwargs)
        self.cardinality = cardinality

        # Confirm that cuda is available if config is using CUDA
        if self.config.device != "cpu" and not torch.cuda.is_available():
            raise ValueError("device=cuda but CUDA not available.")

        # By default, put model in eval mode; switch to train mode in training
        self.eval()

    def _create_L_ind(self, L: np.ndarray) -> np.ndarray:
        """Convert a label matrix with labels in 0...k to a one-hot format.

        Parameters
        ----------
        L
            An [n,m] label matrix with values in {0,1,...,k}

        Returns
        -------
        np.ndarray
            An [n,m*k] dense np.ndarray with values in {0,1}
        """
        L_ind = np.zeros((self.n, self.m * self.cardinality))
        for y in range(1, self.cardinality + 1):
            # A[x::y] slices A starting at x at intervals of y
            # e.g., np.arange(9)[0::3] == np.array([0,3,6])
            L_ind[:, (y - 1) :: self.cardinality] = np.where(L == y, 1, 0)
        return L_ind

    def _get_augmented_label_matrix(
        self, L: np.ndarray, higher_order: bool = False
    ) -> np.ndarray:
        """Create augmented version of label matrix.

        In the augmented version, each column is an indicator
        for whether a certain source or clique of sources voted in a certain
        pattern.

        Parameters
        ----------
        L
            An [n,m] label matrix with values in {0,1,...,k}
        higher_order
            Whether to include higher-order correlations (e.g. LF pairs) in matrix

        Returns
        -------
        np.ndarray
            An [n,m*k] dense matrix with values in {0,1}
        """
        # Create a helper data structure which maps cliques (as tuples of member
        # sources) --> {start_index, end_index, maximal_cliques}, where
        # the last value is a set of indices in this data structure
        self.c_data: Dict[int, _CliqueData] = {}
        for i in range(self.m):
            self.c_data[i] = _CliqueData(
                start_index=i * self.cardinality,
                end_index=(i + 1) * self.cardinality,
                max_cliques=set(
                    [
                        j
                        for j in self.c_tree.nodes()
                        if i in self.c_tree.node[j]["members"]
                    ]
                ),
            )

        L_ind = self._create_L_ind(L)

        # Get the higher-order clique statistics based on the clique tree
        # First, iterate over the maximal cliques (nodes of c_tree) and
        # separator sets (edges of c_tree)
        if higher_order:
            L_aug = np.copy(L_ind)
            for item in chain(self.c_tree.nodes(), self.c_tree.edges()):
                if isinstance(item, int):
                    C = self.c_tree.node[item]
                elif isinstance(item, tuple):
                    C = self.c_tree[item[0]][item[1]]
                else:
                    raise ValueError(item)
                members = list(C["members"])

                # With unary maximal clique, just store its existing index
                C["start_index"] = members[0] * self.cardinality
                C["end_index"] = (members[0] + 1) * self.cardinality
            return L_aug
        else:
            return L_ind

    def _build_mask(self) -> None:
        """Build mask applied to O^{-1}, O for the matrix approx constraint."""
        self.mask = torch.ones(self.d, self.d).bool()
        for ci in self.c_data.values():
            si = ci.start_index
            ei = ci.end_index
            for cj in self.c_data.values():
                sj, ej = cj.start_index, cj.end_index

                # Check if ci and cj are part of the same maximal clique
                # If so, mask out their corresponding blocks in O^{-1}
                if len(ci.max_cliques.intersection(cj.max_cliques)) > 0:
                    self.mask[si:ei, sj:ej] = 0
                    self.mask[sj:ej, si:ei] = 0

    def _generate_O(self, L: np.ndarray, higher_order: bool = False) -> None:
        """Generate overlaps and conflicts matrix from label matrix.

        Parameters
        ----------
        L
            An [n,m] label matrix with values in {0,1,...,k}
        higher_order
            Whether to include higher-order correlations (e.g. LF pairs) in matrix
        """
        L_aug = self._get_augmented_label_matrix(L, higher_order=higher_order)
        self.d = L_aug.shape[1]
        self.O = (
            torch.from_numpy(L_aug.T @ L_aug / self.n).float().to(self.config.device)
        )

    def _init_params(self) -> None:
        r"""Initialize the learned params.

        - \mu is the primary learned parameter, where each row corresponds to
        the probability of a clique C emitting a specific combination of labels,
        conditioned on different values of Y (for each column); that is:

            self.mu[i*self.cardinality + j, y] = P(\lambda_i = j | Y = y)

        and similarly for higher-order cliques.

        Raises
        ------
        ValueError
            If prec_init shape does not match number of LFs
        """
        # Initialize mu so as to break basic reflective symmetry
        # Note that we are given either a single or per-LF initial precision
        # value, prec_i = P(Y=y|\lf=y), and use:
        #   mu_init = P(\lf=y|Y=y) = P(\lf=y) * prec_i / P(Y=y)
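        # For example (illustrative numbers only): with the default prec_init = 0.7, an
        # LF that emits label y on half the data (P(\lf=y) = 0.5) under a uniform
        # two-class balance (P(Y=y) = 0.5) gets mu_init = 0.5 * 0.7 / 0.5 = 0.7.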

        # Handle single values
        if isinstance(self.train_config.prec_init, (int, float)):
            self._prec_init = self.train_config.prec_init * torch.ones(self.m)
        if self._prec_init.shape[0] != self.m:
            raise ValueError(f"prec_init must have shape {self.m}.")

        # Get the per-value labeling propensities
        # Note that self.O must have been computed already!
        lps = torch.diag(self.O).cpu().detach().numpy()

        # TODO: Update for higher-order cliques!
        self.mu_init = torch.zeros(self.d, self.cardinality)
        for i in range(self.m):
            for y in range(self.cardinality):
                idx = i * self.cardinality + y
                mu_init = torch.clamp(lps[idx] * self._prec_init[i] / self.p[y], 0, 1)
                self.mu_init[idx, y] += mu_init

        # Initialize randomly based on self.mu_init
        self.mu = nn.Parameter(  # type: ignore
            self.mu_init.clone() * np.random.random()
        ).float()

        # Build the mask over O^{-1}
        self._build_mask()

    def _get_conditional_probs(self, mu: np.ndarray) -> np.ndarray:
        r"""Return the estimated conditional probabilities table given parameters mu.

        Given a parameter vector mu, return the estimated conditional probabilities
        table cprobs, where cprobs is an (m, k+1, k)-dim np.ndarray with:

            cprobs[i, j, k] = P(\lf_i = j-1 | Y = k)

        where m is the number of LFs, k is the cardinality, and cprobs includes the
        conditional abstain probabilities P(\lf_i = -1 | Y = y).

        Parameters
        ----------
        mu
            An [m * k, k] np.ndarray with entries in [0, 1]

        Returns
        -------
        np.ndarray
            An [m, k + 1, k] np.ndarray conditional probabilities table.
        """
        cprobs = np.zeros((self.m, self.cardinality + 1, self.cardinality))
        for i in range(self.m):
            # si = self.c_data[(i,)]['start_index']
            # ei = self.c_data[(i,)]['end_index']
            # mu_i = mu[si:ei, :]
            mu_i = mu[i * self.cardinality : (i + 1) * self.cardinality, :]
            cprobs[i, 1:, :] = mu_i

            # The 0th row (corresponding to abstains) is one minus the sum of the
            # other rows, by the law of total probability
            cprobs[i, 0, :] = 1 - mu_i.sum(axis=0)
        return cprobs

    def get_conditional_probs(self) -> np.ndarray:
        r"""Return the estimated conditional probabilities table.

        Return the estimated conditional probabilities table cprobs, where cprobs is an
        (m, k+1, k)-dim np.ndarray with:

            cprobs[i, j, k] = P(\lf_i = j-1 | Y = k)

        where m is the number of LFs, k is the cardinality, and cprobs includes the
        conditional abstain probabilities P(\lf_i = -1 | Y = y).

        Returns
        -------
        np.ndarray
            An [m, k + 1, k] np.ndarray conditional probabilities table.
        """
        return self._get_conditional_probs(self.mu.cpu().detach().numpy())

    def get_weights(self) -> np.ndarray:
        """Return the vector of learned LF weights for combining LFs.

        Returns
        -------
        np.ndarray
            [m,1] vector of learned LF weights for combining LFs.

        Example
        -------
        >>> L = np.array([[1, 1, 1], [1, 1, -1], [-1, 0, 0], [0, 0, 0]])
        >>> label_model = LabelModel(verbose=False)
        >>> label_model.fit(L, seed=123)
        >>> np.around(label_model.get_weights(), 2)  # doctest: +SKIP
        array([0.99, 0.99, 0.99])
        """
        accs = np.zeros(self.m)
        cprobs = self.get_conditional_probs()
        for i in range(self.m):
            accs[i] = np.diag(cprobs[i, 1:, :] @ self.P.cpu().detach().numpy()).sum()
        return np.clip(accs / self.coverage, 1e-6, 1.0)

    def predict_proba(self, L: np.ndarray) -> np.ndarray:
        r"""Return label probabilities P(Y | \lambda).

        Parameters
        ----------
        L
            An [n,m] matrix with values in {-1,0,1,...,k-1}

        Returns
        -------
        np.ndarray
            An [n,k] array of probabilistic labels

        Example
        -------
        >>> L = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 1]])
        >>> label_model = LabelModel(verbose=False)
        >>> label_model.fit(L, seed=123)
        >>> np.around(label_model.predict_proba(L), 1)  # doctest: +SKIP
        array([[1., 0.],
               [0., 1.],
               [0., 1.]])
        """
        L_shift = L + 1  # convert to {0, 1, ..., k}
        self._set_constants(L_shift)
        L_aug = self._get_augmented_label_matrix(L_shift)
        mu = self.mu.cpu().detach().numpy()
        jtm = np.ones(L_aug.shape[1])

        # Note: We omit abstains, effectively assuming uniform distribution here
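        # Under the conditionally independent model, for each data point x and class y:
        #   X[x, y] = exp( sum_{fired indicators a} log mu[a, y] + log p[y] )
        #           \propto P(Y = y) * prod_{i : \lf_i(x) != -1} P(\lf_i(x) | Y = y)
        # and Z then normalizes each row so that the result sums to one over y.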
        X = np.exp(L_aug @ np.diag(jtm) @ np.log(mu) + np.log(self.p))
        Z = np.tile(X.sum(axis=1).reshape(-1, 1), self.cardinality)
        return X / Z

    def predict(
        self,
        L: np.ndarray,
        return_probs: Optional[bool] = False,
        tie_break_policy: str = "abstain",
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Return predicted labels, with ties broken according to policy.

        Policies to break ties include:

        - "abstain": return an abstain vote (-1)
        - "true-random": randomly choose among the tied options
        - "random": randomly choose among the tied options using a deterministic hash

        NOTE: if tie_break_policy="true-random", repeated runs may have slightly different
        results due to differences in broken ties

        Parameters
        ----------
        L
            An [n,m] matrix with values in {-1,0,1,...,k-1}
        return_probs
            Whether to return probs along with preds
        tie_break_policy
            Policy to break ties when converting probabilistic labels to predictions

        Returns
        -------
        np.ndarray
            An [n,1] array of integer labels

        (np.ndarray, np.ndarray)
            An [n,1] array of integer labels and an [n,k] array of probabilistic labels

        Example
        -------
        >>> L = np.array([[0, 0, -1], [1, 1, -1], [0, 0, -1]])
        >>> label_model = LabelModel(verbose=False)
        >>> label_model.fit(L)
        >>> label_model.predict(L)
        array([0, 1, 0])
        """
        return super(LabelModel, self).predict(L, return_probs, tie_break_policy)

    def score(
        self,
        L: np.ndarray,
        Y: np.ndarray,
        metrics: Optional[List[str]] = ["accuracy"],
        tie_break_policy: str = "abstain",
    ) -> Dict[str, float]:
        """Calculate one or more scores from user-specified and/or user-defined metrics.

        Parameters
        ----------
        L
            An [n,m] matrix with values in {-1,0,1,...,k-1}
        Y
            Gold labels associated with data points in L
        metrics
            A list of metric names. Possible metrics are `accuracy`, `coverage`,
            `precision`, `recall`, `f1`, `f1_micro`, `f1_macro`, `fbeta`,
            `matthews_corrcoef`, `roc_auc`. See `sklearn.metrics
            <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
            for details on the metrics.
        tie_break_policy
            Policy to break ties when converting probabilistic labels to predictions.
            Same as for the :func:`.predict` method above.

        Returns
        -------
        Dict[str, float]
            A dictionary mapping metric names to metric scores

        Example
        -------
        >>> L = np.array([[1, 1, -1], [0, 0, -1], [1, 1, -1]])
        >>> label_model = LabelModel(verbose=False)
        >>> label_model.fit(L)
        >>> label_model.score(L, Y=np.array([1, 1, 1]))
        {'accuracy': 0.6666666666666666}
        >>> label_model.score(L, Y=np.array([1, 1, 1]), metrics=["f1"])
        {'f1': 0.8}
        """
        return super(LabelModel, self).score(L, Y, metrics, tie_break_policy)

    # These loss functions get all their data directly from the LabelModel
    # (for better or worse). The unused *args make these compatible with the
    # Classifier._train() method, which expects loss functions to accept an input.
    def _loss_l2(self, l2: float = 0) -> torch.Tensor:
        r"""L2 loss centered around mu_init, scaled optionally per-source.

        In other words, diagonal Tikhonov regularization,
            ||D(\mu-\mu_{init})||_2^2
        where D is diagonal.

        Parameters
        ----------
        l2
            A float or np.array representing the per-source regularization
            strengths to use, by default 0

        Returns
        -------
        torch.Tensor
            L2 loss between learned mu and initial mu
        """
        if isinstance(l2, (int, float)):
            D = l2 * torch.eye(self.d)
        else:
            D = torch.diag(torch.from_numpy(l2)).type(torch.float32)
        D = D.to(self.config.device)
        # Note that mu is a matrix and this is the *Frobenius norm*
        return torch.norm(D @ (self.mu - self.mu_init)) ** 2

    def _loss_mu(self, l2: float = 0) -> torch.Tensor:
        r"""Overall mu loss.

        Parameters
        ----------
        l2
            A float or np.array representing the per-source regularization
            strengths to use, by default 0

        Returns
        -------
        torch.Tensor
            Overall mu loss between learned mu and initial mu
        """
        loss_1 = torch.norm((self.O - self.mu @ self.P @ self.mu.t())[self.mask]) ** 2
        loss_2 = torch.norm(torch.sum(self.mu @ self.P, 1) - torch.diag(self.O)) ** 2
        return loss_1 + loss_2 + self._loss_l2(l2=l2)

    def _set_class_balance(
        self, class_balance: Optional[List[float]], Y_dev: np.ndarray
    ) -> None:
        """Set a prior for the class balance.

        In order of preference:
        1) Use user-provided class_balance
        2) Estimate balance from Y_dev
        3) Assume uniform class distribution
        """
        if class_balance is not None:
            self.p = np.array(class_balance)
            if len(self.p) != self.cardinality:
                raise ValueError(
                    f"class_balance has {len(self.p)} entries. Does not match LabelModel cardinality {self.cardinality}."
                )
        elif Y_dev is not None:
            class_counts = Counter(Y_dev)
            sorted_counts = np.array([v for k, v in sorted(class_counts.items())])
            self.p = sorted_counts / sum(sorted_counts)
            if len(self.p) != self.cardinality:
                raise ValueError(
                    f"Y_dev has {len(self.p)} class(es). Does not match LabelModel cardinality {self.cardinality}."
                )
        else:
            self.p = (1 / self.cardinality) * np.ones(self.cardinality)

        if np.any(self.p == 0):
            raise ValueError(
                f"Class balance prior is 0 for class(es) {np.where(self.p == 0)[0]}."
            )
        self.P = torch.diag(torch.from_numpy(self.p)).float().to(self.config.device)

    def _set_constants(self, L: np.ndarray) -> None:
        self.n, self.m = L.shape
        if self.m < 3:
            raise ValueError("L_train should have at least 3 labeling functions")
        self.t = 1

    def _create_tree(self) -> None:
        nodes = range(self.m)
        self.c_tree = get_clique_tree(nodes, [])

    def _execute_logging(self, loss: torch.Tensor) -> Metrics:
        self.eval()
        self.running_examples: int
        self.running_loss: float
        self.running_loss += loss.item()
        self.running_examples += 1

        # Always add average loss
        metrics_dict = {"train/loss": self.running_loss / self.running_examples}

        if self.logger.check():
            if self.config.verbose:
                self.logger.log(metrics_dict)

            # Reset running loss and examples counts
            self.running_loss = 0.0
            self.running_examples = 0

        self.train()
        return metrics_dict

    def _set_logger(self) -> None:
        self.logger = Logger(self.train_config.log_freq)

    def _set_optimizer(self) -> None:
        parameters = filter(lambda p: p.requires_grad, self.parameters())

        optimizer_config = self.train_config.optimizer_config
        optimizer_name = self.train_config.optimizer
        optimizer: optim.Optimizer  # type: ignore

        if optimizer_name == "sgd":
            optimizer = optim.SGD(  # type: ignore
                parameters,
                lr=self.train_config.lr,
                weight_decay=self.train_config.l2,
                **optimizer_config.sgd_config._asdict(),
            )
        elif optimizer_name == "adam":
            optimizer = optim.Adam(
                parameters,
                lr=self.train_config.lr,
                weight_decay=self.train_config.l2,
                **optimizer_config.adam_config._asdict(),
            )
        elif optimizer_name == "adamax":
            optimizer = optim.Adamax(  # type: ignore
                parameters,
                lr=self.train_config.lr,
                weight_decay=self.train_config.l2,
                **optimizer_config.adamax_config._asdict(),
            )
        else:
            raise ValueError(f"Unrecognized optimizer option '{optimizer_name}'")

        self.optimizer = optimizer

    def _set_lr_scheduler(self) -> None:
        # Set warmup scheduler
        self._set_warmup_scheduler()

        # Set lr scheduler
        lr_scheduler_name = self.train_config.lr_scheduler
        lr_scheduler_config = self.train_config.lr_scheduler_config
        lr_scheduler: Optional[optim.lr_scheduler._LRScheduler]

        if lr_scheduler_name == "constant":
            lr_scheduler = None
        elif lr_scheduler_name == "linear":
            total_steps = self.train_config.n_epochs
            linear_decay_func = lambda x: (total_steps - self.warmup_steps - x) / (
                total_steps - self.warmup_steps
            )
            lr_scheduler = optim.lr_scheduler.LambdaLR(  # type: ignore
                self.optimizer, linear_decay_func
            )
        elif lr_scheduler_name == "exponential":
            lr_scheduler = optim.lr_scheduler.ExponentialLR(
                self.optimizer, **lr_scheduler_config.exponential_config._asdict()
            )
        elif lr_scheduler_name == "step":
            lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, **lr_scheduler_config.step_config._asdict()
            )
        else:
            raise ValueError(f"Unrecognized lr scheduler option '{lr_scheduler_name}'")

        self.lr_scheduler = lr_scheduler

    def _set_warmup_scheduler(self) -> None:
        warmup_scheduler: Optional[optim.lr_scheduler.LambdaLR]

        if self.train_config.lr_scheduler_config.warmup_steps:
            warmup_steps = self.train_config.lr_scheduler_config.warmup_steps
            if warmup_steps < 0:
                raise ValueError("warmup_steps must be greater than or equal to 0.")
            warmup_unit = self.train_config.lr_scheduler_config.warmup_unit
            if warmup_unit == "epochs":
                self.warmup_steps = int(warmup_steps)
            else:
                raise ValueError(
                    "LabelModel does not support any warmup_unit other than 'epochs'."
                )

            linear_warmup_func = lambda x: x / self.warmup_steps
            warmup_scheduler = optim.lr_scheduler.LambdaLR(  # type: ignore
                self.optimizer, linear_warmup_func
            )
            if self.config.verbose:  # pragma: no cover
                logging.info(f"Warmup {self.warmup_steps} steps.")

        elif self.train_config.lr_scheduler_config.warmup_percentage:
            warmup_percentage = self.train_config.lr_scheduler_config.warmup_percentage
            self.warmup_steps = int(warmup_percentage * self.train_config.n_epochs)
            linear_warmup_func = lambda x: x / self.warmup_steps
            warmup_scheduler = optim.lr_scheduler.LambdaLR(  # type: ignore
                self.optimizer, linear_warmup_func
            )
            if self.config.verbose:  # pragma: no cover
                logging.info(f"Warmup {self.warmup_steps} steps.")

        else:
            warmup_scheduler = None
            self.warmup_steps = 0

        self.warmup_scheduler = warmup_scheduler

    def _update_lr_scheduler(self, step: int) -> None:
        if self.warmup_scheduler and step < self.warmup_steps:
            self.warmup_scheduler.step()  # type: ignore
        elif self.lr_scheduler is not None:
            self.lr_scheduler.step()  # type: ignore
            min_lr = self.train_config.lr_scheduler_config.min_lr
            if min_lr and self.optimizer.param_groups[0]["lr"] < min_lr:
                self.optimizer.param_groups[0]["lr"] = min_lr

    def _clamp_params(self) -> None:
        """Clamp the values of the learned parameter vector.

        Clamp the entries of self.mu to be in [mu_eps, 1 - mu_eps], where mu_eps is
        either set by the user, or defaults to min(0.01, 1 / 10 ** np.ceil(np.log10(self.n))).

        Note that if mu_eps is set too high, e.g. in sparse settings where LFs
        mostly abstain, this will result in learning conditional probabilities all
        equal to mu_eps (and/or 1 - mu_eps)!  See issue #1422.

        Note: Use the user-provided value of mu_eps in train_config, else default to
            mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
        This rounding is done to make it more obvious when the parameters have been
        clamped.
        """
        if self.train_config.mu_eps is not None:
            mu_eps = self.train_config.mu_eps
        else:
            mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
        self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps)  # type: ignore

    def _break_col_permutation_symmetry(self) -> None:
        r"""Heuristically choose amongst (possibly) several valid mu values.

        If there are several values of mu that equivalently satisfy the optimization
        objective, as there often are due to column permutation symmetries, then pick
        the solution that trusts the user-written LFs most.

        In more detail, suppose that mu satisfies (minimizes) the two loss objectives:
            1. O = mu @ P @ mu.T
            2. diag(O) = sum(mu @ P, axis=1)
        Then any column permutation matrix Z that commutes with P will also equivalently
        satisfy these objectives, and thus is an equally valid (symmetric) solution.
        Therefore, we select the solution that maximizes the summed probability of the
        LFs being accurate when not abstaining.

            \sum_lf \sum_{y=1}^{cardinality} P(\lf = y, Y = y)
        """
        mu = self.mu.cpu().detach().numpy()
        P = self.P.cpu().detach().numpy()
        d, k = mu.shape
        # We want to maximize the sum of diagonals of matrices for each LF. So
        # we start by computing the sum of conditional probabilities here.
        probs_sum = sum([mu[i : i + k] for i in range(0, self.m * k, k)]) @ P

        munkres_solver = Munkres()
        Z = np.zeros([k, k])

        # Compute groups of indices with equal prior in P.
        groups: DefaultDict[float, List[int]] = defaultdict(list)
        for i, f in enumerate(P.diagonal()):
            groups[np.around(f, 3)].append(i)
        for group in groups.values():
            if len(group) == 1:
                Z[group[0], group[0]] = 1.0  # Identity permutation
                continue
            # Compute submatrix corresponding to the group.
            probs_proj = probs_sum[[[g] for g in group], group]
            # Use the Munkres algorithm to find the optimal permutation.
            # We use minus because we want to maximize diagonal sum, not minimize,
            # and transpose because we want to permute columns, not rows.
            permutation_pairs = munkres_solver.compute(-probs_proj.T)
            for i, j in permutation_pairs:
                Z[group[i], group[j]] = 1.0

        # Set mu according to permutation
        self.mu = nn.Parameter(  # type: ignore
            torch.Tensor(mu @ Z).to(self.config.device)  # type: ignore
        )

    def fit(
        self,
        L_train: np.ndarray,
        Y_dev: Optional[np.ndarray] = None,
        class_balance: Optional[List[float]] = None,
        **kwargs: Any,
    ) -> None:
        """Train label model.

        Train label model to estimate mu, the parameters used to combine LFs.

        Parameters
        ----------
        L_train
            An [n,m] matrix with values in {-1,0,1,...,k-1}
        Y_dev
            Gold labels for dev set for estimating class_balance, by default None
        class_balance
            Each class's percentage of the population, by default None
        **kwargs
            Arguments for changing train config defaults.

            n_epochs
                The number of epochs to train (where each epoch is a single
                optimization step), default is 100
            lr
                Base learning rate (will also be affected by lr_scheduler choice
                and settings), default is 0.01
            l2
                Centered L2 regularization strength, default is 0.0
            optimizer
                Which optimizer to use (one of ["sgd", "adam", "adamax"]),
                default is "sgd"
            optimizer_config
                Settings for the optimizer
            lr_scheduler
                Which lr_scheduler to use (one of ["constant", "linear",
                "exponential", "step"]), default is "constant"
            lr_scheduler_config
                Settings for the LRScheduler
            prec_init
                LF precision initializations / priors, default is 0.7
            seed
                A random seed to initialize the random number generator with
            log_freq
                Report loss every this many epochs (steps), default is 10
            mu_eps
                Restrict the learned conditional probabilities to
                [mu_eps, 1-mu_eps], default is None

        Raises
        ------
        Exception
            If loss is NaN

        Examples
        --------
        >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])
        >>> Y_dev = [0, 1, 0]
        >>> label_model = LabelModel(verbose=False)
        >>> label_model.fit(L)
        >>> label_model.fit(L, Y_dev=Y_dev, seed=2020, lr=0.05)
        >>> label_model.fit(L, class_balance=[0.7, 0.3], n_epochs=200, l2=0.4)
        """
        # Update base config so that it includes all parameters
        self.train_config: TrainConfig = merge_config(  # type:ignore
            TrainConfig(), kwargs  # type:ignore
        )
        # Set random seed
        random.seed(self.train_config.seed)
        np.random.seed(self.train_config.seed)
        torch.manual_seed(self.train_config.seed)

        L_shift = L_train + 1  # convert to {0, 1, ..., k}
        if L_shift.max() > self.cardinality:
            raise ValueError(
                f"L_train has cardinality {L_shift.max()}, cardinality={self.cardinality} passed in."
            )

        self._set_constants(L_shift)
        self._set_class_balance(class_balance, Y_dev)
        self._create_tree()
        lf_analysis = LFAnalysis(L_train)
        self.coverage = lf_analysis.lf_coverages()

        # Compute O and initialize params
        if self.config.verbose:  # pragma: no cover
            logging.info("Computing O...")
        self._generate_O(L_shift)
        self._init_params()

        # Estimate \mu
        if self.config.verbose:  # pragma: no cover
            logging.info(r"Estimating \mu...")

        # Set model to train mode
        self.train()

        # Move model to GPU
        self.mu_init = self.mu_init.to(self.config.device)
        if self.config.verbose and self.config.device != "cpu":  # pragma: no cover
            logging.info("Using GPU...")
        self.to(self.config.device)

        # Set training components
        self._set_logger()
        self._set_optimizer()
        self._set_lr_scheduler()

        # Restore model if necessary
        start_iteration = 0

        # Train the model
        metrics_hist = {}  # The most recently seen value for all metrics
        for epoch in range(start_iteration, self.train_config.n_epochs):
            self.running_loss = 0.0
            self.running_examples = 0

            # Zero the parameter gradients
            self.optimizer.zero_grad()

            # Forward pass to calculate the average loss per example
            loss = self._loss_mu(l2=self.train_config.l2)
            if torch.isnan(loss):
                msg = "Loss is NaN. Consider reducing learning rate."
                raise Exception(msg)

            # Backward pass to calculate gradients
            # Loss is an average loss per example
            loss.backward()

            # Perform optimizer step
            self.optimizer.step()

            # Calculate metrics, log, and checkpoint as necessary
            metrics_dict = self._execute_logging(loss)
            metrics_hist.update(metrics_dict)

            # Update learning rate
            self._update_lr_scheduler(epoch)

        # Post-processing operations on mu
        self._clamp_params()
        self._break_col_permutation_symmetry()

        # Return model to eval mode
        self.eval()

        # Log completion of training if verbose
        if self.config.verbose:  # pragma: no cover
            logging.info("Finished Training")
