#!/usr/bin/env python
# coding: utf-8
"""Gradio demo for the CompVis latent-diffusion text-to-image model.

A prompt fills a grid of images, each generated from its own recorded seed.
Choosing "Select as seed" under a grid image switches to a single-image view;
subsequent runs regenerate only that seed, so the prompt can be varied while
the seed stays fixed.  "Clear Seed" returns to the grid view.
"""
import gradio as gr
import random
import torch

from collections import defaultdict
from diffusers import DiffusionPipeline
from functools import partial
from itertools import zip_longest
from typing import List
from PIL import Image

SELECT_LABEL = "Select as seed"
MODEL_ID = "CompVis/ldm-text2im-large-256"
STEPS = 50
ETA = 0.3
# NOTE(review): this constant was previously declared as 12, but it was never
# read -- the pipeline call hard-coded guidance_scale=6.  6 is kept so the
# generated images do not change; confirm which value is actually intended.
GUIDANCE_SCALE = 6

# Grid layout.  The callbacks below size their inputs/outputs from these, so
# they must agree with the number of image/radio cells built in the layout.
GRID_ROWS = 2
GRID_COLS = 3
NUM_IMAGES = GRID_ROWS * GRID_COLS
MAX_SEED = 2 ** 32 - 1

ldm = DiffusionPipeline.from_pretrained(MODEL_ID)


def _random_seed():
    """Return a fresh random seed in ``[0, MAX_SEED]``."""
    return random.randint(0, MAX_SEED)


def infer_seeded_image(prompt, seed):
    """Generate a single image for *prompt* from the fixed *seed*."""
    print(f"Prompt: {prompt}, seed: {seed}")
    images, _ = infer_grid(prompt, n=1, seeds=[seed])
    return images[0]


def infer_grid(prompt, n=NUM_IMAGES, seeds=None):
    """Generate *n* images for *prompt*, one pipeline call per image.

    Args:
        prompt: Text prompt passed to the pipeline.
        n: Number of images to generate.
        seeds: Optional seeds to use, in order.  Missing or ``None`` entries
            are replaced by fresh random seeds; entries past the first *n*
            are ignored.

    Returns:
        ``(images, seeds)``: two lists of length *n*, where ``seeds[i]`` is
        the seed that produced ``images[i]``.
    """
    # Images are generated one at a time instead of in one batched call,
    # because a batched call gives no way to recover each image's seed.
    requested = list(seeds or [])[:n]
    images, used_seeds = [], []
    for i in range(n):
        seed = requested[i] if i < len(requested) else None
        if seed is None:
            seed = _random_seed()
        print(f"Setting seed {seed}")
        torch.manual_seed(seed)
        output = ldm(
            [prompt],
            num_inference_steps=STEPS,
            eta=ETA,
            guidance_scale=GUIDANCE_SCALE,
        )
        images.append(output["sample"][0])
        used_seeds.append(seed)
    return images, used_seeds


def infer(prompt, state):
    """Run the prompt in grid mode or, if a seed is selected, seeded mode.

    Outputs, in order:
        - Grid images (``NUM_IMAGES`` entries; all ``None`` in seeded mode)
        - Seeded image (``None`` in grid mode)
        - Grid box update with its new visibility
        - Seeded box update with its new visibility
        - The updated state
    """
    grid_images = [None] * NUM_IMAGES
    image_with_seed = None
    seed_index = state["selected"]
    if seed_index > -1:
        image_with_seed = infer_seeded_image(prompt, state["seeds"][seed_index])
        visible = (False, True)
    else:
        # Remember the seeds so any of these images can be selected later.
        grid_images, state["seeds"] = infer_grid(prompt)
        visible = (True, False)
    boxes = [gr.Box.update(visible=v) for v in visible]
    return grid_images + [image_with_seed] + boxes + [state]


def update_state(selected_index: int, value, state):
    """Handle a change of the "Select as seed" radio at *selected_index*.

    A non-empty *value* means this radio was selected.  It claims the
    selection, and the other radios are cleared.  An empty value (``''`` or
    ``None``) means this radio was cleared, e.g. by another radio's update
    or by ``clear_seed``.  In that case the selection is dropped only if
    this radio held it.  Returns one update for each of the other
    ``NUM_IMAGES - 1`` radios, followed by the state.
    """
    if value:
        state["selected"] = selected_index
        others_value = ''
    else:
        # If Gradio fires .change for programmatic updates, a cleared radio
        # must not re-claim the selection; that would undo "Clear Seed" or
        # steal the selection from the radio the user just picked.
        if state["selected"] == selected_index:
            state["selected"] = -1
        others_value = None
    others = gr.Radio.update(value=others_value)
    return [others] * (NUM_IMAGES - 1) + [state]


def clear_seed(state):
    """Drop the seed selection and return to the grid view.

    Returns: a cleared value for every selector radio, an update showing the
    grid box, an update hiding the seeded box, and the updated state.
    """
    state["selected"] = -1
    return [''] * NUM_IMAGES + [
        gr.Box.update(visible=True),
        gr.Box.update(visible=False),
        state,
    ]


def image_block():
    """Create a read-only, unlabeled image slot."""
    return gr.Image(interactive=False, show_label=False).style(
        rounded=(True, True, False, False),
    )


def radio_block():
    """Create the single-choice "Select as seed" radio placed under an image."""
    return gr.Radio(
        choices=[SELECT_LABEL],
        interactive=True,
        show_label=False,
    ).style(container=False)


with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    # 'selected': index of the grid image chosen as seed (-1 = none).
    # 'seeds': the seeds behind the images currently shown in the grid.
    state = gr.Variable({
        'selected': -1,
        'seeds': [_random_seed() for _ in range(NUM_IMAGES)],
    })

    gr.Markdown("<h1><center>Latent Diffusion Demo</center></h1>")

    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        # Grid view: GRID_ROWS x GRID_COLS cells, each an image above the
        # radio that selects its seed.
        # TODO: wrap each image/radio pair in a custom Component so a pair can
        # participate in events as a single output.
        images, selectors = [], []
        with (grid := gr.Box()):
            for _row in range(GRID_ROWS):
                with gr.Row():
                    for _col in range(GRID_COLS):
                        with gr.Box().style(border=None):
                            images.append(image_block())
                            selectors.append(radio_block())

        for i, radio in enumerate(selectors):
            others = selectors[:i] + selectors[i + 1:]
            radio.change(
                partial(update_state, i),
                inputs=[radio, state],
                outputs=others + [state],
            )

        # Seeded view: the single image regenerated from the selected seed.
        with (seeded_box := gr.Box()):
            seeded_image = image_block()
            clear_seed_button = gr.Button("Clear Seed")
        seeded_box.visible = False

    clear_seed_button.click(
        clear_seed,
        inputs=[state],
        outputs=selectors + [grid, seeded_box, state],
    )

    infer_outputs = images + [seeded_image, grid, seeded_box, state]
    text.submit(infer, inputs=[text, state], outputs=infer_outputs)
    btn.click(infer, inputs=[text, state], outputs=infer_outputs)

demo.launch(enable_queue=False)