1 2
from typing import List, Optional, Sequence
2

3 2
import numpy as np
4

5 2
from .core import Policy
6

7

8 2
class MeanFieldPolicy(Policy):
    """Sample sequences of TFs according to a distribution.

    Samples sequences of indices of a specified length from a
    user-provided distribution. A distribution over TFs can be
    learned by a TANDA mean-field model, for example.
    See https://hazyresearch.github.io/snorkel/blog/tanda.html

    Parameters
    ----------
    n_tfs
        Total number of TFs
    sequence_length
        Number of TFs to run on each data point
    p
        Probability distribution from which to sample TF indices.
        Must have length ``n_tfs`` and be a valid distribution
        (finite, non-negative entries summing to 1). If ``None``,
        TF indices are sampled uniformly at random.
    n_per_original
        Number of transformed data points per original
    keep_original
        Keep untransformed data point in augmented data set? Note that
        even if in-place modifications are made to the original data
        point by the TFs being applied, the original data point will
        remain unchanged.

    Raises
    ------
    ValueError
        If ``p`` is given but is not one-dimensional, does not have length
        ``n_tfs``, contains negative or non-finite entries, or does not
        sum to 1.

    Attributes
    ----------
    n
        Total number of TFs
    n_per_original
        See above
    keep_original
        See above
    sequence_length
        See above
    """

    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        p: Optional[Sequence[float]] = None,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        self.sequence_length = sequence_length
        # Validate eagerly so a malformed distribution is reported when the
        # policy is built, not on the first ``generate`` call. Store a private
        # copy so later mutation of the caller's sequence cannot silently
        # change the sampling distribution.
        self._p = None if p is None else self._validate_distribution(p, n_tfs)
        super().__init__(
            n_tfs, n_per_original=n_per_original, keep_original=keep_original
        )

    @staticmethod
    def _validate_distribution(p: Sequence[float], n_tfs: int) -> np.ndarray:
        """Return a copy of ``p`` as an array after checking it is a distribution over ``n_tfs`` TFs."""
        probs = np.array(p)  # np.array copies by default
        # Keep floating dtypes (e.g. float32 model outputs) as-is so that
        # np.random.choice applies the same dtype-dependent tolerance it
        # would have applied to the caller's array; promote everything else.
        if not np.issubdtype(probs.dtype, np.floating):
            probs = probs.astype(np.float64)
        if probs.ndim != 1 or probs.shape[0] != n_tfs:
            raise ValueError(
                f"p must be a 1-D sequence of length n_tfs={n_tfs}, "
                f"got shape {probs.shape}"
            )
        if not np.all(np.isfinite(probs)) or np.any(probs < 0):
            raise ValueError("p must contain only finite, non-negative probabilities")
        # Mirror np.random.choice's own sum check: absolute tolerance of
        # sqrt(eps), using the looser of float64 and the array's float dtype.
        tol = max(
            np.sqrt(np.finfo(np.float64).eps), np.sqrt(np.finfo(probs.dtype).eps)
        )
        total = float(probs.astype(np.float64).sum())
        if abs(total - 1.0) > tol:
            raise ValueError(f"p must sum to 1, got {total}")
        return probs

    def generate(self) -> List[int]:
        """Generate a sequence of TF indices by sampling from distribution.

        Indices are drawn with replacement (``np.random.choice`` default)
        using NumPy's global random state. Sampling is uniform when no
        distribution was supplied.

        Returns
        -------
        List[int]
            Indices of TFs to run on data point in order.
        """
        return np.random.choice(self.n, size=self.sequence_length, p=self._p).tolist()
68

69

70 2
class RandomPolicy(MeanFieldPolicy):
    """Baseline policy that picks TFs uniformly at random.

    Each generated sequence has ``sequence_length`` TF indices drawn
    uniformly from ``range(n_tfs)``. Uniform random sampling is the
    standard baseline against which learned augmentation policies are
    compared.

    Parameters
    ----------
    n_tfs
        Total number of TFs
    sequence_length
        Number of TFs to run on each data point
    n_per_original
        Number of transformed data points per original
    keep_original
        Keep untransformed data point in augmented data set? Note that
        even if in-place modifications are made to the original data
        point by the TFs being applied, the original data point will
        remain unchanged.

    Attributes
    ----------
    n
        Total number of TFs
    n_per_original
        See above
    keep_original
        See above
    sequence_length
        See above
    """

    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        # Passing no distribution (p=None) means the parent hands p=None to
        # np.random.choice, which then samples every index with equal weight.
        forwarded = dict(
            sequence_length=sequence_length,
            p=None,
            n_per_original=n_per_original,
            keep_original=keep_original,
        )
        super().__init__(n_tfs, **forwarded)

Read our documentation on viewing source code .

Loading