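# Gradio demo for briaai/FIBO: a VLM helper turns free-text prompts into FIBO's
# structured JSON prompt format, which the FIBO pipeline then renders to an image.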
import gradio as gr
import random
import os
import spaces
import torch
import time
import json
import numpy as np
from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline
MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
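
# FIBO consumes structured JSON prompts; a companion VLM pipeline converts
# free-text prompts into (or refines existing) prompts in that JSON schema.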
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True).to(device)
pipe = BriaFiboPipeline.from_pretrained("briaai/FIBO", trust_remote_code=True, dtype=dtype).to(device)
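
# ZeroGPU: allocate a GPU for this call, with up to 300 s of GPU time per call.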
@spaces.GPU(duration=300)
def infer(
    prompt,
    prompt_refine,
    prompt_in_json,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=768,
    guidance_scale=5,
    num_inference_steps=50,
    mode="generate",
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    with torch.inference_mode():
        # Convert the free-text negative prompt to FIBO's JSON schema as well.
        if negative_prompt:
            neg_output = vlm_pipe(prompt=negative_prompt)
            neg_json_prompt = neg_output.values["json_prompt"]
        else:
            neg_json_prompt = ""

        if mode == "refine":
            # Refine mode: pass the existing structured prompt back to the VLM
            # together with the user's edit instruction.
            json_prompt_str = (
                json.dumps(prompt_in_json)
                if isinstance(prompt_in_json, (dict, list))
                else str(prompt_in_json)
            )
            output = vlm_pipe(json_prompt=json_prompt_str, prompt=prompt_refine)
        else:
            output = vlm_pipe(prompt=prompt)

        json_prompt = output.values["json_prompt"]
        # Seed the sampler so the reported seed is reproducible (assumes
        # BriaFiboPipeline accepts the standard diffusers `generator` argument).
        generator = torch.Generator(device=device).manual_seed(int(seed))
        image = pipe(
            prompt=json_prompt,
            num_inference_steps=num_inference_steps,
            negative_prompt=neg_json_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]

    return image, seed, json_prompt, neg_json_prompt
css = """
#col-container{
margin: 0 auto;
max-width: 768px;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## FIBO")
        with gr.Row():
            with gr.Tab("generate") as tab_generate:
                with gr.Row():
                    prompt_generate = gr.Textbox(label="Prompt")
            # "refine" edits the current structured prompt instead of starting over.
            with gr.Tab("refine") as tab_refine:
                with gr.Row():
                    prompt_refine = gr.Textbox(label="Refinement prompt")
        submit_btn = gr.Button("Generate")
        result = gr.Image(label="Output")
        with gr.Accordion("Structured Prompt", open=False):
            prompt_in_json = gr.JSON(label="JSON structured prompt")
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=60, step=1, value=50
                )
                height = gr.Slider(label="Height", minimum=768, maximum=1248, step=32, value=768)
                width = gr.Slider(label="Width", minimum=832, maximum=1344, step=64, value=1024)
            with gr.Row():
                negative_prompt = gr.Textbox(label="Negative prompt")
                negative_prompt_json = gr.JSON(label="JSON negative prompt")

    # Track the active tab so `infer` knows whether to generate or refine.
    current_mode = gr.State("generate")
    tab_generate.select(lambda: "generate", outputs=current_mode)
    tab_refine.select(lambda: "refine", outputs=current_mode)
    # The input order must match infer's positional parameters.
    submit_btn.click(
        fn=infer,
        inputs=[
            prompt_generate,
            prompt_refine,
            prompt_in_json,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            current_mode,
        ],
        outputs=[result, seed, prompt_in_json, negative_prompt_json],
    )
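
# Queue requests so long-running GPU generations are handled in order.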
demo.queue().launch()