#!/usr/bin/env python # coding: utf-8 import random from typing import List import gradio as gr from collections import defaultdict from functools import partial from PIL import Image block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }") SELECT_LABEL = "Select as seed" selectors: List[gr.Radio] = [] def infer_seeded_image(prompt, seed): return Image.open(f"sample_outputs/seeded_1.png") def infer_grid(prompt): response = defaultdict(list) for i in range(1, 7): response["images"].append(Image.open(f"sample_outputs/{i}.png")) response["seeds"].append(random.randint(0, 2 ** 32 -1)) # TODO: only run when a selection exists # TODO: get seed to reuse return response["images"] def infer(prompt): """ Outputs: - Grid images (list) - Seeded Image (Image or None) - Grid Box with updated visibility - Seeded Box with updated visibility """ grid_images = [None] * 6 image_with_seed = None visible = (False, False) if current_selection() > -1: image_with_seed = infer_seeded_image(prompt, 7667) visible = (False, True) else: grid_images = infer_grid(prompt) visible = (True, False) boxes = [gr.Box.update(visible=v) for v in visible] return grid_images + [image_with_seed] + boxes def image_block(): return gr.Image( interactive=False, show_label=False ).style( # border = (True, True, False, True), rounded = (True, True, False, False), ) selectors_state = [''] * 6 def did_select(radio: gr.Radio): new_state = list(map(lambda r: SELECT_LABEL if r == radio else '', selectors)) return new_state def update_state(radio: gr.Radio, *state): global selectors_state if list(state) != selectors_state: selectors_state = did_select(radio) return selectors_state def current_selection(): try: return selectors_state.index(SELECT_LABEL) except: return -1 def clear_seed(): """Update state of Radio buttons, grid, seeded_box""" global selectors_state selectors_state = [''] * 6 return selectors_state + [gr.Box.update(visible=True), gr.Box.update(visible=False)] def radio_block(): radio = gr.Radio( choices=[SELECT_LABEL], interactive=True, show_label=False, ).style( # border = (False, True, True, True), # rounded = (False, False, True, True) container=False ) return radio with block: gr.Markdown("