diff --git a/pyproject.toml b/pyproject.toml index c06269d..c4ed0f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dynamic = ["version"] dependencies = [ "ultralytics==8.2.50", "opencv-python", - "pyroclient @ git+https://github.com/pyronear/pyro-api.git@92b9e28fe4b329aa28c7833b28f3bc89ab1b756c#egg=pyroclient&subdirectory=client", + "pyroclient @ git+https://github.com/pyronear/pyro-api.git@a46f5a00869049ffd1a8bb920ac685e44f18deb5#egg=pyroclient&subdirectory=client", "requests>=2.20.0,<3.0.0", "tqdm>=4.62.0", "huggingface_hub==0.23.1", diff --git a/pyroengine/core.py b/pyroengine/core.py index b853f2b..d1041de 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -55,13 +55,16 @@ def is_day_time(cache, frame, strategy, delta=0): return is_day -async def capture_camera_image(camera: ReolinkCamera, image_queue: asyncio.Queue) -> None: +async def capture_camera_image(camera: ReolinkCamera, image_queue: asyncio.Queue) -> bool: """ - Captures an image from the camera and puts it into a queue. + Captures an image from the camera and puts it into a queue. Returns whether it is daytime for this camera. Args: camera (ReolinkCamera): The camera instance. image_queue (asyncio.Queue): The queue to put the captured image. + + Returns: + bool: True if it is daytime according to this camera, False otherwise. """ cam_id = camera.ip_address try: @@ -75,14 +78,18 @@ async def capture_camera_image(camera: ReolinkCamera, image_queue: asyncio.Queue if frame is not None: await image_queue.put((cam_id, frame)) await asyncio.sleep(0) # Yield control + if not is_day_time(None, frame, "ir"): + return False else: frame = camera.capture() if frame is not None: await image_queue.put((cam_id, frame)) await asyncio.sleep(0) # Yield control + if not is_day_time(None, frame, "ir"): + return False except Exception as e: logger.exception(f"Error during image capture from camera {cam_id}: {e}") - + return True class SystemController: """ @@ -103,17 +110,21 @@ def __init__(self, engine: Engine, cameras: List[ReolinkCamera]) -> None: """ self.engine = engine self.cameras = cameras - self.day_time = True + self.is_day = True - async def capture_images(self, image_queue: asyncio.Queue) -> None: + async def capture_images(self, image_queue: asyncio.Queue) -> bool: """ Captures images from all cameras using asyncio. Args: image_queue (asyncio.Queue): The queue to put the captured images. + + Returns: + bool: True if it is daytime according to all cameras, False otherwise. """ tasks = [capture_camera_image(camera, image_queue) for camera in self.cameras] - await asyncio.gather(*tasks) + day_times = await asyncio.gather(*tasks) + return all(day_times) async def analyze_stream(self, image_queue: asyncio.Queue) -> None: """ @@ -134,7 +145,7 @@ async def analyze_stream(self, image_queue: asyncio.Queue) -> None: finally: image_queue.task_done() # Mark the task as done - def check_day_time(self) -> None: + async def night_mode(self) -> bool: """ Checks and updates the day_time attribute based on the current frame. """ @@ -145,43 +156,67 @@ def check_day_time(self) -> None: except Exception as e: logger.exception(f"Exception during initial day time check: {e}") - async def run(self, period: int = 30, send_alerts: bool = True) -> None: + for camera in self.cameras: + cam_id = camera.ip_address + try: + if camera.cam_type == "ptz": + for idx, pose_id in enumerate(camera.cam_poses): + cam_id = f"{camera.ip_address}_{pose_id}" + frame = camera.capture() + # Move camera to the next pose to avoid waiting + next_pos_id = camera.cam_poses[(idx + 1) % len(camera.cam_poses)] + camera.move_camera("ToPos", idx=int(next_pos_id), speed=50) + if frame is not None: + if not is_day_time(None, frame, "ir"): + return False + else: + frame = camera.capture() + if frame is not None: + if not is_day_time(None, frame, "ir"): + return False + except Exception as e: + logger.exception(f"Error during image capture from camera {cam_id}: {e}") + return True + + async def run(self, period: int = 30, send_alerts: bool = True) -> bool: """ Captures and analyzes all camera streams, then processes alerts. Args: period (int): The time period between captures in seconds. - send_alerts (bool): Boolean to activate / deactivate alert sending + send_alerts (bool): Boolean to activate / deactivate alert sending. + + Returns: + bool: True if it is daytime according to all cameras, False otherwise. """ try: - self.check_day_time() - - if self.day_time: + image_queue: asyncio.Queue[Any] = asyncio.Queue() - image_queue: asyncio.Queue[Any] = asyncio.Queue() + # Start the image processor task + processor_task = asyncio.create_task(self.analyze_stream(image_queue)) - # Start the image processor task - processor_task = asyncio.create_task(self.analyze_stream(image_queue)) + # Capture images concurrently + self.is_day = await self.capture_images(image_queue) - # Capture images concurrently - await self.capture_images(image_queue) + # Wait for the image processor to finish processing + await image_queue.join() # Ensure all tasks are marked as done - # Wait for the image processor to finish processing - await image_queue.join() # Ensure all tasks are marked as done + # Signal the image processor to stop processing + await image_queue.put(None) + await processor_task # Ensure the processor task completes - # Signal the image processor to stop processing - await image_queue.put(None) - await processor_task # Ensure the processor task completes - - # Process alerts + # Process alerts + if send_alerts: try: - if send_alerts: - self.engine._process_alerts(self.cameras) + self.engine._process_alerts(self.cameras) except Exception as e: logger.exception(f"Error processing alerts: {e}") + return self.is_day + except Exception as e: logger.warning(f"Analyze stream error: {e}") + return True async def main_loop(self, period: int, send_alerts: bool = True) -> None: """ @@ -189,16 +224,22 @@ async def main_loop(self, period: int, send_alerts: bool = True) -> None: Args: period (int): The time period between captures in seconds. - send_alerts (bool): Boolean to activate / deactivate alert sending + send_alerts (bool): Boolean to activate / deactivate alert sending. """ while True: start_ts = time.time() await self.run(period, send_alerts) - # Sleep only once all images are processed - loop_time = time.time() - start_ts - sleep_time = max(period - (loop_time), 0) - logger.info(f"Loop run under {loop_time:.2f} seconds, sleeping for {sleep_time:.2f}") - await asyncio.sleep(sleep_time) + + if not self.is_day: + while not await self.night_mode(): + logger.info("Nighttime detected by at least one camera, sleeping for 1 hour.") + await asyncio.sleep(3600) # Sleep for 1 hour + else: + # Sleep only once all images are processed + loop_time = time.time() - start_ts + sleep_time = max(period - (loop_time), 0) + logger.info(f"Loop run under {loop_time:.2f} seconds, sleeping for {sleep_time:.2f}") + await asyncio.sleep(sleep_time) def __repr__(self) -> str: """ diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 34d4fd0..d0208ea 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -8,6 +8,7 @@ import json import os import shutil +import signal import time from collections import deque from datetime import datetime, timedelta, timezone @@ -29,6 +30,23 @@ __all__ = ["Engine"] +def handler(signum, frame): + raise TimeoutError("Heartbeat check timed out") + + +def heartbeat_with_timeout(api_instance, cam_id, timeout=1): + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout) + try: + api_instance.heartbeat(cam_id) + except TimeoutError: + logger.warning(f"Heartbeat check timed out for {cam_id}") + except ConnectionError: + logger.warning(f"Unable to reach the pyro-api with {cam_id}") + finally: + signal.alarm(0) + + class Engine: """This implements an object to manage predictions and API interactions for wildfire alerts. @@ -61,7 +79,7 @@ class Engine: def __init__( self, model_path: Optional[str] = None, - conf_thresh: float = 0.25, + conf_thresh: float = 0.15, api_host: Optional[str] = None, cam_creds: Optional[Dict[str, Dict[str, str]]] = None, nb_consecutive_frames: int = 4, @@ -78,7 +96,7 @@ def __init__( ) -> None: """Init engine""" # Engine Setup - self.model = Classifier(model_path=model_path) + self.model = Classifier(model_path=model_path, conf=0.05) self.conf_thresh = conf_thresh # API Setup @@ -211,23 +229,30 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> # Get the best ones if boxes.shape[0]: best_boxes = nms(boxes) - ious = box_iou(best_boxes[:, :4], boxes[:, :4]) - best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T]) - combine_predictions = best_boxes[best_boxes_scores > conf_th, :] - conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1) # memory + preds - - if len(combine_predictions): - - # send only preds boxes that match combine_predictions - ious = box_iou(combine_predictions[:, :4], preds[:, :4]) - iou_match = [np.max(iou) > 0 for iou in ious] - output_predictions = preds[iou_match, :] + # We keep only detections with at least two boxes above conf_th + detections = boxes[boxes[:, -1] > self.conf_thresh, :] + ious_detections = box_iou(best_boxes[:, :4], detections[:, :4]) + strong_detection = np.sum(ious_detections > 0, 0) > 1 + best_boxes = best_boxes[strong_detection, :] + if best_boxes.shape[0]: + ious = box_iou(best_boxes[:, :4], boxes[:, :4]) + + best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T]) + combine_predictions = best_boxes[best_boxes_scores > conf_th, :] + conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1) # memory + preds + if len(combine_predictions): + + # send only preds boxes that match combine_predictions + ious = box_iou(combine_predictions[:, :4], preds[:, :4]) + iou_match = [np.max(iou) > 0 for iou in ious] + output_predictions = preds[iou_match, :] # Limit bbox size for api output_predictions = np.round(output_predictions, 3) # max 3 digit output_predictions = output_predictions[:5, :] # max 5 bbox output_predictions_tuples = [tuple(row) for row in output_predictions] + self._states[cam_key]["last_predictions"].append( (frame, preds, output_predictions_tuples, datetime.now(timezone.utc).isoformat(), False) ) @@ -253,10 +278,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None, pose_id: Opt # Heartbeat if len(self.api_client) > 0 and isinstance(cam_id, str): - try: - self.heartbeat(cam_id) - except ConnectionError: - logger.exception(f"Unable to reach the pyro-api with {cam_id}") + heartbeat_with_timeout(self, cam_id, timeout=1) cam_key = cam_id or "-1" # Reduce image size to save bandwidth @@ -265,6 +287,7 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None, pose_id: Opt # Inference with ONNX preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key]) + print(preds) conf = self._update_states(frame, preds, cam_key) if self.save_captured_frames: diff --git a/pyroengine/vision.py b/pyroengine/vision.py index e9b75a1..7b68859 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -43,7 +43,7 @@ class Classifier: model_path: model path """ - def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format="ncnn", model_path=None) -> None: + def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0, format="ncnn", model_path=None) -> None: if model_path is None: if format == "ncnn": if self.is_arm_architecture(): @@ -126,7 +126,7 @@ def load_metadata(self, metadata_path): def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray: - results = self.model(pil_img, imgsz=self.imgsz, conf=self.conf, iou=self.iou) + results = self.model(pil_img, imgsz=self.imgsz, conf=self.conf, iou=self.iou, verbose=False) y = np.concatenate( (results[0].boxes.xyxyn.cpu().numpy(), results[0].boxes.conf.cpu().numpy().reshape((-1, 1))), axis=1 ) diff --git a/src/pyproject.toml b/src/pyproject.toml index 9801747..af66ae9 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" [tool.poetry.dependencies] python = "^3.8" -pyroclient = { git = "https://github.com/pyronear/pyro-api.git", rev = "92b9e28fe4b329aa28c7833b28f3bc89ab1b756c", subdirectory = "client" } +pyroclient = { git = "https://github.com/pyronear/pyro-api.git", rev = "a46f5a00869049ffd1a8bb920ac685e44f18deb5", subdirectory = "client" } pyroengine = "^0.2.0" python-dotenv = ">=0.15.0" ultralytics = "8.2.50" diff --git a/tests/test_core.py b/tests/test_core.py index d63a211..c03fd59 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,18 +16,28 @@ def mock_engine(): @pytest.fixture -def mock_cameras(): +def mock_cameras(mock_wildfire_image): camera = MagicMock() - camera.capture.return_value = Image.new("RGB", (100, 100)) # Mock captured image + camera.capture.return_value = mock_wildfire_image # Mock captured image camera.cam_type = "static" camera.ip_address = "192.168.1.1" return [camera] @pytest.fixture -def mock_cameras_ptz(): +def mock_cameras_ptz(mock_wildfire_image): camera = MagicMock() - camera.capture.return_value = Image.new("RGB", (100, 100)) # Mock captured image + camera.capture.return_value = mock_wildfire_image # Mock captured image + camera.cam_type = "ptz" + camera.cam_poses = [1, 2] + camera.ip_address = "192.168.1.1" + return [camera] + + +@pytest.fixture +def mock_cameras_ptz_night(): + camera = MagicMock() + camera.capture.return_value = Image.new("RGB", (100, 100), (255, 255, 255)) # Mock captured image camera.cam_type = "ptz" camera.cam_poses = [1, 2] camera.ip_address = "192.168.1.1" @@ -44,6 +54,21 @@ def system_controller_ptz(mock_engine, mock_cameras_ptz): return SystemController(engine=mock_engine, cameras=mock_cameras_ptz) +@pytest.fixture +def system_controller_ptz_night(mock_engine, mock_cameras_ptz_night): + return SystemController(engine=mock_engine, cameras=mock_cameras_ptz_night) + + +@pytest.mark.asyncio +async def test_night_mode(system_controller): + assert await system_controller.night_mode() + + +@pytest.mark.asyncio +async def test_night_mode_ptz(system_controller_ptz_night): + assert not await system_controller_ptz_night.night_mode() + + def test_is_day_time_ir_strategy(mock_wildfire_image): # Use day image assert is_day_time(None, mock_wildfire_image, "ir") @@ -94,10 +119,9 @@ async def test_capture_images_ptz(system_controller_ptz): @pytest.mark.asyncio -async def test_analyze_stream(system_controller): +async def test_analyze_stream(system_controller, mock_wildfire_image): queue = asyncio.Queue() - mock_frame = Image.new("RGB", (100, 100)) - + mock_frame = mock_wildfire_image await queue.put(("192.168.1.1", mock_frame)) analyze_task = asyncio.create_task(system_controller.analyze_stream(queue)) @@ -119,9 +143,9 @@ async def test_capture_images_method(system_controller): @pytest.mark.asyncio -async def test_analyze_stream_method(system_controller): +async def test_analyze_stream_method(system_controller, mock_wildfire_image): queue = asyncio.Queue() - mock_frame = Image.new("RGB", (100, 100)) + mock_frame = mock_wildfire_image await queue.put(("192.168.1.1", mock_frame)) await queue.put(None) # Signal the end of the stream @@ -129,25 +153,6 @@ async def test_analyze_stream_method(system_controller): system_controller.engine.predict.assert_called_once_with(mock_frame, "192.168.1.1") -def test_check_day_time(system_controller): - with patch("pyroengine.core.is_day_time", return_value=True) as mock_is_day_time: - system_controller.check_day_time() - assert system_controller.day_time is True - mock_is_day_time.assert_called_once() - - with patch("pyroengine.core.is_day_time", return_value=False) as mock_is_day_time: - system_controller.check_day_time() - assert system_controller.day_time is False - mock_is_day_time.assert_called_once() - - with patch("pyroengine.core.is_day_time", side_effect=Exception("Error in is_day_time")) as mock_is_day_time, patch( - "pyroengine.core.logging.exception" - ) as mock_logging_exception: - system_controller.check_day_time() - mock_is_day_time.assert_called_once() - mock_logging_exception.assert_called_once_with("Exception during initial day time check: Error in is_day_time") - - def test_repr_method(system_controller): repr_str = repr(system_controller) # Check if the representation is a string