# app.py
import time

import gradio as gr

from rag import RAGPipeline
from attack import AdversarialAttackPipeline

# --- 1. Load Models and Pipelines (Singleton Pattern) ---
# This section runs only ONCE when the application starts up.
# This is the key to low latency for subsequent user requests.
print("Starting application setup...")
start_time = time.time()

# Instantiate the RAG pipeline, which will now load from the pre-computed cache
rag_pipeline = RAGPipeline(
    json_path="calebdata.json",
    defense_model_path="./defense_model",
    cache_dir="cache"
)

# Instantiate the Attack pipeline
attack_pipeline = AdversarialAttackPipeline(rag_pipeline_instance=rag_pipeline)

end_time = time.time()
print(f"Application setup complete. Model loading took {end_time - start_time:.2f} seconds.")

# --- 2. Define the Core Function for Gradio ---
# This function is called every time a user clicks "Run Attack".
def run_adversarial_test(query, attack_method, perturbation_level):
    """
    Runs a single attack and returns the results, formatted for the Gradio UI.
    """
    print(f"Received query: '{query}', Method: {attack_method}, Level: {perturbation_level}")

    # The attack_pipeline object is already loaded in memory.
    result = attack_pipeline.run_attack(
        original_query=query,
        perturbation_method=attack_method,
        perturbation_level=perturbation_level
    )

    # Format the output for the Gradio components.
    return (
        result["normal_query"],
        result["perturbed_query"],
        result["normal_response"],
        result["perturbed_response"],
        f"Defense Triggered: {result['defense_triggered']} | Reason: {result['reason']}",
        f"Response Similarity: {result['cos_sim']['response_sim']}%",
        f"Adversarial Risk Index (ARI): {result['ari']}"
    )
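
# If attack.py turns out to expect categorical levels ("low"/"medium"/"high")
# rather than a float (see the note on the slider below), a thin adapter like
# this hypothetical helper can translate the slider value first:
def slider_to_level(value: float) -> str:
    """Map a 0-1 slider value onto coarse perturbation categories."""
    if value <= 0.33:
        return "low"
    if value <= 0.66:
        return "medium"
    return "high"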

# --- 3. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="RAG Adversarial Tester") as demo:
    gr.Markdown("# Adversarial Robustness Testing for RAG Systems")
    gr.Markdown(
        "This interface allows you to test the resilience of a RAG pipeline against various adversarial query perturbations. "
        "Enter a query, select an attack method and perturbation level, and observe the impact."
    )

    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(label="Original Query", placeholder="e.g., Who is the Vice-Chancellor?")
            attack_method_input = gr.Dropdown(
                label="Attack Method",
                choices=["random_deletion", "synonym_replacement", "contextual_word_embedding"],
                value="random_deletion"
            )
            perturbation_level_input = gr.Slider(
                label="Perturbation Level",
                minimum=0, maximum=1,
                step=0.1, value=0.2,
                # NOTE: this slider passes a float straight through, so attack.py
                # must accept one. If it expects "low"/"medium"/"high" instead,
                # translate the value with slider_to_level() above, or swap in a
                # Dropdown with choices=["low", "medium", "high"].
            )
            submit_btn = gr.Button("Run Attack", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Attack Results")
            original_query_output = gr.Textbox(label="Original Query (from input)", interactive=False)
            perturbed_query_output = gr.Textbox(label="Adversarial (Perturbed) Query", interactive=False)
            original_response_output = gr.Textbox(label="βœ… Normal Response", interactive=False)
            perturbed_response_output = gr.Textbox(label="πŸ”΄ Perturbed Response", interactive=False)

            with gr.Row():
                defense_status_output = gr.Textbox(label="Defense Status", scale=2)
                response_sim_output = gr.Textbox(label="Response Similarity")
                ari_output = gr.Textbox(label="ARI Score")

    # Connect the button to the function
    submit_btn.click(
        fn=run_adversarial_test,
        inputs=[query_input, attack_method_input, perturbation_level_input],
        outputs=[
            original_query_output,
            perturbed_query_output,
            original_response_output,
            perturbed_response_output,
            defense_status_output,
            response_sim_output,
            ari_output
        ]
    )
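
    # Optional: pre-filled examples for one-click testing. These queries are
    # illustrative; replace them with ones relevant to calebdata.json.
    gr.Examples(
        examples=[
            ["Who is the Vice-Chancellor?", "synonym_replacement", 0.2],
            ["Who is the Vice-Chancellor?", "random_deletion", 0.5],
        ],
        inputs=[query_input, attack_method_input, perturbation_level_input],
    )
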
# Launch the interface
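# Optional: enabling the request queue first can help under concurrent traffic
# on Hugging Face Spaces or other shared hosts; it is harmless locally.
demo.queue()
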
if __name__ == "__main__":
    demo.launch()