File size: 25,806 Bytes
80cfb11
 
 
 
6a39607
80cfb11
 
a2134eb
 
 
09e2bc4
6a39607
80cfb11
feac34e
80cfb11
 
 
09e2bc4
0391cfb
 
 
 
80cfb11
feac34e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef0ee7c
6b3e14e
e8d2e8a
aba6b7f
 
 
 
 
 
 
 
80cfb11
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
a3f399c
0e5cf1e
 
 
a3f399c
0e5cf1e
 
a3f399c
 
19f43ab
6c01c87
19f43ab
 
09e2bc4
b112622
6c01c87
 
 
09e2bc4
a3f399c
09e2bc4
a3f399c
6c01c87
0e5cf1e
 
6c01c87
0e5cf1e
 
6c01c87
 
09e2bc4
 
6c01c87
09e2bc4
 
 
 
 
 
0e5cf1e
 
 
 
 
 
 
6c01c87
 
 
 
 
e087162
09e2bc4
 
 
 
 
0e5cf1e
 
 
b112622
6c01c87
b112622
 
6c01c87
b112622
6c01c87
b112622
6c01c87
 
 
b112622
09e2bc4
b112622
6c01c87
b112622
6c01c87
b112622
6c01c87
 
b112622
6c01c87
 
 
 
 
 
 
b112622
0e5cf1e
6c01c87
0e5cf1e
 
0391cfb
0e5cf1e
6c01c87
0391cfb
0e5cf1e
6c01c87
09e2bc4
0e5cf1e
6c01c87
 
 
a0d91e2
 
8eaf76a
 
 
09e2bc4
 
 
 
 
6c01c87
09e2bc4
 
6c01c87
0e5cf1e
 
 
 
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09e2bc4
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09e2bc4
 
 
 
 
 
 
 
 
 
80cfb11
 
 
 
 
 
 
 
6b3e14e
80cfb11
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3e14e
80cfb11
 
 
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3e14e
80cfb11
 
6b3e14e
80cfb11
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3e14e
6e75e8f
80cfb11
 
6e75e8f
6b3e14e
80cfb11
440696e
80cfb11
 
 
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440696e
80cfb11
 
 
 
 
 
440696e
80cfb11
440696e
 
80cfb11
 
 
 
 
 
6e75e8f
 
 
6a39607
6b3e14e
80cfb11
 
 
 
6e75e8f
80cfb11
 
6b3e14e
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
440696e
80cfb11
 
 
 
 
 
440696e
80cfb11
440696e
 
80cfb11
 
 
 
440696e
 
e087162
80cfb11
 
 
 
e087162
80cfb11
 
 
6e75e8f
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
e087162
6e75e8f
80cfb11
 
 
6e75e8f
80cfb11
 
 
 
e087162
6e75e8f
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e087162
6e75e8f
80cfb11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2134eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e087162
 
 
 
a2134eb
 
 
a2d0bfa
 
 
 
a2134eb
 
 
 
ef0ee7c
a2134eb
 
 
 
 
8eaf76a
ef0ee7c
8eaf76a
 
 
 
 
a2134eb
 
 
 
 
80cfb11
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Optional, AsyncIterator
import asyncpg 
from mcp.server.fastmcp import FastMCP, Context
from pydantic import Field
import pandasai as pai
import matplotlib as plt
import pandas as pd
import logging
from pandasai_openai import OpenAI


# Constants
# Default cap on the number of rows execute_query includes in its output
# when the caller does not supply an explicit limit.
DEFAULT_QUERY_LIMIT = 100

# Root logging config intentionally left to the host application.
# logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Define our own PromptMessage class if the MCP one isn't available
@dataclass
class PromptMessage:
    # content: the message text delivered to the model.
    content: str
    # role: chat role for the message; defaults to "user".
    role: Optional[str] = "user"


# Database context class
@dataclass
class DbContext:
    # pool: shared asyncpg connection pool, created by db_lifespan.
    pool: asyncpg.Pool
    # schema: target schema name (sourced from the DB_SCHEMA env var).
    schema: str


# Database connection lifecycle manager
@asynccontextmanager
async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
    """Manage database connection lifecycle.

    Reads the connection string from the DB_URL environment variable and
    the target schema from DB_SCHEMA (raises KeyError if either is
    missing), opens an asyncpg pool, yields a DbContext, and always
    closes the pool on exit.
    """
    dsn = os.environ["DB_URL"]
    db_schema = os.environ["DB_SCHEMA"]

    # Small pool: idle connections older than 300s are recycled; pool
    # acquisition waits up to 60s and individual commands up to 300s.
    pool = await asyncpg.create_pool(
        dsn,
        min_size=1,
        max_size=4,
        max_inactive_connection_lifetime=300,
        timeout=60,
        command_timeout=300,
    )
    try:
        yield DbContext(pool=pool, schema=db_schema)
    finally:
        # Clean up: close the pool even if the server exits via an error.
        await pool.close()


# Create server with database lifecycle management
# (db_lifespan supplies a DbContext to handlers via the request context).
mcp = FastMCP(
    "SQL Database Server",
    dependencies=["asyncpg", "pydantic"],
    lifespan=db_lifespan
)


@mcp.resource(
    uri="resource://base_prompt",
    name="base_prompt",
    description="A base prompt to generate SQL queries and answer questions"
)
async def base_prompt_query() -> str:
    """Return the base system-prompt template for SQL generation and Q&A.

    The template contains single-brace placeholders ``{descriptions}`` and
    ``{tools}`` that the consumer is expected to fill in via ``str.format``;
    double-brace spans (``{{...}}``) are escaped so they survive formatting
    and render as literal braced text.
    """

    base_prompt = """
==========================
# Your Role
==========================

You are an expert in generating and executing SQL queries, interacting with a PostgreSQL database using **FastMCP tools**, and visualizing results when requested. These tools allow you to:

- List available tables
- Retrieve schema details
- Execute SQL queries
- Visualize query results using PandasAI

Each tool may return previews or summaries of table contents to help you understand the data structure.

---

==========================
# Your Objective
==========================

When a user submits a request, you must:

1. **Analyze the request** to determine the required data, action, or visualization.
2. **Use FastMCP tools** to gather necessary information (e.g., list tables, retrieve schema).
3. **Generate a valid SQL SELECT query**, if needed, and clearly show the full query.
4. **Execute the SQL query** using the `execute_query` tool and return the results.
5. **Visualize results** using the `visualize_results` tool if the user explicitly requests a visualization (e.g., "create a chart", "visualize", "plot"). For visualizations:
   - Craft a visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
   - Send JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
6. **Chain tools logically**, such as: List Tables → Get Schema → Write and Run Query → Visualize Results (if requested).
7. **Explain your reasoning and each step taken** to ensure clarity and transparency.

---

==========================
# Critical Rules
==========================

- Only use **read-only** SQL queries such as **SELECT**, **COUNT**, or queries with **GROUP BY**, **ORDER BY**, etc.
- **Never** use destructive operations like **DELETE**, **UPDATE**, **INSERT**, or **DROP**.
- Always show the SQL query you generate along with the execution result.
- Validate SQL syntax before execution.
- Never assume table or column names. Use tools to confirm structure.
- Use memory efficiently. Don't rerun a tool unless necessary.
- If you generate a SQL query, immediately call the **execute_query** tool.
- If the user requests a visualization, call the **visualize_results** tool with:
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
  - JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- For non-query or non-visualization requests (e.g., history questions), respond appropriately without forcing a query or visualization.

---

==========================
# Database Description
==========================

{descriptions}

---

==========================
# Tools
==========================

You can use the following FastMCP tools to create **read-only** queries (e.g., `SELECT`, `COUNT`, `GROUP BY`, `ORDER BY`) and visualize results when requested. Chain tools logically to gather information, execute queries, or generate visualizations.

{tools}

---

### Invalid Example — DELETE Operation (Not Allowed):
**User Request:** "Delete all customers from Germany."

**Response Guidance:**
- **Do not generate or execute** destructive queries such as `DELETE`.
- Instead, respond with a message like:
  > Destructive operations such as `DELETE` are not permitted. I can help you retrieve the customers from Germany using a `SELECT` query instead:
  > ```sql
  > SELECT * FROM customers WHERE country = 'Germany';
  > ```

==========================
# Output Format
==========================

Present your final answer using the following structure in markdown language:

# Result
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a beautiful Markdown table for tabular data (rows and columns) including headers and show such simple results using a table. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}

# Explanation
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}

# Query
```sql
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
```

==========================
# Reminder
==========================
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
  - JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- **For non-query or non-visualization requests, respond appropriately without forcing a query or visualization.**

**Conversation History:**
Use the conversation history for context when available to maintain continuity.
"""

    return base_prompt


@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
async def test_connection(ctx: Context) -> str:
    """Probe the database by fetching the server's version string.

    Returns a success message containing the PostgreSQL version, or a
    failure message describing the error encountered.
    """
    try:
        db_pool = ctx.request_context.lifespan_context.pool
        async with db_pool.acquire() as connection:
            pg_version = await connection.fetchval("SELECT version();")
        return f"Connection successful. PostgreSQL version: {pg_version}"
    except Exception as exc:
        return f"Connection failed: {str(exc)}"


@mcp.tool(description="Executes a read-only SQL SELECT query and returns formatted results or an error message.")
async def execute_query(
        query: str = Field(description="SQL query to execute (SELECT only)"),
        limit: Optional[int] = Field(default=DEFAULT_QUERY_LIMIT, description="Maximum number of rows to return"),
        ctx: Context = None
) -> str:
    """Execute a read-only SQL query against the database.

    Args:
        query: SQL text. Must start with SELECT (case-insensitive) to pass
            the read-only guard. NOTE(review): this prefix check also
            rejects CTE-style queries (`WITH ... SELECT`) and is not a full
            SQL parser — confirm that is acceptable.
        limit: Maximum rows included in the formatted output; any falsy
            value falls back to DEFAULT_QUERY_LIMIT.
        ctx: FastMCP request context supplying the connection pool.

    Returns:
        A pipe-delimited text table (header, separator, rows), a
        "no rows" notice, or an error message string.
    """
    # Validate query - simple check for read-only
    query = query.strip()
    if not query.lower().startswith("select"):
        return "Error: Only SELECT queries are allowed for security reasons."

    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            result = await conn.fetch(query)

            if not result:
                return "Query executed successfully. No rows returned."

            # Column names come from the first record of the result set.
            columns = list(result[0].keys())
            header = " | ".join(columns)
            separator = "-" * len(header)

            # Truncation is client-side: the full result set is already
            # fetched, then sliced to the effective limit.
            effective_limit = limit if limit else DEFAULT_QUERY_LIMIT
            rows = [" | ".join(str(val) for val in row.values())
                    for row in result[:effective_limit]]

            return f"{header}\n{separator}\n" + "\n".join(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


# Database helper functions
async def get_all_tables(pool, db_schema):
    """Get all tables from the database.

    Queries pg_class/pg_namespace so that child partitions (relations
    present in pg_inherits) are excluded; only ordinary tables and
    partitioned parents (relkind 'r' or 'p') in the given schema are
    returned.

    Args:
        pool: asyncpg connection pool to acquire a connection from.
        db_schema: Name of the schema to inspect.

    Returns:
        List of records, each with a single 'table_name' field, sorted
        by table name.
    """
    # Log instead of print(): in an MCP server stdout can carry the
    # protocol stream, so diagnostics must go through logging.
    logging.getLogger(__name__).debug("schema: %s", db_schema)
    async with pool.acquire() as conn:
        result = await conn.fetch("""
            SELECT c.relname AS table_name
            FROM pg_class AS c
            JOIN pg_namespace AS n ON n.oid = c.relnamespace
            WHERE NOT EXISTS (
                SELECT 1
                FROM pg_inherits AS i
                WHERE i.inhrelid = c.oid
            )
            AND c.relkind IN ('r', 'p')
            AND n.nspname = $1
            AND c.relname NOT LIKE 'pg_%'
            ORDER BY c.relname;
        """, db_schema)

        return result


async def get_table_schema_info(pool, db_schema, table_name):
    """Fetch per-column metadata for one table.

    Reads information_schema.columns for the (schema, table) pair and
    returns name, type, nullability, default, and max length for each
    column in ordinal order.
    """
    column_query = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = $1
            AND table_name = $2
            ORDER BY ordinal_position;
        """
    async with pool.acquire() as connection:
        return await connection.fetch(column_query, db_schema, table_name)


def format_table_schema(table_name, columns):
    """Render a table's column metadata as human-readable text.

    Args:
        table_name: Name of the table being described.
        columns: Iterable of mappings exposing 'column_name', 'data_type',
            'is_nullable', 'column_default' and 'character_maximum_length'.

    Returns:
        A multi-line description, or a not-found message when *columns*
        is empty.
    """
    if not columns:
        return f"Table '{table_name}' not found."

    def describe(col):
        # e.g. "- name (character varying(50)) NULL DEFAULT 'x'::text"
        size = col['character_maximum_length']
        suffix = f"({size})" if size else ""
        null_marker = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
        default_clause = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
        return f"- {col['column_name']} ({col['data_type']}{suffix}) {null_marker}{default_clause}"

    lines = [f"Table: {table_name}", "Columns:"]
    lines.extend(describe(col) for col in columns)
    return "\n".join(lines)


@mcp.tool(description="Lists all table names in the configured database schema.")
async def list_tables() -> str:
    """Return the names of all tables in the configured schema, one per line."""
    try:
        # NOTE(review): this opens a fresh pool via db_lifespan on every
        # call instead of reusing the server's lifespan context — confirm
        # this is intentional.
        async with db_lifespan(mcp) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)

            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."

            names = [record['table_name'] for record in tables]
            return "\n".join(names)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.tool(description="Retrieves and formats the schema details of a specific table in the database.")
async def get_table_schema(table_name: str) -> str:
    """Get schema information for a specific table.

    Args:
        table_name: Name of the table to describe.

    Returns:
        A formatted column listing, a not-found notice, or an error
        message string.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            # Use the schema carried by the lifespan context rather than
            # re-reading DB_SCHEMA from the environment: same source, one
            # authoritative place (consistent with list_tables).
            columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)

            if not columns:
                return f"Table '{table_name}' not found in {db_ctx.schema} schema."

            return format_table_schema(table_name, columns)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.tool(description="Retrieves foreign key relationships for a specified table.")
async def get_foreign_keys(table_name: str, ctx: Context = None) -> str:
    """Get foreign key information for a table.

    Args:
        table_name: The name of the table to get foreign keys from
        ctx: FastMCP request context (injected) used to run the query.

    Returns:
        Formatted rows describing each FK constraint, or an error message.

    The original implementation interpolated ``{db_schema}``/``{table_name}``
    placeholders inside a NON-f-string (so the braces reached the server
    verbatim), passed a tuple as execute_query's ``limit`` argument, and
    returned an un-awaited coroutine from a sync function. All three are
    fixed here.
    """
    db_schema = os.environ["DB_SCHEMA"]

    # execute_query takes a plain SQL string (no bind parameters), so the
    # values are embedded as quoted literals with single quotes doubled.
    schema_lit = "'" + db_schema.replace("'", "''") + "'"
    table_lit = "'" + table_name.replace("'", "''") + "'"

    sql = f"""
    SELECT 
        tc.constraint_name,
        kcu.column_name as fk_column,
        ccu.table_schema as referenced_schema,
        ccu.table_name as referenced_table,
        ccu.column_name as referenced_column
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
    JOIN information_schema.referential_constraints rc
        ON tc.constraint_name = rc.constraint_name
    JOIN information_schema.constraint_column_usage ccu
        ON rc.unique_constraint_name = ccu.constraint_name
    WHERE tc.constraint_type = 'FOREIGN KEY'
        AND tc.table_schema = {schema_lit}
        AND tc.table_name = {table_lit}
    ORDER BY tc.constraint_name, kcu.ordinal_position
    """

    # Pass limit explicitly: when the decorated function is called directly,
    # the pydantic Field default would not resolve to an int.
    return await execute_query(sql, limit=DEFAULT_QUERY_LIMIT, ctx=ctx)


@mcp.tool(description="Fetches and formats the schema details for all tables in the configured database schema.")
async def get_all_schemas() -> str:
    """Get schema information for all tables in the database.

    Returns:
        The formatted schema of every table in the configured schema,
        separated by blank lines, or an error message string.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)

            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."

            all_schemas = []
            for table in tables:
                table_name = table['table_name']
                # Use db_ctx.schema consistently; the original mixed it
                # with a separate DB_SCHEMA env read (same value, two
                # sources).
                columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
                all_schemas.append(format_table_schema(table_name, columns))
                all_schemas.append("")  # Add empty line between tables

            return "\n".join(all_schemas)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.prompt(description="Generates a prompt message to help craft a best-practice SELECT query for a given table.")
async def generate_select_query(table_name: str) -> list[PromptMessage]:
    """Build a prompt asking for a best-practice SELECT over *table_name*.

    Looks up the table's columns, then returns a single PromptMessage
    embedding the schema and SELECT guidelines; returns an error message
    wrapped in a PromptMessage on failure.
    """
    column_sql = """
                    SELECT column_name, data_type
                    FROM information_schema.columns 
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                columns = await conn.fetch(column_sql, db_ctx.schema, table_name)

            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

            columns_text = "\n".join(f"- {col['column_name']} ({col['data_type']})" for col in columns)

            prompt_text = f"""Please help me write a well-structured, efficient SELECT query for the '{table_name}' table.

                    Table Schema:
                    {columns_text}

                    PostgreSQL SQL Best Practices:
                    - Use explicit column names instead of * when possible
                    - Include LIMIT clauses to restrict result sets
                    - Consider adding WHERE clauses to filter results
                    - Use appropriate indexing considerations
                    - Format SQL with proper indentation and line breaks

                    Create a basic SELECT query following these best practices:"""
            return [PromptMessage(prompt_text)]
    except Exception as e:
        return [PromptMessage(f"Error generating select query: {str(e)}")]


@mcp.prompt(description="Generates a prompt message to assist in writing analytical queries for a given table.")
async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
    """Generate analytical queries for a table.

    Args:
        table_name: The name of the table to generate analytical queries for

    Returns:
        A single-element list with the analysis prompt, a not-found
        notice, or an error message wrapped in a PromptMessage.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            pool = db_ctx.pool
            async with pool.acquire() as conn:
                # Use $1/$2 bind parameters (matching generate_select_query).
                # The original interpolated raw, unquoted values into the SQL
                # via an f-string *and* passed bind arguments, which asyncpg
                # rejects — and which is an SQL-injection vector.
                columns = await conn.fetch("""
                    SELECT column_name, data_type
                    FROM information_schema.columns 
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                """, db_ctx.schema, table_name)

            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

            columns_text = "\n".join([f"- {col['column_name']} ({col['data_type']})" for col in columns])

            return [
                PromptMessage(
                    f"""Please help me create analytical queries for the '{table_name}' table.

                    Table Schema:
                    {columns_text}

                    PostgreSQL SQL Best Practices:
                    - Use aggregation functions (COUNT, SUM, AVG, MIN, MAX) appropriately
                    - Group data using GROUP BY for meaningful aggregations
                    - Filter groups with HAVING clauses when needed
                    - Consider using window functions for advanced analytics
                    - Format SQL with proper indentation and line breaks

                    Create a set of analytical queries for this table:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating analytical query: {str(e)}")]


@mcp.tool(
    description="Identifies both explicit and implied foreign key relationships for a given table using schema analysis and naming patterns.")
async def find_relationships(table_name: str, db_schema: str = 'public', ctx: Context = None) -> str:
    """Find both explicit and implied relationships for a table.

    Args:
        table_name: The name of the table to analyze relationships for
        db_schema: The schema name (defaults to 'public')
        ctx: FastMCP request context (injected) used to execute queries.

    Fixes over the original: execute_query is awaited (it was called
    without await from a sync function, so coroutine reprs leaked into
    the output); interpolated values are quoted SQL literals instead of
    bare identifiers; and the "no rows" sentinel matches the string
    execute_query actually returns.
    """
    try:
        # execute_query takes a plain SQL string, so embed the values as
        # quoted literals with single quotes doubled.
        schema_lit = "'" + db_schema.replace("'", "''") + "'"
        table_lit = "'" + table_name.replace("'", "''") + "'"
        no_rows = "Query executed successfully. No rows returned."

        # First get explicit foreign key relationships
        fk_sql = f"""
        SELECT 
            kcu.column_name,
            ccu.table_name as foreign_table,
            ccu.column_name as foreign_column,
            'Explicit FK' as relationship_type,
            1 as confidence_level
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu 
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = {schema_lit}
            AND tc.table_name = {table_lit}
        """

        # Then look for implied relationships based on common patterns.
        # (LIKE patterns use single '%' wildcards; the original's '%%' was
        # a leftover from %-formatting and matched the same strings.)
        implied_sql = f"""
        WITH source_columns AS (
            -- Get all ID-like columns from our table
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = {schema_lit}
            AND table_name = {table_lit}
            AND (
                column_name LIKE '%id' 
                OR column_name LIKE '%_id'
                OR column_name LIKE '%_fk'
            )
        ),
        potential_references AS (
            -- Find tables that might be referenced by our ID columns
            SELECT DISTINCT
                sc.column_name as source_column,
                sc.data_type as source_type,
                t.table_name as target_table,
                c.column_name as target_column,
                c.data_type as target_type,
                CASE
                    -- Highest confidence: column matches table_id pattern and types match
                    WHEN sc.column_name = t.table_name || '_id' 
                        AND sc.data_type = c.data_type THEN 2
                    -- High confidence: column ends with _id and types match
                    WHEN sc.column_name LIKE '%_id' 
                        AND sc.data_type = c.data_type THEN 3
                    -- Medium confidence: column contains table name and types match
                    WHEN sc.column_name LIKE '%' || t.table_name || '%'
                        AND sc.data_type = c.data_type THEN 4
                    -- Lower confidence: column ends with id and types match
                    WHEN sc.column_name LIKE '%id'
                        AND sc.data_type = c.data_type THEN 5
                END as confidence_level
            FROM source_columns sc
            CROSS JOIN information_schema.tables t
            JOIN information_schema.columns c 
                ON c.table_schema = t.table_schema 
                AND c.table_name = t.table_name
                AND (c.column_name = 'id' OR c.column_name = sc.column_name)
            WHERE t.table_schema = {schema_lit}
                AND t.table_name != {table_lit}  -- Exclude self-references
        )
        SELECT 
            source_column as column_name,
            target_table as foreign_table,
            target_column as foreign_column,
            CASE 
                WHEN confidence_level = 2 THEN 'Strong implied relationship (exact match)'
                WHEN confidence_level = 3 THEN 'Strong implied relationship (_id pattern)'
                WHEN confidence_level = 4 THEN 'Likely implied relationship (name match)'
                ELSE 'Possible implied relationship'
            END as relationship_type,
            confidence_level
        FROM potential_references
        WHERE confidence_level IS NOT NULL
        ORDER BY confidence_level, source_column;
        """

        # Execute both queries and combine results. Pass limit explicitly:
        # when called directly, the pydantic Field default would not
        # resolve to an int.
        fk_results = await execute_query(fk_sql, limit=DEFAULT_QUERY_LIMIT, ctx=ctx)
        implied_results = await execute_query(implied_sql, limit=DEFAULT_QUERY_LIMIT, ctx=ctx)

        # If both queries returned no rows, say so once.
        if fk_results == no_rows and implied_results == no_rows:
            return "No relationships found for this table"

        # Otherwise, return both sets of results
        return f"Explicit Foreign Keys:\n{fk_results}\n\nImplied Relationships:\n{implied_results}"

    except Exception as e:
        return f"Error finding relationships: {str(e)}"


@mcp.tool(description="Visualizes query results using a prompt and JSON data.")
async def visualize_results(json_data: dict, vis_prompt: str) -> str:
    """
    Generates a visualization based on query results using PandasAI.

    Args:
        json_data (dict): A dictionary containing the query results.
            It should have two keys:
                - 'columns': A list of column names (strings).
                - 'data': A list of lists, where each inner list represents a row of data.
                  Each element in the inner list corresponds to a column in 'columns'.
            Example:
            {
                'columns': ['Region', 'Product', 'Sales'],
                'data': [
                    ['North', 'Widget', 150],
                    ['South', 'Widget', 200]
                ]
            }
        vis_prompt (str): A natural language prompt describing the desired visualization
            (e.g., "Create a bar chart showing sales by region").

    Returns:
        str: The path to the saved visualization file (e.g., 'visualization_output.png')
             or an error message if the visualization fails.
    """
    try:
        # Configure the OpenAI-backed LLM only when a key is present.
        OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
        if OPENAI_API_KEY:
            pllm = OpenAI(api_token=OPENAI_API_KEY)
            pai.config.set({"llm": pllm})

        # Convert JSON to DataFrame
        df = pd.DataFrame(json_data["data"], columns=json_data["columns"])

        # Shorten long values in text columns so chart labels stay readable.
        for column in df.select_dtypes(include=['object']).columns:
            df[column] = df[column].apply(lambda x: str(x)[:20] + '...' if len(str(x)) > 20 else str(x))

        # Initialize PandasAI
        df_ai = pai.DataFrame(df)

        api_key = os.environ["PANDAS_KEY"]
        pai.api_key.set(api_key)

        # Generate visualization (PandasAI writes chart files to the
        # exports directory as a side effect).
        df_ai.chat(vis_prompt)

        # Locate the generated chart file.
        PANDAS_EXPORTS_PATH = os.environ["PANDAS_EXPORTS_PATH"]
        generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.startswith("temp_chart")]
        if not generated_files:
            # The original left visualization_path unbound here, raising
            # UnboundLocalError; report the condition explicitly instead.
            return "Visualization error: no chart file was generated."

        # Pick the most recently written chart, not an arbitrary one.
        visualization_path = max(
            (os.path.join(PANDAS_EXPORTS_PATH, f) for f in generated_files),
            key=os.path.getmtime,
        )
        return f"Visualization saved as {visualization_path}"

    except Exception as e:
        return f"Visualization error: {str(e)}"


if __name__ == "__main__":
    # Start the MCP server (blocking) when executed as a script.
    mcp.run()