from typing import Iterator, List

import pandas as pd
from tqdm import tqdm

from .core import BaseTFApplier
7

8

9 2
class PandasTFApplier(BaseTFApplier):
    """TF applier for a Pandas DataFrame.

    Data points are stored as Series in a DataFrame. The TFs
    run on data points obtained via a ``pandas.DataFrame.iterrows``
    call, which is single-process and can be slow for large DataFrames.
    For large datasets, consider ``DaskTFApplier`` or ``SparkTFApplier``.
    """

    @staticmethod
    def _stack_data_points(
        data_points: List[pd.Series], columns: pd.Index
    ) -> pd.DataFrame:
        """Stack transformed data points into a DataFrame with one row per point.

        Parameters
        ----------
        data_points
            Transformed data points, each a Series
        columns
            Columns to give the result when ``data_points`` is empty

        Returns
        -------
        pd.DataFrame
            One row per entry of ``data_points``; an empty DataFrame with
            ``columns`` if there are no data points
        """
        if not data_points:
            # pd.concat raises ValueError on an empty sequence, so build the
            # empty result explicitly instead.
            return pd.DataFrame(columns=columns)
        # concat(axis=1) puts each Series in its own column; transposing makes
        # them rows. Transposing mixed-dtype data yields object columns, so
        # infer_objects() re-infers more specific dtypes where possible.
        return pd.concat(data_points, axis=1).T.infer_objects()

    def apply_generator(
        self, df: pd.DataFrame, batch_size: int
    ) -> Iterator[pd.DataFrame]:
        """Augment a Pandas DataFrame of data points using TFs and policy in batches.

        This method acts as a generator, yielding augmented data points for
        a given input batch of data points. This can be useful in a training
        loop when it is too memory-intensive to pregenerate all transformed
        examples.

        Parameters
        ----------
        df
            Pandas DataFrame containing data points to be transformed
        batch_size
            Batch size for generator. Yields augmented data points
            for the next ``batch_size`` input data points. Must be >= 1.

        Yields
        ------
        pd.DataFrame
            Pandas DataFrame of augmented data points for one batch of input
            rows. The last batch may cover fewer than ``batch_size`` rows.
            Nothing is yielded for an empty ``df``.

        Raises
        ------
        ValueError
            If ``batch_size`` is less than 1 (raised when iteration starts,
            since this is a generator).
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        batch_transformed: List[pd.Series] = []
        rows_in_batch = 0
        for _, x in df.iterrows():
            # _apply_policy_to_data_point is inherited from BaseTFApplier; its
            # outputs for this row are appended to the current batch.
            batch_transformed.extend(self._apply_policy_to_data_point(x))
            rows_in_batch += 1
            if rows_in_batch == batch_size:
                yield self._stack_data_points(batch_transformed, df.columns)
                batch_transformed = []
                rows_in_batch = 0
        # Flush a trailing partial batch only if input rows remain. When len(df)
        # is a multiple of batch_size there is nothing left to emit.
        if rows_in_batch:
            yield self._stack_data_points(batch_transformed, df.columns)

    def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> pd.DataFrame:
        """Augment a Pandas DataFrame of data points using TFs and policy.

        Parameters
        ----------
        df
            Pandas DataFrame containing data points to be transformed
        progress_bar
            Display a progress bar?

        Returns
        -------
        pd.DataFrame
            Pandas DataFrame of data points in augmented data set. If no
            data points are produced (e.g. ``df`` is empty), an empty
            DataFrame with the columns of ``df``.
        """
        x_transformed: List[pd.Series] = []
        for _, x in tqdm(df.iterrows(), total=len(df), disable=(not progress_bar)):
            x_transformed.extend(self._apply_policy_to_data_point(x))
        return self._stack_data_points(x_transformed, df.columns)

Read our documentation on viewing source code .

Loading