# WeavePrompt / app.py
# Uploaded by kevin1kevin1k ("Upload folder using huggingface_hub"),
# commit b3e00f3 (verified), 10.2 kB.
import time

import streamlit as st
from PIL import Image
from dotenv import load_dotenv

from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator
# Load variables from a local .env file into the process environment. This
# runs at import time, i.e. before main() constructs the image generator,
# evaluator and refiner clients (presumably so they can read credentials from
# the environment -- confirm in those modules).
load_dotenv()

# Page-wide Streamlit settings: tab title, favicon and a wide layout.
st.set_page_config(
    page_title="WeavePrompt",
    page_icon="🎨",
    layout="wide",
)
# Seconds to pause on an auto-mode "display" run so the refreshed result and
# history stay on screen before the next optimizer step is triggered.
_AUTO_DISPLAY_DELAY_S = 0.5


def _init_session_state():
    """Seed ``st.session_state`` with every key the app reads.

    Each key is written only when missing, so values survive Streamlit reruns;
    the optimizer is therefore constructed once per session.
    """
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            image_generator=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95,
        )
    defaults = {
        'optimization_started': False,
        # (is_completed, prompt, generated_image) from the latest initialize()/step().
        'current_results': None,
        'auto_mode': False,
        'auto_paused': False,
        # Auto mode alternates a "step" run (True) with a "display" run (False).
        'auto_should_step': False,
        # "expand_analysis_<idx>" -> whether that history entry shows its analysis.
        'analysis_expanded': {},
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _run_optimizer_step():
    """Advance the optimizer by one iteration and store its result tuple."""
    is_completed, prompt, generated_image = st.session_state.optimizer.step()
    st.session_state.current_results = (is_completed, prompt, generated_image)


def _reset_optimization():
    """Return the UI to its pre-start state so another image can be optimized.

    NOTE(review): ``optimizer.history`` is not cleared here; this assumes
    ``PromptOptimizer.initialize()`` resets it when the next run starts -- confirm.
    """
    st.session_state.optimization_started = False
    st.session_state.current_results = None
    st.session_state.auto_mode = False
    st.session_state.auto_paused = False
    st.session_state.auto_should_step = False
    # Expanded-analysis flags are keyed by step index; clear them so they do
    # not carry over into the next run's history.
    st.session_state.analysis_expanded = {}


def _render_current_result(column):
    """Show the latest generated image and prompt in *column*.

    Returns the stored ``is_completed`` flag, or ``False`` (with a loading
    placeholder) when no result has been stored yet.
    """
    results = st.session_state.current_results
    with column:
        st.subheader("Generated Image")
        if results is None:
            st.info("Initializing optimization...")
            st.text_area("Current Prompt", "Generating initial prompt...", height=100)
            return False
        is_completed, prompt, generated_image = results
        st.image(generated_image, width='stretch')
        st.text_area("Current Prompt", prompt, height=100)
    return is_completed


def _render_auto_controls():
    """Render the auto-progress checkbox and the Pause/Resume toggle."""
    st.subheader("Auto Mode Controls")
    col_toggle, col_pause = st.columns(2)
    with col_toggle:
        auto_mode = st.checkbox("Auto-progress steps", value=st.session_state.auto_mode)
        if auto_mode != st.session_state.auto_mode:
            st.session_state.auto_mode = auto_mode
            if auto_mode:
                st.session_state.auto_paused = False
                st.session_state.auto_should_step = True  # Start by stepping
            st.rerun()
    with col_pause:
        if not st.session_state.auto_mode:
            return
        if st.session_state.auto_paused:
            if st.button("▶️ Resume", key="resume_btn"):
                st.session_state.auto_paused = False
                st.rerun()
        elif st.button("⏸️ Pause", key="pause_btn"):
            st.session_state.auto_paused = True
            st.rerun()


def _render_progress(is_completed):
    """Render the iteration / similarity / status metrics and the progress bar."""
    optimizer = st.session_state.optimizer
    history = optimizer.history
    # Completed steps, plus the one currently in progress if not finished.
    current_iteration = len(history) + (0 if is_completed else 1)

    col_iter, col_sim, col_status = st.columns(3)
    with col_iter:
        st.metric("Iteration", current_iteration)
    with col_sim:
        if history:
            st.metric("Similarity", f"{history[-1]['similarity']:.2%}")
    with col_status:
        if is_completed:
            status = "Completed"
        elif st.session_state.auto_paused:
            status = "Paused"
        else:
            status = "In Progress"
        st.metric("Status", status)

    st.subheader("Progress")
    max_iterations = optimizer.max_iterations
    progress_value = min(current_iteration / max_iterations, 1.0) if max_iterations > 0 else 0.0
    st.progress(progress_value, text=f"Step {current_iteration} of {max_iterations}")


def _render_step_controls(is_completed):
    """Render whatever advances the optimization in the current mode.

    In auto mode this is the "step" half of the step/display cycle: when
    ``auto_should_step`` is set it runs one step and reruns immediately. The
    "display" half is scheduled at the end of main(), after the history is drawn.
    """
    if is_completed:
        st.success("Optimization completed! Click 'Reset' to try another image.")
        return
    if not st.session_state.auto_mode:
        if st.button("Next Step"):
            _run_optimizer_step()
            st.rerun()
    elif st.session_state.auto_paused:
        st.info("Auto mode is paused. Click Resume to continue or uncheck Auto-progress to use manual mode.")
    elif st.session_state.auto_should_step:
        _run_optimizer_step()
        # The next run only displays the new result (and the history).
        st.session_state.auto_should_step = False
        st.rerun()


def _render_history():
    """List every recorded optimizer step with a per-step analysis toggle."""
    history = st.session_state.optimizer.history
    if not history:
        return
    st.subheader(f"Optimization History ({len(history)} steps)")
    expanded = st.session_state.analysis_expanded
    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        col_img, col_text = st.columns([2, 3])
        with col_img:
            st.image(hist_entry['image'], width='stretch')
        with col_text:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])
            # Toggle analysis view per history entry.
            expand_key = f"expand_analysis_{idx}"
            if expanded.get(expand_key, False):
                if st.button("Hide Analysis", key=f"hide_{expand_key}"):
                    expanded[expand_key] = False
                    st.rerun()
                st.text("Analysis:")
                for key, value in hist_entry['analysis'].items():
                    st.text(f"{key}: {value}")
            elif st.button("Expand Analysis", key=expand_key):
                expanded[expand_key] = True
                st.rerun()


def main():
    """Streamlit entry point for the WeavePrompt UI.

    The user uploads a target image; "Start Optimization" initializes the
    PromptOptimizer on it. The optimization then advances either one step
    per "Next Step" click (manual mode) or automatically (auto mode), which
    alternates a "step" rerun with a "display" rerun. Every recorded step is
    listed in the history section.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    """)
    _init_session_state()

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    target_image = Image.open(uploaded_file)
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')
        if not st.session_state.optimization_started:
            if st.button("Start Optimization"):
                is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
                st.session_state.current_results = (is_completed, prompt, generated_image)
                # Mark the run as started only after initialize() succeeded, so
                # a failure does not leave the app stuck on "Initializing..."
                # with an uninitialized optimizer behind the Next Step button.
                st.session_state.optimization_started = True
                st.rerun()
        else:
            # Show disabled button when optimization has started.
            st.button("Start Optimization", disabled=True, help="Optimization in progress")

    if not st.session_state.optimization_started:
        return

    is_completed = _render_current_result(col2)
    # Leave auto mode as soon as the run is complete, *before* the checkbox is
    # drawn, so it never shows a stale "checked" state.
    if is_completed and st.session_state.auto_mode:
        st.session_state.auto_mode = False
        st.session_state.auto_paused = False

    _render_auto_controls()
    _render_progress(is_completed)
    _render_step_controls(is_completed)

    if st.button("Reset"):
        _reset_optimization()
        st.rerun()

    _render_history()

    # Auto-mode "display" run: everything above, including Reset and the
    # history, has now been drawn. Pause briefly so it stays visible, then
    # rerun to perform the next step. This must come last because st.rerun()
    # ends the current run immediately.
    if not is_completed and st.session_state.auto_mode and not st.session_state.auto_paused:
        st.session_state.auto_should_step = True
        time.sleep(_AUTO_DISPLAY_DELAY_S)
        st.rerun()
# Streamlit executes the script as __main__, so this guard starts the app
# under `streamlit run app.py` but not when the module is imported.
if __name__ == "__main__":
    main()