diff --git a/demo/mmagic_inference_demo.py b/demo/mmagic_inference_demo.py index 26b08a10da..0c3a17ce12 100644 --- a/demo/mmagic_inference_demo.py +++ b/demo/mmagic_inference_demo.py @@ -43,6 +43,11 @@ def parse_args(): type=int, default=None, help='Pretrained mmagic algorithm setting') + parser.add_argument( + '--config-name', + type=str, + default=None, + help='Pretrained mmagic algorithm config name') parser.add_argument( '--model-config', type=str, diff --git a/demo/mmagic_inference_tutorial.ipynb b/demo/mmagic_inference_tutorial.ipynb index 7ca106d4e0..fa29d2654d 100644 --- a/demo/mmagic_inference_tutorial.ipynb +++ b/demo/mmagic_inference_tutorial.ipynb @@ -404,9 +404,13 @@ "\n", "There are some different configs and checkpoints for one model.\n", "\n", + "You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0.\n", + "\n", "Take conditional GAN model 'biggan' as an example. We have pretrained model for Cifar and Imagenet, and all pretrained models of 'biggan' are listed in its [metafile.yaml](../configs/biggan/metafile.yml)\n", "\n", - "You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0." + "There are six settings in this metafile. If you choose setting 1, then the config 'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py' will be used. If 'model_setting' is not passed to 'MMagicInferencer', the config ‘configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py’ will be used by default.\n", + "\n", + "And you could also use 'config_name' to replace 'model_setting'. For example, you can init a MMagicInferencer with 'MMagicInferencer('biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')', which is the same with 'MMagicInferencer('biggan', model_setting=0)'." ] }, { diff --git a/mmagic/apis/mmagic_inferencer.py b/mmagic/apis/mmagic_inferencer.py index bad7e82df0..11392c32f1 100644 --- a/mmagic/apis/mmagic_inferencer.py +++ b/mmagic/apis/mmagic_inferencer.py @@ -130,6 +130,7 @@ class MMagicInferencer: def __init__(self, model_name: str = None, model_setting: int = None, + config_name: int = None, model_config: str = None, model_ckpt: str = None, device: torch.device = None, @@ -141,13 +142,14 @@ def __init__(self, inferencer_kwargs = {} inferencer_kwargs.update( self._get_inferencer_kwargs(model_name, model_setting, - model_config, model_ckpt, - extra_parameters)) + config_name, model_config, + model_ckpt, extra_parameters)) self.inferencer = Inferencers( device=device, seed=seed, **inferencer_kwargs) def _get_inferencer_kwargs(self, model_name: Optional[str], model_setting: Optional[int], + config_name: Optional[int], model_config: Optional[str], model_ckpt: Optional[str], extra_parameters: Optional[Dict]) -> Dict: @@ -161,6 +163,11 @@ def _get_inferencer_kwargs(self, model_name: Optional[str], if model_setting: setting_to_use = model_setting config_dir = cfgs['settings'][setting_to_use]['Config'] + if config_name: + for setting in cfgs['settings']: + if setting['Name'] == config_name: + config_dir = setting['Config'] + break config_dir = config_dir[config_dir.find('configs'):] if osp.exists( osp.join(osp.dirname(__file__), '..', '..', config_dir)):