1 2
import logging
2 2
import os
3 2
from collections import defaultdict
4 2
from typing import (
5
    Any,
6
    Callable,
7
    Dict,
8
    Iterable,
9
    List,
10
    Mapping,
11
    Optional,
12
    Sequence,
13
    Set,
14
    Tuple,
15
    Union,
16
)
17

18 2
import numpy as np
19 2
import pandas as pd
20 2
import torch
21 2
import torch.nn as nn
22

23 2
from snorkel.analysis import Scorer
24 2
from snorkel.classification.data import DictDataLoader
25 2
from snorkel.classification.utils import metrics_dict_to_dataframe, move_to_device
26 2
from snorkel.types import Config
27 2
from snorkel.utils import probs_to_preds
28

29 2
from .task import Operation, Task
30

31 2
# Return type of MultitaskClassifier.forward: maps each operation name to its
# output, which is either a single value or a mapping of field name to value
# (tuple inputs in an op_sequence index into such a mapping by field key).
OutputDict = Dict[str, Union[Any, Mapping[str, Any]]]
32

33

34 2
class ClassifierConfig(Config):
    """Settings for a ``MultitaskClassifier``.

    Parameters
    ----------
    device
        The device (GPU) to move the model to (-1 is CPU), but device will also be
        moved to CPU if no GPU device is available
    dataparallel
        Whether or not to use PyTorch DataParallel wrappers to automatically utilize
        multiple GPUs if available
    """

    device: int = 0
    dataparallel: bool = True
49

50

51 2
class MultitaskClassifier(nn.Module):
    r"""A classifier built from one or more tasks to support advanced workflows.

    Parameters
    ----------
    tasks
        A list of ``Task``\s to build a model from
    name
        The name of the classifier

    Attributes
    ----------
    config
        The config dict containing the settings for this model
    name
        See above
    module_pool
        A dictionary of all modules used by any of the tasks (See Task docstring)
    task_names
        See Task docstring
    op_sequences
        See Task docstring
    loss_funcs
        See Task docstring
    output_funcs
        See Task docstring
    scorers
        See Task docstring
    """

    def __init__(
        self, tasks: List[Task], name: Optional[str] = None, **kwargs: Any
    ) -> None:
        super().__init__()
        self.config = ClassifierConfig(**kwargs)
        self.name = name or type(self).__name__

        # Initiate the model attributes
        self.module_pool = nn.ModuleDict()
        self.task_names: Set[str] = set()
        self.op_sequences: Dict[str, Sequence[Operation]] = dict()
        self.loss_funcs: Dict[str, Callable[..., torch.Tensor]] = dict()
        self.output_funcs: Dict[str, Callable[..., torch.Tensor]] = dict()
        self.scorers: Dict[str, Scorer] = dict()

        # Build network with given tasks
        self._build_network(tasks)

        # Report total task count and duplicates
        all_ops = [op.name for t in tasks for op in t.op_sequence]
        unique_ops = set(all_ops)
        all_mods = [mod_name for t in tasks for mod_name in t.module_pool]
        unique_mods = set(all_mods)
        logging.info(
            f"Created multi-task model {self.name} that contains "
            f"task(s) {self.task_names} from "
            f"{len(unique_ops)} operations ({len(all_ops) - len(unique_ops)} shared) and "
            f"{len(unique_mods)} modules ({len(all_mods) - len(unique_mods)} shared)."
        )

        # Move model to specified device
        self._move_to_device()

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(name={self.name})"

    def _build_network(self, tasks: List[Task]) -> None:
        r"""Construct the network from a list of ``Task``\s by adding them one by one.

        Parameters
        ----------
        tasks
            A list of ``Task``s

        Raises
        ------
        ValueError
            If an element is not a ``Task`` or two tasks share the same name
        """
        for task in tasks:
            if not isinstance(task, Task):
                raise ValueError(f"Unrecognized task type {task}.")
            if task.name in self.task_names:
                raise ValueError(
                    f"Found duplicate task {task.name}, different task should use "
                    f"different task name."
                )
            self.add_task(task)

    def add_task(self, task: Task) -> None:
        """Add a single task to the network.

        Modules that share a name with a previously registered module are
        shared (the existing instance wins), so tasks can share weights.

        Parameters
        ----------
        task
            A ``Task`` to add
        """
        # Combine module_pool from all tasks
        for key in task.module_pool.keys():
            if key in self.module_pool.keys():
                # Module name already registered: point the task at the
                # existing (shared) module instead of its own copy
                if self.config.dataparallel:
                    task.module_pool[key] = nn.DataParallel(self.module_pool[key])
                else:
                    task.module_pool[key] = self.module_pool[key]
            else:
                if self.config.dataparallel:
                    self.module_pool[key] = nn.DataParallel(task.module_pool[key])
                else:
                    self.module_pool[key] = task.module_pool[key]
        self.task_names.add(task.name)
        self.op_sequences[task.name] = task.op_sequence
        self.loss_funcs[task.name] = task.loss_func
        self.output_funcs[task.name] = task.output_func
        self.scorers[task.name] = task.scorer

        # Move model to specified device
        self._move_to_device()

    def forward(  # type: ignore
        self, X_dict: Dict[str, Any], task_names: Iterable[str]
    ) -> OutputDict:
        """Do a forward pass through the network for all specified tasks.

        Parameters
        ----------
        X_dict
            A dict of data fields
        task_names
            The names of the tasks to execute the forward pass for

        Returns
        -------
        OutputDict
            A dict mapping each operation name to its corresponding output

        Raises
        ------
        TypeError
            If an Operation input has an invalid type
        ValueError
            If a specified Operation failed to execute
        """
        X_dict_moved = move_to_device(X_dict, self.config.device)

        outputs: OutputDict = {"_input_": X_dict_moved}  # type: ignore

        # Call forward for each task, using cached result if available
        # Each op_sequence consists of one or more operations that are executed in order
        for task_name in task_names:
            op_sequence = self.op_sequences[task_name]

            for operation in op_sequence:
                if operation.name not in outputs:
                    try:
                        if operation.inputs:
                            # Feed the inputs the module requested in the requested order
                            inputs = []
                            for op_input in operation.inputs:
                                if isinstance(op_input, tuple):
                                    # The output of the indicated operation has a dict
                                    # of fields; extract the designated field by name
                                    op_name, field_key = op_input
                                    inputs.append(outputs[op_name][field_key])
                                else:
                                    # The output of the indicated operation has only
                                    # one field; use that as the input to the current op
                                    op_name = op_input
                                    inputs.append(outputs[op_name])

                            output = self.module_pool[operation.module_name].forward(
                                *inputs
                            )
                        else:
                            # Feed the entire outputs dict for the module to pull from
                            output = self.module_pool[operation.module_name].forward(
                                outputs
                            )
                    except Exception as e:
                        # Chain the original exception so the root cause stays
                        # visible in the traceback
                        raise ValueError(
                            f"Unsuccessful operation {operation}: {repr(e)}."
                        ) from e
                    outputs[operation.name] = output

        return outputs

    def calculate_loss(
        self, X_dict: Dict[str, Any], Y_dict: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, float]]:
        """Calculate the loss for each task and the number of data points contributing.

        Parameters
        ----------
        X_dict
            A dict of data fields
        Y_dict
            A dict from task names to label sets

        Returns
        -------
        Dict[str, torch.Tensor], Dict[str, float]
            A dict of losses by task name and seen examples by task name
        """

        loss_dict = dict()
        count_dict = dict()

        labels_to_tasks = self._get_labels_to_tasks(Y_dict.keys())
        outputs = self.forward(X_dict, task_names=labels_to_tasks.values())

        # Calculate loss for each task
        for label_name, task_name in labels_to_tasks.items():
            Y = Y_dict[label_name]

            # Select the active samples (label -1 marks an abstain/ignored example)
            if len(Y.size()) == 1:
                active = Y.detach() != -1
            else:
                active = torch.any(Y.detach() != -1, dim=1)

            # Only calculate the loss when active example exists
            if active.any():
                # Note: Use label_name as key, but task_name to access model attributes
                count_dict[label_name] = active.sum().item()

                # Extract the output of the last operation for this task
                inputs = outputs[self.op_sequences[task_name][-1].name]

                # Filter out any inactive examples if inputs is a Tensor
                if not active.all() and isinstance(inputs, torch.Tensor):
                    inputs = inputs[active]
                    Y = Y[active]

                loss_dict[label_name] = self.loss_funcs[task_name](
                    inputs, move_to_device(Y, self.config.device)
                )

        return loss_dict, count_dict

    @torch.no_grad()
    def _calculate_probs(
        self, X_dict: Dict[str, Any], task_names: Iterable[str]
    ) -> Dict[str, Iterable[torch.Tensor]]:
        """Calculate the probabilities for each task.

        Parameters
        ----------
        X_dict
            A dict of data fields
        task_names
            A list of task names to calculate probabilities for

        Returns
        -------
        Dict[str, Iterable[torch.Tensor]]
            A dictionary mapping task name to probabilities
        """

        self.eval()

        prob_dict = dict()

        outputs = self.forward(X_dict, task_names)

        for task_name in task_names:
            # Extract the output of the last operation for this task
            inputs = outputs[self.op_sequences[task_name][-1].name]
            prob_dict[task_name] = self.output_funcs[task_name](inputs).cpu().numpy()

        return prob_dict

    @torch.no_grad()
    def predict(
        self,
        dataloader: DictDataLoader,
        return_preds: bool = False,
        remap_labels: Optional[Dict[str, Optional[str]]] = None,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Calculate probabilities, (optionally) predictions, and pull out gold labels.

        Parameters
        ----------
        dataloader
            A DictDataLoader to make predictions for
        return_preds
            If True, include predictions in the return dict (not just probabilities)
        remap_labels
            A dict specifying which labels in the dataset's Y_dict (key)
            to remap to a new task (value); defaults to no remapping

        Returns
        -------
        Dict[str, Dict[str, torch.Tensor]]
            A dictionary mapping label type ('golds', 'probs', 'preds') to values
        """
        # None sentinel avoids the shared-mutable-default-argument pitfall
        if remap_labels is None:
            remap_labels = {}

        self.eval()

        gold_dict_list: Dict[str, List[torch.Tensor]] = defaultdict(list)
        prob_dict_list: Dict[str, List[torch.Tensor]] = defaultdict(list)

        labels_to_tasks = self._get_labels_to_tasks(
            label_names=dataloader.dataset.Y_dict.keys(),  # type: ignore
            remap_labels=remap_labels,
        )
        for X_batch_dict, Y_batch_dict in dataloader:
            prob_batch_dict = self._calculate_probs(
                X_batch_dict, labels_to_tasks.values()
            )
            for label_name in labels_to_tasks:
                task_name = labels_to_tasks[label_name]
                Y = Y_batch_dict[label_name]

                # Note: store results under label_name
                # but retrieve from pre-computed results using task_name
                prob_dict_list[label_name].extend(prob_batch_dict[task_name])
                gold_dict_list[label_name].extend(Y.cpu().numpy())

        gold_dict: Dict[str, np.ndarray] = {}
        prob_dict: Dict[str, np.ndarray] = {}

        for task_name in gold_dict_list:
            gold_dict[task_name] = np.array(gold_dict_list[task_name])
            prob_dict[task_name] = np.array(prob_dict_list[task_name])

        if return_preds:
            # Plain dict: every key is assigned exactly once below, so the
            # previous defaultdict(list) initialization was misleading
            pred_dict: Dict[str, np.ndarray] = {}
            for task_name, probs in prob_dict.items():
                pred_dict[task_name] = probs_to_preds(probs)

        results = {"golds": gold_dict, "probs": prob_dict}

        if return_preds:
            results["preds"] = pred_dict

        return results

    @torch.no_grad()
    def score(
        self,
        dataloaders: List[DictDataLoader],
        remap_labels: Optional[Dict[str, Optional[str]]] = None,
        as_dataframe: bool = False,
    ) -> Union[Dict[str, float], pd.DataFrame]:
        """Calculate scores for the provided DictDataLoaders.

        Parameters
        ----------
        dataloaders
            A list of DictDataLoaders to calculate scores for
        remap_labels
            A dict specifying which labels in the dataset's Y_dict (key)
            to remap to a new task (value); defaults to no remapping
        as_dataframe
            A boolean indicating whether to return results as pandas
            DataFrame (True) or dict (False)

        Returns
        -------
        Dict[str, float]
            A dictionary mapping metric names to corresponding scores
            Metric names will be of the form "task/dataset/split/metric"
        """
        # None sentinel avoids the shared-mutable-default-argument pitfall
        if remap_labels is None:
            remap_labels = {}

        self.eval()

        metric_score_dict = dict()

        for dataloader in dataloaders:
            # Construct label to task mapping for evaluation
            Y_dict = dataloader.dataset.Y_dict  # type: ignore
            labels_to_tasks = self._get_labels_to_tasks(
                label_names=Y_dict.keys(), remap_labels=remap_labels
            )

            # What labels in Y_dict are we ignoring?
            extra_labels = set(Y_dict.keys()).difference(set(labels_to_tasks.keys()))
            if extra_labels:
                logging.info(
                    f"Ignoring extra labels in dataloader ({dataloader.dataset.split}): {extra_labels}"  # type: ignore
                )

            # Obtain predictions
            results = self.predict(
                dataloader, return_preds=True, remap_labels=remap_labels
            )
            # Score and record metrics for each set of predictions
            for label_name, task_name in labels_to_tasks.items():
                metric_scores = self.scorers[task_name].score(
                    golds=results["golds"][label_name],
                    preds=results["preds"][label_name],
                    probs=results["probs"][label_name],
                )

                for metric_name, metric_value in metric_scores.items():
                    # Type ignore statements are necessary because the DataLoader class
                    # that DictDataLoader inherits from is what actually sets
                    # the class of Dataset, and it doesn't know about name and split.
                    identifier = "/".join(
                        [
                            label_name,
                            dataloader.dataset.name,  # type: ignore
                            dataloader.dataset.split,  # type: ignore
                            metric_name,
                        ]
                    )
                    metric_score_dict[identifier] = metric_value

        if as_dataframe:
            return metrics_dict_to_dataframe(metric_score_dict)

        return metric_score_dict

    def _get_labels_to_tasks(
        self,
        label_names: Iterable[str],
        remap_labels: Optional[Dict[str, Optional[str]]] = None,
    ) -> Dict[str, str]:
        """Map each label to its corresponding task outputs based on whether the task is available.

        If remap_labels specified, overrides specific label -> task mappings.
        If a label is mapped to `None`, that key is removed from the mapping.
        """
        # None sentinel avoids the shared-mutable-default-argument pitfall
        if remap_labels is None:
            remap_labels = {}

        labels_to_tasks = {}
        for label in label_names:
            # Override any existing label -> task mappings
            if label in remap_labels:
                task = remap_labels[label]
                # Note: task might be manually remapped to None to remove it from the labels_to_tasks
                if task is not None:
                    labels_to_tasks[label] = task

            # If available in task flows, label should map to task of same name
            elif label in self.op_sequences:
                labels_to_tasks[label] = label

        return labels_to_tasks

    def _move_to_device(self) -> None:  # pragma: no cover
        """Move the model to the device specified in the model config."""
        device = self.config.device
        if device >= 0:
            if torch.cuda.is_available():
                logging.info(f"Moving model to GPU (cuda:{device}).")
                self.to(torch.device(f"cuda:{device}"))
            else:
                logging.info("No cuda device available. Switch to cpu instead.")

    def save(self, model_path: str) -> None:
        """Save the model to the specified file path.

        Parameters
        ----------
        model_path
            The path where the model should be saved

        Raises
        ------
        BaseException
            If the torch.save() method fails
        """
        # A bare filename has an empty dirname; os.makedirs("") raises
        # FileNotFoundError, so only create directories when a directory
        # component actually exists. exist_ok avoids a check-then-create race.
        model_dir = os.path.dirname(model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        try:
            torch.save(self.state_dict(), model_path)
        except BaseException:  # pragma: no cover
            logging.warning("Saving failed... continuing anyway.")

        logging.info(f"[{self.name}] Model saved in {model_path}")

    def load(self, model_path: str) -> None:
        """Load a saved model from the provided file path and moves it to a device.

        Parameters
        ----------
        model_path
            The path to a saved model
        """
        try:
            self.load_state_dict(
                torch.load(model_path, map_location=torch.device("cpu"))
            )
        except BaseException:  # pragma: no cover
            if not os.path.exists(model_path):
                logging.error("Loading failed... Model does not exist.")
            else:
                logging.error(f"Loading failed... Cannot load model from {model_path}")
            raise

        logging.info(f"[{self.name}] Model loaded from {model_path}")
        self._move_to_device()

Read our documentation on viewing source code .

Loading