#1396 Add cleaner defaults to SliceClassifier

Open Vincent Chen vincentschen
Coverage Reach
classification/training/loggers/checkpointer.py classification/training/loggers/log_manager.py classification/training/loggers/log_writer.py classification/training/loggers/tensorboard_writer.py classification/training/trainer.py classification/training/schedulers/shuffled_scheduler.py classification/training/schedulers/scheduler.py classification/training/schedulers/sequential_scheduler.py classification/training/schedulers/__init__.py classification/multitask_classifier.py classification/data.py classification/utils.py classification/task.py classification/loss.py labeling/model/label_model.py labeling/model/baselines.py labeling/model/logger.py labeling/model/graph_utils.py labeling/lf/nlp.py labeling/lf/core.py labeling/lf/__init__.py labeling/apply/core.py labeling/apply/pandas.py labeling/apply/dask.py labeling/analysis.py labeling/utils.py slicing/utils.py slicing/slicing_classifier.py slicing/monitor.py slicing/modules/slice_combiner.py slicing/sf/core.py slicing/sf/nlp.py slicing/apply/core.py augmentation/apply/core.py augmentation/apply/pandas.py augmentation/policy/core.py augmentation/policy/sampling.py augmentation/tf.py utils/core.py utils/optimizers.py utils/lr_schedulers.py utils/config_utils.py utils/data_operators.py map/core.py analysis/metrics.py analysis/scorer.py analysis/error_analysis.py analysis/__init__.py preprocess/nlp.py preprocess/core.py synthetic/synthetic_data.py types/data.py types/__init__.py types/classifier.py version.py

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then set up custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.


@@ -7,6 +7,7 @@
Loading
7 7
8 8
from snorkel.analysis import Scorer
9 9
from snorkel.classification import DictDataLoader, DictDataset, Operation, Task
10 +
from snorkel.classification.data import DEFAULT_INPUT_DATA_KEY, DEFAULT_TASK_NAME
10 11
from snorkel.classification.multitask_classifier import MultitaskClassifier
11 12
12 13
from .utils import add_slice_labels, convert_to_slice_tasks
@@ -47,8 +48,8 @@
Loading
47 48
        base_architecture: nn.Module,
48 49
        head_dim: int,
49 50
        slice_names: List[str],
50 -
        input_data_key: str,
51 -
        task_name: str,
51 +
        input_data_key: str = DEFAULT_INPUT_DATA_KEY,
52 +
        task_name: str = DEFAULT_TASK_NAME,
52 53
        scorer: Scorer = Scorer(metrics=["accuracy", "f1"]),
53 54
        **multitask_kwargs: Any,
54 55
    ) -> None:

@@ -10,6 +10,11 @@
Loading
10 10
YDict = Dict[str, Tensor]
11 11
Batch = Tuple[XDict, YDict]
12 12
13 +
# Default string names for initializing a DictDataset
14 +
DEFAULT_INPUT_DATA_KEY = "input_data"
15 +
DEFAULT_DATASET_NAME = "SnorkelDataset"
16 +
DEFAULT_TASK_NAME = "task"
17 +
13 18
14 19
class DictDataset(Dataset):
15 20
    """A dataset where both the data fields and labels are stored in as dictionaries.
@@ -29,7 +34,7 @@
Loading
29 34
    Raises
30 35
    ------
31 36
    ValueError
32 -
        All values in the Y_dict must be of type torch.Tensor
37 +
        All values in the ``Y_dict`` must be of type torch.Tensor
33 38
34 39
    Attributes
35 40
    ----------
@@ -74,6 +79,45 @@
Loading
74 79
            f"Y_keys={list(self.Y_dict.keys())})"
75 80
        )
76 81
82 +
    @classmethod
83 +
    def from_tensors(
84 +
        cls,
85 +
        X_tensor: Tensor,
86 +
        Y_tensor: Tensor,
87 +
        split: str,
88 +
        input_data_key: str = DEFAULT_INPUT_DATA_KEY,
89 +
        task_name: str = DEFAULT_TASK_NAME,
90 +
        dataset_name: str = DEFAULT_DATASET_NAME,
91 +
    ) -> "DictDataset":
92 +
        """Initialize a ``DictDataset`` from PyTorch Tensors.
93 +
94 +
        Parameters
95 +
        ----------
96 +
        X_tensor
97 +
            Input data of shape [num_examples, feature_dim]
98 +
        Y_tensor
99 +
            Labels of shape [num_examples, num_classes]
100 +
        split
101 +
            Name of data split corresponding to this dataset.
102 +
        input_data_key
103 +
            Name of data field to initialize in ``X_dict``
104 +
        task_name
105 +
            Name of task and corresponding label key in ``Y_dict``
106 +
        dataset_name
107 +
            Name of the DictDataset to be initialized; see ``__init__`` above.
108 +
109 +
        Returns
110 +
        -------
111 +
        DictDataset
112 +
            Class initialized with single task and label corresponding to input data
113 +
        """
114 +
        return cls(
115 +
            name=dataset_name,
116 +
            split=split,
117 +
            X_dict={input_data_key: X_tensor},
118 +
            Y_dict={task_name: Y_tensor},
119 +
        )
120 +
77 121
78 122
def collate_dicts(batch: List[Batch]) -> Batch:
79 123
    """Combine many one-element dicts into a single many-element dict for both X and Y.

Everything is accounted for!

No changes detected that need to be reviewed.
What changes does Codecov check for?
Lines not adjusted in the diff whose coverage data has changed.
Files that introduced coverage data that had none before.
Files that are missing coverage data but were once tracked.
Files Coverage
snorkel 0.01% 97.55%
Project Totals (55 files) 97.55%
Loading