Skip to content

Commit

Permalink
add config name
Browse files Browse the repository at this point in the history
  • Loading branch information
liuwenran committed Sep 11, 2023
1 parent c4fe495 commit 8928e30
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
5 changes: 5 additions & 0 deletions demo/mmagic_inference_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def parse_args():
type=int,
default=None,
help='Pretrained mmagic algorithm setting')
parser.add_argument(
'--config-name',
type=str,
default=None,
help='Pretrained mmagic algorithm config name')
parser.add_argument(
'--model-config',
type=str,
Expand Down
6 changes: 5 additions & 1 deletion demo/mmagic_inference_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,13 @@
"\n",
"There are some different configs and checkpoints for one model.\n",
"\n",
"You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0.\n",
"\n",
"Take conditional GAN model 'biggan' as an example. We have pretrained model for Cifar and Imagenet, and all pretrained models of 'biggan' are listed in its [metafile.yaml](../configs/biggan/metafile.yml)\n",
"\n",
"You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0."
"There are six settings in this metafile. If you choose setting 1, then the config 'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py' will be used. If 'model_setting' is not passed to 'MMagicInferencer', the config ‘configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py’ will be used by default.\n",
"\n",
"And you could also use 'config_name' to replace 'model_setting'. For example, you can init a MMagicInferencer with 'MMagicInferencer('biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')', which is the same with 'MMagicInferencer('biggan', model_setting=0)'."
]
},
{
Expand Down
11 changes: 9 additions & 2 deletions mmagic/apis/mmagic_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class MMagicInferencer:
def __init__(self,
model_name: str = None,
model_setting: int = None,
config_name: int = None,
model_config: str = None,
model_ckpt: str = None,
device: torch.device = None,
Expand All @@ -141,13 +142,14 @@ def __init__(self,
inferencer_kwargs = {}
inferencer_kwargs.update(
self._get_inferencer_kwargs(model_name, model_setting,
model_config, model_ckpt,
extra_parameters))
config_name, model_config,
model_ckpt, extra_parameters))
self.inferencer = Inferencers(
device=device, seed=seed, **inferencer_kwargs)

def _get_inferencer_kwargs(self, model_name: Optional[str],
model_setting: Optional[int],
config_name: Optional[int],
model_config: Optional[str],
model_ckpt: Optional[str],
extra_parameters: Optional[Dict]) -> Dict:
Expand All @@ -161,6 +163,11 @@ def _get_inferencer_kwargs(self, model_name: Optional[str],
if model_setting:
setting_to_use = model_setting
config_dir = cfgs['settings'][setting_to_use]['Config']
if config_name:
for setting in cfgs['settings']:
if setting['Name'] == config_name:
config_dir = setting['Config']
break
config_dir = config_dir[config_dir.find('configs'):]
if osp.exists(
osp.join(osp.dirname(__file__), '..', '..', config_dir)):
Expand Down

0 comments on commit 8928e30

Please sign in to comment.