121 |
126 |
|
If a file object was used instead of a filename, this parameter should always be used. |
122 |
127 |
|
**kwargs: Other arguments are documented in ``make_grid``. |
123 |
128 |
|
""" |
124 |
|
- |
from PIL import Image |
125 |
129 |
|
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, |
126 |
130 |
|
normalize=normalize, range=range, scale_each=scale_each) |
127 |
131 |
|
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer |
128 |
132 |
|
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() |
129 |
133 |
|
im = Image.fromarray(ndarr) |
130 |
134 |
|
im.save(fp, format=format) |
|
135 |
+ |
|
|
136 |
+ |
|
|
137 |
+ |
@torch.no_grad() |
|
138 |
+ |
def draw_bounding_boxes( |
|
139 |
+ |
image: torch.Tensor, |
|
140 |
+ |
boxes: torch.Tensor, |
|
141 |
+ |
labels: Optional[List[str]] = None, |
|
142 |
+ |
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, |
|
143 |
+ |
width: int = 1, |
|
144 |
+ |
font: Optional[str] = None, |
|
145 |
+ |
font_size: int = 10 |
|
146 |
+ |
) -> torch.Tensor: |
|
147 |
+ |
|
|
148 |
+ |
""" |
|
149 |
+ |
Draws bounding boxes on given image. |
|
150 |
+ |
The values of the input image should be uint8 between 0 and 255. |
|
151 |
+ |
|
|
152 |
+ |
Args: |
|
153 |
+ |
image (Tensor): Tensor of shape (C x H x W) |
|
154 |
+ |
bboxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that |
|
155 |
+ |
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and |
|
156 |
+ |
`0 <= ymin < ymax < H`. |
|
157 |
+ |
labels (List[str]): List containing the labels of bounding boxes. |
|
158 |
+ |
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can |
|
159 |
+ |
be represented as `str` or `Tuple[int, int, int]`. |
|
160 |
+ |
width (int): Width of bounding box. |
|
161 |
+ |
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may |
|
162 |
+ |
also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, |
|
163 |
+ |
`/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. |
|
164 |
+ |
font_size (int): The requested font size in points. |
|
165 |
+ |
""" |
|
166 |
+ |
|
|
167 |
+ |
if not isinstance(image, torch.Tensor): |
|
168 |
+ |
raise TypeError(f"Tensor expected, got {type(image)}") |
|
169 |
+ |
elif image.dtype != torch.uint8: |
|
170 |
+ |
raise ValueError(f"Tensor uint8 expected, got {image.dtype}") |
|
171 |
+ |
elif image.dim() != 3: |
|
172 |
+ |
raise ValueError("Pass individual images, not batches") |
|
173 |
+ |
|
|
174 |
+ |
ndarr = image.permute(1, 2, 0).numpy() |
|
175 |
+ |
img_to_draw = Image.fromarray(ndarr) |
|
176 |
+ |
|
|
177 |
+ |
img_boxes = boxes.to(torch.int64).tolist() |
|
178 |
+ |
|
|
179 |
+ |
draw = ImageDraw.Draw(img_to_draw) |
|
180 |
+ |
|
|
181 |
+ |
for i, bbox in enumerate(img_boxes): |
|
182 |
+ |
color = None if colors is None else colors[i] |
|
183 |
+ |
draw.rectangle(bbox, width=width, outline=color) |
|
184 |
+ |
|
|
185 |
+ |
if labels is not None: |
|
186 |
+ |
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) |
|
187 |
+ |
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font) |
|
188 |
+ |
|
|
189 |
+ |
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1) |