From f523fb37b912566742778ea41ae7af0fdfe5d904 Mon Sep 17 00:00:00 2001 From: thegenerativegeneration <> Date: Mon, 12 Jun 2023 13:48:37 +0200 Subject: [PATCH] fix safety checker --- visual_chatgpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visual_chatgpt.py b/visual_chatgpt.py index 17dc4def..dad4cb8a 100644 --- a/visual_chatgpt.py +++ b/visual_chatgpt.py @@ -1175,7 +1175,7 @@ def __init__(self, device): self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32 self.inpaint = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker')).to(device) + "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker', torch_dtype=self.torch_dtype)).to(device) def __call__(self, prompt, image, mask_image, height=512, width=512, num_inference_steps=50): update_image = self.inpaint(prompt=prompt, image=image.resize((width, height)), mask_image=mask_image.resize((width, height)), height=height, width=width, num_inference_steps=num_inference_steps).images[0]