fepegar / torchio
1
#!/usr/bin/env python
2

3 6
from copy import copy
4 6
import torchio as tio
5 6
from ...utils import TorchioTestCase
6

7

8 6
class TestGridSampler(TorchioTestCase):
    """Tests for `GridSampler`."""

    def test_locations(self):
        """The sampler produces the expected grid of patch locations.

        Each row of ``sampler.locations`` holds the three start indices of a
        patch followed by its three end indices (end = start + patch size,
        as the fixture rows show).
        """
        patch_size = 5, 20, 20
        patch_overlap = 2, 4, 6
        sampler = tio.GridSampler(
            subject=self.sample_subject,
            patch_size=patch_size,
            patch_overlap=patch_overlap,
        )
        fixture = [
            [0, 0, 0, 5, 20, 20],
            [0, 0, 10, 5, 20, 30],
            [3, 0, 0, 8, 20, 20],
            [3, 0, 10, 8, 20, 30],
            [5, 0, 0, 10, 20, 20],
            [5, 0, 10, 10, 20, 30],
        ]
        locations = sampler.locations.tolist()
        self.assertEqual(locations, fixture)

    def test_large_patch(self):
        """A patch that does not fit in the subject raises ``ValueError``.

        The second patch dimension (21) presumably exceeds the sample
        subject's size along that axis; TODO confirm against the fixture.
        """
        with self.assertRaises(ValueError):
            tio.GridSampler(self.sample_subject, (5, 21, 5), (0, 2, 0))

    def test_large_overlap(self):
        """An overlap larger than the patch size raises ``ValueError``.

        Along the last axis the overlap (6) exceeds the patch size (5).
        """
        with self.assertRaises(ValueError):
            tio.GridSampler(self.sample_subject, (5, 20, 5), (2, 4, 6))

    def test_odd_overlap(self):
        """An odd overlap value (3 on the last axis) raises ``ValueError``."""
        with self.assertRaises(ValueError):
            tio.GridSampler(self.sample_subject, (5, 20, 5), (2, 4, 3))

    def test_single_location(self):
        """A patch covering the whole subject yields exactly one location."""
        sampler = tio.GridSampler(self.sample_subject, (10, 20, 30), 0)
        fixture = [[0, 0, 0, 10, 20, 30]]
        self.assertEqual(sampler.locations.tolist(), fixture)

    def test_subject_shape(self):
        """Constructing a padded sampler does not alter the subject's shape."""
        patch_size = 5, 20, 20
        patch_overlap = 2, 4, 6
        # Snapshot the shape before construction so any in-place padding of
        # the subject would be detected by the comparison below.
        initial_shape = copy(self.sample_subject.shape)
        tio.GridSampler(
            self.sample_subject,
            patch_size,
            patch_overlap,
            padding_mode='reflect',
        )
        final_shape = self.sample_subject.shape
        self.assertEqual(initial_shape, final_shape)

    def test_bad_subject(self):
        """Passing a non-subject value as the first argument raises.

        This simulates a caller passing the patch size where the subject
        is expected.
        """
        patch_size = 88
        # Keep only the call under test inside the context manager so the
        # expected ValueError cannot come from any setup statement.
        with self.assertRaises(ValueError):
            tio.GridSampler(patch_size)

Read our documentation on viewing source code.

Loading