From 80f912043146ad5383b7c4f7dfb80d88a403a7d8 Mon Sep 17 00:00:00 2001 From: YanxingLiu <42299757+YanxingLiu@users.noreply.github.com> Date: Thu, 19 Oct 2023 06:26:32 -0500 Subject: [PATCH] [CodeCamp2023-645]Add dreambooth new cfg (#2042) * new config of dreambooth * add dreambooth mmagic new_config * fix import name bug --------- Co-authored-by: YanxingLiu Co-authored-by: rangoliu --- .../dreambooth-finetune_text_encoder.py | 95 ++++++++++++++++++ .../dreambooth/dreambooth-prior_pre.py | 8 ++ mmagic/configs/dreambooth/dreambooth.py | 93 ++++++++++++++++++ .../dreambooth/dreambooth_lora-prior_pre.py | 7 ++ mmagic/configs/dreambooth/dreambooth_lora.py | 97 +++++++++++++++++++ 5 files changed, 300 insertions(+) create mode 100644 mmagic/configs/dreambooth/dreambooth-finetune_text_encoder.py create mode 100644 mmagic/configs/dreambooth/dreambooth-prior_pre.py create mode 100644 mmagic/configs/dreambooth/dreambooth.py create mode 100644 mmagic/configs/dreambooth/dreambooth_lora-prior_pre.py create mode 100644 mmagic/configs/dreambooth/dreambooth_lora.py diff --git a/mmagic/configs/dreambooth/dreambooth-finetune_text_encoder.py b/mmagic/configs/dreambooth/dreambooth-finetune_text_encoder.py new file mode 100644 index 000000000..0a09f8e70 --- /dev/null +++ b/mmagic/configs/dreambooth/dreambooth-finetune_text_encoder.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .._base_.gen_default_runtime import * + +from mmengine.dataset.sampler import InfiniteSampler +from torch.optim import AdamW + +from mmagic.datasets.dreambooth_dataset import DreamBoothDataset +from mmagic.datasets.transforms.aug_shape import Resize +from mmagic.datasets.transforms.formatting import PackInputs +from mmagic.datasets.transforms.loading import LoadImageFromFile +from mmagic.engine import VisualizationHook +from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor +from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper +from mmagic.models.editors.dreambooth import DreamBooth + +# config for model +stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5' + +val_prompts = [ + 'a sks dog in basket', 'a sks dog on the mountain', + 'a sks dog beside a swimming pool', 'a sks dog on the desk', + 'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden' +] + +model = dict( + type=DreamBooth, + vae=dict( + type='AutoencoderKL', + from_pretrained=stable_diffusion_v15_url, + subfolder='vae'), + unet=dict( + type='UNet2DConditionModel', + from_pretrained=stable_diffusion_v15_url, + subfolder='unet', + ), + text_encoder=dict( + type=ClipWrapper, + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_v15_url, + subfolder='text_encoder'), + tokenizer=stable_diffusion_v15_url, + finetune_text_encoder=True, + scheduler=dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_v15_url, + subfolder='scheduler'), + test_scheduler=dict( + type='DDIMScheduler', + from_pretrained=stable_diffusion_v15_url, + subfolder='scheduler'), + data_preprocessor=dict(type=DataPreprocessor), + val_prompts=val_prompts) + +train_cfg = dict(max_iters=1000) + +optim_wrapper.update( + modules='.*unet', + optimizer=dict(type=AdamW, lr=5e-6), + accumulative_counts=4 # batch size = 4 * 1 = 4 +) + +pipeline = [ + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=Resize, scale=(512, 512)), + dict(type=PackInputs) +] + +dataset = dict( + type=DreamBoothDataset, + data_root='./data/dreambooth', + concept_dir='imgs', + prompt='a photo of sks dog', + pipeline=pipeline) +train_dataloader = dict( + dataset=dataset, + num_workers=16, + sampler=dict(type=InfiniteSampler, shuffle=True), + persistent_workers=True, + batch_size=1) +val_cfg = val_evaluator = val_dataloader = None +test_cfg = test_evaluator = test_dataloader = None + +# hooks +default_hooks.update(dict(logger=dict(interval=10))) +custom_hooks = [ + dict( + type=VisualizationHook, + interval=50, + fixed_input=True, + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1) +] diff --git a/mmagic/configs/dreambooth/dreambooth-prior_pre.py b/mmagic/configs/dreambooth/dreambooth-prior_pre.py new file mode 100644 index 000000000..3531dee99 --- /dev/null +++ b/mmagic/configs/dreambooth/dreambooth-prior_pre.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .dreambooth import * + +# config for model +model.update(dict(prior_loss_weight=1, class_prior_prompt='a dog')) diff --git a/mmagic/configs/dreambooth/dreambooth.py b/mmagic/configs/dreambooth/dreambooth.py new file mode 100644 index 000000000..1b3d27e57 --- /dev/null +++ b/mmagic/configs/dreambooth/dreambooth.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .._base_.gen_default_runtime import * + +from mmengine.dataset.sampler import InfiniteSampler +from torch.optim import AdamW + +from mmagic.datasets.dreambooth_dataset import DreamBoothDataset +from mmagic.datasets.transforms.aug_shape import Resize +from mmagic.datasets.transforms.formatting import PackInputs +from mmagic.datasets.transforms.loading import LoadImageFromFile +from mmagic.engine import VisualizationHook +from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor +from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper +from mmagic.models.editors.dreambooth import DreamBooth + +stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5' + +val_prompts = [ + 'a sks dog in basket', 'a sks dog on the mountain', + 'a sks dog beside a swimming pool', 'a sks dog on the desk', + 'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden' +] + +model = dict( + type=DreamBooth, + vae=dict( + type='AutoencoderKL', + from_pretrained=stable_diffusion_v15_url, + subfolder='vae'), + unet=dict( + type='UNet2DConditionModel', + from_pretrained=stable_diffusion_v15_url, + subfolder='unet', + ), + text_encoder=dict( + type=ClipWrapper, + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_v15_url, + subfolder='text_encoder'), + tokenizer=stable_diffusion_v15_url, + scheduler=dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_v15_url, + subfolder='scheduler'), + test_scheduler=dict( + type='DDIMScheduler', + from_pretrained=stable_diffusion_v15_url, + subfolder='scheduler'), + data_preprocessor=dict(type=DataPreprocessor), + val_prompts=val_prompts) + +train_cfg = dict(max_iters=1000) + +optim_wrapper.update( + modules='.*unet', + optimizer=dict(type=AdamW, lr=5e-6), + accumulative_counts=4 # batch size = 4 * 1 = 4 +) + +pipeline = [ + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=Resize, scale=(512, 512)), + dict(type=PackInputs) +] + +dataset = dict( + type=DreamBoothDataset, + data_root='./data/dreambooth', + concept_dir='imgs', + prompt='a photo of sks dog', + pipeline=pipeline) +train_dataloader = dict( + dataset=dataset, + num_workers=16, + sampler=dict(type=InfiniteSampler, shuffle=True), + persistent_workers=True, + batch_size=1) +val_cfg = val_evaluator = val_dataloader = None +test_cfg = test_evaluator = test_dataloader = None + +# hooks +default_hooks.update(dict(logger=dict(interval=10))) +custom_hooks = [ + dict( + type=VisualizationHook, + interval=50, + fixed_input=True, + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1) +] diff --git a/mmagic/configs/dreambooth/dreambooth_lora-prior_pre.py b/mmagic/configs/dreambooth/dreambooth_lora-prior_pre.py new file mode 100644 index 000000000..62c5ce9fc --- /dev/null +++ b/mmagic/configs/dreambooth/dreambooth_lora-prior_pre.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .dreambooth_lora import * + +model.update(dict(prior_loss_weight=1, class_prior_prompt='a dog')) diff --git a/mmagic/configs/dreambooth/dreambooth_lora.py b/mmagic/configs/dreambooth/dreambooth_lora.py new file mode 100644 index 000000000..6670ea26b --- /dev/null +++ b/mmagic/configs/dreambooth/dreambooth_lora.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .._base_.gen_default_runtime import * + +from mmengine.dataset.sampler import InfiniteSampler +from torch.optim import AdamW + +from mmagic.datasets.dreambooth_dataset import DreamBoothDataset +from mmagic.datasets.transforms.aug_shape import Resize +from mmagic.datasets.transforms.formatting import PackInputs +from mmagic.datasets.transforms.loading import LoadImageFromFile +from mmagic.engine import VisualizationHook +from mmagic.models.data_preprocessors.data_preprocessor import DataPreprocessor +from mmagic.models.editors.disco_diffusion.clip_wrapper import ClipWrapper +from mmagic.models.editors.dreambooth import DreamBooth + +stable_diffusion_v15_url = 'runwayml/stable-diffusion-v1-5' + +val_prompts = [ + 'a sks dog in basket', 'a sks dog on the mountain', + 'a sks dog beside a swimming pool', 'a sks dog on the desk', + 'a sleeping sks dog', 'a screaming sks dog', 'a man in the garden' +] +lora_config = dict(target_modules=['to_q', 'to_k', 'to_v']) + +model = dict( + type=DreamBooth, + vae=dict( + type='AutoencoderKL', + from_pretrained=stable_diffusion_v15_url, + subfolder='vae'), + unet=dict( + type='UNet2DConditionModel', + from_pretrained=stable_diffusion_v15_url, + subfolder='unet', + ), + text_encoder=dict( + type=ClipWrapper, + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_v15_url, + subfolder='text_encoder'), + tokenizer=stable_diffusion_v15_url, + scheduler=dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_v15_url, + subfolder='scheduler'), + test_scheduler=dict( + type='DDIMScheduler', + from_pretrained=stable_diffusion_v15_url, + subfolder='scheduler'), + data_preprocessor=dict(type=DataPreprocessor), + prior_loss_weight=0, + val_prompts=val_prompts, + lora_config=lora_config) + +train_cfg = dict(max_iters=1000) + +optim_wrapper = dict( + # Only optimize LoRA mappings + modules='.*.lora_mapping', + # NOTE: lr should be larger than dreambooth finetuning + optimizer=dict(type=AdamW, lr=5e-4), + accumulative_counts=1) + +pipeline = [ + dict(type=LoadImageFromFile, key='img', channel_order='rgb'), + dict(type=Resize, scale=(512, 512)), + dict(type=PackInputs) +] +dataset = dict( + type=DreamBoothDataset, + data_root='./data/dreambooth', + # TODO: rename to instance + concept_dir='imgs', + prompt='a photo of sks dog', + pipeline=pipeline) +train_dataloader = dict( + dataset=dataset, + num_workers=16, + sampler=dict(type=InfiniteSampler, shuffle=True), + persistent_workers=True, + batch_size=1) +val_cfg = val_evaluator = val_dataloader = None +test_cfg = test_evaluator = test_dataloader = None + +# hooks +default_hooks.update(dict(logger=dict(interval=10))) +custom_hooks = [ + dict( + type=VisualizationHook, + interval=50, + fixed_input=True, + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1) +]