Skip to content

Commit

Permalink
feat(detections): ✨ paligemma segmentation support added
Browse files Browse the repository at this point in the history
Signed-off-by: Onuralp SEZER <[email protected]>
  • Loading branch information
onuralpszr committed Nov 8, 2024
1 parent a6e1f03 commit 1aeb573
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
5 changes: 3 additions & 2 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,10 @@ def from_lmm(

if lmm == LMM.PALIGEMMA:
assert isinstance(result, str)
xyxy, class_id, class_name = from_paligemma(result, **kwargs)
xyxy, class_id, class_name, mask = from_paligemma(result, **kwargs)
data = {CLASS_NAME_DATA_FIELD: class_name}
return cls(xyxy=xyxy, class_id=class_id, data=data)
mask = mask if mask is not None else None
return cls(xyxy=xyxy, class_id=class_id, mask=mask, data=data)

if lmm == LMM.FLORENCE_2:
assert isinstance(result, dict)
Expand Down
72 changes: 57 additions & 15 deletions supervision/detection/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,67 @@ def validate_lmm_parameters(

def from_paligemma(
result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]:
"""
Parse results from Paligemma model which can contain object detection and segmentation.
Args:
result (str): Model output string containing loc and optional seg tokens
resolution_wh (Tuple[int, int]): Target resolution (width, height)
classes (Optional[List[str]]): List of class names to filter results
Returns:
xyxy (np.ndarray): Bounding box coordinates
class_id (Optional[np.ndarray]): Class IDs if classes provided
class_name (np.ndarray): Class names
mask (Optional[np.ndarray]): Segmentation masks if available
""" # noqa: E501
w, h = resolution_wh
pattern = re.compile(
r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s\-]+)"
)
matches = pattern.findall(result)
matches = np.array(matches) if matches else np.empty((0, 5))

xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4]
xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h])
class_name = np.char.strip(class_name.astype(str))
class_id = None
segmentation_pattern = re.compile(
r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*"
+ "".join(r"<seg(\d{3})>" for _ in range(16))
+ r"\s+([\w\s\-]+)"
)

if classes is not None:
mask = np.array([name in classes for name in class_name]).astype(bool)
xyxy, class_name = xyxy[mask], class_name[mask]
class_id = np.array([classes.index(name) for name in class_name])
detection_pattern = re.compile(
r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s\-]+)"
)

return xyxy, class_id, class_name
segmentation_matches = segmentation_pattern.findall(result)
if segmentation_matches:
matches = np.array(segmentation_matches)
xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h])
class_name = np.char.strip(matches[:, -1].astype(str))
seg_tokens = matches[:, 4:-1].astype(int)
masks = [np.zeros((h, w), dtype=bool) for tokens in seg_tokens]
masks = np.array(masks)

class_id = None
if classes is not None:
mask = np.array([name in classes for name in class_name]).astype(bool)
xyxy = xyxy[mask]
class_name = class_name[mask]
masks = masks[mask]
class_id = np.array([classes.index(name) for name in class_name])

return xyxy, class_id, class_name, masks

detection_matches = detection_pattern.findall(result)
if detection_matches:
matches = np.array(detection_matches)
xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h])
class_name = np.char.strip(matches[:, 4].astype(str))

class_id = None
if classes is not None:
mask = np.array([name in classes for name in class_name]).astype(bool)
xyxy, class_name = xyxy[mask], class_name[mask]
class_id = np.array([classes.index(name) for name in class_name])

return xyxy, class_id, class_name, None

return np.empty((0, 4)), None, np.array([]), None


def from_florence_2(
Expand Down

0 comments on commit 1aeb573

Please sign in to comment.