"""Gradio demo for briaai/FIBO: a VLM pipeline converts free-text prompts into
FIBO's structured JSON prompts, which the text-to-image pipeline then renders."""

import json
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline

MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)

# VLM pipeline that turns a free-text prompt into FIBO's structured JSON prompt.
vlm_pipe = ModularPipeline.from_pretrained(
    "briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True
).to(device)
# Text-to-image pipeline that consumes the structured JSON prompt.
pipe = BriaFiboPipeline.from_pretrained(
    "briaai/FIBO", trust_remote_code=True, dtype=dtype
).to(device)


@spaces.GPU(duration=300)
def infer(
    prompt,
    prompt_refine,
    prompt_in_json,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=768,
    guidance_scale=5,
    num_inference_steps=50,
    mode="generate",
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    with torch.inference_mode():
        # Convert the negative prompt to a structured JSON prompt, if one was given.
        if negative_prompt:
            neg_output = vlm_pipe(prompt=negative_prompt)
            neg_json_prompt = neg_output.values["json_prompt"]
        else:
            neg_json_prompt = ""

        if mode == "refine":
            # Refine an existing structured prompt with a free-text instruction.
            json_prompt_str = (
                json.dumps(prompt_in_json)
                if isinstance(prompt_in_json, (dict, list))
                else str(prompt_in_json)
            )
            output = vlm_pipe(json_prompt=json_prompt_str, prompt=prompt_refine)
        else:
            # Generate a structured prompt from scratch.
            output = vlm_pipe(prompt=prompt)
        json_prompt = output.values["json_prompt"]

        image = pipe(
            prompt=json_prompt,
            num_inference_steps=num_inference_steps,
            negative_prompt=neg_json_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            # Seed the sampler so `seed` actually affects the output; this assumes
            # the pipeline accepts the standard diffusers `generator` kwarg.
            generator=torch.Generator(device=device).manual_seed(seed),
        ).images[0]

    return image, seed, json_prompt, neg_json_prompt


css = """
#col-container {
    margin: 0 auto;
    max-width: 768px;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## FIBO")
        with gr.Row():
            with gr.Tab("generate") as tab_generate:
                with gr.Row():
                    prompt_generate = gr.Textbox(label="Prompt")
            with gr.Tab("refine") as tab_refine:
                with gr.Row():
                    prompt_refine = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Generate")
        result = gr.Image(label="Output")
        with gr.Accordion("Structured Prompt", open=False):
            prompt_in_json = gr.JSON(label="JSON structured prompt")
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=60, step=1, value=50
                )
            height = gr.Slider(label="Height", minimum=768, maximum=1248, step=32, value=768)
            width = gr.Slider(label="Width", minimum=832, maximum=1344, step=64, value=1024)
            with gr.Row():
                negative_prompt = gr.Textbox(label="Negative prompt")
                negative_prompt_json = gr.JSON(label="JSON negative prompt")

    # Track the active tab so infer() knows whether to generate or refine.
    current_mode = gr.State("generate")
    tab_generate.select(lambda: "generate", outputs=current_mode)
    tab_refine.select(lambda: "refine", outputs=current_mode)

    submit_btn.click(
        fn=infer,
        inputs=[
            prompt_generate,
            prompt_refine,
            prompt_in_json,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            current_mode,
        ],
        outputs=[result, seed, prompt_in_json, negative_prompt_json],
    )

demo.queue().launch()