diff --git a/demo/mmagic_inference_demo.py b/demo/mmagic_inference_demo.py
index 26b08a10da..0c3a17ce12 100644
--- a/demo/mmagic_inference_demo.py
+++ b/demo/mmagic_inference_demo.py
@@ -43,6 +43,11 @@ def parse_args():
         type=int,
         default=None,
         help='Pretrained mmagic algorithm setting')
+    parser.add_argument(
+        '--config-name',
+        type=str,
+        default=None,
+        help='Pretrained mmagic algorithm config name')
     parser.add_argument(
         '--model-config',
         type=str,
diff --git a/demo/mmagic_inference_tutorial.ipynb b/demo/mmagic_inference_tutorial.ipynb
index 7ca106d4e0..fa29d2654d 100644
--- a/demo/mmagic_inference_tutorial.ipynb
+++ b/demo/mmagic_inference_tutorial.ipynb
@@ -404,9 +404,13 @@
     "\n",
     "There are some different configs and checkpoints for one model.\n",
     "\n",
+    "You can configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0.\n",
+    "\n",
     "Take conditional GAN model 'biggan' as an example. We have pretrained model for Cifar and Imagenet, and all pretrained models of 'biggan' are listed in its [metafile.yaml](../configs/biggan/metafile.yml)\n",
     "\n",
-    "You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0."
+    "There are six settings in this metafile. If you choose setting 1, the config 'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py' will be used. If 'model_setting' is not passed to 'MMagicInferencer', the config 'configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py' will be used by default.\n",
+    "\n",
+    "You can also use 'config_name' instead of 'model_setting'. For example, initializing 'MMagicInferencer('biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')' is equivalent to 'MMagicInferencer('biggan', model_setting=0)'."
    ]
   },
   {
diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
index 2143cd64db..573c0750f6 100644
--- a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
+++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -22,7 +22,7 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
         postprocess=[])
 
     def preprocess(self,
-                   text: InputsType,
+                   text: InputsType = None,
                    negative_prompt: InputsType = None,
                    num_inference_steps: int = 20,
                    height=None,
@@ -37,7 +37,8 @@ def preprocess(self,
             result(Dict): Results of preprocess.
         """
         result = self.extra_parameters
-        result['prompt'] = text
+        if text:
+            result['prompt'] = text
         if negative_prompt:
             result['negative_prompt'] = negative_prompt
         if num_inference_steps:
diff --git a/mmagic/apis/mmagic_inferencer.py b/mmagic/apis/mmagic_inferencer.py
index bad7e82df0..5cb9e4cb50 100644
--- a/mmagic/apis/mmagic_inferencer.py
+++ b/mmagic/apis/mmagic_inferencer.py
@@ -130,6 +130,7 @@ class MMagicInferencer:
     def __init__(self,
                  model_name: str = None,
                  model_setting: int = None,
+                 config_name: str = None,
                  model_config: str = None,
                  model_ckpt: str = None,
                  device: torch.device = None,
@@ -140,7 +141,7 @@ def __init__(self,
         MMagicInferencer.init_inference_supported_models_cfg()
         inferencer_kwargs = {}
         inferencer_kwargs.update(
-            self._get_inferencer_kwargs(model_name, model_setting,
+            self._get_inferencer_kwargs(model_name, model_setting, config_name,
                                         model_config, model_ckpt,
                                         extra_parameters))
         self.inferencer = Inferencers(
@@ -148,6 +149,7 @@ def __init__(self,
 
     def _get_inferencer_kwargs(self, model_name: Optional[str],
                                model_setting: Optional[int],
+                               config_name: Optional[str],
                                model_config: Optional[str],
                                model_ckpt: Optional[str],
                                extra_parameters: Optional[Dict]) -> Dict:
@@ -161,6 +163,11 @@ def _get_inferencer_kwargs(self, model_name: Optional[str],
         if model_setting:
             setting_to_use = model_setting
         config_dir = cfgs['settings'][setting_to_use]['Config']
+        if config_name:
+            for setting in cfgs['settings']:
+                if setting['Name'] == config_name:
+                    config_dir = setting['Config']
+                    break
         config_dir = config_dir[config_dir.find('configs'):]
         if osp.exists(
                 osp.join(osp.dirname(__file__), '..', '..', config_dir)):
diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
index dbfeb5c6e8..2e9285a7a0 100644
--- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -2,7 +2,6 @@
 import platform
 
 import pytest
-import torch
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 
@@ -21,29 +20,11 @@ def test_diffusers_pipeline_inferencer():
 
     cfg = dict(
         model=dict(
-            type='DiffusionPipeline',
-            from_pretrained='runwayml/stable-diffusion-v1-5'))
+            type='DiffusionPipeline', from_pretrained='google/ddpm-cat-256'))
 
     inferencer_instance = DiffusersPipelineInferencer(cfg, None)
-
-    def mock_encode_prompt(prompt, do_classifier_free_guidance,
-                           num_images_per_prompt, *args, **kwargs):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-        batch_size *= num_images_per_prompt
-        if do_classifier_free_guidance:
-            batch_size *= 2
-        return torch.randn(batch_size, 5, 16)  # 2 for cfg
-
-    inferencer_instance.model._encode_prompt = mock_encode_prompt
-
-    text_prompts = 'Japanese anime style, girl'
-    negative_prompt = 'bad face, bad hands'
-    result = inferencer_instance(
-        text=text_prompts,
-        negative_prompt=negative_prompt,
-        height=64,
-        width=64)
-    assert result[1][0].size == (64, 64)
+    result = inferencer_instance()
+    assert result[1][0].size == (256, 256)
 
 
 def teardown_module():