#!/usr/bin/env python # coding: utf-8 import gradio as gr import random import torch from collections import defaultdict from diffusers import DiffusionPipeline from functools import partial from itertools import zip_longest from typing import List from PIL import Image SELECT_LABEL = "Select as seed" MODEL_ID = "CompVis/ldm-text2im-large-256" STEPS = 50 ETA = 0.3 GUIDANCE_SCALE = 12 ldm = DiffusionPipeline.from_pretrained(MODEL_ID) with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo: state = gr.Variable({ 'selected': -1, 'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)] }) def infer_seeded_image(prompt, seed): print(f"Prompt: {prompt}, seed: {seed}") images, _ = infer_grid(prompt, n=1, seeds=[seed]) return images[0] def infer_grid(prompt, n=6, seeds=[]): # Unfortunately we have to iterate instead requesting all images at once, # because we have no way to get the generation seeds. result = defaultdict(list) for _, seed in zip_longest(range(n), seeds, fillvalue=None): seed = random.randint(0, 2**32 - 1) if seed is None else seed print(f"Setting seed {seed}") _ = torch.manual_seed(seed) images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"] result["images"].append(images[0]) result["seeds"].append(seed) return result["images"], result["seeds"] def infer(prompt, state): """ Outputs: - Grid images (list) - Seeded Image (Image or None) - Grid Box with updated visibility - Seeded Box with updated visibility """ grid_images = [None] * 6 image_with_seed = None visible = (False, False) if (seed_index := state["selected"]) > -1: seed = state["seeds"][seed_index] image_with_seed = infer_seeded_image(prompt, seed) visible = (False, True) else: grid_images, seeds = infer_grid(prompt) state["seeds"] = seeds visible = (True, False) boxes = [gr.Box.update(visible=v) for v in visible] return grid_images + [image_with_seed] + boxes + [state] def update_state(selected_index: int, value, state): if value == '': others_value = None else: others_value = '' state["selected"] = selected_index others = gr.Radio.update(value=others_value) return [others] * 5 + [state] def clear_seed(state): """Update state of Radio buttons, grid, seeded_box""" state["selected"] = -1 return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state] def image_block(): return gr.Image( interactive=False, show_label=False ).style( # border = (True, True, False, True), rounded = (True, True, False, False), ) def radio_block(): radio = gr.Radio( choices=[SELECT_LABEL], interactive=True, show_label=False, ).style( # border = (False, True, True, True), # rounded = (False, False, True, True) container=False ) return radio gr.Markdown("