#!/usr/bin/env python3 """ LoRA Trainer Funcional para Hugging Face Baseado no kohya-ss sd-scripts """ import gradio as gr import os import sys import json import subprocess import shutil import zipfile import tempfile import toml import logging from pathlib import Path from typing import Optional, Tuple, List, Dict, Any import time import threading import queue # Adicionar o diretório sd-scripts ao path sys.path.insert(0, str(Path(__file__).parent / "sd-scripts")) # Configurar logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) class LoRATrainerHF: def __init__(self): self.base_dir = Path("/tmp/lora_training") self.base_dir.mkdir(exist_ok=True) self.models_dir = self.base_dir / "models" self.models_dir.mkdir(exist_ok=True) self.projects_dir = self.base_dir / "projects" self.projects_dir.mkdir(exist_ok=True) self.sd_scripts_dir = Path(__file__).parent / "sd-scripts" # URLs dos modelos self.model_urls = { "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors", "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt", "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt" } self.training_process = None self.training_output_queue = queue.Queue() def install_dependencies(self) -> str: """Instala as dependências necessárias""" try: logger.info("Instalando dependências...") # Lista de pacotes necessários packages = [ "torch>=2.0.0", "torchvision>=0.15.0", "diffusers>=0.21.0", "transformers>=4.25.0", "accelerate>=0.20.0", "safetensors>=0.3.0", "huggingface-hub>=0.16.0", "xformers>=0.0.20", "bitsandbytes>=0.41.0", "opencv-python>=4.7.0", "Pillow>=9.0.0", "numpy>=1.21.0", "tqdm>=4.64.0", "toml>=0.10.0", "tensorboard>=2.13.0", "wandb>=0.15.0", "scipy>=1.9.0", "matplotlib>=3.5.0", "datasets>=2.14.0", "peft>=0.5.0", "omegaconf>=2.3.0" ] # Instalar pacotes for package in packages: try: subprocess.run([ sys.executable, "-m", "pip", "install", package, "--quiet" ], check=True, capture_output=True, text=True) logger.info(f"✓ {package} instalado") except subprocess.CalledProcessError as e: logger.warning(f"⚠ Erro ao instalar {package}: {e}") return "✅ Dependências instaladas com sucesso!" except Exception as e: logger.error(f"Erro ao instalar dependências: {e}") return f"❌ Erro ao instalar dependências: {e}" def download_model(self, model_choice: str, custom_url: str = "") -> str: """Download do modelo base""" try: if custom_url.strip(): model_url = custom_url.strip() model_name = model_url.split("/")[-1] else: if model_choice not in self.model_urls: return f"❌ Modelo '{model_choice}' não encontrado" model_url = self.model_urls[model_choice] model_name = model_url.split("/")[-1] model_path = self.models_dir / model_name if model_path.exists(): return f"✅ Modelo já existe: {model_name}" logger.info(f"Baixando modelo: {model_url}") # Download usando wget result = subprocess.run([ "wget", "-O", str(model_path), model_url, "--progress=bar:force" ], capture_output=True, text=True) if result.returncode == 0: return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)" else: return f"❌ Erro no download: {result.stderr}" except Exception as e: logger.error(f"Erro ao baixar modelo: {e}") return f"❌ Erro ao baixar modelo: {e}" def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]: """Processa o dataset enviado""" try: if not dataset_zip: return "❌ Nenhum dataset foi enviado", "" if not project_name.strip(): return "❌ Nome do projeto é obrigatório", "" project_name = project_name.strip().replace(" ", "_") project_dir = self.projects_dir / project_name project_dir.mkdir(exist_ok=True) dataset_dir = project_dir / "dataset" if dataset_dir.exists(): shutil.rmtree(dataset_dir) dataset_dir.mkdir() # Extrair ZIP with zipfile.ZipFile(dataset_zip.name, 'r') as zip_ref: zip_ref.extractall(dataset_dir) # Analisar dataset image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'} images = [] captions = [] for file_path in dataset_dir.rglob("*"): if file_path.suffix.lower() in image_extensions: images.append(file_path) # Procurar caption caption_path = file_path.with_suffix('.txt') if caption_path.exists(): captions.append(caption_path) info = f"✅ Dataset processado!\n" info += f"📁 Projeto: {project_name}\n" info += f"🖼️ Imagens: {len(images)}\n" info += f"📝 Captions: {len(captions)}\n" info += f"📂 Diretório: {dataset_dir}" return info, str(dataset_dir) except Exception as e: logger.error(f"Erro ao processar dataset: {e}") return f"❌ Erro ao processar dataset: {e}", "" def create_training_config(self, project_name: str, dataset_dir: str, model_choice: str, custom_model_url: str, resolution: int, batch_size: int, epochs: int, learning_rate: float, text_encoder_lr: float, network_dim: int, network_alpha: int, lora_type: str, optimizer: str, scheduler: str, flip_aug: bool, shuffle_caption: bool, keep_tokens: int, clip_skip: int, mixed_precision: str, save_every_n_epochs: int, max_train_steps: int) -> str: """Cria configuração de treinamento""" try: if not project_name.strip(): return "❌ Nome do projeto é obrigatório" project_name = project_name.strip().replace(" ", "_") project_dir = self.projects_dir / project_name project_dir.mkdir(exist_ok=True) output_dir = project_dir / "output" output_dir.mkdir(exist_ok=True) log_dir = project_dir / "logs" log_dir.mkdir(exist_ok=True) # Determinar modelo if custom_model_url.strip(): model_name = custom_model_url.strip().split("/")[-1] else: model_name = self.model_urls[model_choice].split("/")[-1] model_path = self.models_dir / model_name if not model_path.exists(): return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro." # Configuração do dataset dataset_config = { "general": { "shuffle_caption": shuffle_caption, "caption_extension": ".txt", "keep_tokens": keep_tokens, "flip_aug": flip_aug, "color_aug": False, "face_crop_aug_range": None, "random_crop": False, "debug_dataset": False }, "datasets": [{ "resolution": resolution, "batch_size": batch_size, "subsets": [{ "image_dir": str(dataset_dir), "num_repeats": 1 }] }] } # Configuração de treinamento training_config = { "model_arguments": { "pretrained_model_name_or_path": str(model_path), "v2": False, "v_parameterization": False, "clip_skip": clip_skip }, "dataset_arguments": { "dataset_config": str(project_dir / "dataset_config.toml") }, "training_arguments": { "output_dir": str(output_dir), "output_name": project_name, "save_precision": "fp16", "save_every_n_epochs": save_every_n_epochs, "max_train_epochs": epochs if max_train_steps == 0 else None, "max_train_steps": max_train_steps if max_train_steps > 0 else None, "train_batch_size": batch_size, "gradient_accumulation_steps": 1, "learning_rate": learning_rate, "text_encoder_lr": text_encoder_lr, "lr_scheduler": scheduler, "lr_warmup_steps": 0, "optimizer_type": optimizer, "mixed_precision": mixed_precision, "save_model_as": "safetensors", "seed": 42, "max_data_loader_n_workers": 2, "persistent_data_loader_workers": True, "gradient_checkpointing": True, "xformers": True, "lowram": True, "cache_latents": True, "cache_latents_to_disk": True, "logging_dir": str(log_dir), "log_with": "tensorboard" }, "network_arguments": { "network_module": "networks.lora" if lora_type == "LoRA" else "networks.dylora", "network_dim": network_dim, "network_alpha": network_alpha, "network_train_unet_only": False, "network_train_text_encoder_only": False } } # Adicionar argumentos específicos para LoCon if lora_type == "LoCon": training_config["network_arguments"]["network_module"] = "networks.lora" training_config["network_arguments"]["conv_dim"] = max(1, network_dim // 2) training_config["network_arguments"]["conv_alpha"] = max(1, network_alpha // 2) # Salvar configurações dataset_config_path = project_dir / "dataset_config.toml" training_config_path = project_dir / "training_config.toml" with open(dataset_config_path, 'w') as f: toml.dump(dataset_config, f) with open(training_config_path, 'w') as f: toml.dump(training_config, f) return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}" except Exception as e: logger.error(f"Erro ao criar configuração: {e}") return f"❌ Erro ao criar configuração: {e}" def start_training(self, project_name: str) -> str: """Inicia o treinamento""" try: if not project_name.strip(): return "❌ Nome do projeto é obrigatório" project_name = project_name.strip().replace(" ", "_") project_dir = self.projects_dir / project_name training_config_path = project_dir / "training_config.toml" if not training_config_path.exists(): return "❌ Configuração não encontrada. Crie a configuração primeiro." # Script de treinamento train_script = self.sd_scripts_dir / "train_network.py" if not train_script.exists(): return "❌ Script de treinamento não encontrado" # Comando de treinamento cmd = [ sys.executable, str(train_script), "--config_file", str(training_config_path) ] logger.info(f"Iniciando treinamento: {' '.join(cmd)}") # Executar em thread separada def run_training(): try: process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True, cwd=str(self.sd_scripts_dir) ) self.training_process = process for line in process.stdout: self.training_output_queue.put(line.strip()) logger.info(line.strip()) process.wait() if process.returncode == 0: self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!") else: self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})") except Exception as e: self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}") finally: self.training_process = None # Iniciar thread training_thread = threading.Thread(target=run_training) training_thread.daemon = True training_thread.start() return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo." except Exception as e: logger.error(f"Erro ao iniciar treinamento: {e}") return f"❌ Erro ao iniciar treinamento: {e}" def get_training_output(self) -> str: """Obtém output do treinamento""" output_lines = [] try: while not self.training_output_queue.empty(): line = self.training_output_queue.get_nowait() output_lines.append(line) except queue.Empty: pass if output_lines: return "\n".join(output_lines) elif self.training_process and self.training_process.poll() is None: return "🔄 Treinamento em andamento..." else: return "⏸️ Nenhum treinamento ativo" def stop_training(self) -> str: """Para o treinamento""" try: if self.training_process and self.training_process.poll() is None: self.training_process.terminate() self.training_process.wait(timeout=10) return "⏹️ Treinamento interrompido" else: return "ℹ️ Nenhum treinamento ativo para parar" except Exception as e: return f"❌ Erro ao parar treinamento: {e}" def list_output_files(self, project_name: str) -> List[str]: """Lista arquivos de saída""" try: if not project_name.strip(): return [] project_name = project_name.strip().replace(" ", "_") project_dir = self.projects_dir / project_name output_dir = project_dir / "output" if not output_dir.exists(): return [] files = [] for file_path in output_dir.rglob("*.safetensors"): size_mb = file_path.stat().st_size // (1024 * 1024) files.append(f"{file_path.name} ({size_mb} MB)") return sorted(files, reverse=True) # Mais recentes primeiro except Exception as e: logger.error(f"Erro ao listar arquivos: {e}") return [] # Instância global trainer = LoRATrainerHF() def create_interface(): """Cria a interface Gradio""" with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface: gr.Markdown(""" # 🎨 LoRA Trainer Funcional para Hugging Face **Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!** Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA. """) # Estado para armazenar informações dataset_dir_state = gr.State("") with gr.Tab("🔧 Instalação"): gr.Markdown("### Primeiro, instale as dependências necessárias:") install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg") install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False) install_btn.click( fn=trainer.install_dependencies, outputs=install_status ) with gr.Tab("📁 Configuração do Projeto"): with gr.Row(): project_name = gr.Textbox( label="Nome do Projeto", placeholder="meu_lora_anime", info="Nome único para seu projeto (sem espaços especiais)" ) gr.Markdown("### 📥 Download do Modelo Base") with gr.Row(): model_choice = gr.Dropdown( choices=list(trainer.model_urls.keys()), label="Modelo Base Pré-definido", value="Anime (animefull-final-pruned)", info="Escolha um modelo base ou use URL personalizada" ) custom_model_url = gr.Textbox( label="URL Personalizada (opcional)", placeholder="https://huggingface.co/...", info="URL direta para download de modelo personalizado" ) download_btn = gr.Button("📥 Baixar Modelo", variant="primary") download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False) gr.Markdown("### 📊 Upload do Dataset") gr.Markdown(""" **Formato do Dataset:** - Crie um arquivo ZIP contendo suas imagens - Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições - Exemplo: `imagem1.jpg` + `imagem1.txt` """) dataset_upload = gr.File( label="Upload do Dataset (ZIP)", file_types=[".zip"] ) process_btn = gr.Button("📊 Processar Dataset", variant="primary") dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False) with gr.Tab("⚙️ Parâmetros de Treinamento"): with gr.Row(): with gr.Column(): gr.Markdown("#### 🖼️ Configurações de Imagem") resolution = gr.Slider( minimum=512, maximum=1024, step=64, value=512, label="Resolução", info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)" ) batch_size = gr.Slider( minimum=1, maximum=8, step=1, value=1, label="Batch Size", info="Imagens por lote (aumente se tiver GPU potente)" ) flip_aug = gr.Checkbox( label="Flip Augmentation", info="Espelhar imagens para aumentar dataset" ) shuffle_caption = gr.Checkbox( value=True, label="Shuffle Caption", info="Embaralhar ordem das tags" ) keep_tokens = gr.Slider( minimum=0, maximum=5, step=1, value=1, label="Keep Tokens", info="Número de tokens iniciais que não serão embaralhados" ) with gr.Column(): gr.Markdown("#### 🎯 Configurações de Treinamento") epochs = gr.Slider( minimum=1, maximum=100, step=1, value=10, label="Épocas", info="Número de épocas de treinamento" ) max_train_steps = gr.Number( value=0, label="Max Train Steps (0 = usar épocas)", info="Número máximo de steps (deixe 0 para usar épocas)" ) save_every_n_epochs = gr.Slider( minimum=1, maximum=10, step=1, value=1, label="Salvar a cada N épocas", info="Frequência de salvamento dos checkpoints" ) mixed_precision = gr.Dropdown( choices=["fp16", "bf16", "no"], value="fp16", label="Mixed Precision", info="fp16 = mais rápido, bf16 = mais estável" ) clip_skip = gr.Slider( minimum=1, maximum=12, step=1, value=2, label="CLIP Skip", info="Camadas CLIP a pular (2 para anime, 1 para realista)" ) with gr.Row(): with gr.Column(): gr.Markdown("#### 📚 Learning Rate") learning_rate = gr.Number( value=1e-4, label="Learning Rate (UNet)", info="Taxa de aprendizado principal" ) text_encoder_lr = gr.Number( value=5e-5, label="Learning Rate (Text Encoder)", info="Taxa de aprendizado do text encoder" ) scheduler = gr.Dropdown( choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"], value="cosine_with_restarts", label="LR Scheduler", info="Algoritmo de ajuste da learning rate" ) optimizer = gr.Dropdown( choices=["AdamW8bit", "AdamW", "Lion", "SGD"], value="AdamW8bit", label="Otimizador", info="AdamW8bit = menos memória" ) with gr.Column(): gr.Markdown("#### 🧠 Arquitetura LoRA") lora_type = gr.Radio( choices=["LoRA", "LoCon"], value="LoRA", label="Tipo de LoRA", info="LoRA = geral, LoCon = estilos artísticos" ) network_dim = gr.Slider( minimum=4, maximum=128, step=4, value=32, label="Network Dimension", info="Dimensão da rede (maior = mais detalhes, mais memória)" ) network_alpha = gr.Slider( minimum=1, maximum=128, step=1, value=16, label="Network Alpha", info="Controla a força do LoRA (geralmente dim/2)" ) with gr.Tab("🚀 Treinamento"): create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg") config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False) with gr.Row(): start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg") stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop") training_output = gr.Textbox( label="Output do Treinamento", lines=15, interactive=False, info="Acompanhe o progresso do treinamento em tempo real" ) # Auto-refresh do output def update_output(): return trainer.get_training_output() with gr.Tab("📥 Download dos Resultados"): refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary") output_files = gr.Dropdown( label="Arquivos LoRA Gerados", choices=[], info="Selecione um arquivo para download" ) download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento") # Event handlers download_btn.click( fn=trainer.download_model, inputs=[model_choice, custom_model_url], outputs=download_status ) process_btn.click( fn=trainer.process_dataset, inputs=[dataset_upload, project_name], outputs=[dataset_status, dataset_dir_state] ) create_config_btn.click( fn=trainer.create_training_config, inputs=[ project_name, dataset_dir_state, model_choice, custom_model_url, resolution, batch_size, epochs, learning_rate, text_encoder_lr, network_dim, network_alpha, lora_type, optimizer, scheduler, flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision, save_every_n_epochs, max_train_steps ], outputs=config_status ) start_training_btn.click( fn=trainer.start_training, inputs=project_name, outputs=training_output ) stop_training_btn.click( fn=trainer.stop_training, outputs=training_output ) refresh_files_btn.click( fn=trainer.list_output_files, inputs=project_name, outputs=output_files ) return interface if __name__ == "__main__": print("🚀 Iniciando LoRA Trainer Funcional...") interface = create_interface() interface.launch( server_name="0.0.0.0", server_port=7860, share=False, show_error=True )