Skip to content

Commit

Permalink
feat: add option to color bounding boxes by category or instance
Browse files Browse the repository at this point in the history
  • Loading branch information
CVHub520 committed Oct 15, 2024
1 parent 0b07dc4 commit ba719c5
Showing 1 changed file with 47 additions and 11 deletions.
58 changes: 47 additions & 11 deletions tools/label_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def draw_polygon_from_custom(
save_box=True,
save_label=True,
keep_ori_fn=False,
color_level="category",
):
"""
Draws masks on images from custom dataset annotations and saves the annotated images.
Expand All @@ -105,6 +106,7 @@ def draw_polygon_from_custom(
save_box (bool): Whether to draw bounding boxes around masks.
save_label (bool): Whether to annotate masks with class labels.
keep_ori_fn (bool): If True, keeps the original filename; otherwise, uses a frame index-based naming.
color_level (str): "category" or "instance", whether to color the boxes by category or by instance.
Raises:
FileNotFoundError: If the specified image or label file does not exist.
Expand Down Expand Up @@ -167,13 +169,16 @@ def draw_polygon_from_custom(

# Collect polygons, XYXY coordinates, and class indices
xyxy_list, mask_list, cind_list = [], [], []
for shape in data["shapes"]:
for i, shape in enumerate(data["shapes"]):
if (
shape["shape_type"] != "polygon"
or shape["label"] not in classes
):
continue
label_id = classes.index(shape["label"])
if color_level == "category":
label_id = classes.index(shape["label"])
else:
label_id = i
cind_list.append(label_id)
points = np.array(shape["points"], dtype=np.int32)
xyxy_list.append(sv.polygon_to_xyxy(polygon=points))
Expand Down Expand Up @@ -226,6 +231,7 @@ def draw_rectangle_from_custom(
classes=[],
save_label=True,
keep_ori_fn=False,
color_level="category",
):
"""
Draws horizontal bounding boxes on images from custom rectangle annotations and saves the annotated images.
Expand All @@ -238,6 +244,7 @@ def draw_rectangle_from_custom(
classes (list[str]): List of class names to consider for annotation.
save_label (bool): Whether to annotate boxes with class labels.
keep_ori_fn (bool): If True, keeps the original filename; otherwise, uses a frame index-based naming.
color_level (str): "category" or "instance", whether to color the boxes by category or by instance.
Raises:
FileNotFoundError: If the specified image or label file does not exist.
Expand Down Expand Up @@ -297,24 +304,39 @@ def draw_rectangle_from_custom(

# Collect bounding box coordinates and class indices
xyxy_list, cind_list = [], []
for shape in data["shapes"]:
for i, shape in enumerate(data["shapes"]):
if (
shape["shape_type"] != "rectangle"
or shape["label"] not in classes
):
continue
label_id = classes.index(shape["label"])
if color_level == "category":
label_id = classes.index(shape["label"])
else:
label_id = i
cind_list.append(label_id)
points = shape["points"]
xyxy = (
np.array(points[:2] + points[-2:], dtype=np.int32)
if len(points) == 4
else np.array(points, dtype=np.int32)
)
if len(points) == 2:
# If there are only two points, assume they are diagonal points
x1, y1 = points[0]
x2, y2 = points[1]
xyxy = np.array([x1, y1, x2, y2], dtype=np.float32)
elif len(points) == 4:
# If there are four points, take the top-left and bottom-right points
xyxy = np.array([
min(p[0] for p in points),
min(p[1] for p in points),
max(p[0] for p in points),
max(p[1] for p in points)
], dtype=np.float32)
else:
print(f"Warning: Skipping invalid rectangle: {points}")
continue
xyxy_list.append(xyxy)

# If no rectangles found, save the original image and continue
if not xyxy_list:
print(f"No rectangles found for image: {image_file}")
cv2.imwrite(os.path.join(save_dir, save_name), image)
continue

Expand Down Expand Up @@ -347,6 +369,7 @@ def draw_rotation_from_custom(
classes=[],
save_label=True,
keep_ori_fn=False,
color_level="category",
):
"""
Draws oriented bounding boxes on images from custom rotation annotations and saves the annotated images.
Expand All @@ -359,6 +382,7 @@ def draw_rotation_from_custom(
classes (list[str]): List of class names to consider for annotation.
save_label (bool): Whether to annotate boxes with class labels.
keep_ori_fn (bool): If True, keeps the original filename; otherwise, uses a frame index-based naming.
color_level (str): "category" or "instance", whether to color the boxes by category or by instance.
Raises:
FileNotFoundError: If the specified image or label file does not exist.
Expand Down Expand Up @@ -418,13 +442,16 @@ def draw_rotation_from_custom(

# Collect bounding box coordinates and class indices
xyxyxyxy_list, xyxy_list, cind_list = [], [], []
for shape in data["shapes"]:
for i, shape in enumerate(data["shapes"]):
if (
shape["shape_type"] != "rotation"
or shape["label"] not in classes
):
continue
label_id = classes.index(shape["label"])
if color_level == "category":
label_id = classes.index(shape["label"])
else:
label_id = i
cind_list.append(label_id)
points = shape["points"]
xyxy = sv.polygon_to_xyxy(polygon=points)
Expand Down Expand Up @@ -499,6 +526,12 @@ def main():
action="store_true",
help="Whether to keep original filename",
)
parser.add_argument(
"--color_level",
choices=["category", "instance"],
default="category",
help="Color level for boxes",
)

args = parser.parse_args()

Expand All @@ -523,6 +556,7 @@ def main():
args.save_box,
args.save_label,
args.keep_ori_fn,
args.color_level,
)
elif args.task == "rectangle":
draw_rectangle_from_custom(
Expand All @@ -532,6 +566,7 @@ def main():
args.classes,
args.save_label,
args.keep_ori_fn,
args.color_level,
)
elif args.task == "rotation":
draw_rotation_from_custom(
Expand All @@ -541,6 +576,7 @@ def main():
args.classes,
args.save_label,
args.keep_ori_fn,
args.color_level,
)


Expand Down

0 comments on commit ba719c5

Please sign in to comment.