"""Streamlit demo for WeavePrompt iterative prompt optimization.

The user uploads a target image and starts an optimization run. The run is
then advanced one step at a time. The page shows the latest generated image,
the current prompt, the latest similarity score, and the per-step history
recorded by the optimizer.
"""

import io
import time

import streamlit as st
from PIL import Image
from dotenv import load_dotenv

from image_to_text import LlamaEvaluator
from prompt_refiner import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
# NOTE(review): the mock components are not used below; they look like
# stand-ins for model/evaluator/refiner in _build_optimizer() — confirm.
from mock_components import MockTextToImageModel, MockImageEvaluator, MockPromptRefiner
from lpips_evaluator import LPIPSImageSimilarityMetric
from fal_image_generator import FalImageGenerator

# Load environment variables from .env file
load_dotenv()

# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="WeavePrompt Demo",
    page_icon="🎨",
    layout="wide",
)

# Settings passed to PromptOptimizer when a session's optimizer is created.
MAX_ITERATIONS = 10
SIMILARITY_THRESHOLD = 0.95


def _build_optimizer():
    """Create a PromptOptimizer wired to the generator, evaluator, refiner and metric."""
    return PromptOptimizer(
        model=FalImageGenerator(),
        evaluator=LlamaEvaluator(),
        refiner=LlamaPromptRefiner(),
        similarity_metric=LPIPSImageSimilarityMetric(),
        max_iterations=MAX_ITERATIONS,
        similarity_threshold=SIMILARITY_THRESHOLD,
    )


def _init_session_state():
    """Populate st.session_state defaults; existing values are left untouched.

    The optimizer is only constructed when absent, so it survives reruns.
    """
    if "optimizer" not in st.session_state:
        st.session_state.optimizer = _build_optimizer()
    if "optimization_started" not in st.session_state:
        st.session_state.optimization_started = False
    if "current_results" not in st.session_state:
        st.session_state.current_results = None


def _render_target(column, target_image):
    """Show the target image in *column* and, before a run starts, the Start button."""
    with column:
        st.subheader("Target Image")
        st.image(target_image, width="stretch")

        if not st.session_state.optimization_started and st.button("Start Optimization"):
            # Call initialize() BEFORE marking the run as started. If it
            # raises, the session stays "not started" instead of being left
            # with optimization_started=True and current_results=None. That
            # broken state made every later rerun crash on the tuple unpack,
            # before the Reset button could be rendered.
            is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.session_state.optimization_started = True


def _render_history(history):
    """Render one section per history entry (image, similarity, prompt, analysis)."""
    if len(history) == 0:
        return

    st.subheader("Optimization History")
    for step_number, entry in enumerate(history, start=1):
        st.markdown(f"### Step {step_number}")
        image_col, details_col = st.columns([2, 3])
        with image_col:
            st.image(entry["image"], width="stretch")
        with details_col:
            st.text(f"Similarity: {entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(entry["prompt"])
            st.text("\nAnalysis:")
            for key, value in entry["analysis"].items():
                st.text(f"{key}: {value}")


def _render_progress(column):
    """Show the current run: generated image, prompt, metrics, controls and history.

    Expects st.session_state.current_results to hold the
    (is_completed, prompt, generated_image) tuple stored by the Start/Next
    Step handlers.
    """
    is_completed, prompt, generated_image = st.session_state.current_results
    optimizer = st.session_state.optimizer

    with column:
        st.subheader("Generated Image")
        st.image(generated_image, width="stretch")

    st.text_area("Current Prompt", prompt, height=100)

    # Progress metrics
    iteration_col, similarity_col, status_col = st.columns(3)
    with iteration_col:
        st.metric("Iteration", len(optimizer.history))
    with similarity_col:
        if len(optimizer.history) > 0:
            latest_similarity = optimizer.history[-1]["similarity"]
            st.metric("Similarity", f"{latest_similarity:.2%}")
    with status_col:
        st.metric("Status", "Completed" if is_completed else "In Progress")

    if is_completed:
        st.success("Optimization completed! Click 'Reset' to try another image.")
    elif st.button("Next Step"):
        is_completed, prompt, generated_image = optimizer.step()
        st.session_state.current_results = (is_completed, prompt, generated_image)
        # Rerun so the whole page reflects the new step.
        st.rerun()

    if st.button("Reset"):
        # NOTE(review): optimizer.history is not cleared here; this assumes
        # optimizer.initialize() starts a fresh history on the next run — confirm.
        st.session_state.optimization_started = False
        st.session_state.current_results = None
        st.rerun()

    _render_history(optimizer.history)


def main():
    """Render the WeavePrompt demo page for the current Streamlit script run."""
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    Click **Next Step** to run one optimization iteration at a time.
    """)

    _init_session_state()

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    target_image = Image.open(uploaded_file)
    target_col, generated_col = st.columns(2)
    _render_target(target_col, target_image)

    if st.session_state.optimization_started:
        _render_progress(generated_col)


if __name__ == "__main__":
    main()