Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pipeline_stable_diffusion_3_inpaint.py for SD3 Inference #8709

Merged
merged 17 commits into from
Jul 9, 2024

Conversation

IrohXu
Copy link
Contributor

@IrohXu IrohXu commented Jun 26, 2024

What does this PR do?

This PR support inpaint pipeline in stable diffusion 3. It follows the mask inpainting idea of StableDiffusionInpaintPipeline as it does not need specific weight for inpainting.

We hope this PR can be an initial version for StableDiffusion3InpaintPipeline. It can be replaced by finetuned 33 channel input inpainting weight in the later version.

We put the demo here: DEMO.

How to use it?

import torch
from torchvision import transforms

from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

def preprocess_image(image):
    image = image.convert("RGB")
    image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
    image = transforms.ToTensor()(image)
    image = image.unsqueeze(0).to("cuda")
    return image

def preprocess_mask(mask):
    mask = mask.convert("L")
    mask = transforms.CenterCrop((mask.size[1] // 64 * 64, mask.size[0] // 64 * 64))(mask)
    mask = transforms.ToTensor()(mask)
    mask = mask.to("cuda")
    return mask

pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
source_image = load_image(
    "./overture-creations-5sI6fQgYIuo.png"
)
source = preprocess_image(source_image)
mask = preprocess_mask(
    load_image(
        "./overture-creations-5sI6fQgYIuo_mask.png"
    )
)

image = pipe(
    prompt=prompt,
    image=source,
    mask_image=mask,
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=7.0,
    strength=0.6,
).images[0]

image.save("output.png")

Image Input:

Mask Input:

Prompt: Face of a yellow cat, high resolution, sitting on a park bench

SD3 output:

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@yiyixuxu

@IrohXu IrohXu changed the title Add pipeline_stable_diffusion_3_inpaint.py [Community pipeline] Add pipeline_stable_diffusion_3_inpaint.py Jun 26, 2024
@IrohXu IrohXu changed the title [Community pipeline] Add pipeline_stable_diffusion_3_inpaint.py [SD3 Inference] Add pipeline_stable_diffusion_3_inpaint.py Jun 26, 2024
@IrohXu IrohXu changed the title [SD3 Inference] Add pipeline_stable_diffusion_3_inpaint.py [Community pipeline] Add pipeline_stable_diffusion_3_inpaint.py for SD3 Inference Jun 26, 2024
@IrohXu
Copy link
Contributor Author

IrohXu commented Jun 27, 2024

Our teammate has implemented Inpaint pipeline for SD3. Could you review this PR? @yiyixuxu @sayakpaul Thanks!

@sayakpaul sayakpaul requested a review from yiyixuxu June 27, 2024 03:21
@sayakpaul sayakpaul changed the title [Community pipeline] Add pipeline_stable_diffusion_3_inpaint.py for SD3 Inference Add pipeline_stable_diffusion_3_inpaint.py for SD3 Inference Jun 27, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the PR!
my main feedback is instead of follow the deprecated inpaint_legacy pipeline, can we match the SD/SDXL inpaint pipeline?
I think it is ok to only implement for use-case when a regular text-to-image transformer checkpoint is used (this means we should match the algorithm in SD inpaint pipeline when unet only have 4 channels, not 9) , we can refactor later when we have a sd3 inpaint checkpoint.

negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
add_predicted_noise: Optional[bool] = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't think we support this in any other pipelines - is it ok to remove it? what's the use case for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just removed it.

noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

# get latents
init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so I think we should have same behavior as other inpainting pipelines, where when strength=1, init_latent is pure noise

latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I added the is_strength_max in the pipeline.

init_latents_proper = self.scheduler.scale_noise(
init_latents_orig, torch.tensor([t]), noise_pred_uncond
)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we match the SD and SDXL inpainting pipeline when using the regular unet checkpoint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I updated it with checking the number of transformer channels before running the denoising process.

Copy link
Collaborator

@yiyixuxu yiyixuxu Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I wasn't clear

When I said "can we match the SD and SDXL inpainting pipeline when using the regular unet checkpoint", I didn't mean that we need to check the number of transformer channels, although it is ok if you add the check.

SD/SDXL inpainting pipelines support both inpainting-specific checkpoints (when num_channels_unet==9) and regular text-to-image checkpoint (when num_channels_unet =4); I think the algorithm and overall code structure fo SD3 Inpaiting pipeline should match very closely with SD/SDXL inpainting, but you can ignore the part of logic in these pipelines that only that applies to inpainting-specific checkpoints.

the current implementation of this pipeline matches the inpainting_legacy, which we deprecated and is slightly different from SDXL and SD, both in code structure and the actual algorithm

@IrohXu
Copy link
Contributor Author

IrohXu commented Jun 27, 2024

Hi @yiyixuxu Can you review the new version? Thanks a lot!

"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)

init_latents_proper = self.scheduler.scale_noise(init_latents_orig, torch.tensor([t]), noise)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can see the algorithm here is different from

e.g. init_latents_proper should contain the same level of noise as latents ( latents here is after taking the secheduler.step so technically it is x_t-1, hence we we passed next timestep to add_noise in SD/SDXL inpaint (here in this version we just use current timestep)

                         noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor([noise_timestep])
                        )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. We update it.

init_latents_proper = self.scheduler.scale_noise(
init_latents_orig, torch.tensor([t]), noise_pred_uncond
)
else:
Copy link
Collaborator

@yiyixuxu yiyixuxu Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I wasn't clear

When I said "can we match the SD and SDXL inpainting pipeline when using the regular unet checkpoint", I didn't mean that we need to check the number of transformer channels, although it is ok if you add the check.

SD/SDXL inpainting pipelines support both inpainting-specific checkpoints (when num_channels_unet==9) and regular text-to-image checkpoint (when num_channels_unet =4); I think the algorithm and overall code structure fo SD3 Inpaiting pipeline should match very closely with SD/SDXL inpainting, but you can ignore the part of logic in these pipelines that only that applies to inpainting-specific checkpoints.

the current implementation of this pipeline matches the inpainting_legacy, which we deprecated and is slightly different from SDXL and SD, both in code structure and the actual algorithm

@George0726
Copy link

Thanks for your PR for SD3 inpainting!
I am currently trying to finetune the inpainting version by 33 channels. Your version raised an Error on inpainting-specific checkpoints.

  1. the masked latent and mask should apply classifier-free guidance.
  2. Mask latent needs to apply VAE "normalize" as well.
    Hence, it would be better to check the conv_in dimension and raise ERROR when the channel of conv_in != 16.

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 1, 2024

Thanks for your PR for SD3 inpainting! I am currently trying to finetune the inpainting version by 33 channels. Your version raised an Error on inpainting-specific checkpoints.

  1. the masked latent and mask should apply classifier-free guidance.
  2. Mask latent needs to apply VAE "normalize" as well.
    Hence, it would be better to check the conv_in dimension and raise ERROR when the channel of conv_in != 16.

Thanks for the comment. I will update it based on your idea tomorrow. It is great that you will provide checkpoint for inpainting version by 33 channels. Would like me to add your ideas into this PR or creat a new PR for the 33 channels version later?

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 1, 2024

@yiyixuxu @George0726 Could you review the new update? Thanks. We follow your suggestion on using SDXL and SD inpaint pipeline and also consider @George0726 's idea on keeping the if-else check for 33 channels and implement VAE "normalize" for mask input.

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 1, 2024

@George0726 Could you try your 33 channel checkpoint in this PR? Thanks! Let me known if you meet any new error.

@George0726
Copy link

George0726 commented Jul 1, 2024

@IrohXu Sure. I have tested and modified some parts to make it work.

  1. SD3 VAE added additional self.vae.config.shift_factor to normalize, you can refer to
    init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
  2. masked_latent should also apply the VAE scaling, referred as
    masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

    Here is my git diff:
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index c8f05fe25..a5434d90c 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -677,8 +677,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
         else:
             image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
-        image_latents = self.vae.config.scaling_factor * image_latents
 
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
         return image_latents
 
     def prepare_mask_latents(
@@ -711,6 +711,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
 
+        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
             if not batch_size % mask.shape[0] == 0:

Here are some results of my current models:
input and mask:
case5_2_ori
case5_2_mask

prompt: red hair

inpainting by SD3
2000_2_0_SD3

inpainting by SD3-inpainting
14000_2_3

inpainting by SD3-inpainting with modified codes
14000_2_1

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 1, 2024

@George0726 Thanks a lot! I have added your improved codes into this PR.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jul 1, 2024

@a-r-r-o-w can you give this a review too if you have time?

Copy link
Member

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your work, this looks great! Just a few small requests that need addressing

def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
mask_image = torch.ones((1, 1, 32, 32)).to(device)
image = image / 2 + 0.5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yiyixuxu I'm curious why we do this here and in the other SD3 tests. floats_tensors returns values in [0, 1]. This statement makes the image tensors have values in range [0.5, 1]. Shouldn't the input images be in range [-1, 1] (unless I've missed something SD3-specific)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yiyixuxu should I modify it now? or wait for you create a new PR to fix all issues in SD3 test cases?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yea it is a mistake here
maybe you can fix it for this test and then we fix it everywhere else in a separate PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just fixed it.

num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit that can be addressed in another PR for all SD3 and similar torch.FloatTensor hints: #7535.

is_strength_max = strength == 1.0

# 5. Preprocess mask and image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

`tuple`. When returning a tuple, the first element is a list with the generated images.
"""

callback = kwargs.pop("callback", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, these callbacks were deprecated a while back, no? @yiyixuxu. We can remove them here if that's the case

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes we don't need to accept "callback" for new pipeline!

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these can be removed

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 3, 2024

@a-r-r-o-w Thanks for the comments. I have updated the code based on your suggestions. For this first issue, I think we should wait for @yiyixuxu reply to us. It might be solved by another PR I think as all SD3 pipelines have it.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! looking great!
I think we can merge this soon

`tuple`. When returning a tuple, the first element is a list with the generated images.
"""

callback = kwargs.pop("callback", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes we don't need to accept "callback" for new pipeline!

callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**kwargs,

Comment on lines 917 to 918
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

# 3. Preprocess image and mask
image = self.image_processor.preprocess(image, height, width)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we processing the image twice? we did it again in line 1008, no?


if not output_type == "latent":
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this? I don't think SD3 works with this vae, no?

def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
mask_image = torch.ones((1, 1, 32, 32)).to(device)
image = image / 2 + 0.5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yea it is a mistake here
maybe you can fix it for this test and then we fix it everywhere else in a separate PR?

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 3, 2024

@yiyixuxu @a-r-r-o-w I have updated it based on your comments. Thanks a lot!

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jul 3, 2024

@IrohXu can you run make style and make fix-copies?

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 3, 2024

@IrohXu can you run make style and make fix-copies?

@yiyixuxu I just run them and push the code again.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jul 3, 2024

@George0726 let us know when your checkpoints are ready :)

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jul 3, 2024

@IrohXu can you make sure the tests pass? the new inpainting tests are failing here

@George0726
Copy link

@George0726 let us know when your checkpoints are ready :)

I have uploaded the alpha version of SD3 inpainting model. I don't have enough data and GPUs for large-scale pre-training.
https://huggingface.co/George0667/SD3_inpaint_alpha/tree/main

@IrohXu
Copy link
Contributor Author

IrohXu commented Jul 8, 2024

@IrohXu can you make sure the tests pass? the new inpainting tests are failing here

@yiyixuxu I tested it locally in my machine today, it seems it can pass all failed cases. Do you know how this error appear differently in different test environment?

RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

Here is my log:

============================= test session starts ==============================
platform linux -- Python 3.10.0, pytest-8.2.2, pluggy-1.5.0 -- /home/xucao2/miniconda3/envs/sd_train/bin/python
cachedir: .pytest_cache
rootdir: /home/xucao2/diffusers
configfile: pyproject.toml
plugins: requests-mock-1.10.0, xdist-3.6.1, timeout-2.3.1
created: 32/32 workers
32 workers [30 items]

scheduling tests via LoadFileScheduling

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_StableDiffusionMixin_component 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_StableDiffusionMixin_component 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_attention_slicing_forward_pass 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_attention_slicing_forward_pass 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_callback_cfg 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_callback_cfg 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_callback_inputs 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_callback_inputs 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_cfg 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_cfg 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_components_function 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_components_function 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_cpu_offload_forward_pass_twice 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_cpu_offload_forward_pass_twice 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_dict_tuple_outputs_equivalent 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_dict_tuple_outputs_equivalent 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_float16_inference 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_float16_inference 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_inference_batch_consistent Token indices sequence length is longer than the specified maximum sequence length for this model (402 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>', 'ongvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery long']
Token indices sequence length is longer than the specified maximum sequence length for this model (402 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>', 'ongvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery long']
The following part of your input was truncated because `max_sequence_length` is set to  256 tokens: ['<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', 'longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery longvery long']

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_inference_batch_consistent 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_inference_batch_single_identical Token indices sequence length is longer than the specified maximum sequence length for this model (402 > 77). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (402 > 77). Running this sequence through the model will result in indexing errors

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_inference_batch_single_identical 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_latents_input 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_latents_input 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_model_cpu_offload_forward_pass 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_model_cpu_offload_forward_pass 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_multi_vae 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_multi_vae 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_num_images_per_prompt 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_num_images_per_prompt 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_pipeline_call_signature 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_pipeline_call_signature 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_progress_bar 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_progress_bar 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_pt_np_pil_inputs_equivalent 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_pt_np_pil_inputs_equivalent 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_pt_np_pil_outputs_equivalent 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_pt_np_pil_outputs_equivalent 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_save_load_float16 
Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]
Loading pipeline components...:  11%|█         | 1/9 [00:00<00:01,  7.78it/s]
Loading pipeline components...:  56%|█████▌    | 5/9 [00:00<00:00, 21.23it/s]
Loading pipeline components...:  89%|████████▉ | 8/9 [00:00<00:00, 22.20it/s]
Loading pipeline components...: 100%|██████████| 9/9 [00:00<00:00, 23.09it/s]

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_save_load_float16 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_save_load_local 
Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]Loaded text_encoder_2 as CLIPTextModelWithProjection from `text_encoder_2` subfolder of /tmp/tmpdyb_z201.

Loading pipeline components...:  11%|█         | 1/9 [00:00<00:01,  7.61it/s]Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of /tmp/tmpdyb_z201.
Loaded tokenizer_3 as T5TokenizerFast from `tokenizer_3` subfolder of /tmp/tmpdyb_z201.
An error occurred while trying to fetch /tmp/tmpdyb_z201/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /tmp/tmpdyb_z201/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loaded vae as AutoencoderKL from `vae` subfolder of /tmp/tmpdyb_z201.
Loaded text_encoder as CLIPTextModelWithProjection from `text_encoder` subfolder of /tmp/tmpdyb_z201.

Loading pipeline components...:  56%|█████▌    | 5/9 [00:00<00:00, 20.03it/s]Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of /tmp/tmpdyb_z201.
An error occurred while trying to fetch /tmp/tmpdyb_z201/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /tmp/tmpdyb_z201/transformer.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loaded transformer as SD3Transformer2DModel from `transformer` subfolder of /tmp/tmpdyb_z201.
Loaded text_encoder_3 as T5EncoderModel from `text_encoder_3` subfolder of /tmp/tmpdyb_z201.

Loading pipeline components...:  89%|████████▉ | 8/9 [00:00<00:00, 21.17it/s]Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of /tmp/tmpdyb_z201.

Loading pipeline components...: 100%|██████████| 9/9 [00:00<00:00, 22.05it/s]

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_save_load_local 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_save_load_optional_components 
Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]Loaded text_encoder_2 as CLIPTextModelWithProjection from `text_encoder_2` subfolder of /tmp/tmptksec1h_.

Loading pipeline components...:  11%|█         | 1/9 [00:00<00:01,  7.64it/s]Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of /tmp/tmptksec1h_.
Loaded tokenizer_3 as T5TokenizerFast from `tokenizer_3` subfolder of /tmp/tmptksec1h_.
An error occurred while trying to fetch /tmp/tmptksec1h_/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /tmp/tmptksec1h_/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loaded vae as AutoencoderKL from `vae` subfolder of /tmp/tmptksec1h_.
Loaded text_encoder as CLIPTextModelWithProjection from `text_encoder` subfolder of /tmp/tmptksec1h_.

Loading pipeline components...:  56%|█████▌    | 5/9 [00:00<00:00, 20.29it/s]Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of /tmp/tmptksec1h_.
An error occurred while trying to fetch /tmp/tmptksec1h_/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /tmp/tmptksec1h_/transformer.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loaded transformer as SD3Transformer2DModel from `transformer` subfolder of /tmp/tmptksec1h_.
Loaded text_encoder_3 as T5EncoderModel from `text_encoder_3` subfolder of /tmp/tmptksec1h_.

Loading pipeline components...:  89%|████████▉ | 8/9 [00:00<00:00, 21.24it/s]Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of /tmp/tmptksec1h_.

Loading pipeline components...: 100%|██████████| 9/9 [00:00<00:00, 22.15it/s]

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_save_load_optional_components 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_sequential_cpu_offload_forward_pass 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_sequential_cpu_offload_forward_pass 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_sequential_offload_forward_pass_twice 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_sequential_offload_forward_pass_twice 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_stable_diffusion_3_inpaint_different_negative_prompts 
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 260.00it/s]

  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 268.14it/s]

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_stable_diffusion_3_inpaint_different_negative_prompts 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_stable_diffusion_3_inpaint_different_prompts 
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 268.14it/s]

  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 269.70it/s]

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_stable_diffusion_3_inpaint_different_prompts 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_stable_diffusion_3_inpaint_prompt_embeds 
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 269.07it/s]

  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 262.92it/s]

[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_stable_diffusion_3_inpaint_prompt_embeds 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_to_device 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_to_device 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_to_dtype 
[gw0] PASSED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_to_dtype 
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_xformers_attention_forwardGenerator_pass 
[gw0] SKIPPED tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_xformers_attention_forwardGenerator_pass 

=============================== warnings summary ===============================
src/diffusers/models/transformers/transformer_2d.py:34: 32 warnings
  /home/xucao2/diffusers/src/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.
    deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py: 22 warnings
  /home/xucao2/diffusers/src/diffusers/configuration_utils.py:140: FutureWarning: Accessing config attribute `vae_latent_channels` directly via 'VaeImageProcessor' object attribute is deprecated. Please access 'vae_latent_channels' over 'VaeImageProcessor's config object instead, e.g. 'scheduler.config.vae_latent_channels'.
    deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_inference_batch_consistent
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_inference_batch_single_identical
tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py::StableDiffusion3InpaintPipelineFastTests::test_num_images_per_prompt
  /home/xucao2/diffusers/src/diffusers/image_processor.py:528: FutureWarning: Passing `image` as a list of 4d torch.Tensor is deprecated.Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 29 passed, 1 skipped, 57 warnings in 21.43s ==================

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rebased the branch, can you git pull and test it again to make sure everything works?
we had this PR that may affect inpaint pipeline #8678

f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

@yiyixuxu yiyixuxu merged commit 35cc66d into huggingface:main Jul 9, 2024
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants