Spaces:
Runtime error
Runtime error
File size: 4,621 Bytes
c6eb9ce e7247e4 fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce e7247e4 c6eb9ce 4584c11 c6eb9ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator
# Populate os.environ from a local .env file (if one exists) before any of the
# generator / evaluator / refiner objects are constructed further down.
load_dotenv()

# Browser-tab title and icon, and a full-width page layout.
st.set_page_config(page_title="WeavePrompt", page_icon="🎨", layout="wide")
def _init_session_state():
    """Create the per-session objects used by main(), if they do not exist yet.

    Values live in ``st.session_state`` so that they persist across the
    script reruns that Streamlit performs on every widget interaction.
    """
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            model=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95,
        )
    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False
    if 'current_results' not in st.session_state:
        st.session_state.current_results = None


def _render_history(history):
    """Render one section per recorded optimizer step, oldest entry first.

    Args:
        history: The optimizer's history; each entry is read as a mapping
            with 'image', 'similarity', 'prompt' and 'analysis' keys, where
            'analysis' is itself a mapping.
    """
    if len(history) == 0:
        return
    st.subheader("Optimization History")
    for step_number, entry in enumerate(history, start=1):
        st.markdown(f"### Step {step_number}")
        image_col, details_col = st.columns([2, 3])
        with image_col:
            st.image(entry['image'], width='stretch')
        with details_col:
            st.text(f"Similarity: {entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(entry['prompt'])
            st.text("\nAnalysis:")
            for key, value in entry['analysis'].items():
                st.text(f"{key}: {value}")


def _render_progress(result_col):
    """Show the latest result, metrics, step/reset controls and the history.

    Must only be called once ``st.session_state.current_results`` holds the
    ``(is_completed, prompt, generated_image)`` tuple.

    Args:
        result_col: Streamlit column in which the generated image is drawn.
    """
    optimizer = st.session_state.optimizer
    is_completed, prompt, generated_image = st.session_state.current_results

    with result_col:
        st.subheader("Generated Image")
        st.image(generated_image, width='stretch')

    st.text_area("Current Prompt", prompt, height=100)

    iteration_col, similarity_col, status_col = st.columns(3)
    with iteration_col:
        st.metric("Iteration", len(optimizer.history))
    with similarity_col:
        if len(optimizer.history) > 0:
            # The last history entry is the most recent step's score.
            similarity = optimizer.history[-1]['similarity']
            st.metric("Similarity", f"{similarity:.2%}")
    with status_col:
        st.metric("Status", "Completed" if is_completed else "In Progress")

    if not is_completed:
        if st.button("Next Step"):
            is_completed, prompt, generated_image = optimizer.step()
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.rerun()
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")

    if st.button("Reset"):
        # NOTE(review): the optimizer object (and its history) is reused for
        # the next run; this relies on optimizer.initialize() starting a fresh
        # history — confirm in weave_prompt.PromptOptimizer.
        st.session_state.optimization_started = False
        st.session_state.current_results = None
        st.rerun()

    _render_history(optimizer.history)


def main():
    """Render the WeavePrompt page for the current Streamlit script run.

    Flow: upload a target image -> "Start Optimization" calls
    ``optimizer.initialize`` -> each "Next Step" calls ``optimizer.step``
    until it reports completion -> "Reset" returns to the start state.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown(
        "\nUpload a target image and watch as WeavePrompt iteratively "
        "optimizes a text prompt to recreate it.\n"
    )

    _init_session_state()

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    try:
        # convert() forces the lazy decode (so corrupt files fail here) and
        # normalizes RGBA / palette / grayscale PNGs to 3-channel RGB.
        # NOTE(review): assumes downstream components expect RGB input.
        target_image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # includes PIL.UnidentifiedImageError
        st.error("Could not read the uploaded file as an image.")
        return

    target_col, result_col = st.columns(2)
    with target_col:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    if not st.session_state.optimization_started:
        if st.button("Start Optimization"):
            is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
            st.session_state.current_results = (is_completed, prompt, generated_image)
            # Mark the run as started only after initialize() succeeded, so a
            # failure never leaves "started" set with no results to unpack.
            st.session_state.optimization_started = True
            # Rerun so the now-obsolete Start button is not left on screen.
            st.rerun()
        return

    _render_progress(result_col)
# Script entry point (Streamlit executes this file as __main__).
if __name__ == "__main__":
    main()