1 2
from typing import Any
2

3 2
from torch.utils.tensorboard import SummaryWriter
4

5 2
from snorkel.types import Config
6

7 2
from .log_writer import LogWriter
8

9

10 2
class TensorBoardWriter(LogWriter):
    """Log writer that additionally mirrors logs into TensorBoard.

    Behaves like ``LogWriter`` (see it for the remaining attributes) and also
    owns a ``SummaryWriter`` to which scalars and the config are sent.

    Parameters
    ----------
    kwargs
        Forwarded unchanged to the ``LogWriter`` initializer

    Attributes
    ----------
    writer
        ``SummaryWriter`` created on ``self.log_dir``; receives every scalar
        and config logged through this object
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # ``log_dir`` is never assigned in this class; it is read only after
        # the base initializer has run.
        self.writer = SummaryWriter(log_dir=self.log_dir)

    def add_scalar(self, name: str, value: float, step: float) -> None:
        """Record one point of a scalar series in TensorBoard.

        Parameters
        ----------
        name
            Tag of the scalar series the point belongs to
        value
            Scalar value to record
        step
            Position of the point on the step axis
        """
        self.writer.add_scalar(tag=name, scalar_value=value, global_step=step)

    def write_config(
        self, config: Config, config_filename: str = "config.json"
    ) -> None:
        """Persist the config via ``LogWriter`` and show it in TensorBoard.

        The base class writes the config file first. After that, the
        ``str()`` form of the config is added as text under the
        ``"config"`` tag.

        Parameters
        ----------
        config
            JSON-compatible config to record
        config_filename
            Name of the file the base class writes the config to
        """
        super().write_config(config, config_filename)
        config_text = str(config)
        self.writer.add_text("config", config_text)

    def cleanup(self) -> None:
        """Release TensorBoard resources by closing ``self.writer``."""
        writer = self.writer
        writer.close()

Read our documentation on viewing source code.

Loading