Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import random | |
| import os | |
| import spaces | |
| import torch | |
| import time | |
| import json | |
| import numpy as np | |
| from diffusers import BriaFiboPipeline | |
| from diffusers.modular_pipelines import ModularPipeline | |
# Largest value a random seed may take (also the seed slider's maximum).
MAX_SEED = np.iinfo(np.int32).max
# Precision requested for the FIBO image pipeline's weights.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# This app only runs inference; disable autograd globally.
torch.set_grad_enabled(False)

# Text / refinement instruction -> structured JSON prompt (custom pipeline code
# is pulled from the Hub repo, hence trust_remote_code).
vlm_pipe = ModularPipeline.from_pretrained(
    "briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True
).to(device)
# JSON prompt -> image. diffusers' `from_pretrained` documents `torch_dtype` as
# the load-precision keyword; a bare `dtype=` is not part of that API and risks
# being dropped as an unexpected kwarg, which would load full-precision weights.
pipe = BriaFiboPipeline.from_pretrained(
    "briaai/FIBO", trust_remote_code=True, torch_dtype=dtype
).to(device)
@spaces.GPU
def infer(
    prompt,
    prompt_refine,
    prompt_in_json,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=768,
    guidance_scale=5,
    num_inference_steps=50,
    mode="generate",
):
    """Build a structured JSON prompt with the VLM and render it with FIBO.

    In ``"refine"`` mode the existing structured prompt ``prompt_in_json`` is
    edited by the VLM according to ``prompt_refine``; any other ``mode`` value
    converts ``prompt`` into a fresh JSON prompt. A non-empty
    ``negative_prompt`` is likewise converted to JSON and used as the image
    pipeline's negative prompt.

    Args:
        prompt: Text prompt used outside refine mode.
        prompt_refine: Refinement instruction used in refine mode.
        prompt_in_json: Structured prompt to refine, as a dict/list or string.
        negative_prompt: Optional text negative prompt ("" disables it).
        seed: Seed for the image generator; replaced when ``randomize_seed``.
        randomize_seed: Draw a new seed uniformly from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the FIBO pipeline.
        num_inference_steps: Number of denoising steps.
        mode: ``"refine"`` or anything else (treated as generate).

    Returns:
        Tuple of (image, seed actually used, JSON prompt string,
        negative JSON prompt string or ``None`` when there is none).

    Raises:
        gr.Error: in refine mode when there is no structured prompt to refine.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Sliders may hand back floats; Generator.manual_seed needs an int.
    seed = int(seed)

    # Fail fast, before any model call: refining nothing would otherwise send
    # the literal text "None" to the VLM as the JSON prompt.
    if mode == "refine" and not prompt_in_json:
        raise gr.Error(
            "Nothing to refine yet: generate an image first to get a structured prompt."
        )

    with torch.inference_mode():
        if negative_prompt:
            neg_json_prompt = vlm_pipe(prompt=negative_prompt).values["json_prompt"]
        else:
            neg_json_prompt = ""

        if mode == "refine":
            # gr.JSON delivers parsed objects; the VLM expects a JSON string.
            json_prompt_str = (
                json.dumps(prompt_in_json)
                if isinstance(prompt_in_json, (dict, list))
                else str(prompt_in_json)
            )
            output = vlm_pipe(json_prompt=json_prompt_str, prompt=prompt_refine)
        else:
            output = vlm_pipe(prompt=prompt)
        json_prompt = output.values["json_prompt"]

        # Seed the sampler explicitly so the seed shown in the UI reproduces
        # this image (previously the seed was never passed to the pipeline).
        generator = torch.Generator(device=device).manual_seed(seed)
        image = pipe(
            prompt=json_prompt,
            num_inference_steps=num_inference_steps,
            negative_prompt=neg_json_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]

    # The negative prompt feeds a gr.JSON output, which JSON-decodes string
    # values; "" is not valid JSON, so report "no negative prompt" as None.
    return image, seed, json_prompt, neg_json_prompt or None
# Custom stylesheet for the Blocks app: horizontally centers the element with
# id "col-container" (the main gr.Column) and caps its width at 768px.
css = """
#col-container{
    margin: 0 auto;
    max-width: 768px;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## FIBO")

        # Two ways to enter a prompt; the selected tab decides infer's mode.
        with gr.Row():
            with gr.Tab("generate") as generate_tab:
                with gr.Row():
                    prompt_generate = gr.Textbox(label="Prompt")
            with gr.Tab("refine") as refine_tab:
                with gr.Row():
                    prompt_refine = gr.Textbox(label="Prompt")

        submit_btn = gr.Button("Generate")
        result = gr.Image(label="output")

        # Holds the last JSON prompt; doubles as the input for refine mode.
        with gr.Accordion("Structured Prompt", open=False):
            prompt_in_json = gr.JSON(label="json structured prompt")

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0
                )
                num_inference_steps = gr.Slider(
                    label="number of inference steps", minimum=1, maximum=60, step=1, value=50
                )
                height = gr.Slider(label="Height", minimum=768, maximum=1248, step=32, value=768)
                width = gr.Slider(label="Width", minimum=832, maximum=1344, step=64, value=1024)
            with gr.Row():
                negative_prompt = gr.Textbox(label="negative prompt")
                negative_prompt_json = gr.JSON(label="json negative prompt")

    def _report_mode(mode_name):
        """Return a zero-argument callback that yields ``mode_name``."""
        return lambda: mode_name

    # Remember which tab is active so the single Generate button knows the mode.
    current_mode = gr.State("generate")
    generate_tab.select(_report_mode("generate"), outputs=current_mode)
    refine_tab.select(_report_mode("refine"), outputs=current_mode)

    # Order must match infer's positional parameters.
    infer_inputs = [
        prompt_generate,
        prompt_refine,
        prompt_in_json,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        current_mode,
    ]
    infer_outputs = [result, seed, prompt_in_json, negative_prompt_json]
    submit_btn.click(fn=infer, inputs=infer_inputs, outputs=infer_outputs)

demo.queue()
demo.launch()