|
|
import os |
|
|
from contextlib import asynccontextmanager |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, AsyncIterator |
|
|
import asyncpg |
|
|
from mcp.server.fastmcp import FastMCP, Context |
|
|
from pydantic import Field |
|
|
import pandasai as pai |
|
|
import matplotlib as plt |
|
|
import pandas as pd |
|
|
import logging |
|
|
from pandasai_openai import OpenAI |
|
|
|
|
|
|
|
|
|
|
|
# Default cap on the number of rows returned by execute_query when the caller
# does not supply an explicit limit.
DEFAULT_QUERY_LIMIT = 100


# Module-level logger; handler/level configuration is left to the host app.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PromptMessage:
    """A single chat message returned by the MCP prompt functions below."""

    # Message text presented to the model.
    content: str
    # Chat role for the message; defaults to "user".
    role: Optional[str] = "user"
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DbContext:
    """Per-server database state created by db_lifespan."""

    # Open asyncpg connection pool shared by the tools.
    pool: asyncpg.Pool
    # PostgreSQL schema name (sourced from the DB_SCHEMA environment variable).
    schema: str
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
    """Open an asyncpg pool for the server's lifetime and close it on exit.

    Reads DB_URL and DB_SCHEMA from the environment; both are required and a
    missing variable raises KeyError immediately. Yields a DbContext carrying
    the pool and the schema name.
    """
    connection_url = os.environ["DB_URL"]
    schema_name = os.environ["DB_SCHEMA"]

    connection_pool = await asyncpg.create_pool(
        connection_url,
        min_size=1,
        max_size=4,
        max_inactive_connection_lifetime=300,
        timeout=60,
        command_timeout=300,
    )
    try:
        yield DbContext(pool=connection_pool, schema=schema_name)
    finally:
        # Always release the pool, even when the consuming scope raised.
        await connection_pool.close()
|
|
|
|
|
|
|
|
|
|
|
# FastMCP server instance. db_lifespan opens the asyncpg pool around the
# server's lifetime; tools reach it through the request's lifespan context.
mcp = FastMCP(
    "SQL Database Server",
    dependencies=["asyncpg", "pydantic"],
    lifespan=db_lifespan
)
|
|
|
|
|
|
|
|
@mcp.resource(
    uri="resource://base_prompt",
    name="base_prompt",
    description="A base prompt to generate SQL queries and answer questions"
)
async def base_prompt_query() -> str:
    """Return the base system prompt that drives SQL generation and answering.

    The text contains single-brace ``{descriptions}`` and ``{tools}``
    placeholders and double-brace ``{{...}}`` escapes, so it appears intended
    to be filled in later via ``str.format`` by the caller — TODO confirm.
    """
    base_prompt = """
==========================
# Your Role
==========================

You are an expert in generating and executing SQL queries, interacting with a PostgreSQL database using **FastMCP tools**, and visualizing results when requested. These tools allow you to:

- List available tables
- Retrieve schema details
- Execute SQL queries
- Visualize query results using PandasAI

Each tool may return previews or summaries of table contents to help you understand the data structure.

---

==========================
# Your Objective
==========================

When a user submits a request, you must:

1. **Analyze the request** to determine the required data, action, or visualization.
2. **Use FastMCP tools** to gather necessary information (e.g., list tables, retrieve schema).
3. **Generate a valid SQL SELECT query**, if needed, and clearly show the full query.
4. **Execute the SQL query** using the `execute_query` tool and return the results.
5. **Visualize results** using the `visualize_results` tool if the user explicitly requests a visualization (e.g., "create a chart", "visualize", "plot"). For visualizations:
   - Craft a visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
   - Send JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
6. **Chain tools logically**, such as: List Tables → Get Schema → Write and Run Query → Visualize Results (if requested).
7. **Explain your reasoning and each step taken** to ensure clarity and transparency.

---

==========================
# Critical Rules
==========================

- Only use **read-only** SQL queries such as **SELECT**, **COUNT**, or queries with **GROUP BY**, **ORDER BY**, etc.
- **Never** use destructive operations like **DELETE**, **UPDATE**, **INSERT**, or **DROP**.
- Always show the SQL query you generate along with the execution result.
- Validate SQL syntax before execution.
- Never assume table or column names. Use tools to confirm structure.
- Use memory efficiently. Don't rerun a tool unless necessary.
- If you generate a SQL query, immediately call the **execute_query** tool.
- If the user requests a visualization, call the **visualize_results** tool with:
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
  - JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- For non-query or non-visualization requests (e.g., history questions), respond appropriately without forcing a query or visualization.

---

==========================
# Database Description
==========================

{descriptions}

---

==========================
# Tools
==========================

You can use the following FastMCP tools to create **read-only** queries (e.g., `SELECT`, `COUNT`, `GROUP BY`, `ORDER BY`) and visualize results when requested. Chain tools logically to gather information, execute queries, or generate visualizations.

{tools}

---

### Invalid Example — DELETE Operation (Not Allowed):
**User Request:** "Delete all customers from Germany."

**Response Guidance:**
- **Do not generate or execute** destructive queries such as `DELETE`.
- Instead, respond with a message like:
> Destructive operations such as `DELETE` are not permitted. I can help you retrieve the customers from Germany using a `SELECT` query instead:
> ```sql
> SELECT * FROM customers WHERE country = 'Germany';
> ```

==========================
# Output Format
==========================

Present your final answer using the following structure in markdown language:

# Result
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a beautiful Markdown table for tabular data (rows and columns) including headers and show such simple results using a table. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}

# Explanation
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}

# Query
```sql
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
```

==========================
# Reminder
==========================
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
  - JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- **For non-query or non-visualization requests, respond appropriately without forcing a query or visualization.**

**Conversation History:**
Use the conversation history for context when available to maintain continuity.
"""

    return base_prompt
|
|
|
|
|
|
|
|
@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
async def test_connection(ctx: Context) -> str:
    """Report the PostgreSQL server version, or the connection failure reason."""
    try:
        db_pool = ctx.request_context.lifespan_context.pool
        async with db_pool.acquire() as connection:
            pg_version = await connection.fetchval("SELECT version();")
    except Exception as e:
        return f"Connection failed: {str(e)}"
    return f"Connection successful. PostgreSQL version: {pg_version}"
|
|
|
|
|
|
|
|
@mcp.tool(description="Executes a read-only SQL SELECT query and returns formatted results or an error message.")
async def execute_query(
    query: str = Field(description="SQL query to execute (SELECT only)"),
    limit: Optional[int] = Field(default=DEFAULT_QUERY_LIMIT, description="Maximum number of rows to return"),
    ctx: Context = None
) -> str:
    """Execute a read-only SQL query and return a plain-text table of results.

    Args:
        query: SQL text; must start with SELECT (case-insensitive).
        limit: Maximum number of rows included in the output. Only None falls
            back to DEFAULT_QUERY_LIMIT; an explicit 0 is now honored
            (previously 0 was silently treated as the default via truthiness).
        ctx: FastMCP request context carrying the lifespan database pool.

    Returns:
        A "header / separator / rows" text block, a no-rows notice, or an
        error string (never raises to the caller).
    """
    query = query.strip()
    # Security guard: prefix check only. NOTE(review): this does not block
    # writable CTEs ("WITH ... AS (DELETE ...) SELECT"); tightening would need
    # SQL parsing or a read-only database role.
    if not query.lower().startswith("select"):
        return "Error: Only SELECT queries are allowed for security reasons."

    # Fix: identity check so an explicit limit of 0 is respected.
    max_rows = DEFAULT_QUERY_LIMIT if limit is None else limit

    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            result = await conn.fetch(query)

        if not result:
            return "Query executed successfully. No rows returned."

        columns = list(result[0].keys())
        header = " | ".join(columns)
        separator = "-" * len(header)

        # Rows are truncated client-side; the full result set is still fetched.
        rows = [" | ".join(str(val) for val in row.values())
                for row in result[:max_rows]]

        return f"{header}\n{separator}\n" + "\n".join(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
async def get_all_tables(pool, db_schema):
    """Fetch the names of all regular and partitioned tables in *db_schema*.

    Child partitions (tables inheriting from another table) and pg_* internals
    are excluded.

    Args:
        pool: An open asyncpg pool.
        db_schema: Schema (namespace) name to list.

    Returns:
        A list of asyncpg Records with a single 'table_name' field, sorted
        alphabetically.
    """
    # Fix: use the module logger instead of a stray debug print(), which would
    # pollute the MCP server's stdout transport.
    logger.debug("Listing tables in schema %s", db_schema)
    async with pool.acquire() as conn:
        result = await conn.fetch("""
            SELECT c.relname AS table_name
            FROM pg_class AS c
            JOIN pg_namespace AS n ON n.oid = c.relnamespace
            WHERE NOT EXISTS (
                SELECT 1
                FROM pg_inherits AS i
                WHERE i.inhrelid = c.oid
            )
            AND c.relkind IN ('r', 'p')
            AND n.nspname = $1
            AND c.relname NOT LIKE 'pg_%'
            ORDER BY c.relname;
        """, db_schema)

    return result
|
|
|
|
|
|
|
|
async def get_table_schema_info(pool, db_schema, table_name):
    """Fetch column metadata for one table from information_schema.columns.

    Returns asyncpg Records with column_name, data_type, is_nullable,
    column_default and character_maximum_length, in ordinal order.
    """
    column_query = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = $1
            AND table_name = $2
            ORDER BY ordinal_position;
        """
    async with pool.acquire() as connection:
        return await connection.fetch(column_query, db_schema, table_name)
|
|
|
|
|
|
|
|
def format_table_schema(table_name, columns):
    """Render a table's column metadata as a human-readable text block.

    Args:
        table_name: Name of the table being described.
        columns: Mapping-like rows (e.g. asyncpg Records) with keys
            column_name, data_type, is_nullable, column_default and
            character_maximum_length.

    Returns:
        A "Table: ... / Columns: / - ..." text block, or a not-found message
        when *columns* is empty.
    """
    if not columns:
        return f"Table '{table_name}' not found."

    def describe(col):
        # One "- name (type(len)) NULLability DEFAULT ..." line per column.
        max_len = col['character_maximum_length']
        length = f"({max_len})" if max_len else ""
        nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
        default = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
        return f"- {col['column_name']} ({col['data_type']}{length}) {nullable}{default}"

    lines = [f"Table: {table_name}", "Columns:"]
    lines.extend(describe(col) for col in columns)
    return "\n".join(lines)
|
|
|
|
|
|
|
|
@mcp.tool(description="Lists all table names in the configured database schema.")
async def list_tables() -> str:
    """Return the names of all tables in the configured schema, one per line."""
    try:
        # NOTE(review): this opens a fresh pool per call via db_lifespan rather
        # than reusing the server's lifespan pool — confirm this is intended.
        async with db_lifespan(mcp) as db_ctx:
            table_records = await get_all_tables(db_ctx.pool, db_ctx.schema)

            if not table_records:
                return f"No tables found in the {db_ctx.schema} schema."

            names = [record['table_name'] for record in table_records]
        return "\n".join(names)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
@mcp.tool(description="Retrieves and formats the schema details of a specific table in the database.")
async def get_table_schema(table_name: str) -> str:
    """Get schema information for a specific table.

    Args:
        table_name: Name of the table to describe.

    Returns:
        A formatted column listing, a not-found notice, or an error string.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            # Consistency fix: use the schema already resolved by db_lifespan
            # instead of re-reading DB_SCHEMA from the environment (matches
            # list_tables / get_all_schemas).
            columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)

            if not columns:
                return f"Table '{table_name}' not found in {db_ctx.schema} schema."

            return format_table_schema(table_name, columns)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
@mcp.tool(description="Retrieves foreign key relationships for a specified table.")
async def get_foreign_keys(table_name: str) -> str:
    """Get foreign key information for a table.

    Args:
        table_name: The name of the table to get foreign keys from. The schema
            comes from the DB_SCHEMA environment variable via db_lifespan.

    Returns:
        One line per foreign-key column, a no-keys notice, or an error string.
    """
    # Fix: the previous version embedded {db_schema}/{table_name} in a plain
    # (non-f) string — producing literal braces and invalid SQL — and then
    # called the async execute_query tool synchronously, passing a parameter
    # tuple it does not accept. Run a properly parameterized query ourselves.
    sql = """
        SELECT
            tc.constraint_name,
            kcu.column_name as fk_column,
            ccu.table_schema as referenced_schema,
            ccu.table_name as referenced_table,
            ccu.column_name as referenced_column
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.referential_constraints rc
            ON tc.constraint_name = rc.constraint_name
        JOIN information_schema.constraint_column_usage ccu
            ON rc.unique_constraint_name = ccu.constraint_name
        WHERE tc.constraint_type = 'FOREIGN KEY'
        AND tc.table_schema = $1
        AND tc.table_name = $2
        ORDER BY tc.constraint_name, kcu.ordinal_position
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                rows = await conn.fetch(sql, db_ctx.schema, table_name)

        if not rows:
            return f"No foreign keys found for table '{table_name}'."

        lines = [f"Foreign keys for '{table_name}':"]
        for row in rows:
            lines.append(
                f"- {row['constraint_name']}: {row['fk_column']} -> "
                f"{row['referenced_schema']}.{row['referenced_table']}.{row['referenced_column']}"
            )
        return "\n".join(lines)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
@mcp.tool(description="Fetches and formats the schema details for all tables in the configured database schema.")
async def get_all_schemas() -> str:
    """Get schema information for all tables in the database.

    Returns:
        Concatenated format_table_schema() blocks separated by blank lines,
        a no-tables notice, or an error string.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)

            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."

            all_schemas = []
            for table in tables:
                table_name = table['table_name']
                # Consistency fix: use the lifespan-resolved schema instead of
                # a second, redundant os.environ["DB_SCHEMA"] read.
                columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
                all_schemas.append(format_table_schema(table_name, columns))
                all_schemas.append("")  # blank line between table blocks

            return "\n".join(all_schemas)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|
|
|
|
|
@mcp.prompt(description="Generates a prompt message to help craft a best-practice SELECT query for a given table.")
async def generate_select_query(table_name: str) -> list[PromptMessage]:
    """Build a prompt asking for a best-practice SELECT on *table_name*."""
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                columns = await conn.fetch("""
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                """, db_ctx.schema, table_name)

            # Unknown table: return a short explanatory message instead.
            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

            columns_text = "\n".join(
                f"- {col['column_name']} ({col['data_type']})" for col in columns
            )

            prompt_text = f"""Please help me write a well-structured, efficient SELECT query for the '{table_name}' table.

Table Schema:
{columns_text}

PostgreSQL SQL Best Practices:
- Use explicit column names instead of * when possible
- Include LIMIT clauses to restrict result sets
- Consider adding WHERE clauses to filter results
- Use appropriate indexing considerations
- Format SQL with proper indentation and line breaks

Create a basic SELECT query following these best practices:"""

            return [PromptMessage(prompt_text)]
    except Exception as e:
        return [PromptMessage(f"Error generating select query: {str(e)}")]
|
|
|
|
|
|
|
|
@mcp.prompt(description="Generates a prompt message to assist in writing analytical queries for a given table.")
async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
    """Generate analytical queries for a table.

    Args:
        table_name: The name of the table to generate analytical queries for.
            The schema comes from db_lifespan (DB_SCHEMA env variable).
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            pool = db_ctx.pool
            async with pool.acquire() as conn:
                # Fix: the old query interpolated schema/table unquoted into
                # the SQL (syntax error + injection vector) while ALSO passing
                # them as arguments to a statement with no $1/$2 placeholders,
                # which asyncpg rejects. Parameterize it properly, matching
                # generate_select_query.
                columns = await conn.fetch("""
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                """, db_ctx.schema, table_name)

            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

            columns_text = "\n".join([f"- {col['column_name']} ({col['data_type']})" for col in columns])

            return [
                PromptMessage(
                    f"""Please help me create analytical queries for the '{table_name}' table.

Table Schema:
{columns_text}

PostgreSQL SQL Best Practices:
- Use aggregation functions (COUNT, SUM, AVG, MIN, MAX) appropriately
- Group data using GROUP BY for meaningful aggregations
- Filter groups with HAVING clauses when needed
- Consider using window functions for advanced analytics
- Format SQL with proper indentation and line breaks

Create a set of analytical queries for this table:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating analytical query: {str(e)}")]
|
|
|
|
|
|
|
|
@mcp.tool(
    description="Identifies both explicit and implied foreign key relationships for a given table using schema analysis and naming patterns.")
async def find_relationships(table_name: str, db_schema: str = 'public') -> str:
    """Find both explicit and implied relationships for a table.

    Args:
        table_name: The name of the table to analyze relationships for
        db_schema: The schema name (defaults to 'public')

    Returns:
        A two-section report (explicit foreign keys, then implied
        relationships inferred from *_id naming patterns), or an error string.
    """
    # Fix: the previous version interpolated identifiers unquoted into both
    # SQL strings (invalid SQL / injection risk) and invoked the async
    # execute_query tool without awaiting it and without its required request
    # context. Use server-side $1/$2 parameters on our own pool instead.
    # ('%%' LIKE escapes were a psycopg-ism; with asyncpg parameters a plain
    # '%' wildcard is correct and behaviorally identical.)
    fk_sql = """
        SELECT
            kcu.column_name,
            ccu.table_name as foreign_table,
            ccu.column_name as foreign_column,
            'Explicit FK' as relationship_type,
            1 as confidence_level
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
        AND tc.table_schema = $1
        AND tc.table_name = $2
    """

    implied_sql = """
        WITH source_columns AS (
            -- Get all ID-like columns from our table
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = $1
            AND table_name = $2
            AND (
                column_name LIKE '%id'
                OR column_name LIKE '%_id'
                OR column_name LIKE '%_fk'
            )
        ),
        potential_references AS (
            -- Find tables that might be referenced by our ID columns
            SELECT DISTINCT
                sc.column_name as source_column,
                sc.data_type as source_type,
                t.table_name as target_table,
                c.column_name as target_column,
                c.data_type as target_type,
                CASE
                    -- Highest confidence: column matches table_id pattern and types match
                    WHEN sc.column_name = t.table_name || '_id'
                        AND sc.data_type = c.data_type THEN 2
                    -- High confidence: column ends with _id and types match
                    WHEN sc.column_name LIKE '%_id'
                        AND sc.data_type = c.data_type THEN 3
                    -- Medium confidence: column contains table name and types match
                    WHEN sc.column_name LIKE '%' || t.table_name || '%'
                        AND sc.data_type = c.data_type THEN 4
                    -- Lower confidence: column ends with id and types match
                    WHEN sc.column_name LIKE '%id'
                        AND sc.data_type = c.data_type THEN 5
                END as confidence_level
            FROM source_columns sc
            CROSS JOIN information_schema.tables t
            JOIN information_schema.columns c
                ON c.table_schema = t.table_schema
                AND c.table_name = t.table_name
                AND (c.column_name = 'id' OR c.column_name = sc.column_name)
            WHERE t.table_schema = $1
            AND t.table_name != $2 -- Exclude self-references
        )
        SELECT
            source_column as column_name,
            target_table as foreign_table,
            target_column as foreign_column,
            CASE
                WHEN confidence_level = 2 THEN 'Strong implied relationship (exact match)'
                WHEN confidence_level = 3 THEN 'Strong implied relationship (_id pattern)'
                WHEN confidence_level = 4 THEN 'Likely implied relationship (name match)'
                ELSE 'Possible implied relationship'
            END as relationship_type,
            confidence_level
        FROM potential_references
        WHERE confidence_level IS NOT NULL
        ORDER BY confidence_level, source_column;
    """

    def _render(rows):
        """Format asyncpg records as a simple 'col | col | ...' text table."""
        if not rows:
            return "No results found"
        header = " | ".join(rows[0].keys())
        body = "\n".join(" | ".join(str(v) for v in row.values()) for row in rows)
        return f"{header}\n{body}"

    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                fk_rows = await conn.fetch(fk_sql, db_schema, table_name)
                implied_rows = await conn.fetch(implied_sql, db_schema, table_name)

        if not fk_rows and not implied_rows:
            return "No relationships found for this table"

        return (
            f"Explicit Foreign Keys:\n{_render(fk_rows)}\n\n"
            f"Implied Relationships:\n{_render(implied_rows)}"
        )
    except Exception as e:
        return f"Error finding relationships: {str(e)}"
|
|
|
|
|
|
|
|
@mcp.tool(description="Visualizes query results using a prompt and JSON data.")
async def visualize_results(json_data: dict, vis_prompt: str) -> str:
    """
    Generates a visualization based on query results using PandasAI.

    Args:
        json_data (dict): A dictionary containing the query results, with:
            - 'columns': A list of column names (strings).
            - 'data': A list of lists, where each inner list is a row whose
              elements align with 'columns'.
            Example:
                {
                    'columns': ['Region', 'Product', 'Sales'],
                    'data': [
                        ['North', 'Widget', 150],
                        ['South', 'Widget', 200]
                    ]
                }
        vis_prompt (str): A natural language prompt describing the desired
            visualization (e.g., "Create a bar chart showing sales by region").

    Returns:
        str: The path to the saved visualization file, or an error message if
        the visualization fails or no chart file was produced.
    """
    try:
        # Optional: route PandasAI through OpenAI when a key is configured.
        OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
        if OPENAI_API_KEY:
            pllm = OpenAI(api_token=OPENAI_API_KEY)
            pai.config.set({"llm": pllm})

        df = pd.DataFrame(json_data["data"], columns=json_data["columns"])

        # Truncate long text values so chart labels stay readable.
        for column in df.select_dtypes(include=['object']).columns:
            df[column] = df[column].apply(lambda x: str(x)[:20] + '...' if len(str(x)) > 20 else str(x))

        df_ai = pai.DataFrame(df)

        api_key = os.environ["PANDAS_KEY"]
        pai.api_key.set(api_key)

        # PandasAI writes the chart into its exports directory as a side effect.
        df_ai.chat(vis_prompt)

        PANDAS_EXPORTS_PATH = os.environ["PANDAS_EXPORTS_PATH"]
        # Fix: pick the most recently written chart; the old code took an
        # arbitrary directory entry, which could be a stale file from an
        # earlier call.
        generated_files = sorted(
            (f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.startswith("temp_chart")),
            key=lambda f: os.path.getmtime(os.path.join(PANDAS_EXPORTS_PATH, f)),
            reverse=True,
        )
        if generated_files:
            visualization_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
            return f"Visualization saved as {visualization_path}"

        # Fix: the old code fell off the end (implicitly returning None) when
        # no chart file was produced; report the failure explicitly.
        return "Visualization error: no chart file was generated."
    except Exception as e:
        return f"Visualization error: {str(e)}"
|
|
|
|
|
|
|
|
# Script entry point: start the MCP server (transport selection is handled by
# FastMCP.run()).
if __name__ == "__main__":
    mcp.run()
|
|
|