# app.py
import gradio as gr
from rag import RAGPipeline
from attack import AdversarialAttackPipeline
import time

# --- 1. Load Models and Pipelines (Singleton Pattern) ---
# This section runs only ONCE when the application starts up.
# This is the key to low latency for subsequent user requests.
print("Starting application setup...")
start_time = time.time()

# Instantiate the RAG pipeline, which will now load from the pre-computed cache
rag_pipeline = RAGPipeline(
    json_path="calebdata.json",
    defense_model_path="./defense_model",
    cache_dir="cache"
)

# Instantiate the Attack pipeline
attack_pipeline = AdversarialAttackPipeline(rag_pipeline_instance=rag_pipeline)

end_time = time.time()
print(f"Application setup complete. Model loading took {end_time - start_time:.2f} seconds.")


# --- 2. Define the Core Function for Gradio ---
# This function will be called every time a user clicks "Submit".
def run_adversarial_test(query, attack_method, perturbation_level):
    """
    Runs a single attack and returns the results formatted for the Gradio UI.
    """
    print(f"Received query: '{query}', Method: {attack_method}, Level: {perturbation_level}")

    # The attack_pipeline object is already loaded in memory
    result = attack_pipeline.run_attack(
        original_query=query,
        perturbation_method=attack_method,
        perturbation_level=perturbation_level
    )

    # Format the output for the Gradio components
    return (
        result["normal_query"],
        result["perturbed_query"],
        result["normal_response"],
        result["perturbed_response"],
        f"Defense Triggered: {result['defense_triggered']} | Reason: {result['reason']}",
        f"Response Similarity: {result['cos_sim']['response_sim']}%",
        f"Adversarial Risk Index (ARI): {result['ari']}"
    )


# --- 3. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="RAG Adversarial Tester") as demo:
    gr.Markdown("# Adversarial Robustness Testing for RAG Systems")
    gr.Markdown(
        "This interface allows you to test the resilience of a RAG pipeline against various adversarial query perturbations. "
        "Enter a query, select an attack method and perturbation level, and observe the impact."
    )

    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(label="Original Query", placeholder="e.g., Who is the Vice-Chancellor?")
            attack_method_input = gr.Dropdown(
                label="Attack Method",
                choices=["random_deletion", "synonym_replacement", "contextual_word_embedding"],
                value="random_deletion"
            )
            perturbation_level_input = gr.Slider(
                label="Perturbation Level",
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.2,
                # Maps the low/medium/high levels onto a slider.
                # Note: attack.py will need a slight modification to handle a float level.
                # Alternatively, use a dropdown: choices=["low", "medium", "high"]
            )
            submit_btn = gr.Button("Run Attack", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Attack Results")
            original_query_output = gr.Textbox(label="Original Query (from input)", interactive=False)
            perturbed_query_output = gr.Textbox(label="Adversarial (Perturbed) Query", interactive=False)
            original_response_output = gr.Textbox(label="✅ Normal Response", interactive=False)
            perturbed_response_output = gr.Textbox(label="🔴 Perturbed Response", interactive=False)
            with gr.Row():
                defense_status_output = gr.Textbox(label="Defense Status", scale=2)
                response_sim_output = gr.Textbox(label="Response Similarity")
                ari_output = gr.Textbox(label="ARI Score")

    # Connect the button to the function
    submit_btn.click(
        fn=run_adversarial_test,
        inputs=[query_input, attack_method_input, perturbation_level_input],
        outputs=[
            original_query_output,
            perturbed_query_output,
            original_response_output,
            perturbed_response_output,
            defense_status_output,
            response_sim_output,
            ari_output
        ]
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()
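

# --- Appendix (sketch): handling the slider's float level ---
# The slider above emits a float in [0, 1], and the note in section 3 says
# attack.py needs a slight modification to accept it. One possible shape for
# that change is sketched below; the helper name and the thresholds are
# illustrative assumptions, not the project's actual attack.py code.
def float_to_level(level: float) -> str:
    """Map the slider's float onto the categorical levels attack.py may expect."""
    if level <= 0.3:
        return "low"
    if level <= 0.6:
        return "medium"
    return "high"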