Skip to content

Commit

Permalink
Add bbox (#162)
Browse files Browse the repository at this point in the history
* add bbox

* fix inference

* style

* fix vision

* style vision

* fix engine

* keep all preds

* speed up

* pass dummy loc

* switch to v8

* fix tests

* test

* drop dummy test

* new api version

* unsed import

* install git

* code quality

* use bbox branch

* use apt get

* alert relaxation 3
  • Loading branch information
MateoLostanlen committed Jul 24, 2023
1 parent c103ae5 commit d5a4d8c
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 29 deletions.
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ COPY ./pyproject.toml /tmp/pyproject.toml
COPY ./README.md /tmp/README.md
COPY ./setup.py /tmp/setup.py

# install git
RUN apt-get update && apt-get install git -y

COPY ./src/requirements.txt /tmp/requirements.txt
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y\
&& pip install --upgrade pip setuptools wheel \
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"Pillow>=8.4.0",
"onnxruntime>=1.10.0,<2.0.0",
"numpy>=1.19.5,<2.0.0",
"pyroclient>=0.1.2",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@bbox#egg=pkg&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"opencv-python==4.5.5.64",
]
Expand Down
31 changes: 20 additions & 11 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,27 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:

if is_day_time(self._cache, frame, self.day_time_strategy):
# Inference with ONNX
pred = float(self.model(frame.convert("RGB")))
preds = self.model(frame.convert("RGB"))
if len(preds) == 0:
conf = 0
localization = ""
else:
conf = float(np.max(preds[:, -1]))
localization = str(json.dumps(preds.tolist()))

# Log analysis result
device_str = f"Camera '{cam_id}' - " if isinstance(cam_id, str) else ""
pred_str = "Wildfire detected" if pred >= self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {pred:.2%})")
pred_str = "Wildfire detected" if conf >= self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {conf:.2%})")

# Alert

to_be_staged = self._update_states(pred, cam_key)
to_be_staged = self._update_states(conf, cam_key)
if to_be_staged and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
self._stage_alert(frame_resize, cam_id)
self._stage_alert(frame_resize, cam_id, localization)
else:
pred = 0 # return default value
conf = 0 # return default value

# Uploading pending alerts
if len(self._alerts) > 0:
Expand Down Expand Up @@ -289,7 +296,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
except ConnectionError:
stream.seek(0) # "Rewind" the stream to the beginning so we can read its content

return pred
return float(conf)

def _upload_frame(self, cam_id: str, media_data: bytes) -> Response:
"""Save frame"""
Expand All @@ -303,7 +310,7 @@ def _upload_frame(self, cam_id: str, media_data: bytes) -> Response:

return response

def _stage_alert(self, frame: Image.Image, cam_id: str) -> None:
def _stage_alert(self, frame: Image.Image, cam_id: str, localization: str) -> None:
# Store information in the queue
self._alerts.append(
{
Expand All @@ -312,6 +319,7 @@ def _stage_alert(self, frame: Image.Image, cam_id: str) -> None:
"ts": datetime.utcnow().isoformat(),
"media_id": None,
"alert_id": None,
"localization": localization,
}
)

Expand All @@ -335,9 +343,10 @@ def _process_alerts(self) -> None:
self._alerts[0]["alert_id"] = (
self.api_client[cam_id]
.send_alert_from_device(
self.latitude,
self.longitude,
self._alerts[0]["media_id"],
lat=self.latitude,
lon=self.longitude,
media_id=self._alerts[0]["media_id"],
localization=self._alerts[0]["localization"],
)
.json()["id"]
)
Expand Down
62 changes: 60 additions & 2 deletions pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
import cv2
import numpy as np

__all__ = ["letterbox"]
__all__ = ["letterbox", "nms", "xywh2xyxy"]


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, stride=32):
def xywh2xyxy(x: np.array):
y = np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
return y


def letterbox(
im: np.array, new_shape: tuple = (640, 640), color: tuple = (114, 114, 114), auto: bool = False, stride: int = 32
):
"""Letterbox image transform for yolo models
Args:
im (np.array): Input image
Expand Down Expand Up @@ -51,3 +62,50 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, stride
im_b[top : top + h, left : left + w, :] = im

return im_b.astype("uint8")


def box_iou(box1: np.array, box2: np.array, eps: float = 1e-7):
"""
Calculate intersection-over-union (IoU) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Args:
box1 (np.array): A numpy array of shape (N, 4) representing N bounding boxes.
box2 (np.array): A numpy array of shape (M, 4) representing M bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(np.array): An NxM numpy array containing the pairwise IoU values for every element in box1 and box2.
"""

(a1, a2), (b1, b2) = np.split(box1, 2, 1), np.split(box2, 2, 1)
inter = (np.minimum(a2, b2[:, None, :]) - np.maximum(a1, b1[:, None, :])).clip(0).prod(2)

# IoU = inter / (area1 + area2 - inter)
return inter / ((a2 - a1).prod(1) + (b2 - b1).prod(1)[:, None] - inter + eps)


def nms(boxes: np.array, overlapThresh: int = 0):
"""Non maximum suppression
Args:
boxes (np.array): A numpy array of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2, conf) format
overlapThresh (int, optional): iou threshold. Defaults to 0.
Returns:
boxes: Boxes after NMS
"""
# Return an empty list, if no boxes given
boxes = boxes[boxes[:, -1].argsort()]
if len(boxes) == 0:
return []

indices = np.arange(len(boxes))
rr = box_iou(boxes[:, :4], boxes[:, :4])
for i, box in enumerate(boxes):
temp_indices = indices[indices != i]
if np.any(rr[i, temp_indices] > overlapThresh):
indices = indices[indices != i]

return boxes[indices]
31 changes: 21 additions & 10 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import onnxruntime
from PIL import Image

from .utils import letterbox
from .utils import letterbox, nms, xywh2xyxy

__all__ = ["Classifier"]

MODEL_URL = "https://github.com/pyronear/pyro-vision/releases/download/v0.2.0/yolov5s_v002.onnx"
MODEL_URL = "https://github.com/pyronear/pyro-vision/releases/download/v0.2.0/yolov8s_v001.onnx"


class Classifier:
Expand All @@ -29,16 +29,17 @@ class Classifier:
model_path: model path
"""

def __init__(self, model_path: Optional[str] = "data/model.onnx") -> None:
def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (384, 640)) -> None:
# Download model if not available
if not os.path.isfile(model_path):
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
print(f"Downloading model from {MODEL_URL} ...")
urllib.request.urlretrieve(MODEL_URL, model_path)

self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size

def preprocess_image(self, pil_img: Image.Image, img_size=(640, 384)) -> np.ndarray:
def preprocess_image(self, pil_img: Image.Image) -> np.ndarray:
"""Preprocess an image for inference
Args:
Expand All @@ -49,7 +50,7 @@ def preprocess_image(self, pil_img: Image.Image, img_size=(640, 384)) -> np.ndar
the resized and normalized image of shape (1, C, H, W)
"""

np_img = letterbox(np.array(pil_img)) # letterbox
np_img = letterbox(np.array(pil_img), self.img_size) # letterbox
np_img = np.expand_dims(np_img.astype("float"), axis=0)
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # BHWC to BCHW
np_img = np_img.astype("float32") / 255
Expand All @@ -60,8 +61,18 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray:
np_img = self.preprocess_image(pil_img)

# ONNX inference
y = self.ort_session.run(["output0"], {"images": np_img})[0]
# Non maximum suppression need to be added here when we will use the location information
# let's avoid useless compute for now

return np.max(y[0, :, 4])
y = self.ort_session.run(["output0"], {"images": np_img})[0][0]
# Drop low conf for speed-up
y = y[:, y[-1, :] > 0.05]
# Post processing
y = np.transpose(y)
y = xywh2xyxy(y)
# Sort by confidence
y = y[y[:, 4].argsort()]
y = nms(y)
# Normalize preds
if len(y) > 0:
y[:, :4:2] /= self.img_size[1]
y[:, 1:4:2] /= self.img_size[0]

return y
2 changes: 1 addition & 1 deletion src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main(args):
parser.add_argument(
"--alert_relaxation",
type=int,
default=2,
default=3,
help="Number of consecutive positive detections required to send the first alert",
)
parser.add_argument(
Expand Down
7 changes: 4 additions & 3 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):

# Cache saving
_ts = datetime.utcnow().isoformat()
engine._stage_alert(mock_wildfire_image, 0)
engine._stage_alert(mock_wildfire_image, 0, localization="dummy")
assert len(engine._alerts) == 1
assert engine._alerts[0]["ts"] < datetime.utcnow().isoformat() and _ts < engine._alerts[0]["ts"]
assert engine._alerts[0]["media_id"] is None
Expand All @@ -37,16 +37,17 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
engine._dump_cache()

# Cache dump loading
engine = Engine(cache_folder=folder + "model.onnx")
engine = Engine(cache_folder=folder)
assert len(engine._alerts) == 1
engine.clear_cache()

# inference
engine = Engine(alert_relaxation=3, cache_folder=folder + "model.onnx")
engine = Engine(alert_relaxation=3, cache_folder=folder)
out = engine.predict(mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 0
out = engine.predict(mock_wildfire_image)

assert isinstance(out, float) and 0 <= out <= 1
assert engine._states["-1"]["consec"] == 1
# Alert relaxation
Expand Down
4 changes: 3 additions & 1 deletion tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ def test_classifier(mock_wildfire_image):
assert out.shape == (1, 3, 384, 640)
# Check inference
out = model(mock_wildfire_image)
assert out >= 0 and out <= 1
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

0 comments on commit d5a4d8c

Please sign in to comment.