talk2data / postgre_mcp_server.py
cevheri's picture
ci:add version
6b3e14e
raw
history blame
25.8 kB
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Optional, AsyncIterator
import asyncpg
from mcp.server.fastmcp import FastMCP, Context
from pydantic import Field
import pandasai as pai
import matplotlib as plt
import pandas as pd
import logging
from pandasai_openai import OpenAI
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Default value of execute_query's `limit` parameter (maximum rows returned).
DEFAULT_QUERY_LIMIT = 100

# Module-level logger. No handler is configured in this module; the hosting
# process decides where records go (logging.basicConfig's default handler
# writes to stderr, not stdout).
logger = logging.getLogger(__name__)
@dataclass
class PromptMessage:
    """Lightweight prompt-message record returned by this module's prompts.

    Stands in for the MCP SDK's own prompt message type, which is not
    imported here. Fields are positional in declaration order:
    ``content`` (required) followed by ``role`` (default ``"user"``).
    """

    # Text body of the message.
    content: str
    # Speaker role; the call sites in this file all rely on the default.
    role: Optional[str] = "user"
@dataclass
class DbContext:
    """Database state shared for the server's lifetime.

    Instances are yielded by ``db_lifespan``; tools reach them through
    ``ctx.request_context.lifespan_context``.
    """

    # asyncpg connection pool created by db_lifespan.
    pool: asyncpg.Pool
    # Schema name read from the DB_SCHEMA environment variable.
    schema: str
@asynccontextmanager
async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
    """Create an asyncpg pool, yield it wrapped in a DbContext, then close it.

    Environment variables are read in order ``DB_URL`` (the DSN) then
    ``DB_SCHEMA``; a missing one raises ``KeyError`` before any connection is
    attempted. ``server`` only satisfies the FastMCP lifespan signature and is
    not used.

    Yields:
        DbContext with the open pool and the configured schema name.
    """
    connection_string = os.environ["DB_URL"]
    schema_name = os.environ["DB_SCHEMA"]

    # 1-4 pooled connections; idle ones are dropped after 300 s. `timeout`
    # bounds connection establishment, `command_timeout` each statement.
    pool_options = dict(
        min_size=1,
        max_size=4,
        max_inactive_connection_lifetime=300,
        timeout=60,
        command_timeout=300,
    )
    shared_pool = await asyncpg.create_pool(connection_string, **pool_options)
    try:
        yield DbContext(pool=shared_pool, schema=schema_name)
    finally:
        # Release every connection even when the body raised.
        await shared_pool.close()
# Server instance. `lifespan` gives every request access to one shared pool via
# ctx.request_context.lifespan_context. `dependencies` is the package list used
# when the server environment is provisioned (e.g. `mcp install` / `mcp dev`);
# it must cover every third-party module imported at the top of this file,
# otherwise the provisioned environment presumably fails at import time.
mcp = FastMCP(
    "SQL Database Server",
    dependencies=[
        "asyncpg",
        "pydantic",
        "pandas",
        "pandasai",
        "pandasai-openai",
        "matplotlib",
    ],
    lifespan=db_lifespan,
)
@mcp.resource(
    uri="resource://base_prompt",
    name="base_prompt",
    description="A base prompt to generate SQL queries and answer questions",
)
async def base_prompt_query() -> str:
    """Serve the system-prompt template for the SQL assistant as-is.

    The text contains ``{descriptions}`` and ``{tools}`` placeholders and
    doubled ``{{``/``}}`` braces and is returned unformatted.
    NOTE(review): presumably the client fills it via ``str.format`` — confirm;
    if nothing formats it, the model sees the raw placeholders and braces.
    """
    return """
==========================
# Your Role
==========================
You are an expert in generating and executing SQL queries, interacting with a PostgreSQL database using **FastMCP tools**, and visualizing results when requested. These tools allow you to:
- List available tables
- Retrieve schema details
- Execute SQL queries
- Visualize query results using PandasAI
Each tool may return previews or summaries of table contents to help you understand the data structure.
---
==========================
# Your Objective
==========================
When a user submits a request, you must:
1. **Analyze the request** to determine the required data, action, or visualization.
2. **Use FastMCP tools** to gather necessary information (e.g., list tables, retrieve schema).
3. **Generate a valid SQL SELECT query**, if needed, and clearly show the full query.
4. **Execute the SQL query** using the `execute_query` tool and return the results.
5. **Visualize results** using the `visualize_results` tool if the user explicitly requests a visualization (e.g., "create a chart", "visualize", "plot"). For visualizations:
- Craft a visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
- Send JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
6. **Chain tools logically**, such as: List Tables → Get Schema → Write and Run Query → Visualize Results (if requested).
7. **Explain your reasoning and each step taken** to ensure clarity and transparency.
---
==========================
# Critical Rules
==========================
- Only use **read-only** SQL queries such as **SELECT**, **COUNT**, or queries with **GROUP BY**, **ORDER BY**, etc.
- **Never** use destructive operations like **DELETE**, **UPDATE**, **INSERT**, or **DROP**.
- Always show the SQL query you generate along with the execution result.
- Validate SQL syntax before execution.
- Never assume table or column names. Use tools to confirm structure.
- Use memory efficiently. Don't rerun a tool unless necessary.
- If you generate a SQL query, immediately call the **execute_query** tool.
- If the user requests a visualization, call the **visualize_results** tool with:
- A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
- JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- For non-query or non-visualization requests (e.g., history questions), respond appropriately without forcing a query or visualization.
---
==========================
# Database Description
==========================
{descriptions}
---
==========================
# Tools
==========================
You can use the following FastMCP tools to create **read-only** queries (e.g., `SELECT`, `COUNT`, `GROUP BY`, `ORDER BY`) and visualize results when requested. Chain tools logically to gather information, execute queries, or generate visualizations.
{tools}
---
### Invalid Example — DELETE Operation (Not Allowed):
**User Request:** "Delete all customers from Germany."
**Response Guidance:**
- **Do not generate or execute** destructive queries such as `DELETE`.
- Instead, respond with a message like:
> Destructive operations such as `DELETE` are not permitted. I can help you retrieve the customers from Germany using a `SELECT` query instead:
> ```sql
> SELECT * FROM customers WHERE country = 'Germany';
> ```
==========================
# Output Format
==========================
Present your final answer using the following structure in markdown language:
# Result
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a beautiful Markdown table for tabular data (rows and columns) including headers and show such simple results using a table. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}
# Explanation
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}
# Query
```sql
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
```
==========================
# Reminder
==========================
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
- A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
- JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- **For non-query or non-visualization requests, respond appropriately without forcing a query or visualization.**
**Conversation History:**
Use the conversation history for context when available to maintain continuity.
"""
@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
async def test_connection(ctx: Context) -> str:
    """Check connectivity by running ``SELECT version();`` on the lifespan pool.

    Never raises: any failure (including a missing lifespan context) is
    reported as a ``Connection failed: ...`` string.
    """
    try:
        db_pool = ctx.request_context.lifespan_context.pool
        async with db_pool.acquire() as connection:
            server_version = await connection.fetchval("SELECT version();")
    except Exception as exc:
        return f"Connection failed: {str(exc)}"
    else:
        return f"Connection successful. PostgreSQL version: {server_version}"
@mcp.tool(description="Executes a read-only SQL SELECT query and returns formatted results or an error message.")
async def execute_query(
    query: str = Field(description="SQL query to execute (SELECT only)"),
    limit: Optional[int] = Field(default=DEFAULT_QUERY_LIMIT, description="Maximum number of rows to return"),
    ctx: Context = None
) -> str:
    """Run a read-only query and return its first rows as pipe-separated text.

    The statement runs inside a ``READ ONLY`` transaction, so PostgreSQL itself
    rejects writes that slip past the prefix check (``SELECT ... INTO``,
    data-modifying CTEs, side-effecting functions). Rows are read through a
    cursor, so at most ``limit`` rows are transferred rather than the whole
    result set.

    Args:
        query: SQL text; must start with SELECT or WITH (case-insensitive).
        limit: Maximum rows to return. ``None``, 0, negative values, or a
            non-int (e.g. the pydantic ``FieldInfo`` default when this function
            is called directly from Python) fall back to DEFAULT_QUERY_LIMIT.
        ctx: FastMCP request context; supplies the lifespan connection pool.

    Returns:
        A header line, a dash separator and one ``a | b | c`` line per row;
        otherwise a message for an empty result, a rejected query, or an error.
    """
    query = query.strip()
    # Cheap gate before touching the database; the read-only transaction
    # below is the actual enforcement.
    if not query.lower().startswith(("select", "with")):
        return "Error: Only SELECT queries are allowed for security reasons."

    if not isinstance(limit, int) or limit <= 0:
        limit = DEFAULT_QUERY_LIMIT

    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            # asyncpg cursors require a transaction; make it read-only.
            async with conn.transaction(readonly=True):
                cursor = await conn.cursor(query)
                result = await cursor.fetch(limit)

        if not result:
            return "Query executed successfully. No rows returned."

        header = " | ".join(result[0].keys())
        separator = "-" * len(header)
        rows = [" | ".join(str(val) for val in row.values()) for row in result]
        return f"{header}\n{separator}\n" + "\n".join(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
# Database helper functions
async def get_all_tables(pool, db_schema):
    """Return the top-level tables of ``db_schema``, ordered by name.

    Includes ordinary ('r') and partitioned ('p') tables and skips relations
    that have a parent in pg_inherits (partitions and inheritance children).
    Names beginning with the literal prefix ``pg_`` are excluded.

    Args:
        pool: asyncpg pool to borrow a connection from.
        db_schema: Schema (namespace) name to list.

    Returns:
        Records with a single ``table_name`` field.
    """
    # Use the logger, not print(): with the stdio MCP transport, stdout carries
    # the JSON-RPC stream and stray output corrupts it.
    logger.debug("Listing tables in schema %s", db_schema)
    async with pool.acquire() as conn:
        return await conn.fetch(
            """
            SELECT c.relname AS table_name
            FROM pg_class AS c
            JOIN pg_namespace AS n ON n.oid = c.relnamespace
            WHERE NOT EXISTS (
                SELECT 1
                FROM pg_inherits AS i
                WHERE i.inhrelid = c.oid
            )
              AND c.relkind IN ('r', 'p')
              AND n.nspname = $1
              -- '_' is a LIKE wildcard: escape it so only a literal "pg_"
              -- prefix is excluded (plain 'pg_%' also hid e.g. pgbench_*).
              AND c.relname NOT LIKE 'pg!_%' ESCAPE '!'
            ORDER BY c.relname;
            """,
            db_schema,
        )
async def get_table_schema_info(pool, db_schema, table_name):
    """Look up column metadata for one table in information_schema.columns.

    Args:
        pool: asyncpg-style pool; one connection is borrowed for the query.
        db_schema: Schema name, bound as ``$1``.
        table_name: Table name, bound as ``$2``.

    Returns:
        Rows with column_name, data_type, is_nullable, column_default and
        character_maximum_length in ordinal order; empty when no matching
        (visible) columns exist.
    """
    metadata_sql = """
SELECT
column_name,
data_type,
is_nullable,
column_default,
character_maximum_length
FROM information_schema.columns
WHERE table_schema = $1
AND table_name = $2
ORDER BY ordinal_position;
"""
    async with pool.acquire() as connection:
        rows = await connection.fetch(metadata_sql, db_schema, table_name)
    return rows
def format_table_schema(table_name, columns):
    """Render column metadata as a readable, line-oriented listing.

    Args:
        table_name: Name shown in the ``Table:`` header line.
        columns: Mapping-like rows with keys is_nullable ('YES' means
            nullable), character_maximum_length, column_default, column_name
            and data_type.

    Returns:
        ``Table: <name>`` / ``Columns:`` followed by one
        ``- col (type[(len)]) NULL|NOT NULL[ DEFAULT expr]`` line per column,
        or a "not found" message when ``columns`` is empty.
    """
    if not columns:
        return f"Table '{table_name}' not found."

    def describe(col):
        null_text = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
        size = col['character_maximum_length']
        size_text = f"({size})" if size else ""
        default_expr = col['column_default']
        default_text = f" DEFAULT {default_expr}" if default_expr else ""
        return f"- {col['column_name']} ({col['data_type']}{size_text}) {null_text}{default_text}"

    return "\n".join([f"Table: {table_name}", "Columns:", *map(describe, columns)])
@mcp.tool(description="Lists all table names in the configured database schema.")
async def list_tables(ctx: Context = None) -> str:
    """Return the configured schema's table names, one per line.

    When FastMCP injects ``ctx``, the server's long-lived lifespan pool is
    reused. Without it (e.g. a direct Python call), a temporary pool is opened
    through ``db_lifespan`` and closed afterwards. Errors are returned as
    strings, not raised.
    """
    async def _render(db_ctx):
        tables = await get_all_tables(db_ctx.pool, db_ctx.schema)
        if not tables:
            return f"No tables found in the {db_ctx.schema} schema."
        return "\n".join(row['table_name'] for row in tables)

    try:
        if ctx is not None:
            # Reuse the shared pool instead of creating a new one per call.
            return await _render(ctx.request_context.lifespan_context)
        async with db_lifespan(mcp) as db_ctx:
            return await _render(db_ctx)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.tool(description="Retrieves and formats the schema details of a specific table in the database.")
async def get_table_schema(table_name: str, ctx: Context = None) -> str:
    """Return a formatted column listing for ``table_name`` in the configured schema.

    Args:
        table_name: Table to describe.
        ctx: FastMCP request context (injected). When present, its lifespan
            pool is reused; otherwise a temporary pool is opened via
            ``db_lifespan``.

    Returns:
        The ``format_table_schema`` text, a "not found" message, or an error
        string.
    """
    async def _render(db_ctx):
        columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
        if not columns:
            return f"Table '{table_name}' not found in {db_ctx.schema} schema."
        return format_table_schema(table_name, columns)

    try:
        if ctx is not None:
            return await _render(ctx.request_context.lifespan_context)
        async with db_lifespan(mcp) as db_ctx:
            return await _render(db_ctx)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.tool(description="Retrieves foreign key relationships for a specified table.")
async def get_foreign_keys(table_name: str, ctx: Context = None) -> str:
    """List the foreign keys declared on ``table_name`` in the configured schema.

    Args:
        table_name: Table whose outgoing foreign keys are listed. It is looked
            up in the DB_SCHEMA schema (there is no schema argument).
        ctx: FastMCP request context (injected). When present, its lifespan
            pool is reused; otherwise a temporary pool is opened via
            ``db_lifespan``.

    Returns:
        Pipe-separated rows (constraint_name, fk_column, referenced_schema,
        referenced_table, referenced_column), a "no foreign keys" message, or
        an error string.
    """
    # pg_constraint stores FK columns as parallel arrays (conkey / confkey);
    # unnesting them together pairs each local column with its referenced
    # column, so composite keys do not cross-multiply and cross-schema
    # references are reported correctly. Identifiers are bound as $1/$2.
    sql = """
        SELECT
            con.conname AS constraint_name,
            src.attname AS fk_column,
            tgt_ns.nspname AS referenced_schema,
            tgt_cls.relname AS referenced_table,
            tgt.attname AS referenced_column
        FROM pg_constraint AS con
        JOIN pg_class AS cls ON cls.oid = con.conrelid
        JOIN pg_namespace AS ns ON ns.oid = cls.relnamespace
        JOIN pg_class AS tgt_cls ON tgt_cls.oid = con.confrelid
        JOIN pg_namespace AS tgt_ns ON tgt_ns.oid = tgt_cls.relnamespace
        CROSS JOIN LATERAL unnest(con.conkey, con.confkey)
            WITH ORDINALITY AS k(src_attnum, tgt_attnum, ord)
        JOIN pg_attribute AS src
            ON src.attrelid = con.conrelid AND src.attnum = k.src_attnum
        JOIN pg_attribute AS tgt
            ON tgt.attrelid = con.confrelid AND tgt.attnum = k.tgt_attnum
        WHERE con.contype = 'f'
          AND ns.nspname = $1
          AND cls.relname = $2
        ORDER BY con.conname, k.ord
    """

    async def _lookup(db_ctx):
        async with db_ctx.pool.acquire() as conn:
            rows = await conn.fetch(sql, db_ctx.schema, table_name)
        if not rows:
            return f"No foreign keys found for table '{table_name}' in {db_ctx.schema} schema."
        header = " | ".join(rows[0].keys())
        lines = [header, "-" * len(header)]
        lines.extend(" | ".join(str(val) for val in row.values()) for row in rows)
        return "\n".join(lines)

    try:
        if ctx is not None:
            return await _lookup(ctx.request_context.lifespan_context)
        async with db_lifespan(mcp) as db_ctx:
            return await _lookup(db_ctx)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.tool(description="Fetches and formats the schema details for all tables in the configured database schema.")
async def get_all_schemas(ctx: Context = None) -> str:
    """Return the formatted column listing of every table in the configured schema.

    Tables come from ``get_all_tables``. Each is rendered with
    ``format_table_schema``, and tables are separated by a blank line. When
    FastMCP injects ``ctx``, the lifespan pool is reused; otherwise a temporary
    pool is opened via ``db_lifespan``. Errors are returned as strings.
    """
    async def _render(db_ctx):
        tables = await get_all_tables(db_ctx.pool, db_ctx.schema)
        if not tables:
            return f"No tables found in the {db_ctx.schema} schema."
        sections = []
        for table in tables:
            table_name = table['table_name']
            columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
            sections.append(format_table_schema(table_name, columns))
            sections.append("")  # blank line between tables
        return "\n".join(sections)

    try:
        if ctx is not None:
            return await _render(ctx.request_context.lifespan_context)
        async with db_lifespan(mcp) as db_ctx:
            return await _render(db_ctx)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.prompt(description="Generates a prompt message to help craft a best-practice SELECT query for a given table.")
async def generate_select_query(table_name: str) -> list[PromptMessage]:
    """Build a prompt asking for a well-structured SELECT over ``table_name``.

    The table's columns are read from the configured schema through a
    temporary pool opened with ``db_lifespan`` and embedded in the prompt.
    A missing table, or any exception, is reported as a single prompt message
    instead of being raised.
    """
    column_query = """
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
"""
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as connection:
                column_rows = await connection.fetch(column_query, db_ctx.schema, table_name)

            if not column_rows:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

            column_lines = [f"- {row['column_name']} ({row['data_type']})" for row in column_rows]
            columns_text = "\n".join(column_lines)
            prompt_text = f"""Please help me write a well-structured, efficient SELECT query for the '{table_name}' table.
Table Schema:
{columns_text}
PostgreSQL SQL Best Practices:
- Use explicit column names instead of * when possible
- Include LIMIT clauses to restrict result sets
- Consider adding WHERE clauses to filter results
- Use appropriate indexing considerations
- Format SQL with proper indentation and line breaks
Create a basic SELECT query following these best practices:"""
            return [PromptMessage(prompt_text)]
    except Exception as err:
        return [PromptMessage(f"Error generating select query: {str(err)}")]
@mcp.prompt(description="Generates a prompt message to assist in writing analytical queries for a given table.")
async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
    """Build a prompt asking for analytical (aggregate / window) queries on a table.

    Args:
        table_name: Table to describe; looked up in the DB_SCHEMA schema via a
            temporary pool opened with ``db_lifespan``.

    Returns:
        A one-element list holding the prompt text (with the table's columns),
        a "not found" message, or an error message. Never raises; a missing
        DB_SCHEMA/DB_URL is also reported as an error message.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                # Schema and table are bound as $1/$2 parameters rather than
                # formatted into the SQL text, so they are always treated as
                # values and table_name cannot inject SQL.
                columns = await conn.fetch("""
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_schema = $1 AND table_name = $2
                    ORDER BY ordinal_position
                """, db_ctx.schema, table_name)
            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]
            columns_text = "\n".join([f"- {col['column_name']} ({col['data_type']})" for col in columns])
            return [
                PromptMessage(
                    f"""Please help me create analytical queries for the '{table_name}' table.
Table Schema:
{columns_text}
PostgreSQL SQL Best Practices:
- Use aggregation functions (COUNT, SUM, AVG, MIN, MAX) appropriately
- Group data using GROUP BY for meaningful aggregations
- Filter groups with HAVING clauses when needed
- Consider using window functions for advanced analytics
- Format SQL with proper indentation and line breaks
Create a set of analytical queries for this table:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating analytical query: {str(e)}")]
@mcp.tool(
    description="Identifies both explicit and implied foreign key relationships for a given table using schema analysis and naming patterns.")
async def find_relationships(table_name: str, db_schema: str = 'public', ctx: Context = None) -> str:
    """Find both explicit and implied relationships for a table.

    Explicit relationships are declared foreign keys (from pg_constraint).
    Implied ones are guessed from ID-like column names (``...id``, ``..._id``,
    ``..._fk``) whose data type matches an ``id`` column, or a same-named
    column, in another table of the schema. Lower ``confidence_level`` means
    stronger evidence.

    Args:
        table_name: The name of the table to analyze relationships for.
        db_schema: The schema name (defaults to 'public', not DB_SCHEMA).
        ctx: FastMCP request context (injected). When present, its lifespan
            pool is reused; otherwise a temporary pool is opened via
            ``db_lifespan``.

    Returns:
        Two pipe-separated sections ("Explicit Foreign Keys" / "Implied
        Relationships", each "No results found" when empty), "No relationships
        found for this table" when both are empty, or an error string.
    """
    # Declared FKs. conkey/confkey are parallel arrays, so unnesting them
    # together pairs local and referenced columns (no composite-key cross
    # product, cross-schema targets included).
    fk_sql = """
        SELECT
            src.attname AS column_name,
            tgt_cls.relname AS foreign_table,
            tgt.attname AS foreign_column,
            'Explicit FK' AS relationship_type,
            1 AS confidence_level
        FROM pg_constraint AS con
        JOIN pg_class AS cls ON cls.oid = con.conrelid
        JOIN pg_namespace AS ns ON ns.oid = cls.relnamespace
        JOIN pg_class AS tgt_cls ON tgt_cls.oid = con.confrelid
        CROSS JOIN LATERAL unnest(con.conkey, con.confkey)
            WITH ORDINALITY AS k(src_attnum, tgt_attnum, ord)
        JOIN pg_attribute AS src
            ON src.attrelid = con.conrelid AND src.attnum = k.src_attnum
        JOIN pg_attribute AS tgt
            ON tgt.attrelid = con.confrelid AND tgt.attnum = k.tgt_attnum
        WHERE con.contype = 'f'
          AND ns.nspname = $1
          AND cls.relname = $2
        ORDER BY con.conname, k.ord
    """

    # Naming-pattern heuristics. '_' is a LIKE wildcard, so literal
    # underscores are escaped with '!' (otherwise '%_id' also matches e.g.
    # "paid" or "valid").
    implied_sql = """
        WITH source_columns AS (
            -- ID-like columns of the table being analysed
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = $1
              AND table_name = $2
              AND (
                  column_name LIKE '%id'
                  OR column_name LIKE '%!_id' ESCAPE '!'
                  OR column_name LIKE '%!_fk' ESCAPE '!'
              )
        ),
        potential_references AS (
            -- Tables that might be referenced by those columns
            SELECT DISTINCT
                sc.column_name AS source_column,
                sc.data_type AS source_type,
                t.table_name AS target_table,
                c.column_name AS target_column,
                c.data_type AS target_type,
                CASE
                    -- column is exactly <table>_id and types match
                    WHEN sc.column_name = t.table_name || '_id'
                        AND sc.data_type = c.data_type THEN 2
                    -- column ends with _id and types match
                    WHEN sc.column_name LIKE '%!_id' ESCAPE '!'
                        AND sc.data_type = c.data_type THEN 3
                    -- column contains the table name and types match
                    WHEN sc.column_name LIKE '%' || t.table_name || '%'
                        AND sc.data_type = c.data_type THEN 4
                    -- column ends with id and types match
                    WHEN sc.column_name LIKE '%id'
                        AND sc.data_type = c.data_type THEN 5
                END AS confidence_level
            FROM source_columns sc
            CROSS JOIN information_schema.tables t
            JOIN information_schema.columns c
                ON c.table_schema = t.table_schema
                AND c.table_name = t.table_name
                AND (c.column_name = 'id' OR c.column_name = sc.column_name)
            WHERE t.table_schema = $1
              AND t.table_name != $2  -- exclude self-references
        )
        SELECT
            source_column AS column_name,
            target_table AS foreign_table,
            target_column AS foreign_column,
            CASE
                WHEN confidence_level = 2 THEN 'Strong implied relationship (exact match)'
                WHEN confidence_level = 3 THEN 'Strong implied relationship (_id pattern)'
                WHEN confidence_level = 4 THEN 'Likely implied relationship (name match)'
                ELSE 'Possible implied relationship'
            END AS relationship_type,
            confidence_level
        FROM potential_references
        WHERE confidence_level IS NOT NULL
        ORDER BY confidence_level, source_column;
    """

    def _format(records):
        # Same "header / dashes / rows" layout as execute_query.
        if not records:
            return "No results found"
        header = " | ".join(records[0].keys())
        lines = [header, "-" * len(header)]
        lines.extend(" | ".join(str(val) for val in rec.values()) for rec in records)
        return "\n".join(lines)

    async def _run(pool):
        async with pool.acquire() as conn:
            fk_rows = await conn.fetch(fk_sql, db_schema, table_name)
            implied_rows = await conn.fetch(implied_sql, db_schema, table_name)
        if not fk_rows and not implied_rows:
            return "No relationships found for this table"
        return (
            f"Explicit Foreign Keys:\n{_format(fk_rows)}\n\n"
            f"Implied Relationships:\n{_format(implied_rows)}"
        )

    try:
        if ctx is not None:
            return await _run(ctx.request_context.lifespan_context.pool)
        async with db_lifespan(mcp) as db_ctx:
            return await _run(db_ctx.pool)
    except Exception as e:
        return f"Error finding relationships: {str(e)}"
@mcp.tool(description="Visualizes query results using a prompt and JSON data.")
async def visualize_results(json_data: dict, vis_prompt: str) -> str:
    """
    Generates a visualization based on query results using PandasAI.
    Args:
        json_data (dict): A dictionary containing the query results.
            It should have two keys:
            - 'columns': A list of column names (strings).
            - 'data': A list of lists, where each inner list represents a row of data.
              Each element in the inner list corresponds to a column in 'columns'.
            Example:
            {
                'columns': ['Region', 'Product', 'Sales'],
                'data': [
                    ['North', 'Widget', 150],
                    ['South', 'Widget', 200]
                ]
            }
        vis_prompt (str): A natural language prompt describing the desired visualization
            (e.g., "Create a bar chart showing sales by region").
    Returns:
        str: "Visualization saved as <path>" naming the most recently modified
        ``temp_chart*`` file in PANDAS_EXPORTS_PATH, or a "Visualization
        error: ..." message (including when no chart file is found).
    Environment:
        OPENAI_API_KEY (optional) selects OpenAI as the PandasAI LLM;
        PANDAS_KEY and PANDAS_EXPORTS_PATH are required.
    """
    try:
        # Configure OpenAI as the PandasAI LLM only when a key is provided.
        openai_key = os.environ.get("OPENAI_API_KEY")
        if openai_key:
            pai.config.set({"llm": OpenAI(api_token=openai_key)})

        df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
        # Truncate long text values to keep chart labels readable.
        for column in df.select_dtypes(include=['object']).columns:
            df[column] = df[column].apply(lambda x: str(x)[:20] + '...' if len(str(x)) > 20 else str(x))

        df_ai = pai.DataFrame(df)
        pai.api_key.set(os.environ["PANDAS_KEY"])
        df_ai.chat(vis_prompt)

        exports_path = os.environ["PANDAS_EXPORTS_PATH"]
        charts = [
            os.path.join(exports_path, name)
            for name in os.listdir(exports_path)
            if name.startswith("temp_chart")
        ]
        if not charts:
            # Previously this fell through and returned None.
            return f"Visualization error: no chart file found in {exports_path}"
        # os.listdir order is arbitrary; the newest file is most likely the
        # chart generated by the call above rather than a stale one.
        latest_chart = max(charts, key=os.path.getmtime)
        return f"Visualization saved as {latest_chart}"
    except Exception as e:
        return f"Visualization error: {str(e)}"
if __name__ == "__main__":
    # Start the server with FastMCP's default transport (no arguments given).
    mcp.run()