-
Notifications
You must be signed in to change notification settings - Fork 300
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,350 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "4b92a04e-7143-4eb7-8614-00d19c22568e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Copyright 2022 Google LLC\n", | ||
"#\n", | ||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | ||
"# you may not use this file except in compliance with the License.\n", | ||
"# You may obtain a copy of the License at\n", | ||
"#\n", | ||
"# http://www.apache.org/licenses/LICENSE-2.0\n", | ||
"#\n", | ||
"# Unless required by applicable law or agreed to in writing, software\n", | ||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | ||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | ||
"# See the License for the specific language governing permissions and\n", | ||
"# limitations under the License." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "3b301424-607b-45c4-9786-2ab2fdc4ce0f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from typing import Tuple, Union, Optional, List\n", | ||
"\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"from torch.optim.adamw import AdamW\n", | ||
"from torch.optim.sgd import SGD\n", | ||
"from diffusers import StableDiffusionPipeline, UNet2DConditionModel\n", | ||
"import numpy as np\n", | ||
"from PIL import Image\n", | ||
"from tqdm.notebook import tqdm\n", | ||
"from IPython.display import display, clear_output\n", | ||
"\n", | ||
"T = torch.Tensor\n", | ||
"TN = Optional[T]\n", | ||
"TS = Union[Tuple[T, ...], List[T]]\n", | ||
"\n", | ||
"device = torch.device('cuda:0')\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6fa79d33-f5c9-401e-b06a-0976e0ec48c7", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"## Delta Denoising Score: zero-shot image editing" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "205ef294-fb96-4a8f-86ed-14b3bab27058", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def load_512(image_path: str, left=0, right=0, top=0, bottom=0):\n", | ||
" image = np.array(Image.open(image_path))[:, :, :3] \n", | ||
" h, w, c = image.shape\n", | ||
" left = min(left, w-1)\n", | ||
" right = min(right, w - left - 1)\n", | ||
" top = min(top, h - left - 1)\n", | ||
" bottom = min(bottom, h - top - 1)\n", | ||
" image = image[top:h-bottom, left:w-right]\n", | ||
" h, w, c = image.shape\n", | ||
" if h < w:\n", | ||
" offset = (w - h) // 2\n", | ||
" image = image[:, offset:offset + h]\n", | ||
" elif w < h:\n", | ||
" offset = (h - w) // 2\n", | ||
" image = image[offset:offset + w]\n", | ||
" image = np.array(Image.fromarray(image).resize((512, 512)))\n", | ||
" return image\n", | ||
"\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def get_text_embeddings(pipe: StableDiffusionPipeline, text: str) -> T:\n", | ||
" tokens = pipe.tokenizer([text], padding=\"max_length\", max_length=77, truncation=True,\n", | ||
" return_tensors=\"pt\", return_overflowing_tokens=True).input_ids.to(device)\n", | ||
" return pipe.text_encoder(tokens).last_hidden_state.detach()\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def denormalize(image):\n", | ||
" image = (image / 2 + 0.5).clamp(0, 1)\n", | ||
" image = image.cpu().permute(0, 2, 3, 1).numpy()\n", | ||
" image = (image * 255).astype(np.uint8)\n", | ||
" return image[0]\n", | ||
"\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def decode(latent: T, pipe: StableDiffusionPipeline, im_cat: TN = None):\n", | ||
" image = pipeline.vae.decode((1 / 0.18215) * latent, return_dict=False)[0]\n", | ||
" image = denormalize(image)\n", | ||
" if im_cat is not None:\n", | ||
" image = np.concatenate((im_cat, image), axis=1)\n", | ||
" return Image.fromarray(image)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "7f6419af-2759-47ce-9620-0099b6805ec9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"def init_pipe(device, dtype, unet, scheduler) -> Tuple[UNet2DConditionModel, T, T]:\n", | ||
"\n", | ||
" with torch.inference_mode():\n", | ||
" alphas = torch.sqrt(scheduler.alphas_cumprod).to(device, dtype=dtype)\n", | ||
" sigmas = torch.sqrt(1 - scheduler.alphas_cumprod).to(device, dtype=dtype)\n", | ||
" for p in unet.parameters():\n", | ||
" p.requires_grad = False\n", | ||
" return unet, alphas, sigmas\n", | ||
"\n", | ||
"\n", | ||
"class DDSLoss:\n", | ||
" \n", | ||
" def noise_input(self, z, eps=None, timestep: Optional[int] = None):\n", | ||
" if timestep is None:\n", | ||
" b = z.shape[0]\n", | ||
" timestep = torch.randint(\n", | ||
" low=self.t_min,\n", | ||
" high=min(self.t_max, 1000) - 1, # Avoid the highest timestep.\n", | ||
" size=(b,),\n", | ||
" device=z.device, dtype=torch.long)\n", | ||
" if eps is None:\n", | ||
" eps = torch.randn_like(z)\n", | ||
" alpha_t = self.alphas[timestep, None, None, None]\n", | ||
" sigma_t = self.sigmas[timestep, None, None, None]\n", | ||
" z_t = alpha_t * z + sigma_t * eps\n", | ||
" return z_t, eps, timestep, alpha_t, sigma_t\n", | ||
"\n", | ||
" def get_eps_prediction(self, z_t: T, timestep: T, text_embeddings: T, alpha_t: T, sigma_t: T, get_raw=False,\n", | ||
" guidance_scale=7.5):\n", | ||
"\n", | ||
" latent_input = torch.cat([z_t] * 2)\n", | ||
" timestep = torch.cat([timestep] * 2)\n", | ||
" embedd = text_embeddings.permute(1, 0, 2, 3).reshape(-1, *text_embeddings.shape[2:])\n", | ||
" with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n", | ||
" e_t = self.unet(latent_input, timestep, embedd).sample\n", | ||
" if self.prediction_type == 'v_prediction':\n", | ||
" e_t = torch.cat([alpha_t] * 2) * e_t + torch.cat([sigma_t] * 2) * latent_input\n", | ||
" e_t_uncond, e_t = e_t.chunk(2)\n", | ||
" if get_raw:\n", | ||
" return e_t_uncond, e_t\n", | ||
" e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)\n", | ||
" assert torch.isfinite(e_t).all()\n", | ||
" if get_raw:\n", | ||
" return e_t\n", | ||
" pred_z0 = (z_t - sigma_t * e_t) / alpha_t\n", | ||
" return e_t, pred_z0\n", | ||
"\n", | ||
" def get_sds_loss(self, z: T, text_embeddings: T, eps: TN = None, mask=None, t=None,\n", | ||
" timestep: Optional[int] = None, guidance_scale=7.5) -> TS:\n", | ||
" with torch.inference_mode():\n", | ||
" z_t, eps, timestep, alpha_t, sigma_t = self.noise_input(z, eps=eps, timestep=timestep)\n", | ||
" e_t, _ = self.get_eps_prediction(z_t, timestep, text_embeddings, alpha_t, sigma_t,\n", | ||
" guidance_scale=guidance_scale)\n", | ||
" grad_z = (alpha_t ** self.alpha_exp) * (sigma_t ** self.sigma_exp) * (e_t - eps)\n", | ||
" assert torch.isfinite(grad_z).all()\n", | ||
" grad_z = torch.nan_to_num(grad_z.detach(), 0.0, 0.0, 0.0)\n", | ||
" if mask is not None:\n", | ||
" grad_z = grad_z * mask\n", | ||
" log_loss = (grad_z ** 2).mean()\n", | ||
" sds_loss = grad_z.clone() * z\n", | ||
" del grad_z\n", | ||
" return sds_loss.sum() / (z.shape[2] * z.shape[3]), log_loss\n", | ||
"\n", | ||
" \n", | ||
" def get_dds_loss(self, z_source: T, z_target: T, text_emb_source: T, text_emb_target: T,\n", | ||
" eps=None, reduction='mean', symmetric: bool = False, calibration_grad=None, timestep: Optional[int] = None,\n", | ||
" guidance_scale=7.5, raw_log=False) -> TS:\n", | ||
" with torch.inference_mode():\n", | ||
" z_t_source, eps, timestep, alpha_t, sigma_t = self.noise_input(z_source, eps, timestep)\n", | ||
" z_t_target, _, _, _, _ = self.noise_input(z_target, eps, timestep)\n", | ||
" eps_pred, _ = self.get_eps_prediction(torch.cat((z_t_source, z_t_target)),\n", | ||
" torch.cat((timestep, timestep)),\n", | ||
" torch.cat((text_emb_source, text_emb_target)),\n", | ||
" torch.cat((alpha_t, alpha_t)),\n", | ||
" torch.cat((sigma_t, sigma_t)),\n", | ||
" guidance_scale=guidance_scale)\n", | ||
" eps_pred_source, eps_pred_target = eps_pred.chunk(2)\n", | ||
" grad = (alpha_t ** self.alpha_exp) * (sigma_t ** self.sigma_exp) * (eps_pred_target - eps_pred_source)\n", | ||
" if calibration_grad is not None:\n", | ||
" if calibration_grad.dim() == 4:\n", | ||
" grad = grad - calibration_grad\n", | ||
" else:\n", | ||
" grad = grad - calibration_grad[timestep - self.t_min]\n", | ||
" if raw_log:\n", | ||
" log_loss = eps.detach().cpu(), eps_pred_target.detach().cpu(), eps_pred_source.detach().cpu()\n", | ||
" else:\n", | ||
" log_loss = (grad ** 2).mean()\n", | ||
" loss = z_target * grad.clone()\n", | ||
" if symmetric:\n", | ||
" loss = loss.sum() / (z_target.shape[2] * z_target.shape[3])\n", | ||
" loss_symm = self.rescale * z_source * (-grad.clone())\n", | ||
" loss += loss_symm.sum() / (z_target.shape[2] * z_target.shape[3])\n", | ||
" elif reduction == 'mean':\n", | ||
" loss = loss.sum() / (z_target.shape[2] * z_target.shape[3])\n", | ||
" return loss, log_loss\n", | ||
"\n", | ||
" def __init__(self, device, pipe: StableDiffusionPipeline, dtype=torch.float32):\n", | ||
" self.t_min = 50\n", | ||
" self.t_max = 950\n", | ||
" self.alpha_exp = 0\n", | ||
" self.sigma_exp = 0\n", | ||
" self.dtype = dtype\n", | ||
" self.unet, self.alphas, self.sigmas = init_pipe(device, dtype, pipe.unet, pipe.scheduler)\n", | ||
" self.prediction_type = pipe.scheduler.prediction_type\n", | ||
"\n", | ||
"\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "63e6569a-549f-4b27-920f-532459f35249", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_id = [\"runwayml/stable-diffusion-v1-5\", \"stabilityai/stable-diffusion-2-1\"][0]\n", | ||
"pipeline = StableDiffusionPipeline.from_pretrained(model_id,).to(device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "9567c4ac-21fe-44ce-8ae1-fb00a10855e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def image_optimization(pipe: StableDiffusionPipeline, image: np.ndarray, text_source: str, text_target: str, num_iters=200, use_dds=True) -> None:\n", | ||
" dds_loss = DDSLoss(device, pipe)\n", | ||
" image_source = torch.from_numpy(image).float().permute(2, 0, 1) / 127.5 - 1\n", | ||
" image_source = image_source.unsqueeze(0).to(device)\n", | ||
" with torch.no_grad():\n", | ||
" z_source = pipeline.vae.encode(image_source)['latent_dist'].mean * 0.18215\n", | ||
" image_target = image_source.clone()\n", | ||
" embedding_null = get_text_embeddings(pipeline, \"\")\n", | ||
" embedding_text = get_text_embeddings(pipeline, text_source)\n", | ||
" embedding_text_target = get_text_embeddings(pipeline, text_target)\n", | ||
" embedding_source = torch.stack([embedding_null, embedding_text], dim=1)\n", | ||
" embedding_target = torch.stack([embedding_null, embedding_text_target], dim=1)\n", | ||
"\n", | ||
" guidance_scale = 7.5\n", | ||
" image_target.requires_grad = True\n", | ||
" use_dds_loss = True\n", | ||
"\n", | ||
" z_taregt = z_source.clone()\n", | ||
" z_taregt.requires_grad = True\n", | ||
" optimizer = SGD(params=[z_taregt], lr=1e-1)\n", | ||
"\n", | ||
" for i in range(num_iters):\n", | ||
" if use_dds:\n", | ||
" loss, log_loss = dds_loss.get_dds_loss(z_source, z_taregt, embedding_source, embedding_target)\n", | ||
" else:\n", | ||
" loss, log_loss = dds_loss.get_sds_loss(z_taregt, embedding_target)\n", | ||
" optimizer.zero_grad()\n", | ||
" (2000 * loss).backward()\n", | ||
" optimizer.step()\n", | ||
" if (i + 1) % 10 == 0:\n", | ||
" out = decode(z_taregt, pipeline, im_cat=image)\n", | ||
" clear_output(wait=True)\n", | ||
" display(out)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "636ed780-86d6-42d3-9f8c-4e4f1d670cae", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"### SDS image optimization" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c9419231-57e0-48f2-a837-7cebed997af4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"image = load_512(f\"./example_images/gnochi_mirror.jpeg\")\n", | ||
"image_optimization(pipeline, image, \"a photo of a cat.\", \"a photo of a tiger.\", use_dds=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2e1d00d3-d295-4da4-95d8-ecda0c359a0b", | ||
"metadata": {}, | ||
"source": [ | ||
"### DDS image optimization" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b53596e6-32f7-4a6b-b06f-a87cc42662e8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"image = load_512(f\"./example_images/gnochi_mirror.jpeg\")\n", | ||
"image_optimization(pipeline, image, \"a photo of a cat.\", \"a photo of a tiger.\", use_dds=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "90dd901d-8a7a-4f91-be05-3b6e2729cc34", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |