fepegar / torchio
1 6
from typing import Union, Generator, Optional
2

3 6
import numpy as np
4

5 6
from ...utils import to_tuple
6 6
from ...constants import LOCATION
7 6
from ...data.subject import Subject
8 6
from ...typing import TypePatchSize
9 6
from ...typing import TypeTripletInt
10 6
from .sampler import PatchSampler
11

12

13 6
class GridSampler(PatchSampler):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. They are often used with a :class:`~torchio.data.GridAggregator`.

    Args:
        subject: Instance of :class:`~torchio.data.Subject`
            from which patches will be extracted. This argument should only be
            used before instantiating a :class:`~torchio.data.GridAggregator`,
            or to precompute the number of patches that would be generated from
            a subject.
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided,
            :math:`w = h = d = n`.
            This argument is mandatory (it is a keyword argument for backward
            compatibility).
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
            specifying the overlap between patches for dense inference. If a
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`. If ``None``, the volume will not
            be padded before sampling and patches at the border will not be
            cropped by the aggregator.
            Otherwise, the volume will be padded with
            :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)`
            on each side before sampling. If the sampler is passed to a
            :class:`~torchio.data.GridAggregator`, it will crop the output
            to its original size.

    Raises:
        ValueError: If ``patch_size`` is not given, if ``subject`` is neither
            ``None`` nor a :class:`~torchio.data.Subject`, or if the sizes
            are inconsistent (see :meth:`_parse_sizes`).

    Example::

        >>> import torchio as tio
        >>> sampler = tio.GridSampler(patch_size=88)
        >>> colin = tio.datasets.Colin27()
        >>> for i, patch in enumerate(sampler(colin)):
        ...     patch.t1.save(f'patch_{i}.nii.gz')
        ...
        >>> # To figure out the number of patches beforehand:
        >>> sampler = tio.GridSampler(subject=colin, patch_size=88)
        >>> len(sampler)
        8

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :attr:`patch_overlap` is twice :attr:`border` in NiftyNet
        tutorial.
    """
    def __init__(
            self,
            subject: Optional[Subject] = None,
            patch_size: Optional[TypePatchSize] = None,
            patch_overlap: TypePatchSize = (0, 0, 0),
            padding_mode: Union[str, float, None] = None,
            ):
        # patch_size only defaults to None so that it can follow subject in
        # the signature; it is actually mandatory
        if patch_size is None:
            raise ValueError('A value for patch_size must be given')
        super().__init__(patch_size)
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode
        if subject is not None and not isinstance(subject, Subject):
            raise ValueError('The subject argument must be None or Subject')
        # Both helpers pass None through, so the sampler can be built without
        # a subject even when padding_mode is set
        self.subject = self._pad(subject)
        self.locations = self._compute_locations(self.subject)

    def __len__(self):
        """Return the number of patches precomputed for the stored subject."""
        return len(self._get_precomputed_locations())

    def __getitem__(self, index):
        """Return the patch at position ``index`` of the precomputed grid.

        The returned subject is cropped from the (possibly padded) stored
        subject and its ``LOCATION`` entry is set to the six patch bounds
        ``(i_ini, j_ini, k_ini, i_fin, j_fin, k_fin)``.
        """
        # Assume 3D
        location = self._get_precomputed_locations()[index]
        index_ini = location[:3]
        cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
        cropped_subject[LOCATION] = location
        return cropped_subject

    def _get_precomputed_locations(self) -> np.ndarray:
        """Return ``self.locations`` or fail with an explanatory message.

        Raises:
            TypeError: If the sampler was instantiated without a subject. The
                exception type matches what ``len(None)`` and ``None[index]``
                would raise, so existing handlers keep working.
        """
        if self.locations is None:
            message = (
                'The sampler was instantiated without a subject, so there are'
                ' no precomputed patch locations. Pass the subject argument'
                ' to use len() or indexing'
            )
            raise TypeError(message)
        return self.locations

    def _pad(self, subject: Optional[Subject]) -> Optional[Subject]:
        """Pad the subject with half the patch overlap on each side.

        Nothing is done if ``padding_mode`` is ``None`` or if there is no
        subject to pad.
        """
        if subject is None or self.padding_mode is None:
            return subject
        # Imported here rather than at module level (presumably to avoid a
        # circular import -- TODO confirm)
        from ...transforms import Pad
        border = self.patch_overlap // 2
        # (w, h, d) -> (w, w, h, h, d, d): same amount on both sides
        padding = border.repeat(2)
        pad = Pad(padding, padding_mode=self.padding_mode)
        return pad(subject)

    def _compute_locations(self, subject: Optional[Subject]):
        """Validate the sizes and compute all patch locations for a subject.

        Returns:
            ``None`` if ``subject`` is ``None``; otherwise the array returned
            by :meth:`_get_patches_locations`.

        Raises:
            ValueError: If the sizes are inconsistent
                (see :meth:`_parse_sizes`).
        """
        if subject is None:
            return None
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)
        return self._get_patches_locations(*sizes)

    def _generate_patches(
            self,
            subject: Subject,
            ) -> Generator[Subject, None, None]:
        """Yield every grid patch of ``subject``, padding it first if needed.

        Unlike :meth:`__getitem__`, this does not use the subject stored at
        instantiation; locations are computed for the given subject.
        """
        subject = self._pad(subject)
        locations = self._compute_locations(subject)
        for location in locations:
            index_ini = location[:3]
            yield self.extract_patch(subject, index_ini)

    @staticmethod
    def _parse_sizes(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> None:
        """Check that image size, patch size and overlap are compatible.

        Raises:
            ValueError: If the patch is larger than the image along any
                dimension, if the overlap is not smaller than the patch size
                along any dimension, or if any overlap value is odd.
        """
        image_size = np.array(image_size)
        patch_size = np.array(patch_size)
        patch_overlap = np.array(patch_overlap)
        if np.any(patch_size > image_size):
            message = (
                f'Patch size {tuple(patch_size)} cannot be'
                f' larger than image size {tuple(image_size)}'
            )
            raise ValueError(message)
        # An overlap equal to the patch size would give a step of zero
        if np.any(patch_overlap >= patch_size):
            message = (
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
                f' than patch size {tuple(patch_size)}'
            )
            raise ValueError(message)
        # The overlap is halved to compute the border, so it must be even
        if np.any(patch_overlap % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap)}'
            )
            raise ValueError(message)

    @staticmethod
    def _get_patches_locations(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> np.ndarray:
        """Compute the bounds of all the patches in the grid.

        The sizes are expected to have been validated with
        :meth:`_parse_sizes`; this method does not check them.

        Returns:
            Array of shape :math:`(N, 6)` whose rows are
            ``(i_ini, j_ini, k_ini, i_fin, j_fin, k_fin)``, sorted
            lexicographically and without duplicates.
        """
        # Example with image_size 10, patch_size 5, overlap 2:
        # [0 1 2 3 4 5 6 7 8 9]
        # [0 0 0 0 0]
        #       [1 1 1 1 1]
        #           [2 2 2 2 2]
        # Locations:
        # [[0, 5],
        #  [3, 8],
        #  [5, 10]]
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Add a last patch flush with the image border if the regular
            # steps do not reach it, so that the whole volume is covered
            last_index = im_size_dim - patch_size_dim
            if indices_dim[-1] != last_index:
                indices_dim.append(last_index)
            indices.append(indices_dim)
        # Cartesian product of the per-dimension indices, one row per patch
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))

Read our documentation on viewing source code.

Loading