Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
|
| 3 |
import torch
|
| 4 |
-
import mediapy
|
| 5 |
import sa_handler
|
| 6 |
|
| 7 |
# init models
|
|
@@ -11,10 +10,10 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
|
|
| 11 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
|
| 12 |
scheduler=scheduler
|
| 13 |
).to("cuda")
|
| 14 |
-
#pipeline.enable_sequential_cpu_offload()
|
| 15 |
pipeline.enable_model_cpu_offload()
|
| 16 |
pipeline.enable_vae_slicing()
|
| 17 |
-
|
| 18 |
handler = sa_handler.Handler(pipeline)
|
| 19 |
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
|
| 20 |
share_layer_norm=False,
|
|
@@ -26,43 +25,43 @@ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
|
|
| 26 |
|
| 27 |
handler.register(sa_args, )
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
sets_of_prompts = [
|
| 31 |
-
"a toy train. macro photo. 3d game asset",
|
| 32 |
-
"a toy airplane. macro photo. 3d game asset",
|
| 33 |
-
"a toy bicycle. macro photo. 3d game asset",
|
| 34 |
-
"a toy car. macro photo. 3d game asset",
|
| 35 |
-
"a toy boat. macro photo. 3d game asset",
|
| 36 |
-
]
|
| 37 |
-
|
| 38 |
-
# run StyleAligned
|
| 39 |
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
with gr.Blocks() as demo:
|
| 47 |
with gr.Group():
|
| 48 |
with gr.Column():
|
| 49 |
with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
|
| 50 |
with gr.Row(variant='panel'):
|
|
|
|
| 51 |
initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
|
| 52 |
initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
|
| 53 |
initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
|
| 54 |
initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
|
| 55 |
initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
|
| 56 |
with gr.Row():
|
|
|
|
| 57 |
style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
|
|
|
|
| 58 |
btn = gr.Button("Generate a set of Style-aligned SDXL images",)
|
|
|
|
| 59 |
output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)
|
| 60 |
-
|
|
|
|
| 61 |
btn.click(fn=style_aligned_sdxl,
|
| 62 |
inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
|
| 63 |
outputs=output,
|
| 64 |
api_name="style_aligned_sdxl")
|
| 65 |
|
|
|
|
| 66 |
gr.Examples(examples=[
|
| 67 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
|
| 68 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
|
|
@@ -74,5 +73,5 @@ with gr.Blocks() as demo:
|
|
| 74 |
outputs=[output],
|
| 75 |
fn=style_aligned_sdxl)
|
| 76 |
|
| 77 |
-
demo
|
| 78 |
-
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
|
| 3 |
import torch
|
|
|
|
| 4 |
import sa_handler
|
| 5 |
|
| 6 |
# init models
|
|
|
|
| 10 |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
|
| 11 |
scheduler=scheduler
|
| 12 |
).to("cuda")
|
| 13 |
+
# Configure the pipeline for CPU offloading and VAE slicing#pipeline.enable_sequential_cpu_offload()
|
| 14 |
pipeline.enable_model_cpu_offload()
|
| 15 |
pipeline.enable_vae_slicing()
|
| 16 |
+
# Initialize the style-aligned handler
|
| 17 |
handler = sa_handler.Handler(pipeline)
|
| 18 |
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
|
| 19 |
share_layer_norm=False,
|
|
|
|
| 25 |
|
| 26 |
handler.register(sa_args, )
|
| 27 |
|
| 28 |
+
# Define the function to generate style-aligned images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
|
| 30 |
+
try:
|
| 31 |
+
# Combine the style prompt with each initial prompt
|
| 32 |
+
sets_of_prompts = [ prompt + ". " + style_prompt for prompt in [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5,]]
|
| 33 |
+
# Generate images using the pipeline
|
| 34 |
+
images = pipeline(sets_of_prompts,).images
|
| 35 |
+
return images
|
| 36 |
+
except Exception as e:
|
| 37 |
+
raise gr.Error(f"Error in generating images: {e}")
|
| 38 |
|
| 39 |
with gr.Blocks() as demo:
|
| 40 |
with gr.Group():
|
| 41 |
with gr.Column():
|
| 42 |
with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
|
| 43 |
with gr.Row(variant='panel'):
|
| 44 |
+
# Textboxes for initial prompts
|
| 45 |
initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
|
| 46 |
initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
|
| 47 |
initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
|
| 48 |
initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
|
| 49 |
initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
|
| 50 |
with gr.Row():
|
| 51 |
+
# Textbox for the style prompt
|
| 52 |
style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
|
| 53 |
+
# Button to generate images
|
| 54 |
btn = gr.Button("Generate a set of Style-aligned SDXL images",)
|
| 55 |
+
# Display the generated images
|
| 56 |
output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery",columns=5, rows=1, object_fit="contain", height="auto",)
|
| 57 |
+
|
| 58 |
+
# Button click event
|
| 59 |
btn.click(fn=style_aligned_sdxl,
|
| 60 |
inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
|
| 61 |
outputs=output,
|
| 62 |
api_name="style_aligned_sdxl")
|
| 63 |
|
| 64 |
+
# Providing Example inputs for the demo
|
| 65 |
gr.Examples(examples=[
|
| 66 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
|
| 67 |
["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
|
|
|
|
| 73 |
outputs=[output],
|
| 74 |
fn=style_aligned_sdxl)
|
| 75 |
|
| 76 |
+
# Launch the Gradio demo
|
| 77 |
+
demo.launch()
|