# akhaliq's picture
# akhaliq HF Staff
# Update app.py
# ccb37d8 verified
# raw
# history blame
# 19.2 kB
import gradio as gr
import numpy as np
import torch, random, json, spaces
from ulid import ULID
from diffsynth.pipelines.qwen_image import (
QwenImagePipeline, ModelConfig,
QwenImageUnit_Image2LoRAEncode, QwenImageUnit_Image2LoRADecode
)
from safetensors.torch import save_file
from PIL import Image
# from utils import repo_utils, image_utils, prompt_utils
# repo_utils.clone_repo_if_not_exists("git clone https://huggingface.co/DiffSynth-Studio/General-Image-Encoders", "app/repos")
# repo_utils.clone_repo_if_not_exists("https://huggingface.co/apple/starflow", "app/models")
# Base URL of files in this Space's repository on the Hub.
URL_PUBLIC = "https://huggingface.co/spaces/AiSudo/Qwen-Image-to-LoRA/blob/main"
# Default tensor dtype for the models in this app.
DTYPE = torch.bfloat16
# Upper bound for seeds (largest signed 32-bit integer, 2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

# DiffSynth VRAM-management settings used for the LoRA encoder/decoder models:
# the offload and onload targets are both "disk", while the preparing and
# computation stages run in bfloat16 on CUDA.
vram_config_disk_offload = dict(
    offload_dtype="disk",
    offload_device="disk",
    onload_dtype="disk",
    onload_device="disk",
    preparing_dtype=DTYPE,
    preparing_device="cuda",
    computation_dtype=DTYPE,
    computation_device="cuda",
)
# LoRA encoder/decoder pipeline: the SigLIP2 and DINOv3 image encoders plus the
# Qwen-Image-i2L "Style" weights, each loaded from the Hugging Face Hub with the
# disk-offload VRAM settings.
pipe_lora = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(
            download_source="huggingface",
            model_id=source_repo,
            origin_file_pattern=weights_pattern,
            **vram_config_disk_offload,
        )
        for source_repo, weights_pattern in (
            ("DiffSynth-Studio/General-Image-Encoders", "SigLIP2-G384/model.safetensors"),
            ("DiffSynth-Studio/General-Image-Encoders", "DINOv3-7B/model.safetensors"),
            ("DiffSynth-Studio/Qwen-Image-i2L", "Qwen-Image-i2L-Style.safetensors"),
        )
    ],
    processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
    # mem_get_info() returns (free, total) bytes; budget = total GiB minus 0.5 GiB headroom.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# DiffSynth VRAM-management settings for the image-generation models: offloaded
# to disk, while onloading, preparing and computation use bfloat16 on CUDA.
vram_config = dict(
    offload_dtype="disk",
    offload_device="disk",
    onload_dtype=torch.bfloat16,
    onload_device="cuda",
    preparing_dtype=torch.bfloat16,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)

# Image-generation pipeline: Qwen-Image transformer, text encoder and VAE, all
# pulled from the "Qwen/Qwen-Image" Hub repo with the settings above.
pipe_imagen = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(
            download_source="huggingface",
            model_id="Qwen/Qwen-Image",
            origin_file_pattern=component_pattern,
            **vram_config,
        )
        for component_pattern in (
            "transformer/diffusion_pytorch_model*.safetensors",
            "text_encoder/model*.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(download_source="huggingface", model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
    # Total device memory in GiB minus 0.5 GiB of headroom.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
def read_file(path: str) -> str:
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents
def show_user(profile: gr.OAuthProfile | None):
    """
    Build the Markdown "signed in" banner for the current OAuth profile.

    Returns an empty string when no profile is given (nobody signed in).
    Otherwise it names the user, falling back to "unknown" when the profile
    has no, or an empty, ``username`` attribute.
    """
    if profile is not None:
        display_name = getattr(profile, "username", None) or "unknown"
        return f"✅ Signed in as **{display_name}**"
    return ""
def _open_gallery_image(item):
    """Open one Gradio Gallery entry as an RGB PIL image and close the source file.

    Gallery values normally arrive as ``(filepath, caption)`` pairs; a bare path
    is accepted as well.
    """
    filepath = item[0] if isinstance(item, (tuple, list)) else item
    # convert() loads the pixel data into a new image, so the file handle opened
    # by Image.open can be released as soon as the `with` block exits.
    with Image.open(filepath) as img:
        return img.convert("RGB")


@spaces.GPU
def generate_lora(
    input_images,
    profile: gr.OAuthProfile | None = None,
    oauth_token: gr.OAuthToken | None = None,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Encode the uploaded images into a LoRA, save it locally and optionally upload it.

    - Always saves the LoRA as ``loras/<id>.safetensors``.
    - If the user is signed in (OAuth profile and token present), also uploads
      it to the user's own model repo ``<username>/qwen-image-loras`` under
      ``loras/``, using the user's token.

    Returns:
        A 4-tuple ``(lora_name, hub_link update, imagen_button update,
        lora_download update)``. ``lora_name`` is ``""`` when no images were given.
    """
    import os
    from huggingface_hub import HfApi

    ulid = str(ULID()).lower()[:12]
    print(f"ulid: {ulid}")
    if not input_images:
        return (
            "",
            gr.update(value="⚠️ Please upload at least 1 image."),
            gr.update(interactive=False),
            gr.update(interactive=False, link=""),
        )

    # Load every gallery entry as RGB without leaking open file handles.
    input_images = [_open_gallery_image(item) for item in input_images]

    # Model inference: images -> embeddings -> LoRA tensors.
    with torch.no_grad():
        embs = QwenImageUnit_Image2LoRAEncode().process(pipe_lora, image2lora_images=input_images)
        lora = QwenImageUnit_Image2LoRADecode().process(pipe_lora, **embs)["lora"]

    lora_name = f"{ulid}.safetensors"
    os.makedirs("loras", exist_ok=True)
    lora_path = f"loras/{lora_name}"
    save_file(lora, lora_path)

    # Default: local-only message (the user can still generate images from the local LoRA).
    hub_url = ""
    hub_markdown = "✅ LoRA generated locally. Sign in to upload it to your Hugging Face account."

    # Upload to the signed-in user's own account if available.
    if profile is not None and oauth_token is not None and getattr(oauth_token, "token", None):
        try:
            username = getattr(profile, "username", None) or ""
            if not username:
                raise ValueError("Could not read username from OAuth profile.")
            api = HfApi(token=oauth_token.token)
            # Create / reuse a model repo in the user's namespace for their LoRAs.
            repo_id = f"{username}/qwen-image-loras"
            api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
            api.upload_file(
                path_or_fileobj=lora_path,
                path_in_repo=f"loras/{lora_name}",
                repo_id=repo_id,
                repo_type="model",
                commit_message=f"Add LoRA: {lora_name}",
            )
            hub_url = f"https://huggingface.co/{repo_id}/blob/main/loras/{lora_name}"
            hub_markdown = f"✅ **Uploaded to your account:** {hub_url}"
        except Exception as e:
            # Best-effort upload: the LoRA is already saved locally, so report and continue.
            print(f"Error uploading to user repo: {e}")
            hub_markdown = f"⚠️ Upload failed (still saved locally): `{str(e)}`"

    # Enable image generation regardless; the Hub button only when an upload succeeded.
    return (
        lora_name,
        gr.update(value=hub_markdown),
        gr.update(interactive=True),
        gr.update(interactive=bool(hub_url), link=hub_url),
    )
@spaces.GPU
def generate_image(
    lora_name,
    prompt,
    negative_prompt="blurry ugly bad",
    width=1024,
    height=1024,
    seed=42,
    randomize_seed=True,
    guidance_scale=3.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Render one image with ``pipe_imagen`` after loading the LoRA ``loras/<lora_name>``.

    Args:
        lora_name: File name of a LoRA written by ``generate_lora``. Only the base
            name is used, so the load always stays inside ``./loras/``.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        width: Output width in pixels (cast to int).
        height: Output height in pixels (cast to int).
        seed: Seed passed to the pipeline. It is replaced by a random value in
            ``[0, MAX_SEED]`` when ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh seed.
        guidance_scale: Forwarded to the pipeline as ``cfg_scale``.
        num_inference_steps: Number of inference steps (cast to int).
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        ``(image, seed)``: the generated image and the seed actually used.
        Returns ``(None, seed)`` when no LoRA name was given.

    Raises:
        gr.Error: if ``loras/<lora_name>`` does not exist.
    """
    import os

    if not lora_name:
        return None, seed

    # This endpoint is public (UI, API and MCP server), so treat the name as
    # untrusted: drop any directory components before building the path.
    safe_name = os.path.basename(str(lora_name).strip())
    lora_path = os.path.join("loras", safe_name)
    if not safe_name or not os.path.isfile(lora_path):
        raise gr.Error(f"LoRA not found: {safe_name or lora_name}")

    # Presumably needed so LoRAs from earlier requests don't stack on the shared pipeline.
    pipe_imagen.clear_lora()
    pipe_imagen.load_lora(pipe_imagen.dit, lora_path)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Forward the seed so the returned value reproduces the image, and wire up
    # the "Guidance Scale" slider, which was previously ignored.
    output_image = pipe_imagen(
        prompt=prompt,
        negative_prompt=negative_prompt,
        cfg_scale=guidance_scale,
        seed=seed,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
    )
    return output_image, seed
# Custom stylesheet for the app (Apple-inspired, minimalist look). It is passed
# to demo.launch(css=...) in the __main__ block.
# NOTE(review): most rules target `.gr-*` class names (e.g. `.gr-button`,
# `.gr-textbox`, `.gr-gallery`). Confirm that these classes exist in the installed
# Gradio version, because rules matching no element are silently ignored.
# `#col-container` and `.section-card` match the elem_id / elem_classes set on
# the layout columns.
css = """
/* Pure Apple Design System */
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text", "Helvetica Neue", Helvetica, Arial, sans-serif !important;
background: #ffffff !important;
color: #1d1d1f !important;
line-height: 1.47059 !important;
font-weight: 400 !important;
letter-spacing: -.022em !important;
}
#col-container {
margin: 0 auto;
max-width: 980px;
padding: 40px 20px;
}
/* Ultra-minimalist header */
.gradio-container .gr-block-header {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin-bottom: 60px !important;
box-shadow: none !important;
}
.gradio-container h1 {
font-weight: 600 !important;
font-size: 3rem !important;
color: #1d1d1f !important;
text-align: center !important;
margin-bottom: 16px !important;
letter-spacing: -.003em !important;
}
.gradio-container .subtitle {
font-size: 1.25rem !important;
font-weight: 400 !important;
color: #6e6e73 !important;
text-align: center !important;
margin-bottom: 8px !important;
line-height: 1.4 !important;
}
/* Clean card sections */
.section-card {
background: #f2f2f7 !important;
border-radius: 18px !important;
padding: 32px !important;
margin-bottom: 32px !important;
border: none !important;
box-shadow: none !important;
}
/* Apple-style buttons */
.gradio-container .gr-button {
background: #007aff !important;
border: none !important;
border-radius: 8px !important;
color: white !important;
font-weight: 500 !important;
font-size: 17px !important;
padding: 16px 32px !important;
min-height: 44px !important;
transition: all 0.15s ease !important;
box-shadow: none !important;
letter-spacing: -.022em !important;
}
.gradio-container .gr-button:hover {
background: #0051d5 !important;
transform: none !important;
box-shadow: none !important;
}
.gradio-container .gr-button:active {
background: #004bb8 !important;
transform: scale(0.98) !important;
}
/* Clean input fields */
.gradio-container .gr-textbox,
.gradio-container .gr-slider {
background: #ffffff !important;
border: 1px solid #d2d2d7 !important;
border-radius: 10px !important;
padding: 12px 16px !important;
font-size: 17px !important;
color: #1d1d1f !important;
transition: all 0.15s ease !important;
min-height: 44px !important;
}
.gradio-container .gr-textbox:focus,
.gradio-container .gr-slider:focus {
border-color: #007aff !important;
box-shadow: 0 0 0 3px rgba(0, 122, 255, 0.1) !important;
outline: none !important;
}
/* Gallery styling */
.gradio-container .gr-gallery {
border-radius: 12px !important;
border: 1px solid #d2d2d7 !important;
background: #ffffff !important;
overflow: hidden !important;
}
/* Image output */
.gradio-container .gr-image {
border-radius: 12px !important;
border: 1px solid #d2d2d7 !important;
background: #ffffff !important;
overflow: hidden !important;
}
/* Accordion - Apple style */
.gradio-container .gr-accordion {
background: #f2f2f7 !important;
border: none !important;
border-radius: 12px !important;
padding: 0 !important;
margin-top: 24px !important;
}
.gradio-container .gr-accordion .gr-accordion-button {
background: transparent !important;
border: none !important;
padding: 16px !important;
font-weight: 500 !important;
color: #1d1d1f !important;
}
/* Download button */
.gradio-container .gr-download-button {
background: #34c759 !important;
border: none !important;
border-radius: 8px !important;
color: white !important;
font-weight: 500 !important;
font-size: 17px !important;
padding: 16px 32px !important;
min-height: 44px !important;
}
.gradio-container .gr-download-button:hover {
background: #30a14a !important;
}
/* Examples section */
.gradio-container .gr-examples {
background: #f2f2f7 !important;
border-radius: 18px !important;
padding: 24px !important;
border: none !important;
}
/* Mobile responsiveness */
@media (max-width: 768px) {
#col-container { padding: 20px 16px !important; max-width: 100% !important; }
.gradio-container h1 { font-size: 2rem !important; margin-bottom: 12px !important; }
.gradio-container .subtitle { font-size: 1.1rem !important; }
.section-card { padding: 24px !important; margin-bottom: 24px !important; }
.gradio-container .gr-button { padding: 14px 28px !important; font-size: 16px !important; }
.gradio-container .gr-gallery { height: 200px !important; columns: 2 !important; }
.gradio-container .gr-row { flex-direction: column !important; gap: 20px !important; }
}
@media (max-width: 480px) {
.gradio-container h1 { font-size: 1.75rem !important; }
.section-card { padding: 20px !important; }
.gradio-container .gr-gallery { height: 180px !important; columns: 1 !important; }
}
/* Hide gradio header/footer */
.gradio-container .gr-footer,
.gradio-container .gr-header {
display: none !important;
}
"""
# Load the example inputs offered by the gr.Examples widget. JSON is UTF-8 by
# specification, so decode it explicitly instead of relying on the platform's
# locale encoding.
with open("examples/0_examples.json", "r", encoding="utf-8") as file:
    examples = json.load(file)
print(examples)
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        # Header
        gr.HTML(
            """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<h1>Qwen Image to LoRA</h1>
<p class="subtitle">Generate custom LoRA models from your images</p>
<p style="font-size: 14px; color: #86868b; margin-top: 16px;">
Demo by <a href="https://aisudo.com/" target="_blank" style="color: #007aff; text-decoration: none;">AiSudo</a> •
<a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: #007aff; text-decoration: none;">Built with anycoder</a>
</p>
</div>
"""
        )

        # Hugging Face login button and "signed in as" banner
        with gr.Row():
            with gr.Column(scale=1):
                login_btn = gr.LoginButton(
                    value="Sign in with Hugging Face",
                    logout_value="Logout ({})",
                    variant="huggingface",
                    size="lg",
                )
            with gr.Column(scale=3):
                whoami = gr.Markdown(value="", elem_id="whoami")

        # Step 1: images -> LoRA
        with gr.Row():
            with gr.Column(elem_classes=["section-card"]):
                input_images = gr.Gallery(
                    label="Input Images",
                    file_types=["image"],
                    show_label=True,
                    columns=2,
                    object_fit="cover",
                    height=250,
                )
                lora_button = gr.Button("Generate LoRA", size="lg")
            with gr.Column(elem_classes=["section-card"]):
                lora_name = gr.Textbox(
                    label="Generated LoRA",
                    lines=2,
                    interactive=False,
                    placeholder="Your LoRA will appear here...",
                )
                hub_link = gr.Markdown(value="", label="Hub Link")
                # This becomes clickable only after a successful upload.
                lora_download = gr.Button(
                    value="View on Hub",
                    interactive=False,
                    size="lg",
                    link="",
                )

        # Step 2: LoRA -> images
        with gr.Column(elem_classes=["section-card"]):
            gr.Markdown("### Generate Images")
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        lines=2,
                        placeholder="Describe what you want to generate...",
                        value="a person in a fishing boat.",
                    )
                    imagen_button = gr.Button("Generate Image", interactive=False, size="lg")
                    with gr.Accordion("Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            lines=1,
                            placeholder="What to avoid...",
                            value="blurry, low quality",
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.5,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=768,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=1024,
                            )
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                        )
                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
                with gr.Column():
                    output_image = gr.Image(label="Generated Image", height=350)

        gr.Examples(examples=examples, inputs=[input_images], label="Examples")
        gr.Markdown(read_file("static/footer.md"))

    # Fill the "signed in as" banner whenever the page loads. Gradio injects the
    # OAuthProfile parameter of show_user itself, so no component inputs are
    # passed. Signing in is an OAuth redirect, so a handler on the login button's
    # click would never see the new session.
    demo.load(fn=show_user, inputs=None, outputs=[whoami])

    # Generate LoRA (auto-uploads to the user's account if signed in)
    lora_button.click(
        fn=generate_lora,
        inputs=[input_images],
        outputs=[lora_name, hub_link, imagen_button, lora_download],
        api_visibility="public",
    )

    # Generate Image
    imagen_button.click(
        fn=generate_image,
        inputs=[
            lora_name,
            prompt,
            negative_prompt,
            width,
            height,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
        api_visibility="public",
    )
if __name__ == "__main__":
    # Gradio 6 takes css/theme on launch() rather than on gr.Blocks().
    base_theme = gr.themes.Base(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="gray",
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"],
    )
    anycoder_link = {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}
    demo.launch(
        css=css,
        theme=base_theme,
        mcp_server=True,
        footer_links=[anycoder_link],
    )