diff --git a/sahi/models/yolov8.py b/sahi/models/yolov8.py index cfa37d9c..b8311329 100644 --- a/sahi/models/yolov8.py +++ b/sahi/models/yolov8.py @@ -130,105 +130,70 @@ def _create_object_prediction_list_from_original_predictions( """ original_predictions = self._original_predictions - # compatilibty for sahi v0.8.15 + # compatibility for sahi v0.8.15 shift_amount_list = fix_shift_amount_list(shift_amount_list) full_shape_list = fix_full_shape_list(full_shape_list) + # handle all predictions object_prediction_list_per_image = [] + for image_ind, image_predictions in enumerate(original_predictions): shift_amount = shift_amount_list[image_ind] full_shape = None if full_shape_list is None else full_shape_list[image_ind] object_prediction_list = [] - if self.has_mask: - image_predictions_in_xyxy_format = image_predictions[0] - image_predictions_masks = image_predictions[1] - for prediction, bool_mask in zip( - image_predictions_in_xyxy_format.cpu().detach().numpy(), - image_predictions_masks.cpu().detach().numpy(), - ): - x1 = prediction[0] - y1 = prediction[1] - x2 = prediction[2] - y2 = prediction[3] - bbox = [x1, y1, x2, y2] - score = prediction[4] - category_id = int(prediction[5]) - category_name = self.category_mapping[str(category_id)] + # Extract boxes and optional masks + if self.has_mask: + boxes = image_predictions[0].cpu().detach().numpy() + masks = image_predictions[1].cpu().detach().numpy() + else: + boxes = image_predictions.data.cpu().detach().numpy() + masks = None + + # Process each prediction + for pred_ind, prediction in enumerate(boxes): + # Get bbox coordinates + bbox = prediction[:4].tolist() + score = prediction[4] + category_id = int(prediction[5]) + category_name = self.category_mapping[str(category_id)] + + # Fix box coordinates + bbox = [max(0, coord) for coord in bbox] + if full_shape is not None: + bbox[0] = min(full_shape[1], bbox[0]) + bbox[1] = min(full_shape[0], bbox[1]) + bbox[2] = min(full_shape[1], bbox[2]) + bbox[3] = min(full_shape[0], bbox[3]) + + # Ignore invalid predictions + if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): + logger.warning(f"ignoring invalid prediction with bbox: {bbox}") + continue + + # Get segmentation if available + segmentation = None + if masks is not None: + bool_mask = masks[pred_ind] orig_width = self._original_shape[1] orig_height = self._original_shape[0] bool_mask = cv2.resize(bool_mask.astype(np.uint8), (orig_width, orig_height)) segmentation = get_coco_segmentation_from_bool_mask(bool_mask) if len(segmentation) == 0: continue - # fix negative box coords - bbox[0] = max(0, bbox[0]) - bbox[1] = max(0, bbox[1]) - bbox[2] = max(0, bbox[2]) - bbox[3] = max(0, bbox[3]) - - # fix out of image box coords - if full_shape is not None: - bbox[0] = min(full_shape[1], bbox[0]) - bbox[1] = min(full_shape[0], bbox[1]) - bbox[2] = min(full_shape[1], bbox[2]) - bbox[3] = min(full_shape[0], bbox[3]) - - # ignore invalid predictions - if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): - logger.warning(f"ignoring invalid prediction with bbox: {bbox}") - continue - object_prediction = ObjectPrediction( - bbox=bbox, - category_id=category_id, - score=score, - segmentation=segmentation, - category_name=category_name, - shift_amount=shift_amount, - full_shape=full_shape, - ) - object_prediction_list.append(object_prediction) - object_prediction_list_per_image.append(object_prediction_list) - else: # Only bounding boxes - # process predictions - for prediction in image_predictions.data.cpu().detach().numpy(): - x1 = prediction[0] - y1 = prediction[1] - x2 = prediction[2] - y2 = prediction[3] - bbox = [x1, y1, x2, y2] - score = prediction[4] - category_id = int(prediction[5]) - category_name = self.category_mapping[str(category_id)] - - # fix negative box coords - bbox[0] = max(0, bbox[0]) - bbox[1] = max(0, bbox[1]) - bbox[2] = max(0, bbox[2]) - bbox[3] = max(0, bbox[3]) - - # fix out of image box coords - if full_shape is not None: - bbox[0] = min(full_shape[1], bbox[0]) - bbox[1] = min(full_shape[0], bbox[1]) - bbox[2] = min(full_shape[1], bbox[2]) - bbox[3] = min(full_shape[0], bbox[3]) - - # ignore invalid predictions - if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): - logger.warning(f"ignoring invalid prediction with bbox: {bbox}") - continue - object_prediction = ObjectPrediction( - bbox=bbox, - category_id=category_id, - score=score, - segmentation=None, - category_name=category_name, - shift_amount=shift_amount, - full_shape=full_shape, - ) - object_prediction_list.append(object_prediction) - object_prediction_list_per_image.append(object_prediction_list) + # Create and append object prediction + object_prediction = ObjectPrediction( + bbox=bbox, + category_id=category_id, + score=score, + segmentation=segmentation, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + + object_prediction_list_per_image.append(object_prediction_list) self._object_prediction_list_per_image = object_prediction_list_per_image