Spaces:
Runtime error
Runtime error
File size: 6,813 Bytes
3470339 86b3270 3470339 cf24bfb 3470339 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# flux_kontext_helpers.py (ADUC: O Especialista Pintor - com suporte a callback)
# Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
import torch
from PIL import Image, ImageOps
import gc
from diffusers import FluxKontextPipeline
import huggingface_hub
import os
import threading
import yaml
import logging
from hardware_manager import hardware_manager
# Module-level logger used by FluxWorker, FluxPoolManager and the singleton setup at import time.
logger = logging.getLogger(__name__)
class FluxWorker:
    """A single FluxKontext pipeline instance bound to one target device.

    The pipeline is loaded onto the CPU when the worker is created; ``to_gpu`` and
    ``to_cpu`` move it between the CPU and the target device. When CUDA is not
    available the target device falls back to the CPU and both moves are no-ops.
    """

    def __init__(self, device_id='cuda:0'):
        """Create the worker and eagerly load the pipeline onto the CPU.

        Args:
            device_id: Target device string (e.g. ``'cuda:0'``). Replaced by
                ``'cpu'`` when CUDA is not available.
        """
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        self.pipe = None
        self._load_pipe_to_cpu()

    def _load_pipe_to_cpu(self):
        """Load ``black-forest-labs/FLUX.1-Kontext-dev`` in bfloat16 onto the CPU (no-op if already loaded)."""
        if self.pipe is None:
            logger.info(f"FLUX Worker ({self.device}): Carregando modelo para a CPU...")
            self.pipe = FluxKontextPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
            ).to(self.cpu_device)
            logger.info(f"FLUX Worker ({self.device}): Modelo pronto na CPU.")

    def to_gpu(self):
        """Move the pipeline to the target device (no-op when that device is the CPU)."""
        if self.device.type == 'cpu':
            return
        logger.info(f"FLUX Worker: Movendo modelo para a GPU {self.device}...")
        self.pipe.to(self.device)

    def to_cpu(self):
        """Move the pipeline back to the CPU and release cached CUDA memory.

        No-op when the target device is the CPU.
        """
        if self.device.type == 'cpu':
            return
        logger.info(f"FLUX Worker: Descarregando modelo da GPU {self.device}...")
        self.pipe.to(self.cpu_device)
        # Collect unreferenced objects first so their CUDA blocks are actually
        # free before asking the caching allocator to release them.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _create_composite_reference(self, images: list[Image.Image], target_width: int, target_height: int) -> "Image.Image | None":
        """Combine the reference images into the single image fed to the pipeline.

        * No usable images (``images`` empty/None, or every entry None): returns None.
        * One image: converted to RGB; if its size differs from the target it is
          cropped to the target aspect ratio and resized with ``ImageOps.fit``.
        * Several images: converted to RGB, scaled to the height of the first one
          (keeping aspect ratio) and pasted side by side, left to right. The strip
          is returned at its natural size; it is not fitted to the target size.

        Args:
            images: Candidate reference images; None entries are skipped.
            target_width: Output width in pixels (used only for a single image).
            target_height: Output height in pixels (used only for a single image).
        """
        if not images:
            return None
        valid_images = [img.convert("RGB") for img in images if img is not None]
        if not valid_images:
            return None

        target_size = (target_width, target_height)
        if len(valid_images) == 1:
            single = valid_images[0]
            if single.size != target_size:
                return ImageOps.fit(single, target_size, Image.Resampling.LANCZOS)
            return single

        base_height = valid_images[0].height
        resized_for_concat = []
        for img in valid_images:
            if img.height != base_height:
                aspect_ratio = img.width / img.height
                # Clamp to 1px: a very narrow image could otherwise round down to a
                # zero width, which PIL's resize rejects.
                new_width = max(1, int(base_height * aspect_ratio))
                img = img.resize((new_width, base_height), Image.Resampling.LANCZOS)
            resized_for_concat.append(img)

        total_width = sum(img.width for img in resized_for_concat)
        concatenated = Image.new('RGB', (total_width, base_height))
        x_offset = 0
        for img in resized_for_concat:
            concatenated.paste(img, (x_offset, 0))
            x_offset += img.width
        return concatenated

    @torch.inference_mode()
    def generate_image_internal(self, reference_images: list[Image.Image], prompt: str, target_width: int,
                                target_height: int, seed: int, callback: callable = None, *,
                                num_inference_steps: int = 12, guidance_scale: float = 2.5):
        """Run one FluxKontext generation with the pipeline on its current device.

        This method does not move the pipeline; call ``to_gpu`` first to run on the
        target device.

        Args:
            reference_images: Reference images combined via
                ``_create_composite_reference``. May be empty/None or contain None
                entries, which are skipped.
            prompt: Text prompt.
            target_width: Output width in pixels.
            target_height: Output height in pixels.
            seed: Seed for the CPU ``torch.Generator`` passed to the pipeline.
            callback: Optional ``callback_on_step_end`` hook. When given, the
                ``latents`` tensor is exposed to it.
            num_inference_steps: Denoising steps (default 12, the former fixed value).
            guidance_scale: Guidance scale (default 2.5, the former fixed value).

        Returns:
            The first image of the pipeline output.
        """
        composite_reference = self._create_composite_reference(reference_images, target_width, target_height)
        # Count only images that can reach the composite: the raw argument may be
        # None (len() would raise) or contain None entries.
        num_valid_images = sum(1 for img in (reference_images or []) if img is not None)
        logger.info(f"\n===== [CHAMADA AO PIPELINE FLUX em {self.device}] =====\n"
                    f" - Prompt: '{prompt}'\n"
                    f" - Resolução: {target_width}x{target_height}, Seed: {seed}, Passos: {num_inference_steps}\n"
                    f" - Nº de Imagens na Composição: {num_valid_images}\n"
                    f"==========================================")
        generated_image = self.pipe(
            image=composite_reference,
            prompt=prompt,
            guidance_scale=guidance_scale,
            width=target_width,
            height=target_height,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator(device="cpu").manual_seed(seed),
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"] if callback else None,
        ).images[0]
        return generated_image
class FluxPoolManager:
    """Round-robin pool of ``FluxWorker`` instances, one per device id.

    Each ``generate_image`` call picks the next worker, moves it to its device,
    and starts a background thread that moves the previously used worker back to
    the CPU.
    """

    def __init__(self, device_ids):
        """Create one ``FluxWorker`` per entry of ``device_ids`` (each loads the model)."""
        logger.info(f"FLUX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
        self.workers = [FluxWorker(device_id) for device_id in device_ids]
        self.current_worker_index = 0
        # Serializes worker selection and device moves in generate_image.
        self.lock = threading.Lock()
        # Most recent background offload; joined before the next selection.
        self.last_cleanup_thread = None

    def _cleanup_worker_thread(self, worker):
        """Thread target: move ``worker``'s pipeline back to the CPU."""
        logger.info(f"FLUX CLEANUP THREAD: Iniciando limpeza de {worker.device} em background...")
        worker.to_cpu()

    def generate_image(self, reference_images, prompt, width, height, seed=42, callback=None):
        """Generate one image on the next worker in round-robin order.

        Args:
            reference_images: Forwarded to ``FluxWorker.generate_image_internal``.
            prompt: Text prompt.
            width: Output width in pixels.
            height: Output height in pixels.
            seed: Generation seed (default 42).
            callback: Optional per-step callback forwarded to the pipeline.

        Returns:
            The image returned by ``generate_image_internal``.

        Raises:
            Any exception raised during selection, device moves or generation is
            logged and re-raised unchanged. An empty pool fails on worker
            selection (IndexError).
        """
        try:
            with self.lock:
                # Wait for the previous offload so two large weight transfers
                # started by this manager never overlap.
                if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
                    self.last_cleanup_thread.join()

                worker_to_use = self.workers[self.current_worker_index]
                previous_worker_index = (self.current_worker_index - 1) % len(self.workers)
                worker_to_cleanup = self.workers[previous_worker_index]

                # With a single worker, the "previous" worker is the one about to be
                # used. Offloading it in the background would race with to_gpu() and
                # could leave the model on the CPU (or split across devices) while
                # it generates.
                if worker_to_cleanup is not worker_to_use:
                    cleanup_thread = threading.Thread(
                        target=self._cleanup_worker_thread, args=(worker_to_cleanup,)
                    )
                    cleanup_thread.start()
                    self.last_cleanup_thread = cleanup_thread

                worker_to_use.to_gpu()
                self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)

            # NOTE(review): generation runs outside self.lock. If generate_image is
            # called concurrently, a later call's cleanup can offload a worker that an
            # earlier call is still using. Confirm that callers serialize requests.
            logger.info(f"FLUX POOL MANAGER: Gerando imagem em {worker_to_use.device}...")
            return worker_to_use.generate_image_internal(
                reference_images=reference_images,
                prompt=prompt,
                target_width=width,
                target_height=height,
                seed=seed,
                callback=callback,
            )
        except Exception as e:
            logger.error(f"FLUX POOL MANAGER: Erro durante a geração: {e}", exc_info=True)
            raise
# --- Dynamic singleton instantiation (runs at import time) ---
# Reads config.yaml from the current working directory, logs in to the Hugging
# Face Hub when HF_TOKEN is set, requests GPUs for Flux from hardware_manager and
# builds the module-wide FluxPoolManager.
logger.info("Lendo config.yaml para inicializar o FluxKontext Pool Manager...")
# Be explicit about the encoding so decoding does not depend on the platform locale.
with open("config.yaml", 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

hf_token = os.getenv('HF_TOKEN')
if hf_token:
    huggingface_hub.login(token=hf_token)

flux_gpus_required = config['specialists']['flux']['gpus_required']
flux_device_ids = hardware_manager.allocate_gpus('Flux', flux_gpus_required)
flux_kontext_singleton = FluxPoolManager(device_ids=flux_device_ids)
logger.info("Especialista de Imagem (Flux) pronto.")