Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class configuration file for defining which classes to group during nms #117

Merged
merged 10 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions class_config/colored_robots.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
group_classes:
- robot_red
- robot_blue
- robot_unknown

surrogate_class: robot
2 changes: 2 additions & 0 deletions class_config/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
group_classes:
surrogate_class: ""
89 changes: 52 additions & 37 deletions yoeo/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from torch.utils.data import DataLoader
from torch.autograd import Variable

from typing import Optional, List
from typing import Optional

from imgaug.augmentables.segmaps import SegmentationMapsOnImage

from yoeo.models import load_model
from yoeo.utils.utils import load_classes, rescale_boxes, non_max_suppression, print_environment_info, rescale_segmentation
from yoeo.utils.class_config import ClassConfig
from yoeo.utils.dataclasses import ClassNames, GroupConfig
from yoeo.utils.utils import rescale_boxes, non_max_suppression, print_environment_info, rescale_segmentation
from yoeo.utils.datasets import ImageFolder
from yoeo.utils.transforms import Resize, DEFAULT_TRANSFORMS

Expand All @@ -26,9 +28,9 @@
from matplotlib.ticker import NullLocator


def detect_directory(model_path, weights_path, img_path, classes, output_path,
def detect_directory(model_path, weights_path, img_path, class_config: ClassConfig, output_path,
Flova marked this conversation as resolved.
Show resolved Hide resolved
batch_size=8, img_size=416, n_cpu=8, conf_thres=0.5, nms_thres=0.5,
robot_class_ids: Optional[List[int]] = None):
):
"""Detects objects on all images in specified directory and saves output images with drawn detections.

:param model_path: Path to model definition file (.cfg)
Expand All @@ -37,8 +39,8 @@
:type weights_path: str
:param img_path: Path to directory with images to inference
:type img_path: str
:param classes: List of class names
:type classes: [str]
:param class_config: Class configuration
:type class_config: ClassConfig
:param output_path: Path to output directory
:type output_path: str
:param batch_size: Size of each image batch, defaults to 8
Expand All @@ -51,8 +53,6 @@
:type conf_thres: float, optional
:param nms_thres: IOU threshold for non-maximum suppression, defaults to 0.5
:type nms_thres: float, optional
:param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist.
:type robot_class_ids: List[int], optional
"""
dataloader = _create_data_loader(img_path, batch_size, img_size, n_cpu)
model = load_model(model_path, weights_path)
Expand All @@ -63,30 +63,37 @@
output_path,
conf_thres,
nms_thres,
robot_class_ids=robot_class_ids
class_config.get_group_config()
)
_draw_and_save_output_images(
img_detections, segmentations, imgs, img_size, output_path, classes)
img_detections, segmentations, imgs, img_size, output_path, class_config.get_ungrouped_det_class_names())

print(f"---- Detections were saved to: '{output_path}' ----")


def detect_image(model, image, img_size=416, conf_thres=0.5, nms_thres=0.5, robot_class_ids: Optional[List[int]] = None):
def detect_image(model,
image: np.ndarray,
img_size: int = 416,
conf_thres: float = 0.5,
nms_thres: float = 0.5,
group_config: Optional[GroupConfig] = None
Flova marked this conversation as resolved.
Show resolved Hide resolved
):
"""Inferences one image with model.

:param model: Model for inference
:type model: models.Darknet
:param image: Image to inference
:type image: nd.array
:type image: np.ndarray
:param img_size: Size of each image dimension for yolo, defaults to 416
:type img_size: int, optional
:type img_size: int
:param conf_thres: Object confidence threshold, defaults to 0.5
:type conf_thres: float, optional
:type conf_thres: float
:param nms_thres: IOU threshold for non-maximum suppression, defaults to 0.5
:type nms_thres: float, optional
:param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist.
:type robot_class_ids: List[int], optional
:type nms_thres: float
:param group_config: GroupConfig for this model (optional, defaults to None)
:type group_config: Optional[GroupConfig]

:return: Detections on image with each detection in the format: [x1, y1, x2, y2, confidence, class], Segmentation as 2d numpy array with the corresponding class id in each cell

Check failure on line 96 in yoeo/detect.py

View workflow job for this annotation

GitHub Actions / linter

line too long (179 > 150 characters)
:rtype: np.ndarray, np.ndarray
"""
model.eval() # Set model to evaluation mode
Expand All @@ -105,13 +112,24 @@
# Get detections
with torch.no_grad():
detections, segmentations = model(input_img)
detections = non_max_suppression(detections, conf_thres, nms_thres, robot_class_ids=robot_class_ids)
detections = non_max_suppression(
prediction=detections,
conf_thres=conf_thres,
iou_thres=nms_thres,
group_config=group_config
)
detections = rescale_boxes(detections[0], img_size, image.shape[0:2])
segmentations = rescale_segmentation(segmentations, image.shape[0:2])
return detections.numpy(), segmentations.cpu().detach().numpy()


def detect(model, dataloader, output_path, conf_thres, nms_thres, robot_class_ids: Optional[List[int]] = None):
def detect(model,
dataloader: DataLoader,
output_path: str,
conf_thres: float = 0.5,
nms_thres: float = 0.5,
group_config: Optional[GroupConfig] = None
):
"""Inferences images with model.

:param model: Model for inference
Expand All @@ -121,11 +139,12 @@
:param output_path: Path to output directory
:type output_path: str
:param conf_thres: Object confidence threshold, defaults to 0.5
:type conf_thres: float, optional
:type conf_thres: float
:param nms_thres: IOU threshold for non-maximum suppression, defaults to 0.5
:type nms_thres: float, optional
:param robot_class_ids: List of class IDs of robot classes if multiple robot classes exist.
:type robot_class_ids: List[int], optional
:type nms_thres: float
:param group_config: GroupConfig for this model (optional, defaults to None)
:type group_config: Optional[GroupConfig]

:return: List of detections. The coordinates are given for the padded image that is provided by the dataloader.
Use `utils.rescale_boxes` to transform them into the desired input image coordinate system (before it is transformed by the dataloader),
List of input image paths
Expand All @@ -149,7 +168,12 @@
# Get detections
with torch.no_grad():
detections, segmentations = model(input_imgs)
detections = non_max_suppression(detections, conf_thres, nms_thres, robot_class_ids=robot_class_ids)
detections = non_max_suppression(
prediction=detections,
conf_thres=conf_thres,
iou_thres=nms_thres,
group_config=group_config
)

# Store image and detections
img_detections.extend(detections)
Expand Down Expand Up @@ -310,33 +334,24 @@
parser.add_argument("--n_cpu", type=int, default=8, help="Number of cpu threads to use during batch generation")
parser.add_argument("--conf_thres", type=float, default=0.5, help="Object confidence threshold")
parser.add_argument("--nms_thres", type=float, default=0.4, help="IOU threshold for non-maximum suppression")
parser.add_argument("--multiple_robot_classes", action="store_true",
help="If multiple robot classes exist and nms shall be performed across all robot classes")
parser.add_argument("--class_config", type=str, default="class_config/default.yaml", help="Class configuration for evaluation")
args = parser.parse_args()
print(f"Command line arguments: {args}")

# Extract class names from file
classes = load_classes(args.classes)['detection'] # List of class names

robot_class_ids = None
if args.multiple_robot_classes:
robot_class_ids = []
for idx, c in enumerate(classes):
if "robot" in c:
robot_class_ids.append(idx)
class_names = ClassNames.load_from(args.classes)
class_config = ClassConfig.load_from(args.class_config, class_names)

detect_directory(
args.model,
args.weights,
args.images,
classes,
class_config,
args.output,
batch_size=args.batch_size,
img_size=args.img_size,
n_cpu=args.n_cpu,
conf_thres=args.conf_thres,
nms_thres=args.nms_thres,
robot_class_ids=robot_class_ids
)


Expand Down
Loading
Loading