#1386 Add SlicingClassifier

Open · Vincent Chen (vincentschen)



snorkel/slicing/slicing_classifier.py (new file)

@@ -0,0 +1,180 @@
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from snorkel.analysis import Scorer
from snorkel.classification import DictDataLoader, DictDataset, Operation, Task
from snorkel.classification.multitask_classifier import MultitaskClassifier

from .utils import add_slice_labels, convert_to_slice_tasks


class BinarySlicingClassifier(MultitaskClassifier):
    """A slice-aware binary classifier that supports training and scoring on slice labels.

    Parameters
    ----------
    base_architecture
        A network architecture that accepts input data and outputs a representation
    head_dim
        Output feature dimension of the base_architecture, and input dimension of the
        internal prediction head: ``nn.Linear(head_dim, 2)``.
    slice_names
        A list of slice names that the model will initialize as tasks
        and accept as corresponding labels
    input_data_key
        The key of the input data in the X_dict of a corresponding DictDataset
    task_name
        The name of the base task to be learned
    scorer
        A Scorer to be used for initialization of the ``MultitaskClassifier`` superclass.
    **multitask_kwargs
        Arbitrary keyword arguments to be passed to the ``MultitaskClassifier`` superclass.

    Attributes
    ----------
    base_task
        A base ``snorkel.classification.Task`` that the model will learn.
        This becomes a ``master_head_module`` that combines slice task information.
        For more, see ``snorkel.slicing.convert_to_slice_tasks``.
    slice_names
        See above
    """

    def __init__(
        self,
        base_architecture: nn.Module,
        head_dim: int,
        slice_names: List[str],
        input_data_key: str,
        task_name: str,
        scorer: Scorer = Scorer(metrics=["accuracy", "f1"]),
        **multitask_kwargs: Any,
    ) -> None:

        # Initialize module_pool with 1) base_architecture and 2) prediction_head,
        # assuming `head_dim` maps base_architecture outputs to prediction_head inputs
        module_pool = nn.ModuleDict(
            {
                "base_architecture": base_architecture,
                "prediction_head": nn.Linear(head_dim, 2),
            }
        )

        # Create op_sequence from base_architecture -> prediction_head
        op_sequence = [
            Operation(
                name="input_op",
                module_name="base_architecture",
                inputs=[("_input_", input_data_key)],
            ),
            Operation(
                name="head_op", module_name="prediction_head", inputs=["input_op"]
            ),
        ]

        # Initialize base_task using the specified base_architecture
        self.base_task = Task(
            name=task_name,
            module_pool=module_pool,
            op_sequence=op_sequence,
            scorer=scorer,
        )

        # Convert base_task to its associated slice_tasks
        slice_tasks = convert_to_slice_tasks(self.base_task, slice_names)

        # Initialize a MultitaskClassifier with all slice_tasks
        model_name = f"{task_name}_slicing_classifier"
        super().__init__(tasks=slice_tasks, name=model_name, **multitask_kwargs)
        self.slice_names = slice_names

    def make_slice_dataloader(
        self, dataset: DictDataset, S: np.ndarray, **dataloader_kwargs: Any
    ) -> DictDataLoader:
        """Create a DictDataLoader with slice labels, initialized from the specified dataset.

        Parameters
        ----------
        dataset
            A DictDataset that will be converted into a slice-aware dataloader
        S
            A [num_examples, num_slices] slice matrix indicating whether each
            example belongs to each slice; columns correspond to ``self.slice_names``
        dataloader_kwargs
            Arbitrary kwargs to be passed to DictDataLoader.
            See ``DictDataLoader.__init__``.
        """

        # Validate inputs
        if S.shape[1] != len(self.slice_names):
            raise ValueError("Num. columns in S matrix does not match num. slice_names")

        # Base task must have corresponding labels in dataset
        if self.base_task.name not in dataset.Y_dict:  # type: ignore
            raise ValueError(
                f"Base task ({self.base_task.name}) labels missing from {dataset}"
            )

        # Initialize dataloader
        dataloader = DictDataLoader(dataset, **dataloader_kwargs)

        # Make dataloader slice-aware
        add_slice_labels(dataloader, self.base_task, S, self.slice_names)

        return dataloader

    @torch.no_grad()
    def score_slices(
        self, dataloaders: List[DictDataLoader], as_dataframe: bool = False
    ) -> Union[Dict[str, float], pd.DataFrame]:
        """Score the appropriate slice labels using the overall prediction head.

        In other words, uses ``base_task`` (NOT ``slice_tasks``) to evaluate slices.

        In practice, we'd like to use a final prediction from a _single_ task head.
        To do so, ``self.base_task`` leverages the reweighted slice representations
        to make a prediction. In this method, we remap all slice-specific ``pred``
        labels to ``self.base_task`` for evaluation.

        Parameters
        ----------
        dataloaders
            A list of DictDataLoaders to calculate scores for
        as_dataframe
            A boolean indicating whether to return results as a pandas
            DataFrame (True) or dict (False)

        Returns
        -------
        Union[Dict[str, float], pd.DataFrame]
            A dictionary mapping metric names to corresponding scores
            (or the equivalent DataFrame if ``as_dataframe=True``).
            Metric names will be of the form "task/dataset/split/metric".
        """

        eval_mapping: Dict[str, Optional[str]] = {}

        # Collect all labels
        all_labels: Union[List, Set] = []
        for dl in dataloaders:
            all_labels.extend(dl.dataset.Y_dict.keys())  # type: ignore
        all_labels = set(all_labels)

        # By convention, evaluate on "pred" labels, not "ind" labels.
        # See ``snorkel.slicing.utils.add_slice_labels`` for more about label creation.
        for label in all_labels:
            if "pred" in label:
                eval_mapping[label] = self.base_task.name
            elif "ind" in label:
                eval_mapping[label] = None

        # Call score on the remapped set of labels
        return super().score(
            dataloaders=dataloaders,
            remap_labels=eval_mapping,
            as_dataframe=as_dataframe,
        )
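
For reference, here is a minimal usage sketch (not part of the diff). The base architecture, slice names, dataset fields, and slice matrix below are illustrative assumptions, not the PR's documented API:

# Hedged usage sketch: dataset/task names, field keys, and the toy slice
# matrix are hypothetical; in practice S would come from applying slicing
# functions (e.g., with snorkel.slicing.PandasSFApplier).
import numpy as np
import torch
import torch.nn as nn

from snorkel.classification import DictDataset

# Toy base architecture: 100-d features -> 20-d representation
mlp = nn.Sequential(nn.Linear(100, 20), nn.ReLU())

slice_names = ["short_comment", "has_keyword"]  # hypothetical slice names

clf = BinarySlicingClassifier(
    base_architecture=mlp,
    head_dim=20,                # must match the base architecture's output dim
    slice_names=slice_names,
    input_data_key="features",  # key into the DictDataset's X_dict
    task_name="spam_task",      # hypothetical base task name
)

# Toy dataset: Y_dict must contain labels under the base task's name
dataset = DictDataset(
    name="spam_dataset",
    split="valid",
    X_dict={"features": torch.rand(8, 100)},
    Y_dict={"spam_task": torch.randint(0, 2, (8,))},
)

# Toy [num_examples, num_slices] slice membership matrix
S = np.random.randint(0, 2, (8, 2))

dataloader = clf.make_slice_dataloader(dataset, S, batch_size=4)

# After training (e.g., with snorkel.classification.Trainer), evaluate
# slice-level performance through the single base prediction head:
scores = clf.score_slices([dataloader], as_dataframe=True)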

snorkel/classification/multitask_classifier.py

@@ -16,6 +16,7 @@
 )
 
 import numpy as np
+import pandas as pd
 import torch
 import torch.nn as nn
@@ -315,7 +316,7 @@
         self,
         dataloader: DictDataLoader,
         return_preds: bool = False,
-        remap_labels: Dict[str, str] = {},
+        remap_labels: Dict[str, Optional[str]] = {},
     ) -> Dict[str, Dict[str, torch.Tensor]]:
         """Calculate probabilities, (optionally) predictions, and pull out gold labels.
@@ -379,9 +380,9 @@
     def score(
         self,
         dataloaders: List[DictDataLoader],
-        remap_labels: Dict[str, str] = {},
+        remap_labels: Dict[str, Optional[str]] = {},
         as_dataframe: bool = False,
-    ) -> Dict[str, float]:
+    ) -> Union[Dict[str, float], pd.DataFrame]:
         """Calculate scores for the provided DictDataLoaders.
 
         Parameters
@@ -416,7 +417,7 @@
             # What labels in Y_dict are we ignoring?
             extra_labels = set(Y_dict.keys()).difference(set(labels_to_tasks.keys()))
             if extra_labels:
-                logging.warning(
+                logging.info(
                     f"Ignoring extra labels in dataloader ({dataloader.dataset.split}): {extra_labels}"  # type: ignore
                 )
@@ -453,7 +454,7 @@
             return metric_score_dict
 
     def _get_labels_to_tasks(
-        self, label_names: Iterable[str], remap_labels: Dict[str, str] = {}
+        self, label_names: Iterable[str], remap_labels: Dict[str, Optional[str]] = {}
     ) -> Dict[str, str]:
         """Map each label to its corresponding task outputs based on whether the task is available.
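
The widened ``Dict[str, Optional[str]]`` type lets callers drop a label from evaluation by remapping it to ``None``, which is exactly how ``score_slices`` suppresses the ``ind`` labels above. A minimal sketch of the calling pattern this change enables (the label and task names here are hypothetical):

# Hedged sketch of the remap_labels pattern; label/task names are hypothetical.
scores = model.score(
    dataloaders=[valid_dl],
    remap_labels={
        "spam_task_slice:short_comment_pred": "spam_task",  # score on base head
        "spam_task_slice:short_comment_ind": None,          # drop from evaluation
    },
    as_dataframe=False,
)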

Showing 2 files with coverage changes:

  New file:   snorkel/slicing/slicing_classifier.py
  Changes in: snorkel/classification/multitask_classifier.py (-1, +1)

Files Coverage
  snorkel                     0.10%   97.54%
  Project Totals (55 files)           97.54%