Skip to content

Commit

Permalink
use only one frame
Browse files Browse the repository at this point in the history
  • Loading branch information
MateoLostanlen committed Aug 1, 2024
1 parent fc383d2 commit 85ef7d4
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,15 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
# Inference with ONNX
preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])
print(preds)
conf = self._update_states(frame, preds, cam_key)
# conf = self._update_states(frame, preds, cam_key)
conf = np.max(preds[:, -1])
# Limit bbox size for api
output_predictions = np.round(preds, 3) # max 3 digit
output_predictions = output_predictions[:5, :] # max 5 bbox

# Alert
if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
self._stage_alert(frame, cam_id, ts, output_predictions.tolist())

if self.save_captured_frames:
self._local_backup(frame, cam_id, is_alert=False)
Expand All @@ -291,15 +299,15 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
pred_str = "Wildfire detected" if conf > self.conf_thresh else "No wildfire"
logging.info(f"{device_str}{pred_str} (confidence: {conf:.2%})")

# Alert
if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
# Save the alert in cache to avoid connection issues
for idx, (frame, preds, localization, ts, is_staged) in enumerate(
self._states[cam_key]["last_predictions"]
):
if not is_staged:
self._stage_alert(frame, cam_id, ts, localization)
self._states[cam_key]["last_predictions"][idx] = frame, preds, localization, ts, True
# # Alert
# if conf > self.conf_thresh and len(self.api_client) > 0 and isinstance(cam_id, str):
# # Save the alert in cache to avoid connection issues
# for idx, (frame, preds, localization, ts, is_staged) in enumerate(
# self._states[cam_key]["last_predictions"]
# ):
# if not is_staged:
# self._stage_alert(frame, cam_id, ts, localization)
# self._states[cam_key]["last_predictions"][idx] = frame, preds, localization, ts, True

# Check if it's time to backup pending alerts
ts = datetime.now(timezone.utc)
Expand Down

0 comments on commit 85ef7d4

Please sign in to comment.