Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify yolo detection model code #1107

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 49 additions & 84 deletions sahi/models/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,105 +130,70 @@ def _create_object_prediction_list_from_original_predictions(
"""
original_predictions = self._original_predictions

# compatilibty for sahi v0.8.15
# compatibility for sahi v0.8.15
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

# handle all predictions
object_prediction_list_per_image = []

for image_ind, image_predictions in enumerate(original_predictions):
shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []
if self.has_mask:
image_predictions_in_xyxy_format = image_predictions[0]
image_predictions_masks = image_predictions[1]
for prediction, bool_mask in zip(
image_predictions_in_xyxy_format.cpu().detach().numpy(),
image_predictions_masks.cpu().detach().numpy(),
):
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
y2 = prediction[3]
bbox = [x1, y1, x2, y2]
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# Extract boxes and optional masks
if self.has_mask:
boxes = image_predictions[0].cpu().detach().numpy()
masks = image_predictions[1].cpu().detach().numpy()
else:
boxes = image_predictions.data.cpu().detach().numpy()
masks = None

# Process each prediction
for pred_ind, prediction in enumerate(boxes):
# Get bbox coordinates
bbox = prediction[:4].tolist()
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# Fix box coordinates
bbox = [max(0, coord) for coord in bbox]
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# Ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

# Get segmentation if available
segmentation = None
if masks is not None:
bool_mask = masks[pred_ind]
orig_width = self._original_shape[1]
orig_height = self._original_shape[0]
bool_mask = cv2.resize(bool_mask.astype(np.uint8), (orig_width, orig_height))
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
if len(segmentation) == 0:
continue
# fix negative box coords
bbox[0] = max(0, bbox[0])
bbox[1] = max(0, bbox[1])
bbox[2] = max(0, bbox[2])
bbox[3] = max(0, bbox[3])

# fix out of image box coords
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue
object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)
else: # Only bounding boxes
# process predictions
for prediction in image_predictions.data.cpu().detach().numpy():
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
y2 = prediction[3]
bbox = [x1, y1, x2, y2]
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# fix negative box coords
bbox[0] = max(0, bbox[0])
bbox[1] = max(0, bbox[1])
bbox[2] = max(0, bbox[2])
bbox[3] = max(0, bbox[3])

# fix out of image box coords
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=None,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)
# Create and append object prediction
object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)

object_prediction_list_per_image.append(object_prediction_list)

self._object_prediction_list_per_image = object_prediction_list_per_image
Loading