import gradio as gr
import random
import os
import spaces
import torch
import time
import json
import numpy as np
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
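
# Modular VLM pipeline that converts a free-text prompt into FIBO's structured JSON prompt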
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True).to(device)
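# Text-to-image pipeline conditioned on the structured JSON prompt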
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", trust_remote_code=True, dtype=dtype).to(device)
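

# On ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration of the call (here up to 300 s)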
@spaces.GPU(duration=300)
def infer(
    prompt,
    prompt_refine,
    prompt_in_json,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=768,
    guidance_scale=5,
    num_inference_steps=50,
    mode="generate",
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    with torch.inference_mode():
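        # Run the negative prompt through the same VLM so it matches the JSON prompt format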
        if negative_prompt:
            neg_output = vlm_pipe(prompt=negative_prompt)
            neg_json_prompt = neg_output.values["json_prompt"]
        else:
            neg_json_prompt = ""
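
        # "refine" rewrites an existing structured prompt using a free-text instruction;
        # "generate" builds a fresh structured prompt from the text prompt alone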
        if mode == "refine":
            json_prompt_str = (
                json.dumps(prompt_in_json)
                if isinstance(prompt_in_json, (dict, list))
                else str(prompt_in_json)
            )
            output = vlm_pipe(json_prompt=json_prompt_str, prompt=prompt_refine)
        else:
            output = vlm_pipe(prompt=prompt)
        json_prompt = output.values["json_prompt"]
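
        # Render the image from the structured JSON prompt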
        image = pipe(
            prompt=json_prompt,
            num_inference_steps=num_inference_steps,
            negative_prompt=neg_json_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            # Apply the (possibly randomized) seed; without a generator it would have no effect
            generator=torch.Generator(device=device).manual_seed(seed),
        ).images[0]
    return image, seed, json_prompt, neg_json_prompt
css = """
#col-container{
margin: 0 auto;
max-width: 768px;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## FIBO")
        with gr.Row():
            with gr.Tab("generate") as tab_generate:
                with gr.Row():
                    prompt_generate = gr.Textbox(label="Prompt")
            with gr.Tab("refine") as tab_refine:
                with gr.Row():
                    prompt_refine = gr.Textbox(label="Prompt")
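
        # One Generate button serves both tabs; the active tab (tracked below) picks the mode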
        submit_btn = gr.Button("Generate")
        result = gr.Image(label="Output")
        with gr.Accordion("Structured Prompt", open=False):
            prompt_in_json = gr.JSON(label="JSON structured prompt")
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
guidance_scale = gr.Slider(label="guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
num_inference_steps = gr.Slider(
label="number of inference steps", minimum=1, maximum=60, step=1, value=50
)
height = gr.Slider(label="Height", minimum=768, maximum=1248, step=32, value=768)
width = gr.Slider(label="Width", minimum=832, maximum=1344, step=64, value=1024)
with gr.Row():
negative_prompt = gr.Textbox(label="negative prompt")
negative_prompt_json = gr.JSON(label="json negative prompt")

    # Track the active tab so infer() knows whether to generate from scratch or refine
    current_mode = gr.State("generate")
    tab_generate.select(lambda: "generate", outputs=current_mode)
    tab_refine.select(lambda: "refine", outputs=current_mode)
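
    # Both prompt boxes are always passed; infer() reads the one matching current_mode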
    submit_btn.click(
        fn=infer,
        inputs=[
            prompt_generate,
            prompt_refine,
            prompt_in_json,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            current_mode,
        ],
        outputs=[result, seed, prompt_in_json, negative_prompt_json],
    )

demo.queue().launch()