diff --git a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py index f754c38b44..5c4a07670a 100644 --- a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py @@ -48,6 +48,7 @@ dict( type='InstanceCrop', config_file='mmdet::mask_rcnn/mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py', # noqa + from_pretrained=None, finesize=256, box_num_upbound=5), dict( diff --git a/mmagic/datasets/transforms/crop.py b/mmagic/datasets/transforms/crop.py index 8af6ced157..22dac70ccc 100644 --- a/mmagic/datasets/transforms/crop.py +++ b/mmagic/datasets/transforms/crop.py @@ -958,6 +958,7 @@ class InstanceCrop(BaseTransform): def __init__(self, config_file, + from_pretrained=None, key='img', box_num_upbound=-1, finesize=256): @@ -967,6 +968,11 @@ def __init__(self, "\"mim install 'mmdet >= 3.0.0'\".") cfg = get_config(config_file, pretrained=True) + + # loading checkpoint from local path + if from_pretrained is not None: + cfg.model.backbone.init_cfg.checkpoint = from_pretrained + with DefaultScope.overwrite_default_scope('mmdet'): self.predictor = mmdet_apis.init_detector(cfg, cfg.model_path)