1 2
from collections import defaultdict
2 2
from typing import Any, Callable, Dict, List, Tuple
3

4 2
from torch import Tensor
5 2
from torch.utils.data import DataLoader, Dataset
6

7 2
from .utils import list_to_tensor
8

9 2
# Type aliases for the (features, labels) dictionaries handled in this module.
XDict = Dict[str, Any]  # field name -> per-example values
YDict = Dict[str, Tensor]  # task name -> label tensor
Batch = Tuple[XDict, YDict]  # an (X_dict, Y_dict) pair

# Default string names for initializing a DictDataset
DEFAULT_INPUT_DATA_KEY = "input_data"
DEFAULT_DATASET_NAME = "SnorkelDataset"
DEFAULT_TASK_NAME = "task"
17

18

19 2
class DictDataset(Dataset):
    """A dataset where both the data fields and labels are stored as dictionaries.

    Parameters
    ----------
    name
        The name of the dataset (e.g., this will be used to report metrics on a
        per-dataset basis)
    split
        The name of the split that the data in this object represents
    X_dict
        A map from field name to values (e.g., {"tokens": ..., "uids": ...})
    Y_dict
        A map from task name to its corresponding set of labels

    Raises
    ------
    ValueError
        If any value in ``Y_dict`` is not a torch.Tensor

    Attributes
    ----------
    name
        See above
    split
        See above
    X_dict
        See above
    Y_dict
        See above
    """

    def __init__(self, name: str, split: str, X_dict: XDict, Y_dict: YDict) -> None:
        self.name = name
        self.split = split
        self.X_dict = X_dict
        self.Y_dict = Y_dict

        # Use a distinct loop variable so the ``name`` argument is not shadowed.
        for task_name, label in self.Y_dict.items():
            if not isinstance(label, Tensor):
                raise ValueError(
                    f"Label {task_name} should be torch.Tensor, not {type(label)}."
                )

    def __getitem__(self, index: int) -> Tuple[XDict, YDict]:
        """Return the ``index``-th example as an (x_dict, y_dict) pair.

        Every field in ``X_dict`` and every label set in ``Y_dict`` is indexed
        with ``index``; the dictionary keys are preserved.
        """
        x_dict = {field: values[index] for field, values in self.X_dict.items()}
        y_dict = {task: labels[index] for task, labels in self.Y_dict.items()}
        return x_dict, y_dict

    def __len__(self) -> int:
        """Return the number of examples, taken from the first label set.

        Returns 0 when ``Y_dict`` is empty. The lengths of the ``X_dict`` fields
        are neither consulted nor checked here.
        """
        first_labels = next(iter(self.Y_dict.values()), None)
        if first_labels is None:
            return 0
        return len(first_labels)

    def __repr__(self) -> str:
        """Summarize the dataset by its name and the keys of X_dict and Y_dict."""
        return (
            f"{type(self).__name__}"
            f"(name={self.name}, "
            f"X_keys={list(self.X_dict.keys())}, "
            f"Y_keys={list(self.Y_dict.keys())})"
        )

    @classmethod
    def from_tensors(
        cls,
        X_tensor: Tensor,
        Y_tensor: Tensor,
        split: str,
        input_data_key: str = DEFAULT_INPUT_DATA_KEY,
        task_name: str = DEFAULT_TASK_NAME,
        dataset_name: str = DEFAULT_DATASET_NAME,
    ) -> "DictDataset":
        """Initialize a ``DictDataset`` from PyTorch Tensors.

        Parameters
        ----------
        X_tensor
            Input data of shape [num_examples, feature_dim]
        Y_tensor
            Labels whose first dimension is num_examples (one entry per example)
        split
            Name of data split corresponding to this dataset.
        input_data_key
            Name of data field to initialize in ``X_dict``
        task_name
            Name of task and corresponding label key in ``Y_dict``
        dataset_name
            Name of DictDataset to be initialized; See ``__init__`` above.

        Returns
        -------
        DictDataset
            Class initialized with single task and label corresponding to input data
        """
        return cls(
            name=dataset_name,
            split=split,
            X_dict={input_data_key: X_tensor},
            Y_dict={task_name: Y_tensor},
        )
120

121

122 2
def collate_dicts(batch: List[Batch]) -> Batch:
    """Merge a list of per-example (x_dict, y_dict) pairs into one batched pair.

    Parameters
    ----------
    batch
        A list of (x_dict, y_dict) tuples, where each value holds a single
        example's field or label

    Returns
    -------
    Batch
        A tuple (X_dict, Y_dict). Every Y value, and every X value whose first
        element is a Tensor, is merged with ``list_to_tensor``; any other X
        field is returned as a plain list of its per-example values. Keys keep
        the order in which they were first seen.
    """
    x_lists: Dict[str, Any] = defaultdict(list)
    y_lists: Dict[str, Any] = defaultdict(list)

    # Gather each key's per-example values into a list.
    for example_x, example_y in batch:
        for key, item in example_x.items():
            x_lists[key].append(item)
        for key, item in example_y.items():
            y_lists[key].append(item)

    # Only tensor-valued X fields are merged; labels are always merged.
    merged_x = {
        key: list_to_tensor(items) if isinstance(items[0], Tensor) else items
        for key, items in x_lists.items()
    }
    merged_y = {key: list_to_tensor(items) for key, items in y_lists.items()}
    return merged_x, merged_y
153

154

155 2
class _NotADictDatasetError(TypeError, AssertionError):
    """Raised when ``DictDataLoader`` is given something other than a ``DictDataset``.

    Also subclasses ``AssertionError`` so callers that relied on the former
    ``assert``-based check still catch it.
    """


class DictDataLoader(DataLoader):
    """A DataLoader that uses the appropriate collate_fn for a ``DictDataset``.

    Parameters
    ----------
    dataset
        A dataset to wrap; must be a ``DictDataset``
    collate_fn
        The collate function to use when combining multiple indexed examples for a
        single batch. Usually the default collate_dicts() method should be used, but
        it can be overridden if you want to use different collate logic.
    kwargs
        Keyword arguments to pass on to DataLoader.__init__()

    Raises
    ------
    TypeError
        If ``dataset`` is not a ``DictDataset`` (the raised exception also
        subclasses ``AssertionError`` for backward compatibility)
    """

    def __init__(
        self,
        dataset: DictDataset,
        collate_fn: Callable[..., Any] = collate_dicts,
        **kwargs: Any,
    ) -> None:
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if not isinstance(dataset, DictDataset):
            raise _NotADictDatasetError(
                f"DictDataLoader requires a DictDataset, not {type(dataset)}."
            )
        super().__init__(dataset, collate_fn=collate_fn, **kwargs)

Read our documentation on viewing source code .

Loading