Spaces:
Runtime error
Runtime error
| # app.py | |
| import gradio as gr | |
| from rag import RAGPipeline | |
| from attack import AdversarialAttackPipeline | |
| import time | |
| # --- 1. Load Models and Pipelines (Singleton Pattern) --- | |
| # This section runs only ONCE when the application starts up. | |
| # This is the key to low latency for subsequent user requests. | |
# Everything below runs at import time, so the pipelines are built a single
# time per process and then shared by every request handler.
print("Starting application setup...")
start_time = time.time()

# RAG pipeline configured with the dataset file, the defense model directory
# and the cache directory (presumably holding pre-computed artifacts; verify
# against rag.py).
rag_pipeline = RAGPipeline(
    json_path="calebdata.json",
    defense_model_path="./defense_model",
    cache_dir="cache",
)

# The attack pipeline wraps the RAG pipeline instance created above.
attack_pipeline = AdversarialAttackPipeline(
    rag_pipeline_instance=rag_pipeline,
)

end_time = time.time()
elapsed_seconds = end_time - start_time
print(f"Application setup complete. Model loading took {elapsed_seconds:.2f} seconds.")
| # --- 2. Define the Core Function for Gradio --- | |
| # This function will be called every time a user clicks "Submit" | |
def run_adversarial_test(query, attack_method, perturbation_level):
    """Run one adversarial attack and unpack the result for the Gradio outputs.

    The three arguments are forwarded unchanged to
    ``attack_pipeline.run_attack`` (a module-level object).

    Args:
        query: The user's original query text.
        attack_method: Name of the perturbation method to apply.
        perturbation_level: Perturbation strength selected in the UI.

    Returns:
        A 7-tuple: original query, perturbed query, normal response,
        perturbed response, a defense-status line, a response-similarity
        line, and an ARI line.

    Raises:
        KeyError: If the attack result lacks one of the expected keys.
    """
    print(f"Received query: '{query}', Method: {attack_method}, Level: {perturbation_level}")

    # Reuse the shared pipeline object; nothing is constructed per request.
    outcome = attack_pipeline.run_attack(
        original_query=query,
        perturbation_method=attack_method,
        perturbation_level=perturbation_level,
    )

    # Pull the fields out one by one, in the order the UI displays them.
    normal_query = outcome["normal_query"]
    perturbed_query = outcome["perturbed_query"]
    normal_response = outcome["normal_response"]
    perturbed_response = outcome["perturbed_response"]
    defense_line = f"Defense Triggered: {outcome['defense_triggered']} | Reason: {outcome['reason']}"
    similarity_line = f"Response Similarity: {outcome['cos_sim']['response_sim']}%"
    ari_line = f"Adversarial Risk Index (ARI): {outcome['ari']}"

    return (
        normal_query,
        perturbed_query,
        normal_response,
        perturbed_response,
        defense_line,
        similarity_line,
        ari_line,
    )
| # --- 3. Build the Gradio Interface --- | |
# Layout: a narrow input column on the left (query, attack method, level,
# submit button) and a wider results column on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="RAG Adversarial Tester") as demo:
    gr.Markdown("# Adversarial Robustness Testing for RAG Systems")
    gr.Markdown(
        "This interface allows you to test the resilience of a RAG pipeline against various adversarial query perturbations. "
        "Enter a query, select an attack method and perturbation level, and observe the impact."
    )

    with gr.Row():
        # --- Inputs ---
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Original Query",
                placeholder="e.g., Who is the Vice-Chancellor?",
            )
            attack_method_input = gr.Dropdown(
                label="Attack Method",
                choices=["random_deletion", "synonym_replacement", "contextual_word_embedding"],
                value="random_deletion",
            )
            # The slider maps the low/medium/high idea onto a 0-1 float.
            # NOTE: attack.py needs to accept a float level for this to work;
            # the alternative is a dropdown with choices=["low", "medium", "high"].
            perturbation_level_input = gr.Slider(
                label="Perturbation Level",
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.2,
            )
            submit_btn = gr.Button("Run Attack", variant="primary")

        # --- Results (all read-only) ---
        with gr.Column(scale=2):
            gr.Markdown("### Attack Results")
            original_query_output = gr.Textbox(label="Original Query (from input)", interactive=False)
            perturbed_query_output = gr.Textbox(label="Adversarial (Perturbed) Query", interactive=False)
            # The label emoji were corrupted by a bad re-encoding; they are
            # written as escapes so the file stays ASCII-safe.
            # "\u2705" is a check mark (most likely original, TODO confirm);
            # "\U0001F534" is a red circle.
            original_response_output = gr.Textbox(label="\u2705 Normal Response", interactive=False)
            perturbed_response_output = gr.Textbox(label="\U0001F534 Perturbed Response", interactive=False)
            with gr.Row():
                # Summary metrics, read-only like the other result fields.
                defense_status_output = gr.Textbox(label="Defense Status", scale=2, interactive=False)
                response_sim_output = gr.Textbox(label="Response Similarity", interactive=False)
                ari_output = gr.Textbox(label="ARI Score", interactive=False)

    # Wire the button to the handler. The outputs list must stay in the same
    # order as the tuple returned by run_adversarial_test.
    submit_btn.click(
        fn=run_adversarial_test,
        inputs=[query_input, attack_method_input, perturbation_level_input],
        outputs=[
            original_query_output,
            perturbed_query_output,
            original_response_output,
            perturbed_response_output,
            defense_status_output,
            response_sim_output,
            ari_output,
        ],
    )

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()