slahlou commited on
Commit
5f9cab2
Β·
1 Parent(s): 093945a

adding db credential

Browse files
Files changed (2) hide show
  1. app.py +138 -130
  2. db_work.py +23 -10
app.py CHANGED
@@ -11,6 +11,10 @@ import time
11
 
12
  BASE_URL = "https://beeguy74--example-fastapi-fastapi-app.modal.run"
13
 
 
 
 
 
14
  class API:
15
  def __init__(self, base_url: str):
16
  self.base_url = base_url
@@ -84,64 +88,102 @@ class API:
84
  except Exception as e:
85
  return None, f"❌ Error downloading file: {str(e)}"
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # Initialize services
88
  api_service = API(BASE_URL)
89
- db_interface = DatabaseInterface()
90
 
91
- # All function definitions (keeping your existing ones)
92
  def get_schemas():
93
- """### `get_schemas()`
94
- - **Purpose**: Retrieve all database schemas
95
- - **Returns**: JSON object containing schema information
96
- - **Use Case**: Initial database exploration"""
97
- return db_interface.list_schemas()
98
 
99
  def get_db_infos():
100
- """### `get_db_infos()`
101
- - **Purpose**: Get comprehensive database information and metadata
102
- - **Returns**: JSON object containing database information
103
- - **Use Case**: Initial database exploration"""
104
- return db_interface.list_database_info()
105
 
106
  def get_list_of_tables_in_schema(schema):
107
- """### `get_list_of_tables_in_schema(schema_name: str)`
108
- - **Purpose**: List all tables within a specific schema
109
- - **Parameters**: `schema_name` - Name of the schema to explore
110
- - **Returns**: JSON object containing table information
111
- - **Use Case**: Initial database exploration"""
112
- return db_interface.list_tables_in_schema(schema)
113
 
114
  def get_list_of_column_in_table(schema, table):
115
- """### `get_list_of_column_in_table(schema_name: str, table_name: str)`
116
- - **Purpose**: Get detailed column information for a specific table
117
- - **Parameters**: `schema_name` - Name of the schema containing the table
118
- - **Returns**: JSON object containing column information
119
- - **Use Case**: Initial database exploration"""
120
- return db_interface.list_columns_in_table(schema, table)
121
 
122
  def run_read_only_query(query: str):
123
- """### `run_read_only_query(query: str)`
124
- - **Purpose**: Execute read-only SQL queries safely
125
- - **Parameters**: `query` - SQL SELECT statement
126
- - **Returns**: Query results as rows
127
- - **Use Case**: Initial database exploration"""
128
- return db_interface.read_only_query(query)
129
 
130
  def create_table_from_query(table_name: str, source_query: str):
131
- """### `create_table_from_query(table_name: str, source_query: str)`
132
- - **Purpose**: Create permanent tables from SELECT queries
133
- - **Parameters**: `table_name` - Name of the new table
134
- - **Returns**: Status message indicating success or failure
135
- - **Use Case**: Create analysis datasets from queries"""
136
- return db_interface.create_table_from_query(table_name, source_query)
137
 
138
  def drop_table(table_name: str):
139
- """### `drop_table(table_name: str)`
140
- - **Purpose**: Remove tables from the database
141
- - **Parameters**: `table_name` - Name of the table to drop
142
- - **Returns**: Status message indicating success or failure
143
- - **Use Case**: Clean up analysis tables"""
144
- return db_interface.drop_table(table_name)
145
 
146
  def create_sample_image():
147
  img_path = "./sample_graph.png"
@@ -155,99 +197,27 @@ def serve_image_from_path():
155
  return create_sample_image()
156
 
157
  def do_annova(table_name, min_sample_size=0):
158
- '''
159
- this function runs the annova on the dataset and render the associated F_score and p_value
160
- table_name is the name of the table on which you want to run the ANOVA
161
- the selected table MUST have the following signature:
162
-
163
- groups | measurement
164
-
165
- exemple with the product_type_age table:
166
-
167
- type | age
168
- ----------
169
- 'Coat', '36'
170
- 'Coat', '36'
171
- 'Hat/beanie', '32'
172
- ...
173
-
174
- min_sample_size is used to exclude categories that does not have enough measurement.
175
- default = 0: all categories are selected
176
-
177
- return type is: dict
178
- {
179
- "F-statistic": round(f_stat, 3),
180
- "p-value": round(p_value, 3)
181
- }
182
- '''
183
  return var_stats.anova(db_interface, table_name=table_name, min_sample_size=int(min_sample_size))
184
 
185
  def do_tukey_test(table_name, min_sample_size=0):
186
- '''
187
- this function runs a Tukey's HSD (Honestly Significant Difference) test β€” a post-hoc analysis following ANOVA.
188
- It tells you which specific pairs of groups differ significantly in their means
189
- IT is meant to be used after you run a successful anova and you obtain sgnificant F-satatistics and p-value
190
- table_name is the name of the table on which you want to run the ANOVA
191
- the selected table MUST have the following signature:
192
-
193
- groups | measurement
194
-
195
- exemple with the product_type_age table:
196
-
197
- type | age
198
- ----------
199
- 'Coat', 36
200
- 'Coat', 36
201
- 'Hat/beanie', 32
202
- ...
203
-
204
- min_sample_size is used to exclude categories that does not have enough measurement.
205
- default = 0: all categories are selected
206
-
207
- the return result is the raw dataframe that correspond to the pair wize categorie that reject the hypothesis of non statistically difference between two group
208
- the signature of the dataframe is the following:
209
- group1 | group2 | meandiff p-adj | lower | upper | reject (only true)
210
-
211
- '''
212
  return var_stats.tukey_test(db_interface, table_name=table_name, min_sample_size=int(min_sample_size))
213
 
214
  def do_tsne_embedding(query):
215
- """
216
-
217
- this tool allow to run a TSNE dimensionality reduction algorythme and a clustering (HDBSCAN) on top of that.
218
-
219
- the input query, is a sql query that MUST return a table with at least the item id and the corresponding embeddding.
220
- FOR COMPUTATIONAL PURPOSE, THE QUERY YOU SEND MUST NOT RETURN A TABLE GREATER THAN 500 OUTPUT ROWS
221
- exemple:
222
- result = db_connection.read_only_query(query)
223
- result shape:
224
- article_id | embedding
225
- 0125456 | [0.3, 0.5 ...]
226
-
227
- the return is a dictionnary that has the following format:
228
-
229
- return {
230
- "ids": ids,
231
- "x_axis": tsne_projection_x_list,
232
- "y_axis": tsne_projection_y_list,
233
- "labels": labels
234
- }
235
- """
236
-
237
  return var_stats.embedding_clustering(db_interface, query)
238
 
239
  def do_vector_centroid(query):
240
- """
241
- this tool allow you to compute the centroid of a list of embedding vectors
242
- the input query, is a sql query that MUST return a table with only 1 column, the embeddings.
243
- exemple:
244
- result = db_connection.read_only_query(query)
245
- result shape:
246
- embedding
247
- [0.3, 0.5 ...]
248
-
249
- the return value is the computed centroid vector, that you can use to work with.
250
- """
251
  return var_stats.vector_centroid(db_interface, query)
252
 
253
  def embed_text_modal_api(text):
@@ -452,6 +422,44 @@ def get_mcp_server_instructions():
452
  6. **Error Handling**: All functions return status indicators - check for errors before proceeding
453
  7. **Data Safety**: Core tables (transactions, customers, articles) are protected from modification"""
454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  # TAB 1: Database Operations
456
  with gr.Blocks(title="Database Operations") as tab1:
457
  gr.Markdown("# πŸ—„οΈ Database Operations")
@@ -614,7 +622,7 @@ with gr.Blocks(title="AI Analytics") as tab2:
614
  with gr.Blocks(title="Statistical Analysis") as tab4:
615
  gr.Markdown("# πŸ“Š Statistical Analysis")
616
  gr.Markdown("*Run statistical tests on your data*")
617
-
618
  with gr.Row():
619
  with gr.Column(scale=1):
620
  gr.Markdown("### enter a dict that comply for annova function")
@@ -656,9 +664,9 @@ with gr.Blocks(title="MCP guidelines") as tab5:
656
 
657
  # Create the TabbedInterface
658
  interface = gr.TabbedInterface(
659
- [tab1, tab2, tab4, tab5],
660
- tab_names=["πŸ—„οΈ Database Operations", "πŸ€– AI Analytics", "πŸ“Š Statistical Analysis", "πŸ“Š MCP guidelines"],
661
- title="Universal E-commerce Database Analytics MCP Server",
662
  theme=gr.themes.Soft()
663
  )
664
 
 
11
 
12
  BASE_URL = "https://beeguy74--example-fastapi-fastapi-app.modal.run"
13
 
14
+ # Global state for database connection
15
+ db_interface = None
16
+ db_connection_status = "❌ Not Connected"
17
+
18
  class API:
19
  def __init__(self, base_url: str):
20
  self.base_url = base_url
 
88
  except Exception as e:
89
  return None, f"❌ Error downloading file: {str(e)}"
90
 
91
+ def setup_database_connection(host: str, port: str, database: str, user: str, password: str):
92
+ """Setup database connection with user-provided configuration"""
93
+ global db_interface, db_connection_status
94
+
95
+ if not all([host.strip(), port.strip(), database.strip(), user.strip(), password.strip()]):
96
+ db_connection_status = "❌ All fields are required"
97
+ return db_connection_status, False
98
+
99
+ try:
100
+ db_config = {
101
+ 'host': host.strip(),
102
+ 'port': int(port.strip()),
103
+ 'database': database.strip(),
104
+ 'user': user.strip(),
105
+ 'password': password.strip()
106
+ }
107
+
108
+ # Test connection
109
+ test_interface = DatabaseInterface(db_config)
110
+ test_connection = test_interface.get_db_connection()
111
+ test_connection.close()
112
+
113
+ # If successful, set global interface
114
+ db_interface = test_interface
115
+ db_connection_status = f"βœ… Connected to {database} at {host}:{port}"
116
+ return db_connection_status, True
117
+
118
+ except ValueError:
119
+ db_connection_status = "❌ Port must be a valid number"
120
+ return db_connection_status, False
121
+ except Exception as e:
122
+ db_connection_status = f"❌ Connection failed: {str(e)}"
123
+ return db_connection_status, False
124
+
125
+ def get_connection_status():
126
+ """Get current database connection status"""
127
+ return db_connection_status
128
+
129
+ def check_db_connection():
130
+ """Check if database is connected before operations"""
131
+ if db_interface is None:
132
+ return False, "❌ Please configure database connection first"
133
+ return True, "βœ… Database connected"
134
+
135
  # Initialize services
136
  api_service = API(BASE_URL)
 
137
 
138
+ # Updated database functions with connection checks
139
  def get_schemas():
140
+ """### `get_schemas()`"""
141
+ connected, status = check_db_connection()
142
+ if not connected:
143
+ return status
144
+ return db_interface.list_schemas()
145
 
146
  def get_db_infos():
147
+ """### `get_db_infos()`"""
148
+ connected, status = check_db_connection()
149
+ if not connected:
150
+ return status
151
+ return db_interface.list_database_info()
152
 
153
  def get_list_of_tables_in_schema(schema):
154
+ """### `get_list_of_tables_in_schema(schema_name: str)`"""
155
+ connected, status = check_db_connection()
156
+ if not connected:
157
+ return status
158
+ return db_interface.list_tables_in_schema(schema)
 
159
 
160
  def get_list_of_column_in_table(schema, table):
161
+ """### `get_list_of_column_in_table(schema_name: str, table_name: str)`"""
162
+ connected, status = check_db_connection()
163
+ if not connected:
164
+ return status
165
+ return db_interface.list_columns_in_table(schema, table)
 
166
 
167
  def run_read_only_query(query: str):
168
+ """### `run_read_only_query(query: str)`"""
169
+ connected, status = check_db_connection()
170
+ if not connected:
171
+ return status
172
+ return db_interface.read_only_query(query)
 
173
 
174
  def create_table_from_query(table_name: str, source_query: str):
175
+ """### `create_table_from_query(table_name: str, source_query: str)`"""
176
+ connected, status = check_db_connection()
177
+ if not connected:
178
+ return status
179
+ return db_interface.create_table_from_query(table_name, source_query)
 
180
 
181
  def drop_table(table_name: str):
182
+ """### `drop_table(table_name: str)`"""
183
+ connected, status = check_db_connection()
184
+ if not connected:
185
+ return status
186
+ return db_interface.drop_table(table_name)
 
187
 
188
  def create_sample_image():
189
  img_path = "./sample_graph.png"
 
197
  return create_sample_image()
198
 
199
  def do_annova(table_name, min_sample_size=0):
200
+ connected, status = check_db_connection()
201
+ if not connected:
202
+ return status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  return var_stats.anova(db_interface, table_name=table_name, min_sample_size=int(min_sample_size))
204
 
205
  def do_tukey_test(table_name, min_sample_size=0):
206
+ connected, status = check_db_connection()
207
+ if not connected:
208
+ return status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  return var_stats.tukey_test(db_interface, table_name=table_name, min_sample_size=int(min_sample_size))
210
 
211
  def do_tsne_embedding(query):
212
+ connected, status = check_db_connection()
213
+ if not connected:
214
+ return status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  return var_stats.embedding_clustering(db_interface, query)
216
 
217
  def do_vector_centroid(query):
218
+ connected, status = check_db_connection()
219
+ if not connected:
220
+ return status
 
 
 
 
 
 
 
 
221
  return var_stats.vector_centroid(db_interface, query)
222
 
223
  def embed_text_modal_api(text):
 
422
  6. **Error Handling**: All functions return status indicators - check for errors before proceeding
423
  7. **Data Safety**: Core tables (transactions, customers, articles) are protected from modification"""
424
 
425
+ # TAB 0: Database Configuration
426
+ with gr.Blocks(title="Database Configuration") as tab0:
427
+ gr.Markdown("# πŸ”Œ Database Configuration")
428
+ gr.Markdown("*Configure your database connection before using the analytics platform*")
429
+
430
+ with gr.Row():
431
+ with gr.Column(scale=1):
432
+ gr.Markdown("### πŸ—„οΈ Database Connection")
433
+ host_input = gr.Textbox(label="Host", placeholder="database.example.com", value="")
434
+ port_input = gr.Textbox(label="Port", placeholder="5432", value="")
435
+ database_input = gr.Textbox(label="Database", placeholder="my_database", value="")
436
+ user_input = gr.Textbox(label="User", placeholder="db_user", value="")
437
+ password_input = gr.Textbox(label="Password", type="password", placeholder="β€’β€’β€’β€’β€’β€’β€’β€’", value="")
438
+
439
+ connect_btn = gr.Button("πŸ”Œ Connect to Database", variant="primary")
440
+
441
+ with gr.Column(scale=1):
442
+ connection_status = gr.Textbox(label="πŸ”Œ Connection Status", value=db_connection_status, interactive=False)
443
+ gr.Markdown("### ℹ️ Instructions")
444
+ gr.Markdown("""
445
+ 1. **Fill in your database credentials**
446
+ 2. **Click 'Connect to Database'**
447
+ 3. **Wait for successful connection**
448
+ 4. **Proceed to other tabs once connected**
449
+
450
+ **Note**: All database operations require a valid connection.
451
+ """)
452
+
453
+ def handle_connection(host, port, database, user, password):
454
+ status, success = setup_database_connection(host, port, database, user, password)
455
+ return status
456
+
457
+ connect_btn.click(
458
+ handle_connection,
459
+ inputs=[host_input, port_input, database_input, user_input, password_input],
460
+ outputs=connection_status
461
+ )
462
+
463
  # TAB 1: Database Operations
464
  with gr.Blocks(title="Database Operations") as tab1:
465
  gr.Markdown("# πŸ—„οΈ Database Operations")
 
622
  with gr.Blocks(title="Statistical Analysis") as tab4:
623
  gr.Markdown("# πŸ“Š Statistical Analysis")
624
  gr.Markdown("*Run statistical tests on your data*")
625
+
626
  with gr.Row():
627
  with gr.Column(scale=1):
628
  gr.Markdown("### enter a dict that comply for annova function")
 
664
 
665
  # Create the TabbedInterface
666
  interface = gr.TabbedInterface(
667
+ [tab0, tab1, tab2, tab4, tab5],
668
+ tab_names=["πŸ”Œ Database Setup", "πŸ—„οΈ Database Operations", "πŸ€– AI Analytics", "πŸ“Š Statistical Analysis", "πŸ“Š MCP guidelines"],
669
+ title="E-commerce Database Analytics Platform",
670
  theme=gr.themes.Soft()
671
  )
672
 
db_work.py CHANGED
@@ -17,21 +17,34 @@ COLUMN_IN_TABLE=os.getenv('COLUMN_IN_TABLE')
17
 
18
 
19
  class DatabaseInterface:
20
- def __init__(self):
21
  # Initialize FastMCP server
22
  self.mcp = FastMCP("ecommerce-mcp-server")
23
 
24
- self.db_config = {
25
- 'host': os.getenv('DB_HOST'),
26
- 'port': os.getenv('DB_PORT'),
27
- 'database': os.getenv('DB_NAME'),
28
- 'user': os.getenv('DB_USER'),
29
- 'password': os.getenv('DB_PASSWORD')
30
- }
 
 
 
 
 
 
 
 
 
 
31
 
32
  def get_db_connection(self):
33
- """Create database connection"""
34
- return psycopg2.connect(**self.db_config)
 
 
 
35
 
36
  def list_schemas(self):
37
  print("=======>", LIST_SCHEMA)
 
17
 
18
 
19
  class DatabaseInterface:
20
+ def __init__(self, db_config: Optional[Dict[str, Any]] = None):
21
  # Initialize FastMCP server
22
  self.mcp = FastMCP("ecommerce-mcp-server")
23
 
24
+ if db_config:
25
+ self.db_config = db_config
26
+ else:
27
+ # Fallback to environment variables
28
+ self.db_config = {
29
+ 'host': os.getenv('DB_HOST'),
30
+ 'port': int(os.getenv('DB_PORT', 5432)),
31
+ 'database': os.getenv('DB_NAME'),
32
+ 'user': os.getenv('DB_USER'),
33
+ 'password': os.getenv('DB_PASSWORD')
34
+ }
35
+
36
+ # Validate configuration
37
+ required_fields = ['host', 'database', 'user', 'password']
38
+ missing_fields = [field for field in required_fields if not self.db_config.get(field)]
39
+ if missing_fields:
40
+ raise ValueError(f"Missing required database configuration: {missing_fields}")
41
 
42
  def get_db_connection(self):
43
+ """Create database connection with error handling"""
44
+ try:
45
+ return psycopg2.connect(**self.db_config)
46
+ except psycopg2.Error as e:
47
+ raise ConnectionError(f"Failed to connect to database: {str(e)}")
48
 
49
  def list_schemas(self):
50
  print("=======>", LIST_SCHEMA)