linoyts (HF Staff) committed (verified)
Commit 2525839 · 1 Parent(s): e155fc5

initial commit

Files changed (1):
  1. app.py +166 -0
app.py ADDED
@@ -0,0 +1,166 @@
+ import gradio as gr
+ import random
+ import spaces
+ import torch
+ import json
+ from diffusers import BriaFiboPipeline
+ from diffusers.modular_pipelines import ModularPipeline
+
+ # resolutions = [
+ #     "832 1248",
+ #     "896 1152",
+ #     "960 1088",
+ #     "1024 1024",
+ #     "1088 960",
+ #     "1152 896",
+ #     "1216 832",
+ #     "1280 800",
+ #     "1344 768",
+ # ]
+
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Upper bound for seed values; assumed here, since the original left MAX_SEED undefined.
+ MAX_SEED = 2**32 - 1
+
+ torch.set_grad_enabled(False)
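+
+ # FIBO consumes structured JSON prompts rather than raw text: a VLM pipeline first
+ # converts the user's free-text prompt to JSON, then the diffusion pipeline renders it.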
+ vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True).to(device)
+
+ pipe = BriaFiboPipeline.from_pretrained(
+     "briaai/FIBO",
+     trust_remote_code=True,
+     torch_dtype=dtype,
+ ).to(device)
+
+
+ def handle_json(text):
+     # Mirror the textbox into the JSON preview when the text parses as JSON.
+     try:
+         json.loads(text)
+         return text
+     except (json.JSONDecodeError, TypeError):
+         return "Error"
+
+
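+ # Runs both stages on the GPU: free text -> JSON via the VLM, then JSON -> image
+ # via the FIBO pipeline.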
+ @spaces.GPU()
+ def infer(prompt,
+           negative_prompt="",
+           seed=42,
+           randomize_seed=False,
+           width=1024,
+           height=1024,
+           guidance_scale=5,
+           num_inference_steps=50,
+           ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # Reproducible generator derived from the (possibly randomized) seed
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     with torch.inference_mode():
+         # 1. Convert the free-text prompt into FIBO's structured JSON prompt
+         output = vlm_pipe(prompt=prompt)
+         json_prompt = output.values["json_prompt"]
+
+         # 2. Render the image from the JSON prompt
+         image = pipe(prompt=json_prompt,
+                      num_inference_steps=num_inference_steps,
+                      negative_prompt=negative_prompt,
+                      generator=generator,
+                      width=width, height=height,
+                      guidance_scale=guidance_scale).images[0]
+
+     return image
+
+ css = """
+ #col-container{
+     margin: 0 auto;
+     max-width: 750px;
+ }
+ """
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("## FIBO")
+
+         with gr.Group():
+             with gr.Column():
+                 with gr.Row():
+                     prompt_in = gr.Textbox(label="Prompt")
+                     prompt_in_json = gr.JSON(label="JSON prompt")
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Row():
+                         seed = gr.Slider(
+                             label="Seed",
+                             minimum=0,
+                             maximum=MAX_SEED,
+                             step=1,
+                             value=0,
+                         )
+
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                     with gr.Row():
+                         guidance_scale = gr.Slider(
+                             label="Guidance scale",
+                             minimum=1.0,
+                             maximum=10.0,
+                             step=0.1,
+                             value=5.0,
+                         )
+                         num_inference_steps = gr.Slider(
+                             label="Number of inference steps",
+                             minimum=1,
+                             maximum=60,
+                             step=1,
+                             value=50,
+                         )
+                         height = gr.Slider(
+                             label="Height",
+                             minimum=768,
+                             maximum=1248,
+                             step=32,
+                             value=1024,
+                         )
+                         width = gr.Slider(
+                             label="Width",
+                             minimum=832,
+                             maximum=1344,
+                             step=64,
+                             value=1024,
+                         )
+
+                     with gr.Row():
+                         negative_prompt = gr.Textbox(label="Negative prompt", value="")
+                         negative_prompt_json = gr.JSON(label="JSON negative prompt")
+
+         submit_btn = gr.Button("Generate")
+         result = gr.Image(label="Output")
+
+     prompt_in.change(
+         handle_json,
+         inputs=prompt_in,
+         outputs=prompt_in_json,
+     )
+
+     negative_prompt.change(handle_json, inputs=negative_prompt, outputs=negative_prompt_json)
+
+     submit_btn.click(
+         fn=infer,
+         inputs=[
+             prompt_in,
+             negative_prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+         ],
+         outputs=[result],
+     )
+
+ demo.queue().launch()