Skip to content

Commit

Permalink
Merge branch 'develop' into rs/update-new-datamodel
Browse files Browse the repository at this point in the history
  • Loading branch information
Ronan committed Sep 18, 2024
2 parents 59a5e1e + a618731 commit 5e23a76
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 81 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dynamic = ["version"]
dependencies = [
"ultralytics==8.2.50",
"opencv-python",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@92b9e28fe4b329aa28c7833b28f3bc89ab1b756c#egg=pyroclient&subdirectory=client",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@a46f5a00869049ffd1a8bb920ac685e44f18deb5#egg=pyroclient&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"tqdm>=4.62.0",
"huggingface_hub==0.23.1",
Expand Down
105 changes: 73 additions & 32 deletions pyroengine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ def is_day_time(cache, frame, strategy, delta=0):
return is_day


async def capture_camera_image(camera: ReolinkCamera, image_queue: asyncio.Queue) -> None:
async def capture_camera_image(camera: ReolinkCamera, image_queue: asyncio.Queue) -> bool:
"""
Captures an image from the camera and puts it into a queue.
Captures an image from the camera and puts it into a queue. Returns whether it is daytime for this camera.
Args:
camera (ReolinkCamera): The camera instance.
image_queue (asyncio.Queue): The queue to put the captured image.
Returns:
bool: True if it is daytime according to this camera, False otherwise.
"""
cam_id = camera.ip_address
try:
Expand All @@ -75,14 +78,18 @@ async def capture_camera_image(camera: ReolinkCamera, image_queue: asyncio.Queue
if frame is not None:
await image_queue.put((cam_id, frame))
await asyncio.sleep(0) # Yield control
if not is_day_time(None, frame, "ir"):
return False
else:
frame = camera.capture()
if frame is not None:
await image_queue.put((cam_id, frame))
await asyncio.sleep(0) # Yield control
if not is_day_time(None, frame, "ir"):
return False
except Exception as e:
logger.exception(f"Error during image capture from camera {cam_id}: {e}")

return True

class SystemController:
"""
Expand All @@ -103,17 +110,21 @@ def __init__(self, engine: Engine, cameras: List[ReolinkCamera]) -> None:
"""
self.engine = engine
self.cameras = cameras
self.day_time = True
self.is_day = True

async def capture_images(self, image_queue: asyncio.Queue) -> None:
async def capture_images(self, image_queue: asyncio.Queue) -> bool:
"""
Captures images from all cameras using asyncio.
Args:
image_queue (asyncio.Queue): The queue to put the captured images.
Returns:
bool: True if it is daytime according to all cameras, False otherwise.
"""
tasks = [capture_camera_image(camera, image_queue) for camera in self.cameras]
await asyncio.gather(*tasks)
day_times = await asyncio.gather(*tasks)
return all(day_times)

async def analyze_stream(self, image_queue: asyncio.Queue) -> None:
"""
Expand All @@ -134,7 +145,7 @@ async def analyze_stream(self, image_queue: asyncio.Queue) -> None:
finally:
image_queue.task_done() # Mark the task as done

def check_day_time(self) -> None:
async def night_mode(self) -> bool:
"""
Checks and updates the day_time attribute based on the current frame.
"""
Expand All @@ -145,60 +156,90 @@ def check_day_time(self) -> None:
except Exception as e:
logger.exception(f"Exception during initial day time check: {e}")

async def run(self, period: int = 30, send_alerts: bool = True) -> None:
for camera in self.cameras:
cam_id = camera.ip_address
try:
if camera.cam_type == "ptz":
for idx, pose_id in enumerate(camera.cam_poses):
cam_id = f"{camera.ip_address}_{pose_id}"
frame = camera.capture()
# Move camera to the next pose to avoid waiting
next_pos_id = camera.cam_poses[(idx + 1) % len(camera.cam_poses)]
camera.move_camera("ToPos", idx=int(next_pos_id), speed=50)
if frame is not None:
if not is_day_time(None, frame, "ir"):
return False
else:
frame = camera.capture()
if frame is not None:
if not is_day_time(None, frame, "ir"):
return False
except Exception as e:
logger.exception(f"Error during image capture from camera {cam_id}: {e}")
return True

async def run(self, period: int = 30, send_alerts: bool = True) -> bool:
"""
Captures and analyzes all camera streams, then processes alerts.
Args:
period (int): The time period between captures in seconds.
send_alerts (bool): Boolean to activate / deactivate alert sending
send_alerts (bool): Boolean to activate / deactivate alert sending.
Returns:
bool: True if it is daytime according to all cameras, False otherwise.
"""
try:
self.check_day_time()

if self.day_time:
image_queue: asyncio.Queue[Any] = asyncio.Queue()

image_queue: asyncio.Queue[Any] = asyncio.Queue()
# Start the image processor task
processor_task = asyncio.create_task(self.analyze_stream(image_queue))

# Start the image processor task
processor_task = asyncio.create_task(self.analyze_stream(image_queue))
# Capture images concurrently
self.is_day = await self.capture_images(image_queue)

# Capture images concurrently
await self.capture_images(image_queue)
# Wait for the image processor to finish processing
await image_queue.join() # Ensure all tasks are marked as done

# Wait for the image processor to finish processing
await image_queue.join() # Ensure all tasks are marked as done
# Signal the image processor to stop processing
await image_queue.put(None)
await processor_task # Ensure the processor task completes

# Signal the image processor to stop processing
await image_queue.put(None)
await processor_task # Ensure the processor task completes

# Process alerts
# Process alerts
if send_alerts:
try:
if send_alerts:
self.engine._process_alerts(self.cameras)
self.engine._process_alerts(self.cameras)
except Exception as e:
logger.exception(f"Error processing alerts: {e}")

return self.is_day

except Exception as e:
logger.warning(f"Analyze stream error: {e}")
return True

async def main_loop(self, period: int, send_alerts: bool = True) -> None:
"""
Main loop to capture and process images at regular intervals.
Args:
period (int): The time period between captures in seconds.
send_alerts (bool): Boolean to activate / deactivate alert sending
send_alerts (bool): Boolean to activate / deactivate alert sending.
"""
while True:
start_ts = time.time()
await self.run(period, send_alerts)
# Sleep only once all images are processed
loop_time = time.time() - start_ts
sleep_time = max(period - (loop_time), 0)
logger.info(f"Loop run under {loop_time:.2f} seconds, sleeping for {sleep_time:.2f}")
await asyncio.sleep(sleep_time)

if not self.is_day:
while not await self.night_mode():
logger.info("Nighttime detected by at least one camera, sleeping for 1 hour.")
await asyncio.sleep(3600) # Sleep for 1 hour
else:
# Sleep only once all images are processed
loop_time = time.time() - start_ts
sleep_time = max(period - (loop_time), 0)
logger.info(f"Loop run under {loop_time:.2f} seconds, sleeping for {sleep_time:.2f}")
await asyncio.sleep(sleep_time)

def __repr__(self) -> str:
"""
Expand Down
57 changes: 40 additions & 17 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
import shutil
import signal
import time
from collections import deque
from datetime import datetime, timedelta, timezone
Expand All @@ -29,6 +30,23 @@
__all__ = ["Engine"]


def handler(signum, frame):
raise TimeoutError("Heartbeat check timed out")


def heartbeat_with_timeout(api_instance, cam_id, timeout=1):
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
try:
api_instance.heartbeat(cam_id)
except TimeoutError:
logger.warning(f"Heartbeat check timed out for {cam_id}")
except ConnectionError:
logger.warning(f"Unable to reach the pyro-api with {cam_id}")
finally:
signal.alarm(0)


class Engine:
"""This implements an object to manage predictions and API interactions for wildfire alerts.
Expand Down Expand Up @@ -61,7 +79,7 @@ class Engine:
def __init__(
self,
model_path: Optional[str] = None,
conf_thresh: float = 0.25,
conf_thresh: float = 0.15,
api_host: Optional[str] = None,
cam_creds: Optional[Dict[str, Dict[str, str]]] = None,
nb_consecutive_frames: int = 4,
Expand All @@ -78,7 +96,7 @@ def __init__(
) -> None:
"""Init engine"""
# Engine Setup
self.model = Classifier(model_path=model_path)
self.model = Classifier(model_path=model_path, conf=0.05)
self.conf_thresh = conf_thresh

# API Setup
Expand Down Expand Up @@ -211,23 +229,30 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) ->
# Get the best ones
if boxes.shape[0]:
best_boxes = nms(boxes)
ious = box_iou(best_boxes[:, :4], boxes[:, :4])
best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T])
combine_predictions = best_boxes[best_boxes_scores > conf_th, :]
conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1) # memory + preds

if len(combine_predictions):

# send only preds boxes that match combine_predictions
ious = box_iou(combine_predictions[:, :4], preds[:, :4])
iou_match = [np.max(iou) > 0 for iou in ious]
output_predictions = preds[iou_match, :]
# We keep only detections with at least two boxes above conf_th
detections = boxes[boxes[:, -1] > self.conf_thresh, :]
ious_detections = box_iou(best_boxes[:, :4], detections[:, :4])
strong_detection = np.sum(ious_detections > 0, 0) > 1
best_boxes = best_boxes[strong_detection, :]
if best_boxes.shape[0]:
ious = box_iou(best_boxes[:, :4], boxes[:, :4])

best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T])
combine_predictions = best_boxes[best_boxes_scores > conf_th, :]
conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1) # memory + preds
if len(combine_predictions):

# send only preds boxes that match combine_predictions
ious = box_iou(combine_predictions[:, :4], preds[:, :4])
iou_match = [np.max(iou) > 0 for iou in ious]
output_predictions = preds[iou_match, :]

# Limit bbox size for api
output_predictions = np.round(output_predictions, 3) # max 3 digit
output_predictions = output_predictions[:5, :] # max 5 bbox
output_predictions_tuples = [tuple(row) for row in output_predictions]


self._states[cam_key]["last_predictions"].append(
(frame, preds, output_predictions_tuples, datetime.now(timezone.utc).isoformat(), False)
)
Expand All @@ -253,10 +278,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None, pose_id: Opt

# Heartbeat
if len(self.api_client) > 0 and isinstance(cam_id, str):
try:
self.heartbeat(cam_id)
except ConnectionError:
logger.exception(f"Unable to reach the pyro-api with {cam_id}")
heartbeat_with_timeout(self, cam_id, timeout=1)

cam_key = cam_id or "-1"
# Reduce image size to save bandwidth
Expand All @@ -265,6 +287,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None, pose_id: Opt

# Inference with ONNX
preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])
print(preds)
conf = self._update_states(frame, preds, cam_key)

if self.save_captured_frames:
Expand Down
4 changes: 2 additions & 2 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Classifier:
model_path: model path
"""

def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format="ncnn", model_path=None) -> None:
def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0, format="ncnn", model_path=None) -> None:
if model_path is None:
if format == "ncnn":
if self.is_arm_architecture():
Expand Down Expand Up @@ -126,7 +126,7 @@ def load_metadata(self, metadata_path):

def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray:

results = self.model(pil_img, imgsz=self.imgsz, conf=self.conf, iou=self.iou)
results = self.model(pil_img, imgsz=self.imgsz, conf=self.conf, iou=self.iou, verbose=False)
y = np.concatenate(
(results[0].boxes.xyxyn.cpu().numpy(), results[0].boxes.conf.cpu().numpy().reshape((-1, 1))), axis=1
)
Expand Down
2 changes: 1 addition & 1 deletion src/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = "^3.8"
pyroclient = { git = "https://github.com/pyronear/pyro-api.git", rev = "92b9e28fe4b329aa28c7833b28f3bc89ab1b756c", subdirectory = "client" }
pyroclient = { git = "https://github.com/pyronear/pyro-api.git", rev = "a46f5a00869049ffd1a8bb920ac685e44f18deb5", subdirectory = "client" }
pyroengine = "^0.2.0"
python-dotenv = ">=0.15.0"
ultralytics = "8.2.50"
Loading

0 comments on commit 5e23a76

Please sign in to comment.