Commit 8870d9f: Merge branch 'main' into fixinf

liuwenran committed Sep 11, 2023
2 parents afcb9da + 6fda2cc
Showing 5 changed files with 24 additions and 26 deletions.
5 changes: 5 additions & 0 deletions demo/mmagic_inference_demo.py
@@ -43,6 +43,11 @@ def parse_args():
         type=int,
         default=None,
         help='Pretrained mmagic algorithm setting')
+    parser.add_argument(
+        '--config-name',
+        type=str,
+        default=None,
+        help='Pretrained mmagic algorithm config name')
     parser.add_argument(
         '--model-config',
         type=str,
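For context, a minimal standalone sketch (not part of the commit) of how the new flag parses next to the existing one; the two add_argument calls mirror the hunk above, the surrounding parser setup is illustrative, and the sample value is taken from the tutorial below.

# Minimal sketch, not part of the commit: how --config-name parses
# alongside the existing --model-setting flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--model-setting',
    type=int,
    default=None,
    help='Pretrained mmagic algorithm setting')
parser.add_argument(
    '--config-name',
    type=str,
    default=None,
    help='Pretrained mmagic algorithm config name')

# Sample value taken from the tutorial notebook below.
args = parser.parse_args(
    ['--config-name', 'biggan_2xb25-500kiters_cifar10-32x32'])
print(args.config_name)  # biggan_2xb25-500kiters_cifar10-32x32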
6 changes: 5 additions & 1 deletion demo/mmagic_inference_tutorial.ipynb
@@ -404,9 +404,13 @@
     "\n",
     "There are some different configs and checkpoints for one model.\n",
     "\n",
+    "You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0.\n",
+    "\n",
     "Take the conditional GAN model 'biggan' as an example. We have pretrained models for CIFAR and ImageNet, and all pretrained models of 'biggan' are listed in its [metafile.yml](../configs/biggan/metafile.yml).\n",
     "\n",
-    "You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0."
+    "There are six settings in this metafile. If you choose setting 1, the config 'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py' will be used. If 'model_setting' is not passed to 'MMagicInferencer', the config 'configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py' will be used by default.\n",
+    "\n",
+    "And you could also use 'config_name' instead of 'model_setting'. For example, you can init an MMagicInferencer with 'MMagicInferencer('biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')', which is the same as 'MMagicInferencer('biggan', model_setting=0)'."
 ]
 },
 {
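The tutorial text above boils down to the following equivalence for 'biggan'; both calls select the CIFAR-10 config (a sketch assuming the environment the tutorial describes, since construction downloads weights):

from mmagic.apis import MMagicInferencer

# Name-based selection, new in this commit; per the tutorial text,
# this is equivalent to model_setting=0 for 'biggan'.
by_name = MMagicInferencer(
    'biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')
by_setting = MMagicInferencer('biggan', model_setting=0)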
5 changes: 3 additions & 2 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -22,7 +22,7 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
         postprocess=[])

     def preprocess(self,
-                   text: InputsType,
+                   text: InputsType = None,
                    negative_prompt: InputsType = None,
                    num_inference_steps: int = 20,
                    height=None,
@@ -37,7 +37,8 @@ def preprocess(self,
             result(Dict): Results of preprocess.
         """
         result = self.extra_parameters
-        result['prompt'] = text
+        if text:
+            result['prompt'] = text
         if negative_prompt:
             result['negative_prompt'] = negative_prompt
         if num_inference_steps:
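To see why the guard matters, here is a simplified, self-contained sketch of the new preprocess behaviour (only the parameters shown in the hunk are real; the rest is illustrative): pipeline kwargs are populated only when the caller supplies them, so prompt-free (unconditional) pipelines keep working.

# Simplified sketch of the behaviour after this commit: 'prompt' is only
# added when text is given, so calling with no prompt stays valid.
def preprocess(text=None, negative_prompt=None, num_inference_steps=20):
    result = {}  # stands in for self.extra_parameters
    if text:
        result['prompt'] = text
    if negative_prompt:
        result['negative_prompt'] = negative_prompt
    if num_inference_steps:
        result['num_inference_steps'] = num_inference_steps
    return result

print(preprocess())                    # {'num_inference_steps': 20}
print(preprocess(text='a cat photo'))  # adds 'prompt' only when given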
9 changes: 8 additions & 1 deletion mmagic/apis/mmagic_inferencer.py
@@ -130,6 +130,7 @@ class MMagicInferencer:
     def __init__(self,
                  model_name: str = None,
                  model_setting: int = None,
+                 config_name: str = None,
                  model_config: str = None,
                  model_ckpt: str = None,
                  device: torch.device = None,
@@ -140,14 +141,15 @@ def __init__(self,
         MMagicInferencer.init_inference_supported_models_cfg()
         inferencer_kwargs = {}
         inferencer_kwargs.update(
-            self._get_inferencer_kwargs(model_name, model_setting,
+            self._get_inferencer_kwargs(model_name, model_setting, config_name,
                                         model_config, model_ckpt,
                                         extra_parameters))
         self.inferencer = Inferencers(
             device=device, seed=seed, **inferencer_kwargs)

     def _get_inferencer_kwargs(self, model_name: Optional[str],
                                model_setting: Optional[int],
+                               config_name: Optional[str],
                                model_config: Optional[str],
                                model_ckpt: Optional[str],
                                extra_parameters: Optional[Dict]) -> Dict:
@@ -161,6 +163,11 @@ def _get_inferencer_kwargs(self, model_name: Optional[str],
         if model_setting:
             setting_to_use = model_setting
         config_dir = cfgs['settings'][setting_to_use]['Config']
+        if config_name:
+            for setting in cfgs['settings']:
+                if setting['Name'] == config_name:
+                    config_dir = setting['Config']
+                    break
         config_dir = config_dir[config_dir.find('configs'):]
         if osp.exists(
                 osp.join(osp.dirname(__file__), '..', '..', config_dir)):
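A toy reproduction of the name-based lookup added above; the 'settings' list mimics the shape of entries consulted through cfgs['settings'] ('Name' and 'Config' keys), with names and paths copied from the tutorial and everything else illustrative.

# Toy reproduction of the lookup added in _get_inferencer_kwargs.
settings = [
    {'Name': 'biggan_2xb25-500kiters_cifar10-32x32',
     'Config': 'configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py'},
    {'Name': 'biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128',
     'Config':
     'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py'},
]

config_name = 'biggan_2xb25-500kiters_cifar10-32x32'
config_dir = None
for setting in settings:
    if setting['Name'] == config_name:
        config_dir = setting['Config']
        break  # first match wins, exactly as in the loop above

print(config_dir)  # configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py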
25 changes: 3 additions & 22 deletions tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -2,7 +2,6 @@
 import platform

 import pytest
-import torch
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -21,29 +20,11 @@
 def test_diffusers_pipeline_inferencer():
     cfg = dict(
         model=dict(
-            type='DiffusionPipeline',
-            from_pretrained='runwayml/stable-diffusion-v1-5'))
+            type='DiffusionPipeline', from_pretrained='google/ddpm-cat-256'))

     inferencer_instance = DiffusersPipelineInferencer(cfg, None)

-    def mock_encode_prompt(prompt, do_classifier_free_guidance,
-                           num_images_per_prompt, *args, **kwargs):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-        batch_size *= num_images_per_prompt
-        if do_classifier_free_guidance:
-            batch_size *= 2
-        return torch.randn(batch_size, 5, 16)  # 2 for cfg
-
-    inferencer_instance.model._encode_prompt = mock_encode_prompt
-
-    text_prompts = 'Japanese anime style, girl'
-    negative_prompt = 'bad face, bad hands'
-    result = inferencer_instance(
-        text=text_prompts,
-        negative_prompt=negative_prompt,
-        height=64,
-        width=64)
-    assert result[1][0].size == (64, 64)
+    result = inferencer_instance()
+    assert result[1][0].size == (256, 256)


 def teardown_module():
