from typing import List, Dict, Any, Optional, Tuple

from PIL import Image

from weave_prompt import (
    PromptOptimizer,
    ImageEvaluator,
    PromptRefiner,
    ImageSimilarityMetric,
)
from image_generators import MultiModelFalImageGenerator


class MultiModelPromptOptimizer:
    """Sequential multi-model prompt optimizer that finds the best model-prompt combination.

    The models listed in ``image_generator.selected_models`` are processed one
    after another. For each model a fresh ``PromptOptimizer`` is created and
    stepped until it reports completion; its last history entry is stored in
    ``model_results`` and compared against ``best_result``. When the last model
    finishes, ``step()`` returns the best prompt/image seen across all models.
    """

    def __init__(self,
                 image_generator: MultiModelFalImageGenerator,
                 evaluator: ImageEvaluator,
                 refiner: PromptRefiner,
                 similarity_metric: ImageSimilarityMetric,
                 max_iterations: int = 10,
                 similarity_threshold: float = 0.95):
        """Initialize the multi-model optimizer.

        Args:
            image_generator: Multi-model image generator
            evaluator: Image evaluator for generating initial prompt and analysis
            refiner: Prompt refinement strategy
            similarity_metric: Image similarity metric
            max_iterations: Maximum number of optimization iterations per model
            similarity_threshold: Target similarity threshold for early stopping
        """
        self.image_generator = image_generator
        self.evaluator = evaluator
        self.refiner = refiner
        self.similarity_metric = similarity_metric
        self.max_iterations = max_iterations
        self.similarity_threshold = similarity_threshold

        # Multi-model state
        self.target_img = None          # set by initialize(); None means "not started"
        self.current_model_index = 0
        self.model_results = {}         # model name -> final result + history
        self.current_optimizer = None
        self.best_result = None         # best {'model_name', 'prompt', 'image', 'similarity'}

        # Build the optimizer for the first model (no-op when no model is selected)
        self._create_current_optimizer()

    def _create_current_optimizer(self):
        """Create the optimizer for the current model.

        Does nothing when ``current_model_index`` is past the last selected
        model; in that case ``current_optimizer`` keeps its previous value.
        """
        if self.current_model_index < len(self.image_generator.selected_models):
            # Point the shared image generator at the current model
            self.image_generator.current_model_index = self.current_model_index

            # Each model gets its own optimizer (and therefore its own history)
            self.current_optimizer = PromptOptimizer(
                image_generator=self.image_generator,
                evaluator=self.evaluator,
                refiner=self.refiner,
                similarity_metric=self.similarity_metric,
                max_iterations=self.max_iterations,
                similarity_threshold=self.similarity_threshold
            )

    def get_current_model_name(self) -> str:
        """Get the name of the currently active model.

        Side effect: copies ``current_model_index`` into the image generator
        before asking it for the name.
        """
        # NOTE(review): once every model has been processed, current_model_index
        # equals the number of selected models, so this writes an out-of-range
        # index into the generator - confirm the generator tolerates that.
        self.image_generator.current_model_index = self.current_model_index
        return self.image_generator.get_current_model_name()

    def get_progress_info(self) -> Dict[str, Any]:
        """Get current progress information.

        Returns:
            Dict with the current model index/name, the number of models,
            overall progress as a fraction of models completed, and - when an
            optimizer exists - the iteration count and per-model progress.
        """
        total_models = len(self.image_generator.selected_models)

        info = {
            'current_model_index': self.current_model_index,
            'current_model_name': self.get_current_model_name(),
            'total_models': total_models,
            'models_completed': self.current_model_index,
            'overall_progress': self.current_model_index / total_models if total_models > 0 else 0,
            'is_last_model': self.current_model_index >= total_models - 1
        }

        if self.current_optimizer is not None:
            iterations_done = len(self.current_optimizer.history)
            info['current_iteration'] = iterations_done
            info['max_iterations'] = self.max_iterations
            # Guard against max_iterations == 0, mirroring the overall_progress guard
            info['model_progress'] = (
                iterations_done / self.max_iterations if self.max_iterations > 0 else 0
            )

        return info

    def initialize(self, target_img: Image.Image) -> Tuple[bool, str, Image.Image]:
        """Initialize the multi-model optimization process.

        Resets all multi-model state, rewinds the image generator to its first
        model and initializes that model's optimizer.

        Args:
            target_img: Target image to optimize towards

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)

        Raises:
            RuntimeError: If the image generator has no selected models.
        """
        if len(self.image_generator.selected_models) == 0:
            # Without this check the call below would hit a missing (or stale) optimizer
            raise RuntimeError("No models selected for optimization")

        self.target_img = target_img
        self.current_model_index = 0
        self.model_results = {}
        self.best_result = None

        # Reset image generator to first model
        self.image_generator.reset_to_first_model()
        self._create_current_optimizer()

        # NOTE(review): the per-model completion flag is passed through as the
        # overall flag - confirm PromptOptimizer.initialize() never reports True.
        return self.current_optimizer.initialize(target_img)

    def _record_current_model_result(self, prompt: str, generated_image: Image.Image) -> None:
        """Store the finished model's final result and update the overall best."""
        model_name = self.get_current_model_name()
        history = self.current_optimizer.history

        if len(history) > 0:
            # Use the last history entry so stored results match the history
            last_step = history[-1]
            final_prompt = last_step['prompt']
            final_image = last_step['image']
            final_similarity = last_step['similarity']
        else:
            # Fallback to the step results if there is no history (shouldn't happen)
            final_prompt = prompt
            final_image = generated_image
            final_similarity = 0.0

        self.model_results[model_name] = {
            'final_prompt': final_prompt,
            'final_image': final_image,
            'final_similarity': final_similarity,
            'history': history.copy(),
            'iterations': len(history)
        }

        # Strict '>' keeps the earlier model on ties
        if self.best_result is None or final_similarity > self.best_result['similarity']:
            self.best_result = {
                'model_name': model_name,
                'prompt': final_prompt,
                'image': final_image,
                'similarity': final_similarity
            }

    def step(self) -> Tuple[bool, str, Image.Image]:
        """Perform one optimization step.

        Steps the current model's optimizer. When that optimizer reports
        completion, its result is recorded and the next model is initialized;
        after the last model the best result across all models is returned.
        Calling ``step()`` again after completion returns the same best result
        without running any further optimization.

        Returns:
            Tuple of (is_completed, current_prompt, current_generated_image)

        Raises:
            RuntimeError: If ``initialize()`` has not been called yet.
        """
        # current_optimizer is already built in __init__, so target_img is what
        # actually tells us whether initialize() has run.
        if self.current_optimizer is None or self.target_img is None:
            raise RuntimeError("Must call initialize() before step()")

        total_models = len(self.image_generator.selected_models)

        # Already finished: do not re-run the last model's optimizer
        if self.current_model_index >= total_models and self.best_result is not None:
            return True, self.best_result['prompt'], self.best_result['image']

        # Step the current model optimizer
        is_model_completed, prompt, generated_image = self.current_optimizer.step()

        if not is_model_completed:
            return is_model_completed, prompt, generated_image

        self._record_current_model_result(prompt, generated_image)

        # Move to next model
        self.current_model_index += 1

        if self.current_model_index < len(self.image_generator.selected_models):
            # Initialize next model - ensure both indices are synchronized
            self.image_generator.current_model_index = self.current_model_index
            self._create_current_optimizer()
            # NOTE(review): the new optimizer's completion flag is returned as
            # the overall flag - confirm initialize() never reports True here.
            return self.current_optimizer.initialize(self.target_img)

        # All models completed - return best result
        return True, self.best_result['prompt'], self.best_result['image']

    def get_all_results(self) -> Dict[str, Dict[str, Any]]:
        """Get results from all completed models (shallow copy, keyed by model name)."""
        return self.model_results.copy()

    def get_best_result(self) -> Optional[Dict[str, Any]]:
        """Get the best result across all models.

        Returns:
            A shallow copy of the best result dict, or None if no model has
            completed yet.
        """
        return self.best_result.copy() if self.best_result else None

    @property
    def history(self):
        """Get history from current optimizer for compatibility."""
        if self.current_optimizer is not None:
            return self.current_optimizer.history
        return []