Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 1442 sliding window inference for yolonas #1979

Merged
merged 26 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2ca48bf
wip
shaydeci Apr 18, 2024
c3666e9
wip
shaydeci Apr 30, 2024
4b783ea
wip2
shaydeci May 1, 2024
667ed3b
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 5, 2024
1d8cc8c
working version, hard coded nms params
shaydeci May 5, 2024
fae6d8d
moved post prediction callback to utils
shaydeci May 5, 2024
45aea2a
moved back to wrapper
shaydeci May 7, 2024
684af84
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 9, 2024
837ffd3
added abstract class, small refactoring for pipeline
shaydeci May 9, 2024
f77616c
rolled back customizable detector, solved pretrained weights setting …
shaydeci May 9, 2024
dce1b4a
temp cleanup
shaydeci May 9, 2024
6c64ddd
support for fuse model in predict
shaydeci May 9, 2024
2cdf4ff
example added for predict
shaydeci May 9, 2024
80d81e9
added support for forward wrappers in trainer
shaydeci May 9, 2024
bf809eb
added test for validation forward wrapper
shaydeci May 9, 2024
877e016
added option for None as post prediction callback in DetectionMetrics
shaydeci May 9, 2024
8192a15
wip adding set_model before using wrapper
shaydeci May 15, 2024
60cf723
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 16, 2024
ebfefd1
commit changes before removal of validation during training support
shaydeci May 16, 2024
aa7d0cb
refined docs
shaydeci May 16, 2024
7f3a0d4
removed old test for forward wrapper, fixed defaults
shaydeci May 20, 2024
1056b23
fixed test and added clarifications
shaydeci May 20, 2024
2981c23
forward wrapper test removed
shaydeci May 20, 2024
cf169e9
Merge remote-tracking branch 'origin/master' into feature/SG-1442_sli…
shaydeci May 20, 2024
0bcb821
updated wrong threshold extraction and test result
shaydeci May 20, 2024
2d6331a
fixed docstring format
shaydeci May 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from super_gradients.common.object_names import Models
from super_gradients.training import models


# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
from super_gradients.training.models.detection_models.sliding_window_detection_forward_wrapper import SlidingWindowInferenceDetectionWrapper

model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")

# We want to use cuda if available to speed up inference.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

model = SlidingWindowInferenceDetectionWrapper(model=model, tile_size=640, tile_step=160, tile_nms_conf=0.35)

predictions = model.predict(
"https://images.pexels.com/photos/7968254/pexels-photo-7968254.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2", skip_image_resizing=True
)
predictions.show()
predictions.save(output_path="2.jpg") # Save in working directory
5 changes: 4 additions & 1 deletion src/super_gradients/training/metrics/detection_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class DetectionMetrics(Metric):

:param num_cls: Number of classes.
:param post_prediction_callback: DetectionPostPredictionCallback to be applied on net's output prior to the metric computation (NMS).
When None, the direct outputs of the model will be used.

:param normalize_targets: Whether to normalize bbox coordinates by image size.
:param iou_thres: IoU threshold to compute the mAP.
Could be either instance of IouThreshold, a tuple (lower bound, upper_bound) or single scalar.
Expand Down Expand Up @@ -179,7 +181,8 @@ def update(self, preds, target: torch.Tensor, device: str, inputs: torch.tensor,
targets = target.clone()
crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.clone()

preds = self.post_prediction_callback(preds, device=device)
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
if self.post_prediction_callback is not None:
preds = self.post_prediction_callback(preds, device=device)

new_matching_info = compute_detection_matching(
preds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,24 @@ def set_dataset_processing_params(
if class_agnostic_nms is not None:
self._default_class_agnostic_nms = bool(class_agnostic_nms)

def get_dataset_processing_params(self):
return dict(
class_names=self._class_names,
image_processor=self._image_processor,
iou=self._default_nms_iou,
conf=self._default_nms_iou,
nms_top_k=self._default_nms_top_k,
max_predictions=self._default_max_predictions,
multi_label_per_box=self._default_multi_label_per_box,
class_agnostic_nms=self._default_class_agnostic_nms,
)

def get_processing_params(self) -> Optional[Processing]:
return self._image_processor

def get_class_names(self) -> Optional[List[str]]:
return self._class_names

@lru_cache(maxsize=1)
def _get_pipeline(
self,
Expand Down
Loading
Loading