import subprocess
# import os.path as osp
import pip
# pip.main(["install","-v","-U","git+https://github.com/facebookresearch/xformers.git@main#egg=xformers"])
# subprocess.check_call("pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers", cwd=osp.dirname(__file__), shell=True)
import io
import base64
import os
import sys
import numpy as np
import torch
from torch import autocast
import diffusers
from diffusers.configuration_utils import FrozenDict
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    DDIMScheduler,
    LMSDiscreteScheduler,
    StableDiffusionUpscalePipeline,
    DPMSolverMultistepScheduler,
)
from diffusers.models import AutoencoderKL
from PIL import Image
from PIL import ImageOps
import gradio as gr
import skimage
import skimage.measure
import yaml
import json
from enum import Enum
try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except Exception:
    pass

from utils import *

# assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0"
USE_NEW_DIFFUSERS = True
RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
class ModelChoice(Enum):
    INPAINTING = "stablediffusion-inpainting"
    INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5"
    MODEL_1_5 = "stablediffusion-v1.5"
    MODEL_1_4 = "stablediffusion-v1.4"

try:
    from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
except Exception:
    # fall back to the stock inpainting pipeline when sd_grpcserver is unavailable
    UnifiedPipeline = StableDiffusionInpaintPipeline

# sys.path.append("./glid_3_xl_stable")
USE_GLID = False
# try:
#     from glid3xlmodel import GlidModel
# except:
#     USE_GLID = False

try:
    cuda_available = torch.cuda.is_available()
except Exception:
    cuda_available = False
finally:
    if sys.platform == "darwin":
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    elif cuda_available:
        device = "cuda"
    else:
        device = "cpu"

import contextlib

# torch.autocast is replaced with a no-op context manager: the pipelines below
# manage dtype explicitly, so autocast is never entered.
autocast = contextlib.nullcontext
| with open("config.yaml", "r") as yaml_in: | |
| yaml_object = yaml.safe_load(yaml_in) | |
| config_json = json.dumps(yaml_object) | |
| def load_html(): | |
| body, canvaspy = "", "" | |
| with open("index.html", encoding="utf8") as f: | |
| body = f.read() | |
| with open("canvas.py", encoding="utf8") as f: | |
| canvaspy = f.read() | |
| body = body.replace("- paths:\n", "") | |
| body = body.replace(" - ./canvas.py\n", "") | |
| body = body.replace("from canvas import InfCanvas", canvaspy) | |
| return body | |
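
# The iframe document is self-contained: the pyscript "paths" config entries are
# dropped and canvas.py is inlined in place of its import, so the sandboxed
# page needs no extra file requests.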
def test(x):
    x = load_html()
    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
DEBUG_MODE = False

try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except Exception:
    SAMPLING_MODE = Image.LANCZOS

try:
    contain_func = ImageOps.contain
except Exception:
    def contain_func(image, size, method=SAMPLING_MODE):
        # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
        im_ratio = image.width / image.height
        dest_ratio = size[0] / size[1]
        if im_ratio != dest_ratio:
            if im_ratio > dest_ratio:
                new_height = int(image.height / image.width * size[0])
                if new_height != size[1]:
                    size = (size[0], new_height)
            else:
                new_width = int(image.width / image.height * size[1])
                if new_width != size[0]:
                    size = (new_width, size[1])
        return image.resize(size, resample=method)
import argparse

parser = argparse.ArgumentParser(description="stablediffusion-infinity")
parser.add_argument("--port", type=int, help="listen port", dest="server_port")
parser.add_argument("--host", type=str, help="host", dest="server_name")
parser.add_argument("--share", action="store_true", help="share this app?")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--fp32", action="store_true", help="use full precision")
parser.add_argument("--encrypt", action="store_true", help="use https?")
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
parser.add_argument(
    "--auth", nargs=2, metavar=("username", "password"), help="use username password"
)
parser.add_argument(
    "--remote_model",
    type=str,
    help="use a model (e.g. DreamBooth fine-tuned) from the Hugging Face hub",
    default="",
)
parser.add_argument(
    "--local_model", type=str, help="use a model stored on your PC", default=""
)
if __name__ == "__main__" and not RUN_IN_SPACE:
    args = parser.parse_args()
else:
    # when imported (e.g. inside a Space), don't parse the host's argv;
    # an empty list yields the declared defaults
    args = parser.parse_args([])
# args = parser.parse_args(["--debug"])
if args.auth is not None:
    args.auth = tuple(args.auth)
model = {}


def get_token():
    token = ""
    if os.path.exists(".token"):
        with open(".token", "r") as f:
            token = f.read()
    token = os.environ.get("hftoken", token)
    return token


def save_token(token):
    with open(".token", "w") as f:
        f.write(token)
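
# diffusers deprecated steps_offset=0: patch older scheduler configs in place
# so every scheduler steps consistently with the bundled pipelines.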
def prepare_scheduler(scheduler):
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)
    return scheduler
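
# Heuristic upscale of small selections toward SD's 512px native resolution;
# outputs snap to multiples of 8 (the latent grid). Illustrative values:
#   my_resize(300, 300) -> (512, 512)
#   my_resize(320, 480) -> (560, 840)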
def my_resize(width, height):
    if width >= 512 and height >= 512:
        return width, height
    if width == height:
        return 512, 512
    smaller = min(width, height)
    larger = max(width, height)
    if larger >= 608:
        return width, height
    factor = 1
    if smaller < 290:
        factor = 2
    elif smaller < 330:
        factor = 1.75
    elif smaller < 384:
        factor = 1.375
    elif smaller < 400:
        factor = 1.25
    elif smaller < 450:
        factor = 1.125
    return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
def load_learned_embed_in_clip(
    learned_embeds_path, text_encoder, tokenizer, token=None
):
    # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]
    # cast to dtype of text_encoder (Tensor.to is not in-place, so reassign)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)
    # add the token in tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )
    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))
    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
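
# "PLMS" is filled with whatever default scheduler ships with the loaded
# pipeline; the remaining entries are constructed once a model is set up.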
| scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None} | |
class StableDiffusionInpaint:
    def __init__(
        self, token: str = "", model_name: str = "", model_path: str = "", **kwargs,
    ):
        self.token = token
        original_checkpoint = False
        if model_path and os.path.exists(model_path):
            if model_path.endswith(".ckpt"):
                original_checkpoint = True
            elif model_path.endswith(".json"):
                model_name = os.path.dirname(model_path)
            else:
                model_name = model_path
        # fine-tuned VAE, pre-cast to fp16 for the half-precision paths below
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        vae.to(torch.float16)
        if original_checkpoint:
            print(f"Converting & Loading {model_path}")
            from convert_checkpoint import convert_checkpoint

            pipe = convert_checkpoint(model_path, inpainting=True)
            if device == "cuda":
                pipe.to(torch.float16)
            inpaint = StableDiffusionInpaintPipeline(
                vae=vae,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                unet=pipe.unet,
                scheduler=pipe.scheduler,
                safety_checker=pipe.safety_checker,
                feature_extractor=pipe.feature_extractor,
            )
        else:
            print(f"Loading {model_name}")
            if device == "cuda":
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    model_name,
                    revision="fp16",
                    torch_dtype=torch.float16,
                    use_auth_token=token,
                    vae=vae,
                )
            else:
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    model_name, use_auth_token=token,
                )
        if os.path.exists("./embeddings"):
            print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
            for item in os.listdir("./embeddings"):
                if item.endswith(".bin"):
                    load_learned_embed_in_clip(
                        os.path.join("./embeddings", item),
                        inpaint.text_encoder,
                        inpaint.tokenizer,
                    )
        inpaint.to(device)
        # try:
        #     inpaint.vae = torch.compile(inpaint.vae, dynamic=True)
        #     inpaint.unet = torch.compile(inpaint.unet, dynamic=True)
        # except Exception as e:
        #     print(e)
        # inpaint.enable_xformers_memory_efficient_attention()
        # if device == "mps":
        #     _ = text2img("", num_inference_steps=1)
        scheduler_dict["PLMS"] = inpaint.scheduler
        scheduler_dict["DDIM"] = prepare_scheduler(
            DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
        )
        scheduler_dict["K-LMS"] = prepare_scheduler(
            LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )
        )
        scheduler_dict["DPM"] = prepare_scheduler(
            DPMSolverMultistepScheduler.from_config(inpaint.scheduler.config)
        )
        self.safety_checker = inpaint.safety_checker
        save_token(token)
        try:
            # slice attention on low-VRAM GPUs (<= 5 GiB)
            total_memory = torch.cuda.get_device_properties(0).total_memory // (
                1024 ** 3
            )
            if total_memory <= 5:
                inpaint.enable_attention_slicing()
        except Exception:
            pass
        self.inpaint = inpaint
    def run(
        self,
        image_pil,
        prompt="",
        negative_prompt="",
        guidance_scale=7.5,
        resize_check=True,
        enable_safety=True,
        fill_mode="patchmatch",
        strength=0.75,
        step=50,
        enable_img2img=False,
        use_seed=False,
        seed_val=-1,
        generate_num=1,
        scheduler="",
        scheduler_eta=0.0,
        **kwargs,
    ):
        inpaint = self.inpaint
        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
        for item in [inpaint]:
            item.scheduler = selected_scheduler
            if enable_safety:
                item.safety_checker = self.safety_checker
            else:
                item.safety_checker = lambda images, **kwargs: (images, None)
        width, height = image_pil.size
        sel_buffer = np.array(image_pil)
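        # RGBA selection buffer: RGB carries the known pixels, alpha marks the
        # painted area; nmask (inverted alpha) is the region to be outpainted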
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        nmask = 255 - mask
        process_width = width
        process_height = height
        if resize_check:
            process_width, process_height = my_resize(width, height)
        # round down to multiples of 8 so the UNet accepts the size
        process_width = process_width // 8 * 8
        process_height = process_height // 8 * 8
        extra_kwargs = {
            "num_inference_steps": step,
            "guidance_scale": guidance_scale,
            "eta": scheduler_eta,
        }
        if USE_NEW_DIFFUSERS:
            extra_kwargs["negative_prompt"] = negative_prompt
            extra_kwargs["num_images_per_prompt"] = generate_num
        if use_seed:
            generator = torch.Generator(inpaint.device).manual_seed(seed_val)
            extra_kwargs["generator"] = generator
        img, mask = functbl[fill_mode](img, mask)
        mask = 255 - mask
        # snap the mask to the 8px latent grid: max-pool 8x8 blocks, then upsample back
        mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
        mask = mask.repeat(8, axis=0).repeat(8, axis=1)
        extra_kwargs["strength"] = strength
        inpaint_func = inpaint
        init_image = Image.fromarray(img)
        mask_image = Image.fromarray(mask)
        # mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=8))
        images = inpaint_func(
            prompt=prompt,
            image=init_image.resize(
                (process_width, process_height), resample=SAMPLING_MODE
            ),
            mask_image=mask_image.resize((process_width, process_height)),
            width=process_width,
            height=process_height,
            **extra_kwargs,
        )["images"]
        return images
class StableDiffusion:
    def __init__(
        self,
        token: str = "",
        model_name: str = "runwayml/stable-diffusion-v1-5",
        model_path: str = None,
        inpainting_model: bool = False,
        **kwargs,
    ):
        self.token = token
        original_checkpoint = False
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        vae.to(torch.float16)
        if model_path and os.path.exists(model_path):
            if model_path.endswith(".ckpt"):
                original_checkpoint = True
            elif model_path.endswith(".json"):
                model_name = os.path.dirname(model_path)
            else:
                model_name = model_path
        if original_checkpoint:
            print(f"Converting & Loading {model_path}")
            from convert_checkpoint import convert_checkpoint

            text2img = convert_checkpoint(model_path)
            if device == "cuda" and not args.fp32:
                text2img.to(torch.float16)
        else:
            print(f"Loading {model_name}")
            if device == "cuda" and not args.fp32:
                text2img = StableDiffusionPipeline.from_pretrained(
                    model_name,
                    revision="fp16",
                    torch_dtype=torch.float16,
                    use_auth_token=token,
                    vae=vae,
                )
            else:
                text2img = StableDiffusionPipeline.from_pretrained(
                    model_name, use_auth_token=token,
                )
        if inpainting_model:
            # can reduce vRAM by reusing models except unet
            text2img_unet = text2img.unet
            del text2img.vae
            del text2img.text_encoder
            del text2img.tokenizer
            del text2img.scheduler
            del text2img.safety_checker
            del text2img.feature_extractor
            import gc

            gc.collect()
            if device == "cuda":
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    "runwayml/stable-diffusion-inpainting",
                    revision="fp16",
                    torch_dtype=torch.float16,
                    use_auth_token=token,
                    vae=vae,
                ).to(device)
            else:
                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
                    "runwayml/stable-diffusion-inpainting", use_auth_token=token,
                ).to(device)
            text2img_unet.to(device)
            del text2img
            gc.collect()
            # rebuild text2img around the inpainting pipeline's components,
            # keeping only its own unet
            text2img = StableDiffusionPipeline(
                vae=inpaint.vae,
                text_encoder=inpaint.text_encoder,
                tokenizer=inpaint.tokenizer,
                unet=text2img_unet,
                scheduler=inpaint.scheduler,
                safety_checker=inpaint.safety_checker,
                feature_extractor=inpaint.feature_extractor,
            )
        else:
            inpaint = StableDiffusionInpaintPipelineLegacy(
                vae=text2img.vae,
                text_encoder=text2img.text_encoder,
                tokenizer=text2img.tokenizer,
                unet=text2img.unet,
                scheduler=text2img.scheduler,
                safety_checker=text2img.safety_checker,
                feature_extractor=text2img.feature_extractor,
            ).to(device)
        text_encoder = text2img.text_encoder
        tokenizer = text2img.tokenizer
        if os.path.exists("./embeddings"):
            for item in os.listdir("./embeddings"):
                if item.endswith(".bin"):
                    load_learned_embed_in_clip(
                        os.path.join("./embeddings", item),
                        text2img.text_encoder,
                        text2img.tokenizer,
                    )
        text2img.to(device)
        if device == "mps":
            # warm up MPS kernels with a single dummy step
            _ = text2img("", num_inference_steps=1)
| scheduler_dict["PLMS"] = text2img.scheduler | |
| scheduler_dict["DDIM"] = prepare_scheduler( | |
| DDIMScheduler( | |
| beta_start=0.00085, | |
| beta_end=0.012, | |
| beta_schedule="scaled_linear", | |
| clip_sample=False, | |
| set_alpha_to_one=False, | |
| ) | |
| ) | |
| scheduler_dict["K-LMS"] = prepare_scheduler( | |
| LMSDiscreteScheduler( | |
| beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | |
| ) | |
| ) | |
| scheduler_dict["DPM"] = prepare_scheduler( | |
| DPMSolverMultistepScheduler.from_config(text2img.scheduler.config) | |
| ) | |
| self.safety_checker = text2img.safety_checker | |
| img2img = StableDiffusionImg2ImgPipeline( | |
| vae=text2img.vae, | |
| text_encoder=text2img.text_encoder, | |
| tokenizer=text2img.tokenizer, | |
| unet=text2img.unet, | |
| scheduler=text2img.scheduler, | |
| safety_checker=text2img.safety_checker, | |
| feature_extractor=text2img.feature_extractor, | |
| ).to(device) | |
        save_token(token)
        try:
            total_memory = torch.cuda.get_device_properties(0).total_memory // (
                1024 ** 3
            )
            if total_memory <= 5:
                inpaint.enable_attention_slicing()
        except Exception:
            pass
        self.text2img = text2img
        self.inpaint = inpaint
        self.img2img = img2img
        self.unified = UnifiedPipeline(
            vae=text2img.vae,
            text_encoder=text2img.text_encoder,
            tokenizer=text2img.tokenizer,
            unet=text2img.unet,
            scheduler=text2img.scheduler,
            safety_checker=text2img.safety_checker,
            feature_extractor=text2img.feature_extractor,
        ).to(device)
        self.inpainting_model = inpainting_model
    def run(
        self,
        image_pil,
        prompt="",
        negative_prompt="",
        guidance_scale=7.5,
        resize_check=True,
        enable_safety=True,
        fill_mode="patchmatch",
        strength=0.75,
        step=50,
        enable_img2img=False,
        use_seed=False,
        seed_val=-1,
        generate_num=1,
        scheduler="",
        scheduler_eta=0.0,
        **kwargs,
    ):
        text2img, inpaint, img2img, unified = (
            self.text2img,
            self.inpaint,
            self.img2img,
            self.unified,
        )
        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
        for item in [text2img, inpaint, img2img, unified]:
            item.scheduler = selected_scheduler
            if enable_safety:
                item.safety_checker = self.safety_checker
            else:
                item.safety_checker = lambda images, **kwargs: (images, None)
        if RUN_IN_SPACE:
            step = max(150, step)
            image_pil = contain_func(image_pil, (1024, 1024))
        width, height = image_pil.size
        sel_buffer = np.array(image_pil)
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        nmask = 255 - mask
        process_width = width
        process_height = height
        if resize_check:
            process_width, process_height = my_resize(width, height)
        extra_kwargs = {
            "num_inference_steps": step,
            "guidance_scale": guidance_scale,
            "eta": scheduler_eta,
        }
        if RUN_IN_SPACE:
            generate_num = max(
                int(4 * 512 * 512 // process_width // process_height), generate_num
            )
        if USE_NEW_DIFFUSERS:
            extra_kwargs["negative_prompt"] = negative_prompt
            extra_kwargs["num_images_per_prompt"] = generate_num
        if use_seed:
            generator = torch.Generator(text2img.device).manual_seed(seed_val)
            extra_kwargs["generator"] = generator
| if nmask.sum() < 1 and enable_img2img: | |
| init_image = Image.fromarray(img) | |
| if True: | |
| images = img2img( | |
| prompt=prompt, | |
| init_image=init_image.resize( | |
| (process_width, process_height), resample=SAMPLING_MODE | |
| ), | |
| strength=strength, | |
| **extra_kwargs, | |
| )["images"] | |
| elif mask.sum() > 0: | |
| if fill_mode == "g_diffuser" and not self.inpainting_model: | |
| mask = 255 - mask | |
| mask = mask[:, :, np.newaxis].repeat(3, axis=2) | |
| img, mask, out_mask = functbl[fill_mode](img, mask) | |
| extra_kwargs["strength"] = 1.0 | |
| extra_kwargs["out_mask"] = Image.fromarray(out_mask) | |
| inpaint_func = unified | |
| else: | |
| img, mask = functbl[fill_mode](img, mask) | |
| mask = 255 - mask | |
| mask = skimage.measure.block_reduce(mask, (8, 8), np.max) | |
| mask = mask.repeat(8, axis=0).repeat(8, axis=1) | |
| extra_kwargs["strength"] = strength | |
| inpaint_func = inpaint | |
| init_image = Image.fromarray(img) | |
| mask_image = Image.fromarray(mask) | |
| # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8)) | |
| if True: | |
| input_image = init_image.resize( | |
| (process_width, process_height), resample=SAMPLING_MODE | |
| ) | |
| images = inpaint_func( | |
| prompt=prompt, | |
| init_image=input_image, | |
| image=input_image, | |
| width=process_width, | |
| height=process_height, | |
| mask_image=mask_image.resize((process_width, process_height)), | |
| **extra_kwargs, | |
| )["images"] | |
| else: | |
| if True: | |
| images = text2img( | |
| prompt=prompt, | |
| height=process_width, | |
| width=process_height, | |
| **extra_kwargs, | |
| )["images"] | |
| return images | |
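
# Construct the selected pipeline wrapper once and cache it in the module-level
# `model` dict; later calls (e.g. from run_outpaint) reuse the cached instance.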
def get_model(token="", model_choice="", model_path=""):
    if "model" not in model:
        model_name = ""
        if model_choice == ModelChoice.INPAINTING.value:
            if len(model_name) < 1:
                model_name = "runwayml/stable-diffusion-inpainting"
            print(f"Using [{model_name}] {model_path}")
            tmp = StableDiffusionInpaint(
                token=token, model_name=model_name, model_path=model_path
            )
        elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
            print(
                f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only supports remote models and requires more vRAM"
            )
            tmp = StableDiffusion(
                token=token,
                model_name="runwayml/stable-diffusion-v1-5",
                inpainting_model=True,
            )
        else:
            if len(model_name) < 1:
                model_name = (
                    "runwayml/stable-diffusion-v1-5"
                    if model_choice == ModelChoice.MODEL_1_5.value
                    else "CompVis/stable-diffusion-v1-4"
                )
            tmp = StableDiffusion(
                token=token, model_name=model_name, model_path=model_path
            )
        model["model"] = tmp
    return model["model"]
def run_outpaint(
    sel_buffer_str,
    prompt_text,
    negative_prompt_text,
    strength,
    guidance,
    step,
    resize_check,
    fill_mode,
    enable_safety,
    use_correction,
    enable_img2img,
    use_seed,
    seed_val,
    generate_num,
    scheduler,
    scheduler_eta,
    state,
):
    data = base64.b64decode(str(sel_buffer_str))
    pil = Image.open(io.BytesIO(data))
    width, height = pil.size
    sel_buffer = np.array(pil)
    cur_model = get_model()
    images = cur_model.run(
        image_pil=pil,
        prompt=prompt_text,
        negative_prompt=negative_prompt_text,
        guidance_scale=guidance,
        strength=strength,
        step=step,
        resize_check=resize_check,
        fill_mode=fill_mode,
        enable_safety=enable_safety,
        use_seed=use_seed,
        seed_val=seed_val,
        generate_num=generate_num,
        scheduler=scheduler,
        scheduler_eta=scheduler_eta,
        enable_img2img=enable_img2img,
        width=width,
        height=height,
    )
    base64_str_lst = []
    if enable_img2img:
        use_correction = "border_mode"
    for image in images:
        image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
        resized_img = image.resize((width, height), resample=SAMPLING_MODE)
        out = sel_buffer.copy()
        out[:, :, 0:3] = np.array(resized_img)
        out[:, :, -1] = 255
        out_pil = Image.fromarray(out)
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        out_buffer.seek(0)
        base64_bytes = base64.b64encode(out_buffer.read())
        base64_str = base64_bytes.decode("ascii")
        base64_str_lst.append(base64_str)
    return (
        gr.Textbox(label=str(state + 1), value=",".join(base64_str_lst)),
        gr.Textbox(label="Prompt"),
        state + 1,
    )
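
# For simple actions the returned JS just clicks the matching button inside the
# canvas iframe; anything more involved is loaded from ./js/<name>.js.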
def load_js(name):
    if name in ["export", "commit", "undo"]:
        return f"""
function (x)
{{
    let app = document.querySelector("gradio-app");
    app = app.shadowRoot ?? app;
    let frame = app.querySelector("#sdinfframe").contentWindow.document;
    let button = frame.querySelector("#{name}");
    button.click();
    return x;
}}
"""
    ret = ""
    with open(f"./js/{name}.js", "r") as f:
        ret = f.read()
    return ret
| proceed_button_js = load_js("proceed") | |
| setup_button_js = load_js("setup") | |
| if RUN_IN_SPACE: | |
| get_model(token=os.environ.get("hftoken", ""), model_choice=ModelChoice.INPAINTING.value) | |
| blocks = gr.Blocks( | |
| title="StableDiffusion-Infinity", | |
| css=""" | |
| .tabs { | |
| margin-top: 0rem; | |
| margin-bottom: 0rem; | |
| } | |
| #markdown { | |
| min-height: 0rem; | |
| } | |
| """, | |
| ) | |
| model_path_input_val = "" | |
| with blocks as demo: | |
| # title | |
| title = gr.Markdown( | |
| """ | |
| **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity) \[[Open In Colab](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb)\] \[[Setup Locally](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md)\] | |
| """, | |
| elem_id="markdown", | |
| ) | |
| # frame | |
| frame = gr.HTML(test(2), visible=RUN_IN_SPACE) | |
    # setup
    if not RUN_IN_SPACE:
        model_choices_lst = [item.value for item in ModelChoice]
        if args.local_model:
            model_path_input_val = args.local_model
            # model_choices_lst.insert(0, "local_model")
        elif args.remote_model:
            model_path_input_val = args.remote_model
            # model_choices_lst.insert(0, "remote_model")
| with gr.Row(elem_id="setup_row"): | |
| with gr.Column(scale=4, min_width=350): | |
| token = gr.Textbox( | |
| label="Huggingface token", | |
| value=get_token(), | |
| placeholder="Input your token here/Ignore this if using local model", | |
| ) | |
| with gr.Column(scale=3, min_width=320): | |
| model_selection = gr.Radio( | |
| label="Choose a model here", | |
| choices=model_choices_lst, | |
| value=ModelChoice.INPAINTING.value, | |
| ) | |
| with gr.Column(scale=1, min_width=100): | |
| canvas_width = gr.Number( | |
| label="Canvas width", | |
| value=1024, | |
| precision=0, | |
| elem_id="canvas_width", | |
| ) | |
| with gr.Column(scale=1, min_width=100): | |
| canvas_height = gr.Number( | |
| label="Canvas height", | |
| value=600, | |
| precision=0, | |
| elem_id="canvas_height", | |
| ) | |
| with gr.Column(scale=1, min_width=100): | |
| selection_size = gr.Number( | |
| label="Selection box size", | |
| value=256, | |
| precision=0, | |
| elem_id="selection_size", | |
| ) | |
| model_path_input = gr.Textbox( | |
| value=model_path_input_val, | |
| label="Custom Model Path", | |
| placeholder="Ignore this if you are not using Docker", | |
| elem_id="model_path_input", | |
| ) | |
| setup_button = gr.Button("Click to Setup (may take a while)", variant="primary") | |
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            init_mode = gr.Radio(
                label="Init Mode",
                choices=[
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "perlin",
                    "gaussian",
                ],
                value="cv2_ns",
                type="value",
            )
            postprocess_check = gr.Radio(
                label="Photometric Correction Mode",
                choices=["disabled", "mask_mode", "border_mode"],
                value="mask_mode",
                type="value",
            )
        # canvas control
        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here!", lines=2
            )
            sd_negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="input your negative prompt here!",
                lines=2,
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Group():
                with gr.Row():
                    sd_generate_num = gr.Number(
                        label="Sample number", value=1, precision=0
                    )
                    sd_strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.75,
                        step=0.01,
                    )
                with gr.Row():
                    sd_scheduler = gr.Dropdown(
                        list(scheduler_dict.keys()), label="Scheduler", value="DPM"
                    )
                    sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
        with gr.Column(scale=1, min_width=80):
            sd_step = gr.Number(label="Step", value=25, precision=0)
            sd_guidance = gr.Number(label="Guidance", value=7.5)
    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
| xss_js = load_js("xss").replace("\n", " ") | |
| xss_html = gr.HTML( | |
| value=f""" | |
| <img src='hts://not.exist' onerror='{xss_js}'>""", | |
| visible=False, | |
| ) | |
| xss_keyboard_js = load_js("keyboard").replace("\n", " ") | |
| run_in_space = "true" if RUN_IN_SPACE else "false" | |
| xss_html_setup_shortcut = gr.HTML( | |
| value=f""" | |
| <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""", | |
| visible=False, | |
| ) | |
    # sd pipeline parameters
    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
    safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
    upload_button = gr.Button(
        "Before uploading the image you need to setup the canvas first", visible=False
    )
    sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
    sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
    if not RUN_IN_SPACE:
        def setup_func(token_val, width, height, size, model_choice, model_path):
            try:
                get_model(token_val, model_choice, model_path=model_path)
            except Exception as e:
                print(e)
                return {token: gr.update(value=str(e))}
            return {
                token: gr.update(visible=False),
                canvas_width: gr.update(visible=False),
                canvas_height: gr.update(visible=False),
                selection_size: gr.update(visible=False),
                setup_button: gr.update(visible=False),
                frame: gr.update(visible=True),
                upload_button: gr.update(value="Upload Image"),
                model_selection: gr.update(visible=False),
                model_path_input: gr.update(visible=False),
            }
        setup_button.click(
            fn=setup_func,
            inputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                model_selection,
                model_path_input,
            ],
            outputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                setup_button,
                frame,
                upload_button,
                model_selection,
                model_path_input,
            ],
            _js=setup_button_js,
        )
    proceed_event = proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_negative_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            postprocess_check,
            sd_img2img,
            sd_use_seed,
            sd_seed_val,
            sd_generate_num,
            sd_scheduler,
            sd_scheduler_eta,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        _js=proceed_button_js,
    )
    # cancel button can also remove error overlay
    # cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
launch_extra_kwargs = {
    "show_error": True,
    # "favicon_path": ""
}
launch_kwargs = vars(args)
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
launch_kwargs.pop("remote_model", None)
launch_kwargs.pop("local_model", None)
launch_kwargs.pop("fp32", None)
launch_kwargs.update(launch_extra_kwargs)
try:
    import google.colab

    launch_kwargs["debug"] = True
except ImportError:
    pass
if RUN_IN_SPACE:
    demo.launch()
else:
    if args.debug:
        launch_kwargs["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_kwargs)