Skip to content

Commit

Permalink
fix padding (#179)
Browse files Browse the repository at this point in the history
* fix padding

* style

* fix test

* check pad is a tuple

* fix style

* mypy
  • Loading branch information
MateoLostanlen authored Apr 9, 2024
1 parent 035c5ae commit 2cbe807
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def letterbox(
im_b = np.zeros((h + top + bottom, w + left + right, 3)) + color
im_b[top : top + h, left : left + w, :] = im

return im_b.astype("uint8")
return im_b.astype("uint8"), (left, top)


def box_iou(box1: np.ndarray, box2: np.ndarray, eps: float = 1e-7):
Expand Down
32 changes: 19 additions & 13 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os
from typing import Optional
from typing import Optional, Tuple
from urllib.request import urlretrieve

import numpy as np
Expand Down Expand Up @@ -41,26 +41,27 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl
self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size

def preprocess_image(self, pil_img: Image.Image, mask: Optional[np.ndarray] = None) -> np.ndarray:
def preprocess_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference
Args:
pil_img: a valid pillow image
mask: occlusion mask to drop prediction in an area
pil_img: A valid PIL image.
Returns:
the resized and normalized image of shape (1, C, H, W)
A tuple containing:
- The resized and normalized image of shape (1, C, H, W).
- Padding information as a tuple of integers (pad_height, pad_width).
"""

np_img = letterbox(np.array(pil_img), self.img_size) # letterbox
np_img = np.expand_dims(np_img.astype("float"), axis=0)
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # BHWC to BCHW
np_img = np_img.astype("float32") / 255
np_img, pad = letterbox(np.array(pil_img), self.img_size) # Applies letterbox resize with padding
np_img = np.expand_dims(np_img.astype("float"), axis=0) # Add batch dimension
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # Convert from BHWC to BCHW format
np_img = np_img.astype("float32") / 255 # Normalize to [0, 1]

return np_img
return np_img, pad

def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray:
np_img = self.preprocess_image(pil_img)
np_img, pad = self.preprocess_image(pil_img)

# ONNX inference
y = self.ort_session.run(["output0"], {"images": np_img})[0][0]
Expand All @@ -72,10 +73,15 @@ def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] =
# Sort by confidence
y = y[y[:, 4].argsort()]
y = nms(y)

# Normalize preds
if len(y) > 0:
y[:, :4:2] /= self.img_size[1]
y[:, 1:4:2] /= self.img_size[0]
# Remove padding
left_pad, top_pad = pad
y[:, :4:2] -= left_pad
y[:, 1:4:2] -= top_pad
y[:, :4:2] /= self.img_size[1] - 2 * left_pad
y[:, 1:4:2] /= self.img_size[0] - 2 * top_pad
else:
y = np.zeros((0, 5)) # normalize output

Expand Down
3 changes: 2 additions & 1 deletion tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ def test_classifier(mock_wildfire_image):
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out = model.preprocess_image(mock_wildfire_image)
out, pad = model.preprocess_image(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 384, 640)
assert isinstance(pad, tuple)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
Expand Down

0 comments on commit 2cbe807

Please sign in to comment.