#!/usr/bin/env python3
"""Wrapper around a diffusers ``FluxPipeline`` that records per-step latents."""
from functools import partial

import torch
from diffusers import FluxPipeline, DPMSolverMultistepScheduler

from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
from BeamDiffusionModel.models.diffusionModel.Latents_Singleton import Latents

# Accepted values for CONFIG["flux"]["precision"] and the dtype each selects.
_PRECISIONS = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


class Flux:
    """Load the Flux model named in ``CONFIG["flux"]["id"]`` and generate images.

    Reads from ``CONFIG["flux"]``:
        id: model identifier passed to ``FluxPipeline.from_pretrained``.
        use_cuda: prefer CUDA when available (default True).
        precision: one of ``_PRECISIONS`` keys (default "bfloat16").
        diffusion_settings.steps: default number of inference steps.
    """

    def __init__(self):
        flux_cfg = CONFIG.get("flux", {})
        self.device = "cuda" if flux_cfg.get("use_cuda", True) and torch.cuda.is_available() else "cpu"

        precision = flux_cfg.get("precision", "bfloat16")
        if precision not in _PRECISIONS:
            raise ValueError(
                f"Unsupported flux precision {precision!r}; expected one of {sorted(_PRECISIONS)}"
            )
        self.torch_dtype = _PRECISIONS[precision]

        print(f"Loading model: {CONFIG['flux']['id']} on {self.device}")
        # Load with the configured dtype (previously hard-coded to bfloat16,
        # which made the precision setting a no-op).
        self.pipe = FluxPipeline.from_pretrained(CONFIG["flux"]["id"], torch_dtype=self.torch_dtype)

        if self.device == "cuda":
            # Keep submodules on CPU and move each to the GPU only while it runs,
            # trading speed for a much smaller VRAM footprint.
            self.pipe.enable_sequential_cpu_offload()
        else:
            # Offload always targets a GPU; on CPU just place the whole pipeline there.
            self.pipe.to(self.device)

        # Decode in slices/tiles to reduce peak memory in the VAE.
        self.pipe.vae.enable_slicing()
        self.pipe.vae.enable_tiling()
        # Overlong prompts are truncated from the left, keeping their ending.
        self.pipe.tokenizer.truncation_side = 'left'
        print("Model loaded successfully!")

    def capture_latents(self, latents_store: Latents, pipe, step, timestep, callback_kwargs):
        """Step-end callback: record the current latents and pass kwargs through.

        Bound to a store via ``functools.partial`` so the remaining parameters
        match diffusers' ``callback_on_step_end(pipe, step, timestep,
        callback_kwargs)`` signature.

        Args:
            latents_store: store that receives each step's ``latents`` tensor.
            pipe, step, timestep: supplied by diffusers; unused here.
            callback_kwargs: dict containing ``"latents"`` (requested via
                ``callback_on_step_end_tensor_inputs``).

        Returns:
            ``callback_kwargs`` unchanged, as diffusers expects.
        """
        latents = callback_kwargs["latents"]
        latents_store.add_latents(latents)
        return callback_kwargs

    def generate_image(self, prompt: str, latent=None, generator=None, *,
                       height=768, width=768, guidance_scale=3.5,
                       max_sequence_length=512, num_inference_steps=None):
        """Generate one image for ``prompt`` and return it with recorded latents.

        Args:
            prompt: text prompt.
            latent: optional initial latents forwarded as ``latents=``.
            generator: optional ``torch.Generator`` for reproducibility.
            height, width, guidance_scale, max_sequence_length: forwarded to the
                pipeline; defaults match the previously hard-coded values.
            num_inference_steps: defaults to
                ``CONFIG["flux"]["diffusion_settings"]["steps"]``.

        Returns:
            ``(image, recorded)``: the first image from the pipeline output and
            the result of ``Latents.dump_and_clear()`` after generation.
        """
        if num_inference_steps is None:
            num_inference_steps = CONFIG["flux"]["diffusion_settings"]["steps"]

        latents = Latents()
        callback = partial(self.capture_latents, latents)
        output = self.pipe(
            prompt,
            latents=latent,
            callback_on_step_end=callback,
            generator=generator,
            callback_on_step_end_tensor_inputs=["latents"],
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            max_sequence_length=max_sequence_length,
            num_inference_steps=num_inference_steps,
        )
        img = output.images[0]
        return img, latents.dump_and_clear()