Addresses #1602. Added a method to analysis/error_analysis that wraps the get_label_buckets functionality. Given a bucket, a NumPy array x of your data, and the corresponding y label(s), it returns x restricted to only the instances belonging to that bucket.
1  2 
from typing import List, Optional, Sequence 
2  
3  2 
import numpy as np 
4  
5  2 
from .core import Policy 
6  
7  
8  2 
class MeanFieldPolicy(Policy):
    """Sample sequences of TFs according to a distribution.

    Samples sequences of indices of a specified length from a
    user-provided distribution. A distribution over TFs can be
    learned by a TANDA mean-field model, for example.
    See https://hazyresearch.github.io/snorkel/blog/tanda.html

    Parameters
    ----------
    n_tfs
        Total number of TFs
    sequence_length
        Number of TFs to run on each data point
    p
        Probability distribution from which to sample TF indices.
        Must have length ``n_tfs`` and be a valid distribution.
    n_per_original
        Number of transformed data points per original
    keep_original
        Keep untransformed data point in augmented data set? Note that
        even if in-place modifications are made to the original data
        point by the TFs being applied, the original data point will
        remain unchanged.

    Raises
    ------
    ValueError
        If ``p`` is provided but its length does not equal ``n_tfs``

    Attributes
    ----------
    n
        Total number of TFs
    n_per_original
        See above
    keep_original
        See above
    sequence_length
        See above
    """

    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        p: Optional[Sequence[float]] = None,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        # Fail fast on an ill-sized distribution at construction time.
        # Without this check, the mismatch only surfaces as a ValueError
        # from np.random.choice on the first generate() call, far from
        # where the bad argument was supplied. (Whether p sums to 1 is
        # still validated by np.random.choice itself.)
        if p is not None and len(p) != n_tfs:
            raise ValueError(
                f"Expected p to have length {n_tfs}, got {len(p)}"
            )
        self.sequence_length = sequence_length
        self._p = p
        super().__init__(
            n_tfs, n_per_original=n_per_original, keep_original=keep_original
        )

    def generate(self) -> List[int]:
        """Generate a sequence of TF indices by sampling from distribution.

        Returns
        -------
        List[int]
            Indices of TFs to run on data point in order.
        """
        # p=None makes np.random.choice sample uniformly over range(self.n).
        return np.random.choice(self.n, size=self.sequence_length, p=self._p).tolist()
68  
69  
70  2 
class RandomPolicy(MeanFieldPolicy):
    """Naive random augmentation policy.

    Samples sequences of TF indices a specified length at random
    from the total number of TFs. Sampling uniformly at random is
    a common baseline approach to data augmentation.

    Parameters
    ----------
    n_tfs
        Total number of TFs
    sequence_length
        Number of TFs to run on each data point
    n_per_original
        Number of transformed data points per original
    keep_original
        Keep untransformed data point in augmented data set? Note that
        even if in-place modifications are made to the original data
        point by the TFs being applied, the original data point will
        remain unchanged.

    Attributes
    ----------
    n
        Total number of TFs
    n_per_original
        See above
    keep_original
        See above
    sequence_length
        See above
    """

    def __init__(
        self,
        n_tfs: int,
        sequence_length: int = 1,
        n_per_original: int = 1,
        keep_original: bool = True,
    ) -> None:
        # Uniform sampling is the degenerate mean-field case: passing
        # p=None to the parent makes the TF-index draw uniform.
        super().__init__(
            n_tfs,
            p=None,
            sequence_length=sequence_length,
            n_per_original=n_per_original,
            keep_original=keep_original,
        )

Read our documentation on viewing source code.