Skip to content

Commit

Permalink
Feature/sg 1442 sliding window inference for yolonas (#1979)
Browse files Browse the repository at this point in the history
* wip

* wip

* wip2

* working version, hard coded nms params

* moved post prediction callback to utils

* moved back to wrapper

* added abstract class, small refactoring for pipeline

* rolled back customizable detector, solved pretrained weights setting of processing for the wrapper

* temp cleanup

* support for fuse model in predict

* example added for predict

* added support for forward wrappers in trainer

* added test for validation forward wrapper

* added option for None as post prediction callback in DetectionMetrics

* wip adding set_model before using wrapper

* commit changes before removal of validation during training support

* refined docs

* removed old test for forward wrapper, fixed defaults

* fixed test and added clarifications

* forward wrapper test removed

* updated wrong threshold extraction and test result

* fixed docstring format
  • Loading branch information
shaydeci authored May 22, 2024
1 parent dd56e79 commit 217353a
Show file tree
Hide file tree
Showing 6 changed files with 506 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from super_gradients.common.object_names import Models
from super_gradients.training import models


# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
from super_gradients.training.models.detection_models.sliding_window_detection_forward_wrapper import SlidingWindowInferenceDetectionWrapper

# Pick the inference device up front: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load YOLO-NAS-S with COCO-pretrained weights and move it to the chosen device.
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
model = model.to(device)

# Wrap the detector so that predict() goes through the sliding-window wrapper.
# NOTE(review): tile_size / tile_step are presumably in pixels and tile_nms_conf is
# presumably the per-tile confidence threshold -- confirm against the wrapper's docs.
model = SlidingWindowInferenceDetectionWrapper(
    model=model,
    tile_size=640,
    tile_step=160,
    tile_nms_conf=0.35,
)

image_url = "https://images.pexels.com/photos/7968254/pexels-photo-7968254.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2"

# skip_image_resizing=True is forwarded to predict() as-is.
predictions = model.predict(image_url, skip_image_resizing=True)

# Display the result, then write it to the current working directory.
predictions.show()
predictions.save(output_path="2.jpg")
5 changes: 4 additions & 1 deletion src/super_gradients/training/metrics/detection_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class DetectionMetrics(Metric):
:param num_cls: Number of classes.
:param post_prediction_callback: DetectionPostPredictionCallback to be applied on net's output prior to the metric computation (NMS).
When None, the direct outputs of the model will be used.
:param normalize_targets: Whether to normalize bbox coordinates by image size.
:param iou_thres: IoU threshold to compute the mAP.
Could be either instance of IouThreshold, a tuple (lower bound, upper_bound) or single scalar.
Expand Down Expand Up @@ -179,7 +181,8 @@ def update(self, preds, target: torch.Tensor, device: str, inputs: torch.tensor,
targets = target.clone()
crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.clone()

preds = self.post_prediction_callback(preds, device=device)
if self.post_prediction_callback is not None:
preds = self.post_prediction_callback(preds, device=device)

new_matching_info = compute_detection_matching(
preds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,24 @@ def set_dataset_processing_params(
if class_agnostic_nms is not None:
self._default_class_agnostic_nms = bool(class_agnostic_nms)

def get_dataset_processing_params(self) -> dict:
    """Return the dataset processing parameters currently stored on this model.

    The returned dictionary has one entry per stored default:

    - ``class_names``: value of ``self._class_names``
    - ``image_processor``: value of ``self._image_processor``
    - ``iou``: value of ``self._default_nms_iou``
    - ``conf``: value of ``self._default_nms_conf``
    - ``nms_top_k``: value of ``self._default_nms_top_k``
    - ``max_predictions``: value of ``self._default_max_predictions``
    - ``multi_label_per_box``: value of ``self._default_multi_label_per_box``
    - ``class_agnostic_nms``: value of ``self._default_class_agnostic_nms``

    :return: A new dict holding the values listed above.
    """
    return {
        "class_names": self._class_names,
        "image_processor": self._image_processor,
        "iou": self._default_nms_iou,
        # Fix: this entry previously read `self._default_nms_iou`, so the IoU
        # threshold was silently reported as the confidence threshold.
        "conf": self._default_nms_conf,
        "nms_top_k": self._default_nms_top_k,
        "max_predictions": self._default_max_predictions,
        "multi_label_per_box": self._default_multi_label_per_box,
        "class_agnostic_nms": self._default_class_agnostic_nms,
    }

def get_processing_params(self) -> Optional[Processing]:
    """Return the image processor stored on this model.

    :return: The value of ``self._image_processor``; annotated as optional, so it may be None.
    """
    image_processor = self._image_processor
    return image_processor

def get_class_names(self) -> Optional[List[str]]:
    """Return the class names stored on this model.

    :return: The value of ``self._class_names``; annotated as optional, so it may be None.
    """
    class_names = self._class_names
    return class_names

@lru_cache(maxsize=1)
def _get_pipeline(
self,
Expand Down
Loading

0 comments on commit 217353a

Please sign in to comment.