fepegar / torchio
1 3
import copy
2 3
import shutil
3 3
import random
4 3
import tempfile
5 3
import unittest
6 3
from pathlib import Path
7 3
from random import shuffle
8

9 3
import torch
10 3
import numpy as np
11 3
from numpy.testing import assert_array_equal, assert_array_almost_equal
12 3
import torchio as tio
13

14

15 3
class TorchioTestCase(unittest.TestCase):
    """Base test case that writes small random TorchIO images to disk.

    Each test gets its own temporary directory (``self.dir``) in which the
    random fixture images are saved; the directory is removed in
    :meth:`tearDown`.
    """

    def setUp(self):
        """Set up test fixtures, if any."""
        # Use a unique directory per test instead of a fixed shared path, so
        # concurrently running test processes cannot overwrite or delete
        # (via tearDown's rmtree) each other's fixture files.
        self.dir = Path(tempfile.mkdtemp(prefix='torchio_tests_'))
        random.seed(42)
        np.random.seed(42)

        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_b')),
            label=tio.LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = tio.Subject(
            label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=tio.ScalarImage(self.get_image_path('t2_d')),
            label=tio.LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = tio.SubjectsDataset(self.subjects_list)
        self.sample_subject = self.dataset[-1]  # subject_d

    def make_2d(self, subject):
        """Return a deep copy of ``subject`` keeping only the first slice
        along the last axis of every image."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(image.data[..., :1])
        return subject

    def make_multichannel(self, subject):
        """Return a deep copy of ``subject`` whose images contain four
        copies of their data concatenated along dimension 0."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.set_data(torch.cat(4 * (image.data,)))
        return subject

    def flip_affine_x(self, subject):
        """Return a deep copy of ``subject`` with the first row of every
        affine negated (a flip of the first world axis)."""
        subject = copy.deepcopy(subject)
        for image in subject.get_images(intensity_only=False):
            image.affine = np.diag((-1, 1, 1, 1)) @ image.affine
        return subject

    def get_inconsistent_shape_subject(self):
        """Return a subject containing images of different shape."""
        subject = tio.Subject(
            t1=tio.ScalarImage(self.get_image_path('t1_inc')),
            t2=tio.ScalarImage(
                self.get_image_path('t2_inc', shape=(10, 20, 31))),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_inc',
                    shape=(8, 17, 25),
                    binary=True,
                ),
            ),
            label2=tio.LabelMap(
                self.get_image_path(
                    'label2_inc',
                    shape=(18, 17, 25),
                    binary=True,
                ),
            ),
        )
        return subject

    def get_reference_image_and_path(self):
        """Return a reference image and its path."""
        path = self.get_image_path(
            'ref',
            shape=(10, 20, 31),
            spacing=(1, 1, 2),
        )
        image = tio.ScalarImage(path)
        return image, path

    def get_subject_with_partial_volume_label_map(self, components=1):
        """Return a subject with a partial-volume label map.

        The label map is saved from non-binary (float) random data with
        ``components`` channels.
        """
        return tio.Subject(
            t1=tio.ScalarImage(
                self.get_image_path('t1_d'),
            ),
            label=tio.LabelMap(
                self.get_image_path(
                    'label_d2', binary=False, components=components
                )
            ),
        )

    def get_subject_with_labels(self, labels):
        """Return a subject with a single label map using the given label
        values (see ``get_image_path(labels=...)``)."""
        return tio.Subject(
            label=tio.LabelMap(
                self.get_image_path(
                    'label_multi', labels=labels
                )
            )
        )

    def get_unique_labels(self, label_map):
        """Return the set of distinct non-zero values in ``label_map.data``
        as Python scalars."""
        labels = torch.unique(label_map.data)
        labels = {i.item() for i in labels if i != 0}
        return labels

    def tearDown(self):
        """Tear down test fixtures, if any."""
        shutil.rmtree(self.dir)

    def get_ixi_tiny(self):
        """Return the IXITiny dataset, downloading it into a fixed directory
        under the system temp dir (reused across runs)."""
        root_dir = Path(tempfile.gettempdir()) / 'torchio' / 'ixi_tiny'
        return tio.datasets.IXITiny(root_dir, download=True)

    def get_image_path(
            self,
            stem,
            binary=False,
            labels=None,
            shape=(10, 20, 30),
            spacing=(1, 1, 1),
            components=1,
            add_nans=False,
            suffix=None,
            force_binary_foreground=True,
            ):
        """Save a random image in ``self.dir`` and return its path.

        Args:
            stem: File name without extension.
            binary: If ``True``, generate a 0/1 ``uint8`` image.
            labels: Optional sequence of label values. Each voxel is
                randomly set to background (0) or to one of the labels.
            shape: Spatial shape. A 2-element shape gets a trailing
                singleton dimension.
            spacing: Voxel spacing, written on the affine diagonal.
            components: Number of channels (first array dimension).
            add_nans: If ``True``, every voxel is NaN and the image is
                created with ``check_nans=False``.
            suffix: File extension. Chosen at random if ``None``.
            force_binary_foreground: If ``True`` and a binary image came out
                empty, set the first slice along the last axis to 1.

        Returns:
            The saved path, randomly either a ``str`` or a ``Path``.
        """
        shape = (*shape, 1) if len(shape) == 2 else shape
        data = np.random.rand(components, *shape)
        if binary:
            data = (data > 0.5).astype(np.uint8)
            if not data.sum() and force_binary_foreground:
                data[..., 0] = 1
        elif labels is not None:
            # After scaling, values are in 0..len(labels): 0 stays
            # background and value i + 1 is mapped to labels[i].
            data = (data * (len(labels) + 1)).astype(np.uint8)
            new_data = np.zeros_like(data)
            for i, label in enumerate(labels):
                new_data[data == (i + 1)] = label
                if not (new_data == label).sum():
                    # Ensure each requested label appears at least once
                    new_data[..., i] = label
            data = new_data
        elif self.flip_coin():  # cast some images
            data *= 100
            dtype = np.uint8 if self.flip_coin() else np.uint16
            data = data.astype(dtype)
        if add_nans:
            # Build a fresh float array instead of assigning NaN in place:
            # binary, label and randomly cast images are integer arrays,
            # which cannot hold NaN. The random calls above are unchanged,
            # so the generated fixtures stay deterministic for a given seed.
            data = np.full(data.shape, np.nan)
        affine = np.diag((*spacing, 1))
        if suffix is None:
            extensions = '.nii.gz', '.nii', '.nrrd', '.img', '.mnc'
            suffix = random.choice(extensions)
        path = self.dir / f'{stem}{suffix}'
        if self.flip_coin():
            # Exercise both str and Path inputs in the code under test
            path = str(path)
        image = tio.ScalarImage(
            tensor=data,
            affine=affine,
            check_nans=not add_nans,
        )
        image.save(path)
        return path

    def flip_coin(self):
        """Return ``True`` with probability 0.5 using NumPy's global RNG."""
        return np.random.rand() > 0.5

    def get_tests_data_dir(self):
        """Return the ``image_data`` directory next to this file."""
        return Path(__file__).parent / 'image_data'

    def assertTensorNotEqual(self, *args, **kwargs):  # noqa: N802
        """Assert that ``assertTensorEqual`` fails for the given arguments.

        A third positional argument is used both as the numpy error message
        and as the ``msg`` of the surrounding ``assertRaises``.
        """
        message_kwarg = {'msg': args[2]} if len(args) == 3 else {}
        with self.assertRaises(AssertionError, **message_kwarg):
            self.assertTensorEqual(*args, **kwargs)

    @staticmethod
    def assertTensorEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_equal``."""
        assert_array_equal(*args, **kwargs)

    @staticmethod
    def assertTensorAlmostEqual(*args, **kwargs):  # noqa: N802
        """Thin wrapper around ``numpy.testing.assert_array_almost_equal``."""
        assert_array_almost_equal(*args, **kwargs)

    def get_large_composed_transform(self):
        """Return a ``Compose`` of every ``Random*`` transform with default
        arguments, in an order shuffled with the ``random`` module."""
        all_classes = get_all_random_transforms()
        shuffle(all_classes)
        transforms = [t() for t in all_classes]
        # Hack as default patch size for RandomSwap is 15 and sample_subject
        # is (10, 20, 30)
        for tr in transforms:
            if tr.name == 'RandomSwap':
                tr.patch_size = np.array((10, 10, 10))
        return tio.Compose(transforms)
224

225

226 3
def get_all_random_transforms():
    """Collect every class exported by ``tio.transforms`` whose name begins
    with ``Random``, in the alphabetical order given by ``dir()``."""
    module = tio.transforms
    return [
        getattr(module, attr_name)
        for attr_name in dir(module)
        if attr_name.startswith('Random')
    ]

Read our documentation on viewing source code.

Loading