snorkel-team / snorkel
Showing 1 of 2 files from the diff.

@@ -9,6 +9,7 @@
Loading
9 9
import torch.nn as nn
10 10
import torch.optim as optim
11 11
from munkres import Munkres  # type: ignore
12 +
from tqdm import trange
12 13
13 14
from snorkel.labeling.analysis import LFAnalysis
14 15
from snorkel.labeling.model.base_labeler import BaseLabeler
@@ -809,6 +810,7 @@
Loading
809 810
        L_train: np.ndarray,
810 811
        Y_dev: Optional[np.ndarray] = None,
811 812
        class_balance: Optional[List[float]] = None,
813 +
        progress_bar: bool = True,
812 814
        **kwargs: Any,
813 815
    ) -> None:
814 816
        """Train label model.
@@ -823,6 +825,8 @@
Loading
823 825
            Gold labels for dev set for estimating class_balance, by default None
824 826
        class_balance
825 827
            Each class's percentage of the population, by default None
828 +
        progress_bar
829 +
            To display a progress bar, by default True
826 830
        **kwargs
827 831
            Arguments for changing train config defaults.
828 832
@@ -918,7 +922,13 @@
Loading
918 922
919 923
        # Train the model
920 924
        metrics_hist = {}  # The most recently seen value for all metrics
921 -
        for epoch in range(start_iteration, self.train_config.n_epochs):
925 +
926 +
        if progress_bar:
927 +
            epochs = trange(start_iteration, self.train_config.n_epochs, unit="epoch")
928 +
        else:
929 +
            epochs = range(start_iteration, self.train_config.n_epochs)
930 +
931 +
        for epoch in epochs:
922 932
            self.running_loss = 0.0
923 933
            self.running_examples = 0
924 934
@@ -945,6 +955,10 @@
Loading
945 955
            # Update learning rate
946 956
            self._update_lr_scheduler(epoch)
947 957
958 +
        # Cleanup progress bar if enabled
959 +
        if progress_bar:
960 +
            epochs.close()
961 +
948 962
        # Post-processing operations on mu
949 963
        self._clamp_params()
950 964
        self._break_col_permutation_symmetry()
Files Coverage
snorkel 97.32%
Project Totals (68 files) 97.32%
406.1
TRAVIS_PYTHON_VERSION=3.6
TRAVIS_OS_NAME=linux
TOXENV=coverage,complex,spark,doctest,type,check
1
coverage:
2
  status:
3
    project:
4
      default:
5
        target: 95%
6
    patch:
7
      default:
8
        threshold: 2%
9

10
comment:
11
  layout: "header, diff, flags, files"
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading