chibuzordevAlternative committed on
Commit
e98a755
·
verified ·
1 Parent(s): e57d868

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ from rag import RAGPipeline
4
+ from attack import AdversarialAttackPipeline
5
+ import time
6
+
7
# --- 1. Load Models and Pipelines (Singleton Pattern) ---
# Everything in this section executes exactly once, at process start-up.
# Paying the model-loading cost up front is what keeps per-request latency low.
print("Starting application setup...")
_setup_started = time.time()

# The RAG pipeline restores its embeddings/index from the pre-computed cache
# directory rather than rebuilding them on every boot.
rag_pipeline = RAGPipeline(
    json_path="calebdata.json",
    defense_model_path="./defense_model",
    cache_dir="cache",
)

# The attack pipeline wraps the RAG pipeline so both share one set of models.
attack_pipeline = AdversarialAttackPipeline(rag_pipeline_instance=rag_pipeline)

print(f"Application setup complete. Model loading took {time.time() - _setup_started:.2f} seconds.")
26
+ # --- 2. Define the Core Function for Gradio ---
27
+ # This function will be called every time a user clicks "Submit"
28
def run_adversarial_test(query, attack_method, perturbation_level):
    """Execute one adversarial attack and shape its result for the Gradio UI.

    Invoked on every "Run Attack" click. Returns a 7-tuple whose elements
    line up positionally with the output components wired to the submit
    button (queries, responses, then the three summary strings).
    """
    print(f"Received query: '{query}', Method: {attack_method}, Level: {perturbation_level}")

    # attack_pipeline is the module-level singleton created during setup,
    # so no model loading happens per request.
    outcome = attack_pipeline.run_attack(
        original_query=query,
        perturbation_method=attack_method,
        perturbation_level=perturbation_level,
    )

    # Human-readable summaries for the three status textboxes.
    defense_summary = f"Defense Triggered: {outcome['defense_triggered']} | Reason: {outcome['reason']}"
    similarity_summary = f"Response Similarity: {outcome['cos_sim']['response_sim']}%"
    ari_summary = f"Adversarial Risk Index (ARI): {outcome['ari']}"

    return (
        outcome["normal_query"],
        outcome["perturbed_query"],
        outcome["normal_response"],
        outcome["perturbed_response"],
        defense_summary,
        similarity_summary,
        ari_summary,
    )
51
+
52
# --- 3. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="RAG Adversarial Tester") as demo:
    gr.Markdown("# Adversarial Robustness Testing for RAG Systems")
    gr.Markdown(
        "This interface allows you to test the resilience of a RAG pipeline against various adversarial query perturbations. "
        "Enter a query, select an attack method and perturbation level, and observe the impact."
    )

    with gr.Row():
        # Left column: attack configuration controls.
        with gr.Column(scale=1):
            query_box = gr.Textbox(label="Original Query", placeholder="e.g., Who is the Vice-Chancellor?")
            method_dropdown = gr.Dropdown(
                label="Attack Method",
                choices=["random_deletion", "synonym_replacement", "contextual_word_embedding"],
                value="random_deletion",
            )
            level_slider = gr.Slider(
                label="Perturbation Level",
                minimum=0, maximum=1,
                step=0.1, value=0.2,  # low/medium/high expressed as a 0-1 float
                # NOTE(review): attack.py must accept a float level for this
                # slider to work — otherwise swap in a dropdown with
                # choices=["low", "medium", "high"]. Confirm against attack.py.
            )
            run_button = gr.Button("Run Attack", variant="primary")

        # Right column: attack outcome displays.
        with gr.Column(scale=2):
            gr.Markdown("### Attack Results")
            out_original_query = gr.Textbox(label="Original Query (from input)", interactive=False)
            out_perturbed_query = gr.Textbox(label="Adversarial (Perturbed) Query", interactive=False)
            out_normal_response = gr.Textbox(label="✅ Normal Response", interactive=False)
            out_perturbed_response = gr.Textbox(label="🔴 Perturbed Response", interactive=False)

            with gr.Row():
                out_defense = gr.Textbox(label="Defense Status", scale=2)
                out_similarity = gr.Textbox(label="Response Similarity")
                out_ari = gr.Textbox(label="ARI Score")

    # Wire the button to the test runner; the outputs list order must match
    # the 7-tuple returned by run_adversarial_test.
    run_button.click(
        fn=run_adversarial_test,
        inputs=[query_box, method_dropdown, level_slider],
        outputs=[
            out_original_query,
            out_perturbed_query,
            out_normal_response,
            out_perturbed_response,
            out_defense,
            out_similarity,
            out_ari,
        ],
    )

# Launch the interface only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()