from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, Sequence, Tuple
from torch import Tensor
from snorkel.classification.data import DictDataLoader  # noqa: F401
# Iterator type produced by ``Scheduler.get_batches``. Each item pairs a batch
# ``(X_dict, Y_dict)`` with the ``DictDataLoader`` it came from. Both dicts
# have ``str`` keys; ``X_dict`` values are arbitrary and ``Y_dict`` values
# are ``Tensor``s. ``DictDataLoader`` is written as a string forward
# reference, so it is not resolved when this alias is built.
BatchIterator = Iterator[
    Tuple[Tuple[Dict[str, Any], Dict[str, Tensor]], "DictDataLoader"]
]


class Scheduler(ABC):
    """Return batches from all dataloaders according to a specified strategy.

    This is an abstract base class: concrete schedulers must implement
    :meth:`get_batches`. ``Scheduler`` itself cannot be instantiated
    because that method is marked ``@abstractmethod``.
    """

    def __init__(self) -> None:
        # The base scheduler holds no state of its own.
        pass

    @abstractmethod
    def get_batches(self, dataloaders: Sequence["DictDataLoader"]) -> BatchIterator:
        """Return batches from dataloaders according to a specified strategy.

        Parameters
        ----------
        dataloaders
            A sequence of dataloaders to get batches from

        Yields
        ------
        (batch, dataloader)
            batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader
            that the batch came from. That dataloader will not be accessed by the
            model; it is passed primarily so that the model can pull the necessary
            metadata to know what to do with the batch it has been given.
        """
