Skip to content

Commit

Permalink
simplify yolo detection model code (#1107)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored Dec 16, 2024
1 parent 61fdf6f commit aaeb57c
Showing 1 changed file with 49 additions and 84 deletions.
133 changes: 49 additions & 84 deletions sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,105 +130,70 @@ def _create_object_prediction_list_from_original_predictions(
"""
original_predictions = self._original_predictions

# compatilibty for sahi v0.8.15
# compatibility for sahi v0.8.15
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

# handle all predictions
object_prediction_list_per_image = []

for image_ind, image_predictions in enumerate(original_predictions):
shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []
if self.has_mask:
image_predictions_in_xyxy_format = image_predictions[0]
image_predictions_masks = image_predictions[1]
for prediction, bool_mask in zip(
image_predictions_in_xyxy_format.cpu().detach().numpy(),
image_predictions_masks.cpu().detach().numpy(),
):
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
y2 = prediction[3]
bbox = [x1, y1, x2, y2]
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# Extract boxes and optional masks
if self.has_mask:
boxes = image_predictions[0].cpu().detach().numpy()
masks = image_predictions[1].cpu().detach().numpy()
else:
boxes = image_predictions.data.cpu().detach().numpy()
masks = None

# Process each prediction
for pred_ind, prediction in enumerate(boxes):
# Get bbox coordinates
bbox = prediction[:4].tolist()
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# Fix box coordinates
bbox = [max(0, coord) for coord in bbox]
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# Ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

# Get segmentation if available
segmentation = None
if masks is not None:
bool_mask = masks[pred_ind]
orig_width = self._original_shape[1]
orig_height = self._original_shape[0]
bool_mask = cv2.resize(bool_mask.astype(np.uint8), (orig_width, orig_height))
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
if len(segmentation) == 0:
continue
# fix negative box coords
bbox[0] = max(0, bbox[0])
bbox[1] = max(0, bbox[1])
bbox[2] = max(0, bbox[2])
bbox[3] = max(0, bbox[3])

# fix out of image box coords
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue
object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)
else: # Only bounding boxes
# process predictions
for prediction in image_predictions.data.cpu().detach().numpy():
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
y2 = prediction[3]
bbox = [x1, y1, x2, y2]
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# fix negative box coords
bbox[0] = max(0, bbox[0])
bbox[1] = max(0, bbox[1])
bbox[2] = max(0, bbox[2])
bbox[3] = max(0, bbox[3])

# fix out of image box coords
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=None,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)
# Create and append object prediction
object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)

object_prediction_list_per_image.append(object_prediction_list)

self._object_prediction_list_per_image = object_prediction_list_per_image

0 comments on commit aaeb57c

Please sign in to comment.