```python
from typing import List, Optional, Sequence

import numpy as np

from .core import Policy


class MeanFieldPolicy(Policy):
    """Sample sequences of TFs according to a distribution.

    Samples sequences of TF indices of a specified length from a
    user-provided distribution. A distribution over TFs can be
    learned by a TANDA mean-field model, for example.
    See https://hazyresearch.github.io/snorkel/blog/tanda.html

    Parameters
    ----------
    n_tfs
        Total number of TFs
    sequence_length
        Number of TFs to run on each data point
    p
        Probability distribution from which to sample TF indices.
        Must have length ``n_tfs`` and be a valid distribution.
    n_per_original
        Number of transformed data points per original
    keep_original
        Keep untransformed data point in augmented data set? Note that
        even if in-place modifications are made to the original data
        point by the TFs being applied, the original data point will
        remain unchanged.

    Attributes
    ----------
    n
        Total number of TFs
    n_per_original
        See above
    keep_original
        See above
    sequence_length
        See above
    """

    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        p: Optional[Sequence[float]] = None,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        self.sequence_length = sequence_length
        self._p = p
        super().__init__(
            n_tfs, n_per_original=n_per_original, keep_original=keep_original
        )

    def generate(self) -> List[int]:
        """Generate a sequence of TF indices by sampling from distribution.

        Returns
        -------
        List[int]
            Indices of TFs to run on data point in order.
        """
        return np.random.choice(self.n, size=self.sequence_length, p=self._p).tolist()


class RandomPolicy(MeanFieldPolicy):
    """Naive random augmentation policy.

    Samples sequences of TF indices of a specified length uniformly
    at random from the total number of TFs. Sampling uniformly at
    random is a common baseline approach to data augmentation.

    Parameters
    ----------
    n_tfs
        Total number of TFs
    sequence_length
        Number of TFs to run on each data point
    n_per_original
        Number of transformed data points per original
    keep_original
        Keep untransformed data point in augmented data set? Note that
        even if in-place modifications are made to the original data
        point by the TFs being applied, the original data point will
        remain unchanged.

    Attributes
    ----------
    n
        Total number of TFs
    n_per_original
        See above
    keep_original
        See above
    sequence_length
        See above
    """

    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        super().__init__(
            n_tfs,
            sequence_length=sequence_length,
            p=None,
            n_per_original=n_per_original,
            keep_original=keep_original,
        )
```
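The sampling behavior of these policies can be sketched in a self-contained way. Since the real `Policy` base class lives in `.core` (not shown here), the snippet below uses a minimal hypothetical stand-in for it; only the sampling logic mirrors the module above.

```python
import numpy as np
from typing import List, Optional, Sequence


class Policy:
    # Hypothetical minimal stand-in for the real base class in `.core`,
    # which also handles n_per_original and keep_original bookkeeping.
    def __init__(
        self, n_tfs: int, n_per_original: int = 1, keep_original: bool = True
    ) -> None:
        self.n = n_tfs
        self.n_per_original = n_per_original
        self.keep_original = keep_original


class MeanFieldPolicy(Policy):
    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        p: Optional[Sequence[float]] = None,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        self.sequence_length = sequence_length
        self._p = p
        super().__init__(
            n_tfs, n_per_original=n_per_original, keep_original=keep_original
        )

    def generate(self) -> List[int]:
        # p=None makes np.random.choice sample uniformly,
        # which is exactly what RandomPolicy relies on.
        return np.random.choice(self.n, size=self.sequence_length, p=self._p).tolist()


# Sequences of 3 TF indices over 4 TFs, heavily favoring TF 0.
policy = MeanFieldPolicy(n_tfs=4, sequence_length=3, p=[0.7, 0.1, 0.1, 0.1])
seq = policy.generate()
assert len(seq) == 3
assert all(0 <= i < 4 for i in seq)
```

Each call to `generate()` draws a fresh length-3 sequence, so applying the policy with `n_per_original > 1` yields differently transformed copies of each data point.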
