-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
202 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
# Modified from thirdparty/mmdetection/demo/image_demo.py | ||
import asyncio | ||
import glob | ||
import os | ||
from argparse import ArgumentParser | ||
|
||
from mmcv import Config | ||
from mmdet.apis import async_inference_detector, inference_detector, show_result_pyplot | ||
|
||
from ssod.apis.inference import init_detector, save_result | ||
from ssod.utils import patch_config | ||
|
||
|
||
def parse_args(): | ||
parser = ArgumentParser() | ||
parser.add_argument("img", help="Image file") | ||
parser.add_argument("config", help="Config file") | ||
parser.add_argument("checkpoint", help="Checkpoint file") | ||
parser.add_argument("--device", default="cuda:0", help="Device used for inference") | ||
parser.add_argument( | ||
"--score-thr", type=float, default=0.3, help="bbox score threshold" | ||
) | ||
parser.add_argument( | ||
"--async-test", | ||
action="store_true", | ||
help="whether to set async options for async inference.", | ||
) | ||
parser.add_argument( | ||
"--output", | ||
type=str, | ||
default=None, | ||
help="specify the directory to save visualization results.", | ||
) | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def main(args): | ||
cfg = Config.fromfile(args.config) | ||
# Not affect anything, just avoid index error | ||
cfg.work_dir = "./work_dirs" | ||
cfg = patch_config(cfg) | ||
# build the model from a config file and a checkpoint file | ||
model = init_detector(cfg, args.checkpoint, device=args.device) | ||
imgs = glob.glob(args.img) | ||
for img in imgs: | ||
# test a single image | ||
result = inference_detector(model, img) | ||
# show the results | ||
if args.output is None: | ||
show_result_pyplot(model, img, result, score_thr=args.score_thr) | ||
else: | ||
out_file_path = os.path.join(args.output, os.path.basename(img)) | ||
print(f"Save results to {out_file_path}") | ||
save_result( | ||
model, img, result, score_thr=args.score_thr, out_file=out_file_path | ||
) | ||
|
||
|
||
async def async_main(args): | ||
cfg = Config.fromfile(args.config) | ||
# Not affect anything, just avoid index error | ||
cfg.work_dir = "./work_dirs" | ||
cfg = patch_config(cfg) | ||
# build the model from a config file and a checkpoint file | ||
model = init_detector(cfg, args.checkpoint, device=args.device) | ||
# test a single image | ||
args.img = glob.glob(args.img) | ||
tasks = asyncio.create_task(async_inference_detector(model, args.img)) | ||
result = await asyncio.gather(tasks) | ||
# show the results | ||
for img, pred in zip(args.img, result): | ||
if args.output is None: | ||
show_result_pyplot(model, img, pred, score_thr=args.score_thr) | ||
else: | ||
out_file_path = os.path.join(args.output, os.path.basename(img)) | ||
print(f"Save results to {out_file_path}") | ||
save_result( | ||
model, img, pred, score_thr=args.score_thr, out_file=out_file_path | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
if args.async_test: | ||
asyncio.run(async_main(args)) | ||
else: | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import warnings | ||
|
||
import mmcv | ||
from mmcv.runner import load_checkpoint | ||
|
||
from mmdet.core import get_classes | ||
from mmdet.models import build_detector | ||
|
||
|
||
def init_detector(config, checkpoint=None, device="cuda:0", cfg_options=None): | ||
"""Initialize a detector from config file. | ||
Args: | ||
config (str or :obj:`mmcv.Config`): Config file path or the config | ||
object. | ||
checkpoint (str, optional): Checkpoint path. If left as None, the model | ||
will not load any weights. | ||
cfg_options (dict): Options to override some settings in the used | ||
config. | ||
Returns: | ||
nn.Module: The constructed detector. | ||
""" | ||
if isinstance(config, str): | ||
config = mmcv.Config.fromfile(config) | ||
elif not isinstance(config, mmcv.Config): | ||
raise TypeError( | ||
"config must be a filename or Config object, " f"but got {type(config)}" | ||
) | ||
if cfg_options is not None: | ||
config.merge_from_dict(cfg_options) | ||
config.model.train_cfg = None | ||
|
||
if hasattr(config.model, "model"): | ||
config.model.model.pretrained = None | ||
config.model.model.train_cfg = None | ||
else: | ||
config.model.pretrained = None | ||
|
||
model = build_detector(config.model, test_cfg=config.get("test_cfg")) | ||
if checkpoint is not None: | ||
map_loc = "cpu" if device == "cpu" else None | ||
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) | ||
if "CLASSES" in checkpoint.get("meta", {}): | ||
model.CLASSES = checkpoint["meta"]["CLASSES"] | ||
else: | ||
warnings.simplefilter("once") | ||
warnings.warn( | ||
"Class names are not saved in the checkpoint's " | ||
"meta data, use COCO classes by default." | ||
) | ||
model.CLASSES = get_classes("coco") | ||
model.cfg = config # save the config in the model for convenience | ||
model.to(device) | ||
model.eval() | ||
return model | ||
|
||
|
||
def save_result(model, img, result, score_thr=0.3, out_file="res.png"): | ||
"""Save the detection results on the image. | ||
Args: | ||
model (nn.Module): The loaded detector. | ||
img (str or np.ndarray): Image filename or loaded image. | ||
result (tuple[list] or list): The detection result, can be either | ||
(bbox, segm) or just bbox. | ||
score_thr (float): The threshold to visualize the bboxes and masks. | ||
out_file (str): Specifies where to save the visualization result | ||
""" | ||
if hasattr(model, "module"): | ||
model = model.module | ||
model.show_result( | ||
img, | ||
result, | ||
score_thr=score_thr, | ||
show=False, | ||
out_file=out_file, | ||
bbox_color=(72, 101, 241), | ||
text_color=(72, 101, 241), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters