import logging
import os
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from snorkel.classification.data import DictDataLoader  # noqa: F401
from snorkel.classification.multitask_classifier import (
    ClassifierConfig,
    MultitaskClassifier,
)
from snorkel.types import Config
from snorkel.utils.config_utils import merge_config
from snorkel.utils.lr_schedulers import LRSchedulerConfig
from snorkel.utils.optimizers import OptimizerConfig

from .loggers import (
    Checkpointer,
    CheckpointerConfig,
    LogManager,
    LogManagerConfig,
    LogWriter,
    LogWriterConfig,
    TensorBoardWriter,
)
from .schedulers import batch_schedulers

Metrics = Dict[str, float]


class TrainerConfig(Config):
    """Settings for the Trainer.

    Parameters
    ----------
    seed
        A random seed to set before training; if None, no seed is set
    n_epochs
        The number of epochs to train
    lr
        Base learning rate (will also be affected by lr_scheduler choice and settings)
    l2
        L2 regularization coefficient (weight decay)
    grad_clip
        The maximum gradient norm; gradients are clipped to this norm if they exceed it
    train_split
        The name of the split to use as the training set
    valid_split
        The name of the split to use as the validation set
    test_split
        The name of the split to use as the test set
    progress_bar
        If True, print a tqdm progress bar during training
    model_config
        Settings for the MultitaskClassifier
    log_manager_config
        Settings for the LogManager
    checkpointing
        If True, use a Checkpointer to save the best model during training
    checkpointer_config
        Settings for the Checkpointer
    logging
        If True, log metrics (to file or TensorBoard) during training
    log_writer
        The type of LogWriter to use (one of ["json", "tensorboard"])
    log_writer_config
        Settings for the LogWriter
    optimizer
        Which optimizer to use (one of ["sgd", "adam", "adamax"])
    optimizer_config
        Settings for the optimizer
    lr_scheduler
        Which lr_scheduler to use (one of ["constant", "linear", "exponential", "step"])
    lr_scheduler_config
        Settings for the LRScheduler
    batch_scheduler
        Which batch scheduler to use (in what order batches will be drawn from multiple
        tasks)
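
    Example
    -------
    Settings are overridden by passing keyword arguments to ``Trainer``, which merges
    them into this default config. A minimal sketch (values are illustrative only):

    >>> trainer = Trainer(n_epochs=5, lr=1e-4, progress_bar=False) # doctest: +SKIP
    >>> trainer.config.n_epochs # doctest: +SKIP
    5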
    """

    seed: Optional[int] = None
    n_epochs: int = 1
    lr: float = 0.01
    l2: float = 0.0
    grad_clip: float = 1.0
    train_split: str = "train"
    valid_split: str = "valid"
    test_split: str = "test"
    progress_bar: bool = True
    model_config: ClassifierConfig = ClassifierConfig()  # type:ignore
    log_manager_config: LogManagerConfig = LogManagerConfig()  # type:ignore
    checkpointing: bool = False
    checkpointer_config: CheckpointerConfig = CheckpointerConfig()  # type:ignore
    logging: bool = False
    log_writer: str = "tensorboard"
    log_writer_config: LogWriterConfig = LogWriterConfig()  # type:ignore
    optimizer: str = "adam"
    optimizer_config: OptimizerConfig = OptimizerConfig()  # type:ignore
    lr_scheduler: str = "constant"
    lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig()  # type:ignore
    batch_scheduler: str = "shuffled"


class Trainer:
    """A class for training a MultitaskClassifier.

    Parameters
    ----------
    name
        An optional name for this trainer object
    kwargs
        Settings to be merged into the default Trainer config dict

    Attributes
    ----------
    name
        See above
    config
        The config dict with the settings for the Trainer
    checkpointer
        Saves the best model seen during training
    log_manager
        Identifies when it's time to log or evaluate on the valid set
    log_writer
        Writes training statistics to file or TensorBoard
    optimizer
        Updates model weights based on the loss
    lr_scheduler
        Adjusts the learning rate over the course of training
    batch_scheduler
        Returns batches from the DataLoaders in a particular order for training
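
    Example
    -------
    A minimal usage sketch; ``model`` is assumed to be an already-constructed
    ``MultitaskClassifier`` and ``dataloaders`` a list of ``DictDataLoader`` objects
    covering the configured splits:

    >>> trainer = Trainer(n_epochs=10, lr=1e-3) # doctest: +SKIP
    >>> trainer.fit(model, dataloaders) # doctest: +SKIP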
    """

    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        self.config: TrainerConfig = merge_config(  # type:ignore
            TrainerConfig(), kwargs  # type:ignore
        )
        self.name = name if name is not None else type(self).__name__

    def fit(
        self, model: MultitaskClassifier, dataloaders: List["DictDataLoader"]
    ) -> None:
        """Train a MultitaskClassifier.

        Parameters
        ----------
        model
            The model to train
        dataloaders
            A list of DataLoaders. These will be split into train, valid, and test
            splits based on the ``split`` attribute of the DataLoaders.
        """
        self._check_dataloaders(dataloaders)

        # Identify the dataloaders to train on
        train_dataloaders = [
            dl
            for dl in dataloaders
            if dl.dataset.split == self.config.train_split  # type: ignore
        ]

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = sum(
            [len(dataloader) for dataloader in train_dataloaders]
        )

        # Set training helpers
        self._set_log_writer()
        self._set_checkpointer()
        self._set_log_manager()
        self._set_optimizer(model)
        self._set_lr_scheduler()
        self._set_batch_scheduler()

        # Set to training mode
        model.train()

        logging.info("Start training...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        for epoch_num in range(self.config.n_epochs):
            batches = tqdm(
                enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
                total=self.n_batches_per_epoch,
                disable=(not self.config.progress_bar),
                desc=f"Epoch {epoch_num}:",
            )
            for batch_num, (batch, dataloader) in batches:
                X_dict, Y_dict = batch

                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                batch_size = len(next(iter(Y_dict.values())))

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

                # Perform forward pass and calculate the loss and count
                loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

                # Update running loss and count
                for task_name in loss_dict.keys():
                    identifier = "/".join(
                        [
                            task_name,
                            dataloader.dataset.name,
                            dataloader.dataset.split,
                            "loss",
                        ]
                    )
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * count_dict[task_name]
                    )
                    self.running_counts[identifier] += count_dict[task_name]

                # Skip the backward pass if no loss is calculated
                if not loss_dict:
                    continue

                # Sum the task losses into a single scalar loss
                loss = torch.stack(list(loss_dict.values())).sum()

                # Perform backward pass to calculate gradients
                loss.backward()

                # Clip gradient norm
                if self.config.grad_clip:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), self.config.grad_clip
                    )

                # Update the parameters
                self.optimizer.step()

                # Update lr using lr scheduler
                self._update_lr_scheduler(total_batch_num)

                # Update metrics
                self.metrics.update(self._logging(model, dataloaders, batch_size))

                batches.set_postfix(self.metrics)

        model = self.log_manager.cleanup(model)

    def _check_dataloaders(self, dataloaders: List["DictDataLoader"]) -> None:
        """Validate the dataloader splits."""
        train_split = self.config.train_split
        valid_split = self.config.valid_split
        test_split = self.config.test_split

        all_splits = [train_split, valid_split, test_split]
        if not all(d.dataset.split in all_splits for d in dataloaders):  # type: ignore
            raise ValueError(f"Dataloader splits must be one of {all_splits}")

        if not any(d.dataset.split == train_split for d in dataloaders):  # type: ignore
            raise ValueError(
                f"Cannot find any dataloaders with split matching train split: "
                f"{self.config.train_split}."
            )

    def _set_log_writer(self) -> None:
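        """Set the LogWriter from the config, or None if logging is disabled."""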
        self.log_writer: Optional[LogWriter] = None
        if self.config.logging:
            if self.config.log_writer == "json":
                self.log_writer = LogWriter(**self.config.log_writer_config._asdict())
            elif self.config.log_writer == "tensorboard":
                self.log_writer = TensorBoardWriter(
                    **self.config.log_writer_config._asdict()
                )
            else:
                raise ValueError(
                    f"Unrecognized writer option: {self.config.log_writer}"
                )

    def _set_checkpointer(self) -> None:
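        """Set the Checkpointer if checkpointing is enabled, else None."""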
        self.checkpointer: Optional[Checkpointer]

        if self.config.checkpointing:
            checkpointer_config = self.config.checkpointer_config
            evaluation_freq = self.config.log_manager_config.evaluation_freq
            counter_unit = self.config.log_manager_config.counter_unit
            self.checkpointer = Checkpointer(
                counter_unit, evaluation_freq, **checkpointer_config._asdict()
            )
        else:
            self.checkpointer = None

    def _set_log_manager(self) -> None:
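        """Set the LogManager that tracks when to evaluate, log, and checkpoint."""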
        self.log_manager = LogManager(
            self.n_batches_per_epoch,
            log_writer=self.log_writer,
            checkpointer=self.checkpointer,
            **self.config.log_manager_config._asdict(),
        )

    def _set_optimizer(self, model: nn.Module) -> None:
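        """Set the configured optimizer over the model's trainable parameters."""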
        optimizer_config = self.config.optimizer_config
        optimizer_name = self.config.optimizer

        parameters = filter(lambda p: p.requires_grad, model.parameters())

        optimizer: optim.Optimizer  # type: ignore

        if optimizer_name == "sgd":
            optimizer = optim.SGD(  # type: ignore
                parameters,
                lr=self.config.lr,
                weight_decay=self.config.l2,
                **optimizer_config.sgd_config._asdict(),
            )
        elif optimizer_name == "adam":
            optimizer = optim.Adam(
                parameters,
                lr=self.config.lr,
                weight_decay=self.config.l2,
                **optimizer_config.adam_config._asdict(),
            )
        elif optimizer_name == "adamax":
            optimizer = optim.Adamax(  # type: ignore
                parameters,
                lr=self.config.lr,
                weight_decay=self.config.l2,
                **optimizer_config.adamax_config._asdict(),
            )
        else:
            raise ValueError(f"Unrecognized optimizer option '{optimizer_name}'")

        logging.info(f"Using optimizer {optimizer}")

        self.optimizer = optimizer

    def _set_lr_scheduler(self) -> None:
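        """Set the lr scheduler named in the config, plus any warmup scheduler."""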
        # Set warmup scheduler
        self._set_warmup_scheduler()

        # Set lr scheduler
        lr_scheduler_name = self.config.lr_scheduler
        lr_scheduler_config = self.config.lr_scheduler_config
        lr_scheduler: Optional[optim.lr_scheduler._LRScheduler]

        if lr_scheduler_name == "constant":
            lr_scheduler = None
        elif lr_scheduler_name == "linear":
            total_steps = self.n_batches_per_epoch * self.config.n_epochs
            linear_decay_func = lambda x: (total_steps - self.warmup_steps - x) / (
                total_steps - self.warmup_steps
            )
            lr_scheduler = optim.lr_scheduler.LambdaLR(  # type: ignore
                self.optimizer, linear_decay_func
            )
        elif lr_scheduler_name == "exponential":
            lr_scheduler = optim.lr_scheduler.ExponentialLR(
                self.optimizer, **lr_scheduler_config.exponential_config._asdict()
            )
        elif lr_scheduler_name == "step":
            lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, **lr_scheduler_config.step_config._asdict()
            )
        else:
            raise ValueError(f"Unrecognized lr scheduler option '{lr_scheduler_name}'")

        self.lr_scheduler = lr_scheduler

    def _set_warmup_scheduler(self) -> None:
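        """Set a linear warmup scheduler from warmup_steps or warmup_percentage."""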
        warmup_scheduler: Optional[optim.lr_scheduler.LambdaLR]

        if self.config.lr_scheduler_config.warmup_steps:
            warmup_steps = self.config.lr_scheduler_config.warmup_steps
            if warmup_steps < 0:
                raise ValueError("warmup_steps must be greater than or equal to 0.")
            warmup_unit = self.config.lr_scheduler_config.warmup_unit
            if warmup_unit == "epochs":
                self.warmup_steps = int(warmup_steps * self.n_batches_per_epoch)
            elif warmup_unit == "batches":
                self.warmup_steps = int(warmup_steps)
            else:
                raise ValueError(
                    f"warmup_unit must be 'batches' or 'epochs', but '{warmup_unit}' was found."
                )
            linear_warmup_func = lambda x: x / self.warmup_steps
            warmup_scheduler = optim.lr_scheduler.LambdaLR(  # type: ignore
                self.optimizer, linear_warmup_func
            )
            logging.info(f"Warmup {self.warmup_steps} batches.")
        elif self.config.lr_scheduler_config.warmup_percentage:
            warmup_percentage = self.config.lr_scheduler_config.warmup_percentage
            self.warmup_steps = int(
                warmup_percentage * self.config.n_epochs * self.n_batches_per_epoch
            )
            linear_warmup_func = lambda x: x / self.warmup_steps
            warmup_scheduler = optim.lr_scheduler.LambdaLR(  # type: ignore
                self.optimizer, linear_warmup_func
            )
            logging.info(f"Warmup {self.warmup_steps} batches.")
        else:
            warmup_scheduler = None
            self.warmup_steps = 0

        self.warmup_scheduler = warmup_scheduler

    def _update_lr_scheduler(self, step: int) -> None:
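        """Step the warmup or lr scheduler, enforcing the configured min_lr."""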
        if self.warmup_scheduler and step < self.warmup_steps:
            self.warmup_scheduler.step()  # type: ignore
        elif self.lr_scheduler is not None:
            self.lr_scheduler.step()  # type: ignore
            min_lr = self.config.lr_scheduler_config.min_lr
            if min_lr and self.optimizer.param_groups[0]["lr"] < min_lr:
                self.optimizer.param_groups[0]["lr"] = min_lr

    def _set_batch_scheduler(self) -> None:
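        """Set the batch scheduler that orders batches across dataloaders."""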
        scheduler_class = batch_schedulers.get(self.config.batch_scheduler)
        if not scheduler_class:
            raise ValueError(
                f"Unrecognized batch scheduler option '{self.config.batch_scheduler}'"
            )

        self.batch_scheduler = scheduler_class()  # type: ignore

    def _evaluate(
        self,
        model: MultitaskClassifier,
        dataloaders: List["DictDataLoader"],
        split: str,
    ) -> Metrics:
        """Evaluate the current quality of the model on data for the requested split."""
        loaders = [d for d in dataloaders if d.dataset.split in split]  # type: ignore
        return model.score(loaders)

    def _logging(
        self,
        model: MultitaskClassifier,
        dataloaders: List["DictDataLoader"],
        batch_size: int,
    ) -> Metrics:
        """Log and checkpoint if it is time to do so."""

        # Switch to eval mode for evaluation
        model.eval()

        self.log_manager.update(batch_size)

        # Log the loss and lr
        metric_dict: Metrics = dict()
        metric_dict.update(self._aggregate_losses())

        # Evaluate the model and log the metrics
        if self.log_manager.trigger_evaluation():

            # Log metrics
            metric_dict.update(
                self._evaluate(model, dataloaders, self.config.valid_split)
            )
            self._log_metrics(metric_dict)
            self._reset_losses()

        # Checkpoint the model
        if self.log_manager.trigger_checkpointing():
            self._checkpoint_model(model, metric_dict)
            self._reset_losses()

        # Switch back to train mode
        model.train()
        return metric_dict

    def _log_metrics(self, metric_dict: Metrics) -> None:
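        """Write each metric to the log writer, if one is configured."""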
        if self.log_writer is not None:
            for metric_name, metric_value in metric_dict.items():
                self.log_writer.add_scalar(
                    metric_name, metric_value, self.log_manager.point_total
                )

    def _checkpoint_model(
        self, model: MultitaskClassifier, metric_dict: Metrics
    ) -> None:
        """Save the current model."""
        if self.checkpointer is not None:
            self.checkpointer.checkpoint(
                self.log_manager.unit_total, model, metric_dict
            )

    def _aggregate_losses(self) -> Metrics:
        """Calculate the task-specific loss, micro-average loss, and learning rate."""

        metric_dict = dict()

        # Log task specific loss
        self.running_losses: DefaultDict[str, float]
        self.running_counts: DefaultDict[str, float]
        for identifier in self.running_losses.keys():
            if self.running_counts[identifier] > 0:
                metric_dict[identifier] = (
                    self.running_losses[identifier] / self.running_counts[identifier]
                )

        # Calculate average micro loss
        total_loss = sum(self.running_losses.values())
        total_count = sum(self.running_counts.values())
        if total_count > 0:
            metric_dict["model/all/train/loss"] = total_loss / total_count

        # Log the learning rate
        metric_dict["model/all/train/lr"] = self.optimizer.param_groups[0]["lr"]

        return metric_dict

    def _reset_losses(self) -> None:
        """Reset the loss counters."""
        self.running_losses = defaultdict(float)
        self.running_counts = defaultdict(int)

    def save(self, trainer_path: str) -> None:
        """Save the trainer config and optimizer state to the specified file path.

        Parameters
        ----------
        trainer_path
            The path where trainer config and optimizer state should be saved.
        """

        head, tail = os.path.split(trainer_path)

        if head and not os.path.exists(head):
            os.makedirs(head)
        try:
            torch.save(
                {
                    "trainer_config": self.config._asdict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                trainer_path,
            )
        except BaseException:  # pragma: no cover
            logging.warning("Saving failed... continuing anyway.")

        logging.info(f"[{self.name}] Trainer config saved in {trainer_path}")

    def load(self, trainer_path: str, model: Optional[MultitaskClassifier]) -> None:
        """Load the trainer config and optimizer state from the specified file path.

        The optimizer state is restored as well, but it is only meaningful when it is
        loaded together with the model that was originally fit by this Trainer.

        Parameters
        ----------
        trainer_path
            The path to the saved trainer config to be loaded
        model
            MultitaskClassifier for which the optimizer has been set. The optimizer
            state must match the model parameters, so this should be the model that
            was fit by the stored Trainer.

        Example
        -------
        Saving model and corresponding trainer:

        >>> model.save('./my_saved_model_file') # doctest: +SKIP
        >>> trainer.save('./my_saved_trainer_file') # doctest: +SKIP

        Now we can resume training and load the saved model and trainer into new model
        and trainer objects:

        >>> new_model.load('./my_saved_model_file') # doctest: +SKIP
        >>> new_trainer.load('./my_saved_trainer_file', model=new_model) # doctest: +SKIP
        >>> new_trainer.fit(...) # doctest: +SKIP
        """

        try:
            saved_state = torch.load(trainer_path)
        except BaseException:
            if not os.path.exists(trainer_path):
                logging.error("Loading failed... Trainer config does not exist.")
            else:
                logging.error(
                    f"Loading failed... Cannot load trainer config from {trainer_path}"
                )
            raise

        self.config = TrainerConfig(**saved_state["trainer_config"])
        logging.info(f"[{self.name}] Trainer config loaded from {trainer_path}")

        if model is not None:
            try:
                self._set_optimizer(model)
                self.optimizer.load_state_dict(saved_state["optimizer_state_dict"])
                logging.info(f"[{self.name}] Optimizer loaded from {trainer_path}")
            except BaseException:
                logging.error(
                    "Loading the optimizer for your model failed. Optimizer state NOT loaded."
                )
