fepegar / torchio
1 3
import torch.nn.functional as F  # noqa: N812
2

3 3
from .label_transform import LabelTransform
4

5

6 3
class OneHot(LabelTransform):
    r"""Re-encode label maps using one-hot encoding.

    Each single-channel label map is replaced by a multi-channel image with
    one channel per class: channel ``k`` is ``1`` where the input label
    equals ``k`` and ``0`` elsewhere. The result is cast back to the data
    type of the input label map.

    Args:
        num_classes: Total number of classes. See
            :func:`~torch.nn.functional.one_hot`. ``-1`` (default) lets the
            number of classes be inferred from the data; ``None`` is treated
            the same as ``-1``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        RuntimeError: If a label map to be transformed has more than one
            channel.
    """

    def __init__(self, num_classes: int = -1, **kwargs):
        super().__init__(**kwargs)
        # Stored exactly as given (may be ``None``); normalized when applied.
        self.num_classes = num_classes

    def apply_transform(self, subject):
        # ``F.one_hot`` uses -1 to mean "infer from the data"; map None to it.
        # This does not depend on the image, so resolve it once.
        num_classes = -1 if self.num_classes is None else self.num_classes
        for image in self.get_images(subject):
            if image.num_channels > 1:
                message = (
                    'The number of input channels must be 1,'
                    f' but it is {image.num_channels}'
                )
                raise RuntimeError(message)
            # Drop the single channel axis; one_hot requires integer labels,
            # so non-integer values are truncated by ``.long()``.
            data = image.data[0]
            one_hot = F.one_hot(data.long(), num_classes=num_classes)
            # one_hot appends the class axis last; move it to the front so it
            # becomes the channel axis (assumes three spatial dims, as the
            # original permutation does), then restore the input dtype.
            image.set_data(one_hot.permute(3, 0, 1, 2).to(data.dtype))
        return subject

Read our documentation on viewing source code.

Loading