File size: 4,785 Bytes
c6eb9ce
 
 
e7247e4
c6eb9ce
 
 
 
 
 
 
 
e7247e4
 
 
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
from PIL import Image
import time
from dotenv import load_dotenv
from image_to_text import LlamaEvaluator
from prompt_refiner import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from mock_components import MockTextToImageModel, MockImageEvaluator, MockPromptRefiner
from lpips_evaluator import LPIPSImageSimilarityMetric
from fal_image_generator import FalImageGenerator
import io

# Pull KEY=VALUE pairs from a local .env file into the process environment at
# import time, i.e. before main() constructs any of the optimizer components.
load_dotenv()

# Page-level settings for the demo: browser-tab title, tab icon, wide layout.
st.set_page_config(page_title="WeavePrompt Demo", page_icon="🎨", layout="wide")

def _init_session_state():
    """Create the optimizer and UI flags in ``st.session_state`` when absent.

    ``st.session_state`` persists across Streamlit reruns (including the
    ``st.rerun()`` calls issued by the step/reset controls), so each entry is
    only created the first time it is missing. The optimizer is therefore built
    once per session and reused for every subsequent run.
    """
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            model=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95
        )

    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False

    if 'current_results' not in st.session_state:
        st.session_state.current_results = None


def _render_metrics(history, is_completed):
    """Show iteration count, latest similarity and completion status side by side.

    Args:
        history: the optimizer's history list; the last entry's 'similarity'
            value is formatted as a percentage when any entry exists.
        is_completed: completion flag from the latest optimizer result.
    """
    iteration_col, similarity_col, status_col = st.columns(3)
    with iteration_col:
        st.metric("Iteration", len(history))
    with similarity_col:
        if history:
            st.metric("Similarity", f"{history[-1]['similarity']:.2%}")
    with status_col:
        st.metric("Status", "Completed" if is_completed else "In Progress")


def _render_controls(is_completed):
    """Render the Next Step / Reset controls; each calls ``st.rerun()`` when clicked."""
    if not is_completed:
        if st.button("Next Step"):
            is_completed, prompt, generated_image = st.session_state.optimizer.step()
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.rerun()
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")

    if st.button("Reset"):
        # NOTE(review): the optimizer (and its history) is kept; this relies on
        # optimizer.initialize() starting a fresh history on the next run — confirm.
        st.session_state.optimization_started = False
        st.session_state.current_results = None
        st.rerun()


def _render_history(history):
    """Render one section per recorded optimization step.

    Each entry is read with the keys 'image', 'similarity', 'prompt' and
    'analysis'; 'analysis' is iterated via ``.items()``, so it must be dict-like.
    """
    if not history:
        return

    st.subheader("Optimization History")
    for step_number, entry in enumerate(history, start=1):
        st.markdown(f"### Step {step_number}")
        image_col, details_col = st.columns([2, 3])
        with image_col:
            st.image(entry['image'], width='stretch')
        with details_col:
            st.text(f"Similarity: {entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(entry['prompt'])
            st.text("\nAnalysis:")
            for key, value in entry['analysis'].items():
                st.text(f"{key}: {value}")


def main():
    """Render the WeavePrompt demo page.

    Flow: the user uploads a target image and clicks "Start Optimization",
    which calls ``optimizer.initialize(target_image)``. Each "Next Step" click
    then calls ``optimizer.step()`` until the result reports completion. Both
    calls are unpacked as ``(is_completed, prompt, generated_image)``, and the
    latest tuple is cached in ``st.session_state.current_results``.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    Click 'Next Step' to run one optimization iteration at a time.
    """)

    _init_session_state()

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    target_image = Image.open(uploaded_file)

    target_col, generated_col = st.columns(2)
    with target_col:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    if not st.session_state.optimization_started:
        if st.button("Start Optimization"):
            # Run initialize() BEFORE flipping the flag: if it raises, the
            # session must not be left "started" with current_results = None,
            # which would make every later rerun fail while unpacking it.
            is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.session_state.optimization_started = True

    if not st.session_state.optimization_started:
        return

    is_completed, prompt, generated_image = st.session_state.current_results
    with generated_col:
        st.subheader("Generated Image")
        st.image(generated_image, width='stretch')

    st.text_area("Current Prompt", prompt, height=100)

    _render_metrics(st.session_state.optimizer.history, is_completed)
    _render_controls(is_completed)
    _render_history(st.session_state.optimizer.history)


if __name__ == "__main__":
    main()