import gradio as gr
import spaces
import torch
import diffusers
import transformers
import copy
import random
import numpy as np
import torchvision.transforms as T
import math
import os
import peft
from peft import LoraConfig
from safetensors import safe_open
from omegaconf import OmegaConf
from omnitry.models.transformer_flux import FluxTransformer2DModel
from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline
from huggingface_hub import snapshot_download

# download the OmniTry LoRA weights from the Hub
snapshot_download(repo_id="Kunbyte/OmniTry", local_dir="./OmniTry")

device = torch.device('cuda:0')
weight_dtype = torch.bfloat16
args = OmegaConf.load('configs/omnitry_v1_unified.yaml')

# init model
transformer = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-Fill-dev', subfolder='transformer'
).requires_grad_(False).to(device, dtype=weight_dtype)

pipeline = FluxFillPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-Fill-dev',
    transformer=transformer,
    torch_dtype=weight_dtype
).to(device)

# insert LoRA
lora_config = LoraConfig(
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    init_lora_weights="gaussian",
    target_modules=[
        'x_embedder',
        'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
        'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
        'ff.net.0.proj', 'ff.net.2',
        'ff_context.net.0.proj', 'ff_context.net.2',
        'norm1_context.linear', 'norm1.linear', 'norm.linear',
        'proj_mlp', 'proj_out',
    ]
)
transformer.add_adapter(lora_config, adapter_name='vtryon_lora')
transformer.add_adapter(lora_config, adapter_name='garment_lora')

with safe_open('OmniTry/omnitry_v1_unified.safetensors', framework="pt") as f:
    lora_weights = {k: f.get_tensor(k) for k in f.keys()}
transformer.load_state_dict(lora_weights, strict=False)


# hack lora forward: route the first sample of each batch through the try-on
# LoRA and the second sample through the garment LoRA
def create_hacked_forward(module):

    def lora_forward(self, active_adapter, x, *args, **kwargs):
        result = self.base_layer(x, *args, **kwargs)
        if active_adapter is not None:
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            x = x.to(lora_A.weight.dtype)
            result = result + lora_B(lora_A(dropout(x))) * scaling
        return result

    def hacked_lora_forward(self, x, *args, **kwargs):
        return torch.cat((
            lora_forward(self, 'vtryon_lora', x[:1], *args, **kwargs),
            lora_forward(self, 'garment_lora', x[1:], *args, **kwargs),
        ), dim=0)

    return hacked_lora_forward.__get__(module, type(module))


for n, m in transformer.named_modules():
    if isinstance(m, peft.tuners.lora.layer.Linear):
        m.forward = create_hacked_forward(m)


def seed_everything(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@spaces.GPU
def generate(person_image, object_image, object_class, steps, guidance_scale, seed):
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    seed_everything(seed)

    # resize the person image so its area stays within 1024x1024,
    # with both sides rounded down to multiples of 16
    max_area = 1024 * 1024
    oW, oH = person_image.width, person_image.height
    ratio = math.sqrt(max_area / (oW * oH))
    ratio = min(1, ratio)
    tW, tH = int(oW * ratio) // 16 * 16, int(oH * ratio) // 16 * 16

    transform = T.Compose([
        T.Resize((tH, tW)),
        T.ToTensor(),
    ])
    person_image = transform(person_image)

    # letterbox the object image onto a white canvas matching the person image size
    ratio = min(tW / object_image.width, tH / object_image.height)
    transform = T.Compose([
        T.Resize((int(object_image.height * ratio), int(object_image.width * ratio))),
        T.ToTensor(),
    ])
    object_image_padded = torch.ones_like(person_image)
    object_image = transform(object_image)
    new_h, new_w = object_image.shape[1], object_image.shape[2]
    min_x = (tW - new_w) // 2
    min_y = (tH - new_h) // 2
    object_image_padded[:, min_y: min_y + new_h, min_x: min_x + new_w] = object_image

    # both branches (person and object) share the same prompt from the object-class map
    prompts = [args.object_map[object_class]] * 2
    img_cond = torch.stack([person_image, object_image_padded]).to(dtype=weight_dtype, device=device)
    mask = torch.zeros_like(img_cond).to(img_cond)

    with torch.no_grad():
        img = pipeline(
            prompt=prompts,
            height=tH,
            width=tW,
            img_cond=img_cond,
            mask=mask,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator(device).manual_seed(seed),
        ).images[0]  # the first sample of the pair is the try-on result

    return img


# Custom CSS
custom_css = """
/* Overall background */
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Inter', sans-serif;
}

/* === Remove all placeholders === */
.gr-image svg,
.gr-image [data-testid*="placeholder"],
.gr-image [class*="placeholder"],
.gr-image [aria-label*="placeholder"],
.gr-image [class*="svelte"][class*="placeholder"],
.gr-image .absolute.inset-0.flex.items-center.justify-center,
.gr-image .flex.items-center.justify-center svg {
    display: none !important;
    visibility: hidden !important;
}

.gr-image [class*="overlay"],
.gr-image .fixed.inset-0,
.gr-image .absolute.inset-0 {
    pointer-events: none !important;
}

/* Image upload area */
.gr-image .wrap {
    background: transparent !important;
    min-height: 400px !important;
}

.gr-image .upload-container {
    min-height: 400px !important;
    border: 3px dashed rgba(102, 126, 234, 0.4) !important;
    border-radius: 12px !important;
    background: linear-gradient(135deg, rgba(248, 250, 252, 0.5) 0%, rgba(241, 245, 249, 0.5) 100%) !important;
    position: relative !important;
}

/* When an image has been uploaded */
.gr-image:has(img) .upload-container {
    border: none !important;
    background: transparent !important;
}

/* Upload hint text */
.gr-image .upload-container::after {
    content: "Click or Drag to Upload";
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    color: rgba(102, 126, 234, 0.7);
    font-size: 1.05em;
    font-weight: 500;
    pointer-events: none;
}

.gr-image:has(img) .upload-container::after {
    display: none !important;
}

/* Uploaded image */
.gr-image img {
    border-radius: 12px !important;
    position: relative !important;
    z-index: 10 !important;
}

/* Buttons, labels, etc. keep the default styling below */
.gr-button-primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    border: none !important;
    padding: 15px 40px !important;
    font-size: 1.2em !important;
    border-radius: 50px !important;
    cursor: pointer !important;
}
"""


if __name__ == '__main__':
    with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
        with gr.Column(elem_id="header"):
            gr.HTML("""

                <h1>✨ CodiFit-AI Virtual Try-On ✨</h1>

                <p>Experience the future of fashion with AI-powered virtual clothing try-on</p>
            """)

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                person_image = gr.Image(type="pil", label="Upload Person Photo", height=500, interactive=True)
            with gr.Column(scale=1):
                object_image = gr.Image(type="pil", label="Upload Object Image", height=400, interactive=True)
                object_class = gr.Dropdown(label='Select Object Category', choices=list(args.object_map.keys()))
                run_button = gr.Button(value="🚀 Generate Try-On", variant='primary')
            with gr.Column(scale=1):
                image_out = gr.Image(type="pil", label="Virtual Try-On Result", height=500, interactive=False)

        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                guidance_scale = gr.Slider(label="🎯 Guidance Scale", minimum=1, maximum=50, value=30, step=0.1)
                steps = gr.Slider(label="🔄 Inference Steps", minimum=1, maximum=50, value=20, step=1)
                seed = gr.Number(label="🎲 Random Seed", value=-1, precision=0)

        run_button.click(
            generate,
            inputs=[person_image, object_image, object_class, steps, guidance_scale, seed],
            outputs=[image_out],
        )

    demo.launch(server_name="0.0.0.0")