# Hugging Face Spaces status at capture time: Runtime error
import time

import streamlit as st
from dotenv import load_dotenv
from PIL import Image

from image_evaluators import LlamaEvaluator
from image_generators import FalImageGenerator
from prompt_refiners import LlamaPromptRefiner
from similarity_metrics import LPIPSImageSimilarityMetric
from weave_prompt import PromptOptimizer
# Pull settings from a local .env file into the process environment before
# any component is constructed.
load_dotenv()

# Page-level Streamlit configuration (title, icon, wide layout).
st.set_page_config(
    page_title="WeavePrompt",
    page_icon="🎨",
    layout="wide",
)
def _init_session_state():
    """Seed ``st.session_state`` with every key ``main`` reads.

    Existing values are left untouched so state survives Streamlit reruns.
    """
    state = st.session_state

    if 'optimizer' not in state:
        state.optimizer = PromptOptimizer(
            image_generator=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95,
        )

    defaults = {
        'optimization_started': False,
        'current_results': None,
        # Auto mode state
        'auto_mode': False,
        'auto_paused': False,
        # Alternates between "execute a step" and "display only" renders.
        'auto_should_step': False,
    }
    for key, value in defaults.items():
        if key not in state:
            state[key] = value


def _render_history(history):
    """Render one row per recorded step, with a per-step analysis toggle.

    Each entry is read through the keys 'image', 'similarity', 'prompt' and
    'analysis' (the latter iterated with ``.items()``).
    """
    if not history:
        return

    st.subheader(f"Optimization History ({len(history)} steps)")

    # Per-step expanded/collapsed flags, keyed by "expand_analysis_<idx>".
    if 'analysis_expanded' not in st.session_state:
        st.session_state['analysis_expanded'] = {}
    expanded = st.session_state['analysis_expanded']

    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        image_col, detail_col = st.columns([2, 3])
        with image_col:
            st.image(hist_entry['image'], width='stretch')
        with detail_col:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])

            # Toggle analysis view per history entry
            expand_key = f"expand_analysis_{idx}"
            expanded.setdefault(expand_key, False)
            if expanded[expand_key]:
                if st.button("Hide Analysis", key=f"hide_{expand_key}"):
                    expanded[expand_key] = False
                    st.rerun()
                st.text("Analysis:")
                for key, value in hist_entry['analysis'].items():
                    st.text(f"{key}: {value}")
            else:
                if st.button("Expand Analysis", key=expand_key):
                    expanded[expand_key] = True
                    st.rerun()


def main():
    """Render the WeavePrompt page and drive the optimization loop.

    The user uploads a target image, starts the optimizer, and then advances
    it either manually ("Next Step") or in auto mode. Auto mode alternates
    between a render that executes ``optimizer.step()`` and a render that only
    displays the result, so each step is visible before the next one runs.

    NOTE(review): block nesting was reconstructed from a source whose
    indentation had been lost -- confirm layout against the deployed app.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    """)

    _init_session_state()
    state = st.session_state

    # File uploader
    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    # Display target image
    target_image = Image.open(uploaded_file)
    target_col, generated_col = st.columns(2)
    with target_col:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    # Start button
    if not state.optimization_started:
        if st.button("Start Optimization"):
            # Initialize first and only then mark the run as started: if
            # initialize() raises, the session must not be left "started"
            # with no results and a permanently disabled Start button.
            is_completed, prompt, generated_image = state.optimizer.initialize(target_image)
            state.current_results = (is_completed, prompt, generated_image)
            state.optimization_started = True
            st.rerun()
    else:
        # Show disabled button or status when optimization has started
        st.button("Start Optimization", disabled=True, help="Optimization in progress")

    if not state.optimization_started:
        return

    # Display optimization progress
    if state.current_results is not None:
        is_completed, prompt, generated_image = state.current_results
        with generated_col:
            st.subheader("Generated Image")
            st.image(generated_image, width='stretch')
        st.text_area("Current Prompt", prompt, height=100)
    else:
        # Show loading state
        with generated_col:
            st.subheader("Generated Image")
            st.info("Initializing optimization...")
        st.text_area("Current Prompt", "Generating initial prompt...", height=100)
        is_completed = False
        prompt = ""

    # Auto mode controls
    st.subheader("Auto Mode Controls")
    toggle_col, pause_col = st.columns(2)
    with toggle_col:
        auto_mode = st.checkbox("Auto-progress steps", value=state.auto_mode)
        if auto_mode != state.auto_mode:
            state.auto_mode = auto_mode
            if auto_mode:
                state.auto_paused = False
                state.auto_should_step = True  # Start by stepping
            st.rerun()
    with pause_col:
        if state.auto_mode:
            if state.auto_paused:
                if st.button("▶️ Resume", key="resume_btn"):
                    state.auto_paused = False
                    st.rerun()
            else:
                if st.button("⏸️ Pause", key="pause_btn"):
                    state.auto_paused = True
                    st.rerun()

    history = state.optimizer.history

    # Progress metrics
    iteration_col, similarity_col, status_col = st.columns(3)
    with iteration_col:
        # Show current iteration: completed steps + 1 if still in progress
        current_iteration = len(history) + (0 if is_completed else 1)
        st.metric("Iteration", current_iteration)
    with similarity_col:
        if history:
            similarity = history[-1]['similarity']
            st.metric("Similarity", f"{similarity:.2%}")
    with status_col:
        if is_completed:
            status = "Completed"
        elif state.auto_paused:
            status = "Paused"
        else:
            status = "In Progress"
        st.metric("Status", status)

    # Progress bar
    st.subheader("Progress")
    max_iterations = state.optimizer.max_iterations
    progress_value = min(current_iteration / max_iterations, 1.0) if max_iterations > 0 else 0.0
    st.progress(progress_value, text=f"Step {current_iteration} of {max_iterations}")

    # Next step logic
    advance_after_render = False
    if not is_completed:
        if state.auto_mode and not state.auto_paused:
            if state.auto_should_step:
                # Execute the step, then rerun at once so the next render is
                # a display-only pass showing the new result.
                is_completed, prompt, generated_image = state.optimizer.step()
                state.current_results = (is_completed, prompt, generated_image)
                state.auto_should_step = False
                st.rerun()
            else:
                # Display-only pass: arm the next step, but postpone the
                # rerun until the Reset button and history below have been
                # rendered -- st.rerun() halts the script immediately, so
                # calling it here would skip them.
                state.auto_should_step = True
                advance_after_render = True
        # Manual mode
        elif not state.auto_mode:
            if st.button("Next Step"):
                is_completed, prompt, generated_image = state.optimizer.step()
                state.current_results = (is_completed, prompt, generated_image)
                st.rerun()
        # Show status when auto mode is paused
        elif state.auto_paused:
            st.info("Auto mode is paused. Click Resume to continue or uncheck Auto-progress to use manual mode.")
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")
        # Turn off auto mode when completed
        if state.auto_mode:
            state.auto_mode = False
            state.auto_paused = False

    # Reset button
    if st.button("Reset"):
        state.optimization_started = False
        state.current_results = None
        state.auto_mode = False
        state.auto_paused = False
        # Also drop per-run UI state so it cannot leak into the next run.
        state.auto_should_step = False
        state['analysis_expanded'] = {}
        st.rerun()

    # Display history - simple approach
    _render_history(history)

    if advance_after_render:
        time.sleep(0.5)  # Give user time to see the history
        st.rerun()
# Script entry point: render the page only when executed directly
# (e.g. via `streamlit run`), not when imported as a module.
if __name__ == "__main__":
    main()