from typing import Generator, Optional, Union

import numpy as np

from ...constants import LOCATION
from ...data.subject import Subject
from ...typing import TypePatchSize, TypeTripletInt
from ...utils import to_tuple
from .sampler import PatchSampler
class GridSampler(PatchSampler):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :class:`~torchio.data.GridAggregator`.

    Args:
        subject: Instance of :class:`~torchio.data.Subject`
            from which patches will be extracted. This argument should only be
            used before instantiating a :class:`~torchio.data.GridAggregator`,
            or to precompute the number of patches that would be generated from
            a subject.
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided,
            :math:`w = h = d = n`.
            This argument is mandatory (it is a keyword argument for backward
            compatibility).
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
            specifying the overlap between patches for dense inference. If a
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`. If ``None``, the volume will not
            be padded before sampling and patches at the border will not be
            cropped by the aggregator.
            Otherwise, the volume will be padded with
            :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)`
            on each side before sampling. If the sampler is passed to a
            :class:`~torchio.data.GridAggregator`, it will crop the output
            to its original size.

    Example::

        >>> import torchio as tio
        >>> sampler = tio.GridSampler(patch_size=88)
        >>> colin = tio.datasets.Colin27()
        >>> for i, patch in enumerate(sampler(colin)):
        ...     patch.t1.save(f'patch_{i}.nii.gz')
        ...
        >>> # To figure out the number of patches beforehand:
        >>> sampler = tio.GridSampler(subject=colin, patch_size=88)
        >>> len(sampler)
        8

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :attr:`patch_overlap` is twice :attr:`border` in NiftyNet
        tutorial.
    """

    def __init__(
        self,
        subject: Optional[Subject] = None,
        patch_size: TypePatchSize = None,
        patch_overlap: TypePatchSize = (0, 0, 0),
        padding_mode: Union[str, float, None] = None,
    ):
        if patch_size is None:
            raise ValueError('A value for patch_size must be given')
        super().__init__(patch_size)
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode
        if subject is not None and not isinstance(subject, Subject):
            raise ValueError('The subject argument must be None or Subject')
        # Pad eagerly so that __getitem__ can crop without re-padding.
        self.subject = self._pad(subject)
        # None when no subject was given; otherwise an (N, 6) array of
        # [i_ini, j_ini, k_ini, i_fin, j_fin, k_fin] patch locations.
        self.locations = self._compute_locations(self.subject)

    def __len__(self):
        """Return the number of patches that will be extracted."""
        return len(self.locations)

    def __getitem__(self, index):
        """Return the patch (as a Subject) at the given location index."""
        # Assume 3D
        location = self.locations[index]
        index_ini = location[:3]
        cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
        # Store the location so a GridAggregator knows where the patch goes.
        cropped_subject[LOCATION] = location
        return cropped_subject

    def _pad(self, subject: Subject) -> Subject:
        """Pad the subject by half the patch overlap on each side, if a
        padding mode was set. Returns the subject unchanged otherwise."""
        if self.padding_mode is not None:
            # Imported lazily to avoid a circular import with transforms.
            from ...transforms import Pad
            border = self.patch_overlap // 2
            padding = border.repeat(2)
            pad = Pad(padding, padding_mode=self.padding_mode)
            subject = pad(subject)
        return subject

    def _compute_locations(self, subject: Subject):
        """Validate sizes and return patch locations, or None if no subject."""
        if subject is None:
            return None
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)
        return self._get_patches_locations(*sizes)

    def _generate_patches(
        self,
        subject: Subject,
    ) -> Generator[Subject, None, None]:
        """Yield all patches from the given subject, padding it first."""
        subject = self._pad(subject)
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)
        locations = self._get_patches_locations(*sizes)
        for location in locations:
            index_ini = location[:3]
            yield self.extract_patch(subject, index_ini)

    @staticmethod
    def _parse_sizes(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> None:
        """Raise ValueError if the patch size exceeds the image size, the
        overlap is not smaller than the patch size, or the overlap is odd."""
        image_size = np.array(image_size)
        patch_size = np.array(patch_size)
        patch_overlap = np.array(patch_overlap)
        if np.any(patch_size > image_size):
            message = (
                f'Patch size {tuple(patch_size)} cannot be'
                f' larger than image size {tuple(image_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap >= patch_size):
            message = (
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
                f' than patch size {tuple(patch_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap)}'
            )
            raise ValueError(message)

    @staticmethod
    def _get_patches_locations(
        image_size: TypeTripletInt,
        patch_size: TypeTripletInt,
        patch_overlap: TypeTripletInt,
    ) -> np.ndarray:
        """Return an (N, 6) array of [ini, fin) corner indices covering the
        whole image, stepping by patch_size - patch_overlap per axis."""
        # Example with image_size 10, patch_size 5, overlap 2:
        # [0 1 2 3 4 5 6 7 8 9]
        # [0 0 0 0 0]
        #       [1 1 1 1 1]
        #           [2 2 2 2 2]
        # Locations:
        # [[0, 5],
        #  [3, 8],
        #  [5, 10]]
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Append a final patch flush with the image border so the last
            # voxels are always covered, even if it overlaps more than asked.
            if indices_dim[-1] != im_size_dim - patch_size_dim:
                indices_dim.append(im_size_dim - patch_size_dim)
            indices.append(indices_dim)
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))
