diff --git a/demo/Diffusion/models.py b/demo/Diffusion/models.py
index ffeb6cfd..763e96d0 100644
--- a/demo/Diffusion/models.py
+++ b/demo/Diffusion/models.py
@@ -232,9 +232,23 @@ def get_dicts(self,
 
                 }
             else:
-                # Otherwise, we're dealing with the old format.
-                warn_message = "You have saved the LoRA weights using the old format. To convert LoRA weights to the new format, first load them in a dictionary and then create a new dictionary as follows: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-                print(warn_message)
+                # Otherwise, convert old LoRA format to the new format.
+                self.state_dict[path] = {f'unet.{module_name}': params for module_name, params in self.state_dict[path].items()}
+            keys = list(self.state_dict[path].keys())
+            if all(key.startswith(('unet', 'text_encoder')) for key in keys):
+                keys = [k for k in keys if k.startswith(prefix)]
+                if keys:
+                    print(f"Processing {prefix} LoRA: {path}")
+                state_dict[path] = {k.replace(f"{prefix}.", ""): v for k, v in self.state_dict[path].items() if k in keys}
+
+            if path in self.network_alphas:
+                if self.network_alphas[path]:
+                    alpha_keys = [k for k in self.network_alphas[path].keys() if k.startswith(prefix)]
+                    network_alphas[path] = {
+                        k.replace(f"{prefix}.", ""): v for k, v in self.network_alphas[path].items() if k in alpha_keys
+                    }
+                else:
+                    network_alphas[path] = None
 
         return state_dict, network_alphas
 
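For reference, a minimal sketch of what the new else-branch does: instead of warning the user, it prepends 'unet.' to every module name of an old-format checkpoint, after which the usual prefix filtering applies. The module names and values below are hypothetical placeholders, not keys from a real LoRA checkpoint.

    # Old-format LoRA state dict: module names carry no 'unet.'/'text_encoder.' prefix.
    # Keys and values are made-up placeholders for illustration.
    old_state_dict = {
        "mid_block.attentions.0.to_q.lora.down.weight": 0.0,
        "mid_block.attentions.0.to_q.lora.up.weight": 0.0,
    }

    # Step 1: convert to the new format by prepending 'unet.' to each module name,
    # mirroring the dict comprehension added in the else-branch above.
    new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}

    # Step 2: filter by prefix and strip it, as get_dicts() does for prefix='unet'.
    prefix = "unet"
    keys = [k for k in new_state_dict if k.startswith(prefix)]
    unet_dict = {k.replace(f"{prefix}.", ""): v for k, v in new_state_dict.items() if k in keys}

    # Stripping the prefix recovers the bare module names from the old format.
    assert set(unet_dict) == set(old_state_dict)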