fepegar / torchio

Compare 68ffea7 ... +1 ... c3c0ed0

Coverage Reach
torchio/transforms/augmentation/intensity/random_labels_to_image.py torchio/transforms/augmentation/intensity/random_motion.py torchio/transforms/augmentation/intensity/random_ghosting.py torchio/transforms/augmentation/intensity/random_swap.py torchio/transforms/augmentation/intensity/random_bias_field.py torchio/transforms/augmentation/intensity/random_spike.py torchio/transforms/augmentation/intensity/random_gamma.py torchio/transforms/augmentation/intensity/random_blur.py torchio/transforms/augmentation/intensity/random_noise.py torchio/transforms/augmentation/intensity/__init__.py torchio/transforms/augmentation/spatial/random_affine.py torchio/transforms/augmentation/spatial/random_elastic_deformation.py torchio/transforms/augmentation/spatial/random_flip.py torchio/transforms/augmentation/spatial/random_anisotropy.py torchio/transforms/augmentation/spatial/__init__.py torchio/transforms/augmentation/composition.py torchio/transforms/augmentation/random_transform.py torchio/transforms/augmentation/__init__.py torchio/transforms/preprocessing/spatial/resample.py torchio/transforms/preprocessing/spatial/crop_or_pad.py torchio/transforms/preprocessing/spatial/pad.py torchio/transforms/preprocessing/spatial/crop.py torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py torchio/transforms/preprocessing/spatial/to_canonical.py torchio/transforms/preprocessing/spatial/bounds_transform.py torchio/transforms/preprocessing/intensity/histogram_standardization.py torchio/transforms/preprocessing/intensity/rescale.py torchio/transforms/preprocessing/intensity/z_normalization.py torchio/transforms/preprocessing/intensity/normalization_transform.py torchio/transforms/preprocessing/intensity/__init__.py torchio/transforms/preprocessing/label/remap_labels.py torchio/transforms/preprocessing/label/one_hot.py torchio/transforms/preprocessing/label/sequential_labels.py torchio/transforms/preprocessing/label/keep_largest_component.py torchio/transforms/preprocessing/label/remove_labels.py torchio/transforms/preprocessing/label/contour.py torchio/transforms/preprocessing/label/label_transform.py torchio/transforms/preprocessing/__init__.py torchio/transforms/transform.py torchio/transforms/data_parser.py torchio/transforms/__init__.py torchio/transforms/lambda_transform.py torchio/transforms/interpolation.py torchio/transforms/intensity_transform.py torchio/transforms/fourier.py torchio/transforms/spatial_transform.py torchio/data/sampler/weighted.py torchio/data/sampler/grid.py torchio/data/sampler/label.py torchio/data/sampler/sampler.py torchio/data/sampler/uniform.py torchio/data/sampler/__init__.py torchio/data/image.py torchio/data/io.py torchio/data/subject.py torchio/data/queue.py torchio/data/inference/aggregator.py torchio/data/inference/__init__.py torchio/data/dataset.py torchio/data/__init__.py torchio/datasets/mni/icbm.py torchio/datasets/mni/colin.py torchio/datasets/mni/pediatric.py torchio/datasets/mni/sheep.py torchio/datasets/mni/__init__.py torchio/datasets/mni/mni.py torchio/datasets/ixi.py torchio/datasets/episurg.py torchio/datasets/bite.py torchio/datasets/itk_snap/itk_snap.py torchio/datasets/itk_snap/__init__.py torchio/datasets/fpg.py torchio/datasets/slicer.py torchio/datasets/__init__.py torchio/utils.py torchio/visualization.py torchio/download.py torchio/cli/apply_transform.py torchio/cli/print_info.py torchio/typing.py torchio/constants.py torchio/__init__.py torchio/reference.py tests/transforms/augmentation/test_random_labels_to_image.py tests/transforms/augmentation/test_random_affine.py tests/transforms/augmentation/test_random_ghosting.py tests/transforms/augmentation/test_random_elastic_deformation.py tests/transforms/augmentation/test_random_motion.py tests/transforms/augmentation/test_random_spike.py tests/transforms/augmentation/test_random_gamma.py tests/transforms/augmentation/test_random_noise.py tests/transforms/augmentation/test_random_blur.py tests/transforms/augmentation/test_random_anisotropy.py tests/transforms/augmentation/test_random_flip.py tests/transforms/augmentation/test_random_bias_field.py tests/transforms/augmentation/test_oneof.py tests/transforms/augmentation/test_random_swap.py tests/transforms/preprocessing/test_crop_pad.py tests/transforms/preprocessing/test_rescale.py tests/transforms/preprocessing/test_resample.py tests/transforms/preprocessing/test_histogram_standardization.py tests/transforms/preprocessing/test_ensure_shape_multiple.py tests/transforms/preprocessing/test_pad.py tests/transforms/preprocessing/test_z_normalization.py tests/transforms/preprocessing/test_to_canonical.py tests/transforms/preprocessing/test_crop.py tests/transforms/test_transforms.py tests/transforms/test_invertibility.py tests/transforms/label/test_remove_labels.py tests/transforms/label/test_sequential_labels.py tests/transforms/label/test_remap_labels.py tests/transforms/test_lambda_transform.py tests/transforms/test_collate.py tests/transforms/test_reproducibility.py tests/data/test_image.py tests/data/test_io.py tests/data/inference/test_aggregator.py tests/data/inference/test_grid_sampler.py tests/data/inference/test_inference.py tests/data/sampler/test_label_sampler.py tests/data/sampler/test_weighted_sampler.py tests/data/sampler/test_uniform_sampler.py tests/data/sampler/test_patch_sampler.py tests/data/sampler/test_random_sampler.py tests/data/test_subject.py tests/data/test_subjects_dataset.py tests/data/test_queue.py tests/utils.py tests/test_utils.py tests/test_cli.py tests/datasets/test_ixi.py print_system.py

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.


@@ -1,7 +1,7 @@
Loading
1 1
#!/usr/bin/env python
2 2
3 3
from copy import copy
4 -
from torchio.data import GridSampler
4 +
import torchio as tio
5 5
from ...utils import TorchioTestCase
6 6
7 7
@@ -11,7 +11,11 @@
Loading
11 11
    def test_locations(self):
12 12
        patch_size = 5, 20, 20
13 13
        patch_overlap = 2, 4, 6
14 -
        sampler = GridSampler(self.sample_subject, patch_size, patch_overlap)
14 +
        sampler = tio.GridSampler(
15 +
            subject=self.sample_subject,
16 +
            patch_size=patch_size,
17 +
            patch_overlap=patch_overlap,
18 +
        )
15 19
        fixture = [
16 20
            [0, 0, 0, 5, 20, 20],
17 21
            [0, 0, 10, 5, 20, 30],
@@ -25,30 +29,35 @@
Loading
25 29
26 30
    def test_large_patch(self):
27 31
        with self.assertRaises(ValueError):
28 -
            GridSampler(self.sample_subject, (5, 21, 5), (0, 2, 0))
32 +
            tio.GridSampler(self.sample_subject, (5, 21, 5), (0, 2, 0))
29 33
30 34
    def test_large_overlap(self):
31 35
        with self.assertRaises(ValueError):
32 -
            GridSampler(self.sample_subject, (5, 20, 5), (2, 4, 6))
36 +
            tio.GridSampler(self.sample_subject, (5, 20, 5), (2, 4, 6))
33 37
34 38
    def test_odd_overlap(self):
35 39
        with self.assertRaises(ValueError):
36 -
            GridSampler(self.sample_subject, (5, 20, 5), (2, 4, 3))
40 +
            tio.GridSampler(self.sample_subject, (5, 20, 5), (2, 4, 3))
37 41
38 42
    def test_single_location(self):
39 -
        sampler = GridSampler(self.sample_subject, (10, 20, 30), 0)
43 +
        sampler = tio.GridSampler(self.sample_subject, (10, 20, 30), 0)
40 44
        fixture = [[0, 0, 0, 10, 20, 30]]
41 45
        self.assertEqual(sampler.locations.tolist(), fixture)
42 46
43 47
    def test_subject_shape(self):
44 48
        patch_size = 5, 20, 20
45 49
        patch_overlap = 2, 4, 6
46 50
        initial_shape = copy(self.sample_subject.shape)
47 -
        GridSampler(
51 +
        tio.GridSampler(
48 52
            self.sample_subject,
49 53
            patch_size,
50 54
            patch_overlap,
51 55
            padding_mode='reflect',
52 56
        )
53 57
        final_shape = self.sample_subject.shape
54 58
        self.assertEqual(initial_shape, final_shape)
59 +
60 +
    def test_bad_subject(self):
61 +
        with self.assertRaises(ValueError):
62 +
            patch_size = 88
63 +
            tio.GridSampler(patch_size)

@@ -1,28 +1,33 @@
Loading
1 -
from typing import Union
1 +
from typing import Union, Generator, Optional
2 2
3 3
import numpy as np
4 -
from torch.utils.data import Dataset
5 4
6 5
from ...utils import to_tuple
7 6
from ...constants import LOCATION
8 -
from ...typing import TypeTuple, TypeTripletInt
9 -
from ..subject import Subject
10 -
from ..sampler.sampler import PatchSampler
7 +
from ...data.subject import Subject
8 +
from ...typing import TypePatchSize
9 +
from ...typing import TypeTripletInt
10 +
from .sampler import PatchSampler
11 11
12 12
13 -
class GridSampler(PatchSampler, Dataset):
13 +
class GridSampler(PatchSampler):
14 14
    r"""Extract patches across a whole volume.
15 15
16 16
    Grid samplers are useful to perform inference using all patches from a
17 17
    volume. It is often used with a :class:`~torchio.data.GridAggregator`.
18 18
19 19
    Args:
20 20
        subject: Instance of :class:`~torchio.data.Subject`
21 -
            from which patches will be extracted.
21 +
            from which patches will be extracted. This argument should only be
22 +
            used before instantiating a :class:`~torchio.data.GridAggregator`,
23 +
            or to precompute the number of patches that would be generated from
24 +
            a subject.
22 25
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
23 26
            of size :math:`w \times h \times d`.
24 27
            If a single number :math:`n` is provided,
25 28
            :math:`w = h = d = n`.
29 +
            This argument is mandatory (it is a keyword argument for backward
30 +
            compatibility).
26 31
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
27 32
            specifying the overlap between patches for dense inference. If a
28 33
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
@@ -36,6 +41,19 @@
Loading
36 41
            :class:`~torchio.data.GridAggregator`, it will crop the output
37 42
            to its original size.
38 43
44 +
    Example::
45 +
46 +
        >>> import torchio as tio
47 +
        >>> sampler = tio.GridSampler(patch_size=88)
48 +
        >>> colin = tio.datasets.Colin27()
49 +
        >>> for i, patch in enumerate(sampler(colin)):
50 +
        ...     patch.t1.save(f'patch_{i}.nii.gz')
51 +
        ...
52 +
        >>> # To figure out the number of patches beforehand:
53 +
        >>> sampler = tio.GridSampler(subject=colin, patch_size=88)
54 +
        >>> len(sampler)
55 +
        8
56 +
39 57
    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
40 58
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
41 59
        information about patch based sampling. Note that
@@ -44,24 +62,20 @@
Loading
44 62
    """
45 63
    def __init__(
46 64
            self,
47 -
            subject: Subject,
48 -
            patch_size: TypeTuple,
49 -
            patch_overlap: TypeTuple = (0, 0, 0),
65 +
            subject: Optional[Subject] = None,
66 +
            patch_size: TypePatchSize = None,
67 +
            patch_overlap: TypePatchSize = (0, 0, 0),
50 68
            padding_mode: Union[str, float, None] = None,
51 69
            ):
52 -
        self.subject = subject
70 +
        if patch_size is None:
71 +
            raise ValueError('A value for patch_size must be given')
72 +
        super().__init__(patch_size)
53 73
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
54 74
        self.padding_mode = padding_mode
55 -
        if padding_mode is not None:
56 -
            from ...transforms import Pad
57 -
            border = self.patch_overlap // 2
58 -
            padding = border.repeat(2)
59 -
            pad = Pad(padding, padding_mode=padding_mode)
60 -
            self.subject = pad(self.subject)
61 -
        PatchSampler.__init__(self, patch_size)
62 -
        sizes = self.subject.spatial_shape, self.patch_size, self.patch_overlap
63 -
        self.parse_sizes(*sizes)
64 -
        self.locations = self.get_patches_locations(*sizes)
75 +
        if subject is not None and not isinstance(subject, Subject):
76 +
            raise ValueError('The subject argument must be None or Subject')
77 +
        self.subject = self._pad(subject)
78 +
        self.locations = self._compute_locations(self.subject)
65 79
66 80
    def __len__(self):
67 81
        return len(self.locations)
@@ -74,8 +88,36 @@
Loading
74 88
        cropped_subject[LOCATION] = location
75 89
        return cropped_subject
76 90
91 +
    def _pad(self, subject: Subject) -> Subject:
92 +
        if self.padding_mode is not None:
93 +
            from ...transforms import Pad
94 +
            border = self.patch_overlap // 2
95 +
            padding = border.repeat(2)
96 +
            pad = Pad(padding, padding_mode=self.padding_mode)
97 +
            subject = pad(subject)
98 +
        return subject
99 +
100 +
    def _compute_locations(self, subject: Subject):
101 +
        if subject is None:
102 +
            return None
103 +
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
104 +
        self._parse_sizes(*sizes)
105 +
        return self._get_patches_locations(*sizes)
106 +
107 +
    def _generate_patches(
108 +
            self,
109 +
            subject: Subject,
110 +
            ) -> Generator[Subject, None, None]:
111 +
        subject = self._pad(subject)
112 +
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
113 +
        self._parse_sizes(*sizes)
114 +
        locations = self._get_patches_locations(*sizes)
115 +
        for location in locations:
116 +
            index_ini = location[:3]
117 +
            yield self.extract_patch(subject, index_ini)
118 +
77 119
    @staticmethod
78 -
    def parse_sizes(
120 +
    def _parse_sizes(
79 121
            image_size: TypeTripletInt,
80 122
            patch_size: TypeTripletInt,
81 123
            patch_overlap: TypeTripletInt,
@@ -103,7 +145,7 @@
Loading
103 145
            raise ValueError(message)
104 146
105 147
    @staticmethod
106 -
    def get_patches_locations(
148 +
    def _get_patches_locations(
107 149
            image_size: TypeTripletInt,
108 150
            patch_size: TypeTripletInt,
109 151
            patch_overlap: TypeTripletInt,

@@ -1,6 +1,5 @@
Loading
1 1
import torch
2 2
from ...data.subject import Subject
3 -
from ...typing import TypePatchSize
4 3
from .sampler import RandomSampler
5 4
from typing import Generator
6 5
import numpy as np
@@ -16,20 +15,11 @@
Loading
16 15
    def get_probability_map(self, subject: Subject) -> torch.Tensor:
17 16
        return torch.ones(1, *subject.spatial_shape)
18 17
19 -
    def __call__(
18 +
    def _generate_patches(
20 19
            self,
21 20
            subject: Subject,
22 21
            num_patches: int = None,
23 22
            ) -> Generator[Subject, None, None]:
24 -
        subject.check_consistent_spatial_shape()
25 -
26 -
        if np.any(self.patch_size > subject.spatial_shape):
27 -
            message = (
28 -
                f'Patch size {tuple(self.patch_size)} cannot be'
29 -
                f' larger than image size {tuple(subject.spatial_shape)}'
30 -
            )
31 -
            raise RuntimeError(message)
32 -
33 23
        valid_range = subject.spatial_shape - self.patch_size
34 24
        patches_left = num_patches if num_patches is not None else True
35 25
        while patches_left:

@@ -1,9 +1,11 @@
Loading
1 +
from .grid import GridSampler
1 2
from .label import LabelSampler
2 3
from .uniform import UniformSampler
3 4
from .weighted import WeightedSampler
4 5
from .sampler import PatchSampler, RandomSampler
5 6
6 7
__all__ = [
8 +
    'GridSampler',
7 9
    'LabelSampler',
8 10
    'UniformSampler',
9 11
    'WeightedSampler',

@@ -2,8 +2,9 @@
Loading
2 2
from .subject import Subject
3 3
from .dataset import SubjectsDataset
4 4
from .image import Image, ScalarImage, LabelMap
5 -
from .inference import GridSampler, GridAggregator
5 +
from .inference import GridAggregator
6 6
from .sampler import (
7 +
    GridSampler,
7 8
    PatchSampler,
8 9
    LabelSampler,
9 10
    WeightedSampler,

Click to load this diff.
Loading diff...

Click to load this diff.
Loading diff...

Click to load this diff.
Loading diff...

Click to load this diff.
Loading diff...

Learn more Showing 2 files with coverage changes found.

Changes in torchio/utils.py
-1
+1
Loading file...
Changes in torchio/transforms/transform.py
-6
+6
Loading file...
Files Coverage
tests +<.01% 99.76%
torchio 0.06% 86.47%
print_system.py 0.00%
Project Totals (132 files) 90.48%
Loading