"""BeamDiffusion integration with the Hugging Face ``transformers`` API.

Defines a configuration class, a thin ``PreTrainedModel`` wrapper that
delegates to ``beam_inference``, and a pipeline that calls that model.
"""

import torch
import torch.nn as nn
from huggingface_hub import ModelHubMixin
from transformers import PretrainedConfig, PreTrainedModel, Pipeline

from BeamDiffusionModel.beamInference import beam_inference
from BeamDiffusionModel.models.diffusionModel.StableDiffusion import StableDiffusion
from BeamDiffusionModel.models.diffusionModel.Flux import Flux


# Maps the ``sd`` config value to the diffusion backend class it selects.
_SD_BACKENDS = {
    "flux": Flux,
    "SD-2.1": StableDiffusion,
}


class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for :class:`BeamDiffusionModel`.

    Only the backend *name* (``sd_name``) is part of the serialized config.
    The backend object itself is created lazily on first access to
    :attr:`sd` and is excluded from :meth:`to_dict`. There are two reasons:
    the object is not JSON-serializable, and ``transformers`` builds
    throw-away default config instances (e.g. when computing diffs for
    ``repr``/``save_pretrained``) that must not load a model.

    Args:
        sd: Backend name; one of ``"flux"`` or ``"SD-2.1"``.
        latents_idx: Forwarded to ``beam_inference``; an empty or ``None``
            value becomes ``[0, 1, 2, 3]``.
        n_seeds, steps_back, beam_width, window_size, use_rand: Forwarded
            unchanged to ``beam_inference``.
        seeds: Forwarded to ``beam_inference``; an empty or ``None`` value
            becomes ``[]``.
        **kwargs: Passed through to :class:`transformers.PretrainedConfig`.
    """

    model_type = "beam_diffusion"

    def __init__(self, sd="SD-2.1", latents_idx=None, n_seeds=4, seeds=None,
                 steps_back=2, beam_width=4, window_size=2, use_rand=True,
                 **kwargs):
        super().__init__(**kwargs)
        # Only the name is stored eagerly; the backend is built on demand.
        self.sd_name = sd
        self.latents_idx = latents_idx if latents_idx else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand

    @property
    def sd(self):
        """The diffusion backend object, instantiated on first access."""
        model = self.__dict__.get("_sd_model")
        if model is None:
            model = self.get_model(self.sd_name)
        return model

    @sd.setter
    def sd(self, value):
        # Kept under a private __dict__ key; to_dict() strips it again so the
        # model object never reaches JSON serialization.
        self.__dict__["_sd_model"] = value

    def get_model(self, sd=None):
        """Instantiate the backend named *sd* and cache it as :attr:`sd`.

        Args:
            sd: Backend name; defaults to ``self.sd_name``. When given, it
                also becomes the new ``sd_name`` so the config stays
                consistent with the cached model.

        Returns:
            The newly created backend object.

        Raises:
            ValueError: If the name is not a supported backend.
        """
        name = self.sd_name if sd is None else sd
        try:
            backend_cls = _SD_BACKENDS[name]
        except KeyError:
            raise ValueError(
                f"Unknown diffusion backend {name!r}; expected one of "
                f"{sorted(_SD_BACKENDS)}"
            ) from None
        model = backend_cls()
        self.sd_name = name
        self.sd = model
        return model

    def to_dict(self):
        """Serialize the config without the (non-serializable) backend object.

        The backend name is also written under the ``sd`` key, matching the
        constructor argument, so ``from_dict``/``from_pretrained`` restore
        the same backend choice.
        """
        # Detach the cached model so PretrainedConfig.to_dict neither
        # deep-copies it nor tries to JSON-encode it; always restore it.
        cached = self.__dict__.pop("_sd_model", None)
        try:
            output = super().to_dict()
        finally:
            if cached is not None:
                self.__dict__["_sd_model"] = cached
        output["sd"] = self.sd_name
        return output


class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """``PreTrainedModel`` wrapper whose forward pass calls ``beam_inference``.

    All generation is delegated to ``beam_inference``, using the backend
    exposed by ``config.sd`` and the search settings stored on the config.
    """

    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Ensures the module has at least one parameter (presumably so
        # parameter-based device/dtype lookups and .to(device) have
        # something to act on).
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, input_data):
        """Run beam inference with the configured backend.

        Args:
            input_data: Mapping whose ``"steps"`` entry (default ``[]``) is
                passed to ``beam_inference`` as ``steps``.

        Returns:
            ``{"images": ...}`` holding whatever ``beam_inference`` returns.
        """
        cfg = self.config
        images = beam_inference(
            cfg.sd,
            steps=input_data.get('steps', []),
            latents_idx=cfg.latents_idx,
            n_seeds=cfg.n_seeds,
            seeds=cfg.seeds,
            steps_back=cfg.steps_back,
            beam_width=cfg.beam_width,
            window_size=cfg.window_size,
            use_rand=cfg.use_rand,
        )
        return {"images": images}


class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Minimal pipeline that forwards raw inputs straight to the model.

    ``__call__`` is overridden to call ``_forward`` directly. As a result,
    ``Pipeline.__call__`` is bypassed and ``preprocess``/``postprocess`` are
    not applied: callers receive the model's raw ``{"images": ...}`` dict.
    """

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        super().__init__(model=model, tokenizer=tokenizer, device=device,
                         framework=framework)

    def __call__(self, inputs):
        # NOTE(review): skips preprocess/postprocess (see class docstring);
        # confirm that returning the raw model output dict is intended.
        return self._forward(inputs)

    def preprocess(self, inputs):
        """Return *inputs* unchanged; the model consumes the raw input mapping."""
        return inputs

    def postprocess(self, model_outputs):
        """Extract the ``"images"`` entry from the model output dict."""
        return model_outputs["images"]

    def _sanitize_parameters(self, **kwargs):
        """Accept and ignore all call-time parameters."""
        return {}, {}, {}

    def _forward(self, model_inputs):
        """Invoke the wrapped model on *model_inputs*."""
        return self.model(model_inputs)