File size: 4,621 Bytes
c6eb9ce
 
e7247e4
fb2f0a7
 
c6eb9ce
fb2f0a7
 
c6eb9ce
e7247e4
 
 
c6eb9ce
4584c11
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator

# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# Browser-tab title and icon, and a full-width page layout.
st.set_page_config(page_title="WeavePrompt", page_icon="🎨", layout="wide")

def _init_session_state():
    """Create the session-state entries the app reads, if they are missing.

    The PromptOptimizer is constructed only once per session. That way the same
    instance (and its ``history``) is reused across Streamlit's reruns.
    """
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            model=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95
        )

    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False

    if 'current_results' not in st.session_state:
        st.session_state.current_results = None


def _reset_optimization():
    """Return the UI to the "not started" state; the optimizer object is kept."""
    st.session_state.optimization_started = False
    st.session_state.current_results = None


def _render_history(history):
    """Render one section per recorded optimization step.

    Each entry is read as a mapping with 'image', 'similarity', 'prompt' and
    'analysis' keys, where 'analysis' is itself a mapping.
    """
    if not history:
        return
    st.subheader("Optimization History")
    for step_number, hist_entry in enumerate(history, start=1):
        st.markdown(f"### Step {step_number}")
        col1, col2 = st.columns([2, 3])
        with col1:
            st.image(hist_entry['image'], width='stretch')
        with col2:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])
            st.text("\nAnalysis:")
            for key, value in hist_entry['analysis'].items():
                st.text(f"{key}: {value}")


def main():
    """Render the WeavePrompt app.

    Flow: the user uploads a target image and presses "Start Optimization".
    That calls ``optimizer.initialize(target)``. Each "Next Step" then calls
    ``optimizer.step()``. Both are expected to return a 3-tuple
    ``(is_completed, prompt, generated_image)``, which is kept in
    ``st.session_state.current_results`` across reruns.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    """)

    _init_session_state()
    optimizer = st.session_state.optimizer

    # File uploader
    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    # A different upload is a different target; results computed for the
    # previous image no longer apply, so drop them.
    target_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get('target_key') != target_key:
        st.session_state.target_key = target_key
        _reset_optimization()

    target_image = Image.open(uploaded_file)

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    # Treat "started but no results" as not started, so the app can never try
    # to unpack a None result.
    if not st.session_state.optimization_started or st.session_state.current_results is None:
        if st.button("Start Optimization"):
            try:
                results = optimizer.initialize(target_image)
            except Exception as exc:  # UI boundary: show the error, stay usable
                st.error("Failed to start optimization.")
                st.exception(exc)
            else:
                is_completed, prompt, generated_image = results
                st.session_state.current_results = (is_completed, prompt, generated_image)
                # Flag the run as started only once results exist.
                st.session_state.optimization_started = True
                # Rerun so the Start button disappears at once, like Next Step.
                st.rerun()
        return

    # Display optimization progress
    is_completed, prompt, generated_image = st.session_state.current_results
    with col2:
        st.subheader("Generated Image")
        st.image(generated_image, width='stretch')

    st.text_area("Current Prompt", prompt, height=100)

    # Progress metrics
    history = optimizer.history
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Iteration", len(history))
    with col2:
        if len(history) > 0:
            st.metric("Similarity", f"{history[-1]['similarity']:.2%}")
    with col3:
        st.metric("Status", "Completed" if is_completed else "In Progress")

    if not is_completed:
        if st.button("Next Step"):
            try:
                is_completed, prompt, generated_image = optimizer.step()
            except Exception as exc:  # UI boundary: keep Reset/History rendered
                st.error("Optimization step failed.")
                st.exception(exc)
            else:
                st.session_state.current_results = (is_completed, prompt, generated_image)
                st.rerun()
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")

    if st.button("Reset"):
        _reset_optimization()
        st.rerun()

    # Re-read the attribute in case a failed step() rebound it.
    _render_history(optimizer.history)

# Start the app when this file runs as the main script (e.g. via `streamlit run`).
if __name__ == '__main__':
    main()