# Hugging Face Spaces app (Space status: Sleeping)
| import torch | |
| from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import hf_hub_url, login, HfApi, create_repo | |
| import os | |
| import traceback | |
| from peft import PeftModel | |
| import gradio as gr | |
def display_image(image):
    """Hand the generated image back untouched so a UI component can render it.

    Args:
        image: Any image object; it is not inspected or copied.

    Returns:
        The very same object that was passed in.
    """
    shown = image
    return shown
def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name, lora_scale=1.0, device="cuda"):
    """Load a base diffusion pipeline and fuse a LoRA into its weights.

    The LoRA is loaded with the pipeline's own LoRA loader. It is fused into
    the base weights at ``lora_scale`` and then unloaded, so the returned
    pipeline holds plain merged weights. ``pipe.save_pretrained`` can then
    write them out as a regular model.

    Args:
        base_model_id: Hub id or local path of the base pipeline. Loaded in
            float16 from its ``fp16`` variant, using safetensors.
        lora_id: Hub id or local path of the LoRA weights.
        lora_adapter_name: Name the LoRA is registered under while it is
            fused. An empty string lets the loader choose a default name.
        lora_scale: Strength at which the LoRA is baked into the weights.
            Because the weights are merged, a LoRA scale passed later at
            inference time (e.g. via ``cross_attention_kwargs``) no longer
            changes the LoRA strength.
        device: Device the pipeline is moved to.

    Returns:
        The pipeline with the LoRA merged in, or ``None`` if any step failed.
        On failure the full traceback is written to ``errors.txt``.
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to(device)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )
        # Use the pipeline's LoRA loader instead of wrapping pipe.unet in a
        # peft.PeftModel. A PeftModel wrapper only accepts PEFT-format adapter
        # repos, is never merged into the base weights, and cannot be saved
        # by pipe.save_pretrained as a normal UNet.
        pipe.load_lora_weights(lora_id, adapter_name=lora_adapter_name or None)
        pipe.fuse_lora(lora_scale=float(lora_scale))
        # The LoRA deltas now live in the base weights; drop the separate LoRA
        # layers so only merged weights remain (and get saved).
        pipe.unload_lora_weights()
        print("LoRA merged successfully!")
        return pipe
    except Exception as e:
        # Top-level boundary for the UI: report, persist the traceback, and
        # signal failure with None instead of raising.
        error_msg = traceback.format_exc()
        print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")
        with open("errors.txt", "w", encoding="utf-8") as f:
            f.write(error_msg)
        return None
def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None, repo_id=None):
    """Save the pipeline to ``save_path`` and optionally upload it to the Hub.

    Args:
        pipe: Object exposing ``save_pretrained(path)`` (a diffusers pipeline).
        save_path: Local directory the pipeline is written into.
        push_to_hub: When true, upload ``save_path`` to a Hub model repo.
        hf_token: Hub write token. ``None`` or ``""`` means "not provided".
            In that case the user is prompted only when stdin is an
            interactive terminal; otherwise the push fails with a message.
        repo_id: Target ``user/name`` repository. When omitted, the user is
            prompted on an interactive terminal. Otherwise it defaults to
            ``<token owner>/<basename of save_path>``.

    Returns:
        ``True`` on success, ``False`` if saving or pushing failed. Errors are
        printed, not raised.
    """
    import getpass
    import sys

    # Prompting inside a web request handler would block the server forever,
    # so only prompt when a human is at an interactive terminal.
    interactive = sys.stdin is not None and sys.stdin.isatty()
    try:
        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")
        if not push_to_hub:
            return True
        # Gradio's password Textbox yields "" rather than None when left empty.
        if not hf_token:
            if not interactive:
                raise ValueError(
                    "A Hugging Face write token is required to push to the Hub."
                )
            hf_token = getpass.getpass("Enter your Hugging Face write token: ")
        # Pass the token explicitly instead of calling login(): login() would
        # persist this user's token on the host for every later request.
        api = HfApi(token=hf_token)
        if not repo_id:
            if interactive:
                repo_id = input("Enter the Hugging Face repository name "
                                "(e.g., your_username/your_model_name): ")
            else:
                folder = os.path.basename(os.path.normpath(save_path))
                repo_id = f"{api.whoami()['name']}/{folder}"
        # Create the repository if it doesn't exist
        create_repo(repo_id, token=hf_token, exist_ok=True)
        api.upload_folder(
            folder_path=save_path,
            repo_id=repo_id,
            token=hf_token,
            repo_type="model",
        )
        print(f"Model pushed successfully to Hugging Face Hub: {repo_id}")
        return True
    except Exception as e:
        print(f"Error saving/pushing the merged model: {e}")
        return False
def generate_and_save(base_model_id, lora_id, lora_adapter_name, prompt, lora_scale, save_path, push_to_hub, hf_token):
    """Gradio handler: merge the LoRA, render one image, then save/push the model.

    Args:
        base_model_id: Hub id or path of the base pipeline.
        lora_id: Hub id or path of the LoRA.
        lora_adapter_name: Adapter name passed through to the loader.
        prompt: Text prompt for the single generated image.
        lora_scale: LoRA strength from the UI slider (converted to float).
        save_path: Directory for the merged model. Empty means "don't save".
        push_to_hub: Whether to also upload the saved model to the Hub.
        hf_token: Hub write token from the UI (may be an empty string).

    Returns:
        ``(image, status)``, one value per Gradio output component. ``image``
        is ``None`` when the pipeline could not be built.
    """
    pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)
    if pipe is None:
        # Gradio expects one value per output; a bare None would make it
        # raise instead of showing the failure in the Status box.
        return None, "Failed to load/merge the LoRA. See errors.txt for the full traceback."

    lora_scale = float(lora_scale)
    # NOTE(review): this runtime scale only affects LoRA layers that are still
    # separate from the base weights; if the LoRA was fused into the weights
    # while loading, it has no effect — confirm against load_and_merge_lora.
    image = pipe(
        prompt,
        num_inference_steps=30,
        cross_attention_kwargs={"scale": lora_scale},
        generator=torch.manual_seed(0),
    ).images[0]
    image.save("generated_image.png")
    print("Image saved to: generated_image.png")

    if not save_path:
        return image, "Image generated; no save path given, so the model was not saved."

    saved = save_merged_model(pipe, save_path, push_to_hub, hf_token)
    # Only an explicit False counts as failure, so a save helper that returns
    # nothing is still reported as success.
    if saved is False:
        return image, "Image generated, but saving/pushing the model failed (see logs)."
    return image, "Image generated and model saved/pushed (if selected)."
# UI components, listed in the same order as generate_and_save's parameters
# and return values.
_INPUT_COMPONENTS = [
    gr.Textbox(label="Base Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)"),
    gr.Textbox(label="LoRA ID (e.g., your_username/your_lora)"),
    gr.Textbox(label="LoRA Adapter Name"),
    gr.Textbox(label="Prompt"),
    gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, value=0.7, step=0.1),
    gr.Textbox(label="Save Path"),
    gr.Checkbox(label="Push to Hugging Face Hub"),
    gr.Textbox(label="Hugging Face Write Token", type="password"),
]
_OUTPUT_COMPONENTS = [
    gr.Image(label="Generated Image"),
    gr.Textbox(label="Status"),
]

iface = gr.Interface(
    fn=generate_and_save,
    inputs=_INPUT_COMPONENTS,
    outputs=_OUTPUT_COMPONENTS,
    title="LoRA Merger and Image Generator",
    description="Merge a LoRA with a base Stable Diffusion model and generate images.",
)
# Start the web server as soon as the module runs (same as before).
iface.launch()