File size: 25,806 Bytes
80cfb11 6a39607 80cfb11 a2134eb 09e2bc4 6a39607 80cfb11 feac34e 80cfb11 09e2bc4 0391cfb 80cfb11 feac34e 80cfb11 ef0ee7c 6b3e14e e8d2e8a aba6b7f 80cfb11 6b3e14e 80cfb11 a3f399c 0e5cf1e a3f399c 0e5cf1e a3f399c 19f43ab 6c01c87 19f43ab 09e2bc4 b112622 6c01c87 09e2bc4 a3f399c 09e2bc4 a3f399c 6c01c87 0e5cf1e 6c01c87 0e5cf1e 6c01c87 09e2bc4 6c01c87 09e2bc4 0e5cf1e 6c01c87 e087162 09e2bc4 0e5cf1e b112622 6c01c87 b112622 6c01c87 b112622 6c01c87 b112622 6c01c87 b112622 09e2bc4 b112622 6c01c87 b112622 6c01c87 b112622 6c01c87 b112622 6c01c87 b112622 0e5cf1e 6c01c87 0e5cf1e 0391cfb 0e5cf1e 6c01c87 0391cfb 0e5cf1e 6c01c87 09e2bc4 0e5cf1e 6c01c87 a0d91e2 8eaf76a 09e2bc4 6c01c87 09e2bc4 6c01c87 0e5cf1e 80cfb11 09e2bc4 80cfb11 09e2bc4 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 6b3e14e 6e75e8f 80cfb11 6e75e8f 6b3e14e 80cfb11 440696e 80cfb11 6b3e14e 80cfb11 6b3e14e 80cfb11 440696e 80cfb11 440696e 80cfb11 440696e 80cfb11 6e75e8f 6a39607 6b3e14e 80cfb11 6e75e8f 80cfb11 6b3e14e 80cfb11 440696e 80cfb11 440696e 80cfb11 440696e 80cfb11 440696e e087162 80cfb11 e087162 80cfb11 6e75e8f 80cfb11 e087162 6e75e8f 80cfb11 6e75e8f 80cfb11 e087162 6e75e8f 80cfb11 e087162 6e75e8f 80cfb11 a2134eb e087162 a2134eb a2d0bfa a2134eb ef0ee7c a2134eb 8eaf76a ef0ee7c 8eaf76a a2134eb 80cfb11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 |
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Optional, AsyncIterator
import asyncpg
from mcp.server.fastmcp import FastMCP, Context
from pydantic import Field
import pandasai as pai
# Fixed: was `import matplotlib as plt`, which bound the top-level matplotlib
# package under an alias that conventionally means matplotlib.pyplot.
import matplotlib.pyplot as plt
import pandas as pd
import logging
from pandasai_openai import OpenAI

# Constants
DEFAULT_QUERY_LIMIT = 100  # default row cap applied by execute_query

# Module-level logger; configuring handlers/levels is left to the host app.
logger = logging.getLogger(__name__)
# Define our own PromptMessage class if the MCP one isn't available
@dataclass
class PromptMessage:
content: str
role: Optional[str] = "user"
# Database context class
@dataclass
class DbContext:
pool: asyncpg.Pool
schema: str
# Database connection lifecycle manager
@asynccontextmanager
async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
"""Manage database connection lifecycle"""
dsn = os.environ["DB_URL"]
db_schema = os.environ["DB_SCHEMA"]
pool = await asyncpg.create_pool(
dsn,
min_size=1,
max_size=4,
max_inactive_connection_lifetime=300,
timeout=60,
command_timeout=300,
)
try:
yield DbContext(pool=pool, schema=db_schema)
finally:
# Clean up
await pool.close()
# Create server with database lifecycle management
mcp = FastMCP(
"SQL Database Server",
dependencies=["asyncpg", "pydantic"],
lifespan=db_lifespan
)
@mcp.resource(
    uri="resource://base_prompt",
    name="base_prompt",
    description="A base prompt to generate SQL queries and answer questions"
)
async def base_prompt_query() -> str:
    """Returns a base prompt to generate sql queries and answer questions

    The returned template contains literal `{descriptions}` and `{tools}`
    placeholders (and doubled `{{...}}` braces that survive a later
    `.format(...)` call), so it is intended to be filled in by the consumer.
    """
    # NOTE: this is a runtime string sent to the model — do not reflow/edit
    # casually; doubled braces are deliberate escapes for str.format.
    base_prompt = """
==========================
# Your Role
==========================
You are an expert in generating and executing SQL queries, interacting with a PostgreSQL database using **FastMCP tools**, and visualizing results when requested. These tools allow you to:
- List available tables
- Retrieve schema details
- Execute SQL queries
- Visualize query results using PandasAI
Each tool may return previews or summaries of table contents to help you understand the data structure.
---
==========================
# Your Objective
==========================
When a user submits a request, you must:
1. **Analyze the request** to determine the required data, action, or visualization.
2. **Use FastMCP tools** to gather necessary information (e.g., list tables, retrieve schema).
3. **Generate a valid SQL SELECT query**, if needed, and clearly show the full query.
4. **Execute the SQL query** using the `execute_query` tool and return the results.
5. **Visualize results** using the `visualize_results` tool if the user explicitly requests a visualization (e.g., "create a chart", "visualize", "plot"). For visualizations:
- Craft a visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
- Send JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
6. **Chain tools logically**, such as: List Tables → Get Schema → Write and Run Query → Visualize Results (if requested).
7. **Explain your reasoning and each step taken** to ensure clarity and transparency.
---
==========================
# Critical Rules
==========================
- Only use **read-only** SQL queries such as **SELECT**, **COUNT**, or queries with **GROUP BY**, **ORDER BY**, etc.
- **Never** use destructive operations like **DELETE**, **UPDATE**, **INSERT**, or **DROP**.
- Always show the SQL query you generate along with the execution result.
- Validate SQL syntax before execution.
- Never assume table or column names. Use tools to confirm structure.
- Use memory efficiently. Don't rerun a tool unless necessary.
- If you generate a SQL query, immediately call the **execute_query** tool.
- If the user requests a visualization, call the **visualize_results** tool with:
- A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
- JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- For non-query or non-visualization requests (e.g., history questions), respond appropriately without forcing a query or visualization.
---
==========================
# Database Description
==========================
{descriptions}
---
==========================
# Tools
==========================
You can use the following FastMCP tools to create **read-only** queries (e.g., `SELECT`, `COUNT`, `GROUP BY`, `ORDER BY`) and visualize results when requested. Chain tools logically to gather information, execute queries, or generate visualizations.
{tools}
---
### Invalid Example — DELETE Operation (Not Allowed):
**User Request:** "Delete all customers from Germany."
**Response Guidance:**
- **Do not generate or execute** destructive queries such as `DELETE`.
- Instead, respond with a message like:
> Destructive operations such as `DELETE` are not permitted. I can help you retrieve the customers from Germany using a `SELECT` query instead:
> ```sql
> SELECT * FROM customers WHERE country = 'Germany';
> ```
==========================
# Output Format
==========================
Present your final answer using the following structure in markdown language:
# Result
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a beautiful Markdown table for tabular data (rows and columns) including headers and show such simple results using a table. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}
# Explanation
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}
# Query
```sql
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
```
==========================
# Reminder
==========================
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
- A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
- JSON data in the format `{{'columns': [list of column names], 'data': [list of rows, each row a list of values]}}`.
- **For non-query or non-visualization requests, respond appropriately without forcing a query or visualization.**
**Conversation History:**
Use the conversation history for context when available to maintain continuity.
"""
    return base_prompt
@mcp.tool(description="tests the database connection and returns the PostgreSQL version or an error message.")
async def test_connection(ctx: Context) -> str:
    """Check that the database is reachable and report the server version."""
    try:
        db_pool = ctx.request_context.lifespan_context.pool
        async with db_pool.acquire() as conn:
            pg_version = await conn.fetchval("SELECT version();")
    except Exception as e:
        return f"Connection failed: {str(e)}"
    return f"Connection successful. PostgreSQL version: {pg_version}"
@mcp.tool(description="Executes a read-only SQL SELECT query and returns formatted results or an error message.")
async def execute_query(
    query: str = Field(description="SQL query to execute (SELECT only)"),
    limit: Optional[int] = Field(default=DEFAULT_QUERY_LIMIT, description="Maximum number of rows to return"),
    ctx: Context = None
) -> str:
    """Execute a read-only SQL query and return a plain-text result table.

    Args:
        query: SQL text; must start with SELECT (naive read-only guard).
        limit: Maximum number of rows to include in the output.
        ctx: FastMCP request context carrying the lifespan DbContext.

    Returns:
        "header\\n----\\nrows" text, a no-rows notice, or an error string.
    """
    # Naive read-only guard: reject anything not starting with SELECT.
    # NOTE(review): this also blocks legitimate read-only CTEs
    # ("WITH ... SELECT"); widen deliberately if CTE support is wanted.
    query = query.strip()
    if not query.lower().startswith("select"):
        return "Error: Only SELECT queries are allowed for security reasons."
    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            result = await conn.fetch(query)
        if not result:
            return "Query executed successfully. No rows returned."
        # Format: header line, dashed separator, one " | "-joined line per row.
        columns = list(result[0].keys())
        header = " | ".join(columns)
        separator = "-" * len(header)
        max_rows = limit if limit else DEFAULT_QUERY_LIMIT
        rows = [" | ".join(str(val) for val in row.values())
                for row in result[:max_rows]]
        return f"{header}\n{separator}\n" + "\n".join(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
# Database helper functions
async def get_all_tables(pool, db_schema):
"""Get all tables from the database"""
print(f"schema: {db_schema}")
async with pool.acquire() as conn:
result = await conn.fetch("""
SELECT c.relname AS table_name
FROM pg_class AS c
JOIN pg_namespace AS n ON n.oid = c.relnamespace
WHERE NOT EXISTS (
SELECT 1
FROM pg_inherits AS i
WHERE i.inhrelid = c.oid
)
AND c.relkind IN ('r', 'p')
AND n.nspname = $1
AND c.relname NOT LIKE 'pg_%'
ORDER BY c.relname;
""", db_schema)
return result
async def get_table_schema_info(pool, db_schema, table_name):
    """Fetch column metadata for one table from information_schema.columns."""
    column_query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default,
            character_maximum_length
        FROM information_schema.columns
        WHERE table_schema = $1
          AND table_name = $2
        ORDER BY ordinal_position;
    """
    async with pool.acquire() as conn:
        return await conn.fetch(column_query, db_schema, table_name)
def format_table_schema(table_name, columns):
    """Render column metadata rows as a human-readable schema listing.

    Each entry becomes "- name (type[(len)]) NULL|NOT NULL[ DEFAULT expr]".
    Returns a not-found message when `columns` is empty.
    """
    if not columns:
        return f"Table '{table_name}' not found."
    lines = [f"Table: {table_name}", "Columns:"]
    for column in columns:
        max_len = column['character_maximum_length']
        col_default = column['column_default']
        pieces = [
            f"- {column['column_name']} ",
            f"({column['data_type']}{f'({max_len})' if max_len else ''})",
            " NULL" if column['is_nullable'] == 'YES' else " NOT NULL",
            f" DEFAULT {col_default}" if col_default else "",
        ]
        lines.append("".join(pieces))
    return "\n".join(lines)
@mcp.tool(description="Lists all table names in the configured database schema.")
async def list_tables() -> str:
    """Return one table name per line for the configured schema."""
    try:
        # NOTE(review): this opens a fresh pool per call via db_lifespan
        # rather than reusing the server's lifespan pool — consider wiring
        # ctx through here like test_connection does.
        async with db_lifespan(mcp) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)
            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."
            return "\n".join(record['table_name'] for record in tables)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.tool(description="Retrieves and formats the schema details of a specific table in the database.")
async def get_table_schema(table_name: str) -> str:
    """Return the formatted column listing for one table, or an error string."""
    try:
        # Env read stays inside try so a missing DB_SCHEMA surfaces as an
        # "Error: ..." string rather than an unhandled KeyError.
        db_schema = os.environ["DB_SCHEMA"]
        async with db_lifespan(mcp) as db_ctx:
            column_rows = await get_table_schema_info(db_ctx.pool, db_schema, table_name)
        if not column_rows:
            return f"Table '{table_name}' not found in {db_schema} schema."
        return format_table_schema(table_name, column_rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.tool(description="Retrieves foreign key relationships for a specified table.")
async def get_foreign_keys(table_name: str) -> str:
    """Get foreign key information for a table.

    Args:
        table_name: The name of the table to get foreign keys from.

    Returns:
        A plain-text table (constraint name, FK column, referenced
        schema/table/column) or a no-results / error message.

    Fixed: the original built the SQL in a NON-f-string (literal `{db_schema}`
    braces reached the server), passed a tuple into execute_query's `limit`
    parameter, and called the async tool without awaiting it from a sync def.
    Now runs its own parameterized ($1/$2) query.
    """
    db_schema = os.environ["DB_SCHEMA"]
    sql = """
        SELECT
            tc.constraint_name,
            kcu.column_name as fk_column,
            ccu.table_schema as referenced_schema,
            ccu.table_name as referenced_table,
            ccu.column_name as referenced_column
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.referential_constraints rc
            ON tc.constraint_name = rc.constraint_name
        JOIN information_schema.constraint_column_usage ccu
            ON rc.unique_constraint_name = ccu.constraint_name
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = $1
            AND tc.table_name = $2
        ORDER BY tc.constraint_name, kcu.ordinal_position
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                rows = await conn.fetch(sql, db_schema, table_name)
        if not rows:
            return f"No foreign keys found for table '{table_name}'."
        columns = list(rows[0].keys())
        header = " | ".join(columns)
        separator = "-" * len(header)
        body = "\n".join(" | ".join(str(v) for v in row.values()) for row in rows)
        return f"{header}\n{separator}\n{body}"
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.tool(description="Fetches and formats the schema details for all tables in the configured database schema.")
async def get_all_schemas() -> str:
    """Concatenate formatted column listings for every table in the schema."""
    try:
        db_schema = os.environ["DB_SCHEMA"]
        async with db_lifespan(mcp) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)
            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."
            rendered = []
            for record in tables:
                current_table = record['table_name']
                column_rows = await get_table_schema_info(db_ctx.pool, db_schema, current_table)
                rendered.append(format_table_schema(current_table, column_rows))
                rendered.append("")  # blank line between tables
            return "\n".join(rendered)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
@mcp.prompt(description="Generates a prompt message to help craft a best-practice SELECT query for a given table.")
async def generate_select_query(table_name: str) -> list[PromptMessage]:
    """Generate a SELECT query with best practices for a table

    Looks up the table's columns via a parameterized information_schema
    query, then returns a single PromptMessage asking the model to write a
    well-structured SELECT for that table. Errors are reported as a
    PromptMessage rather than raised.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            pool = db_ctx.pool
            async with pool.acquire() as conn:
                # Parameterized ($1/$2) lookup of columns in declared order.
                columns = await conn.fetch("""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
""", db_ctx.schema, table_name)
            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]
            columns_text = "\n".join([f"- {col['column_name']} ({col['data_type']})" for col in columns])
            # The f-string below is runtime prompt text sent to the model.
            return [
                PromptMessage(
                    f"""Please help me write a well-structured, efficient SELECT query for the '{table_name}' table.
Table Schema:
{columns_text}
PostgreSQL SQL Best Practices:
- Use explicit column names instead of * when possible
- Include LIMIT clauses to restrict result sets
- Consider adding WHERE clauses to filter results
- Use appropriate indexing considerations
- Format SQL with proper indentation and line breaks
Create a basic SELECT query following these best practices:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating select query: {str(e)}")]
@mcp.prompt(description="Generates a prompt message to assist in writing analytical queries for a given table.")
async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
    """Generate analytical queries for a table

    Args:
        table_name: The name of the table to generate analytical queries for

    Returns:
        A single PromptMessage guiding the model to write aggregation
        queries; errors are reported as a PromptMessage rather than raised.
    """
    try:
        async with db_lifespan(mcp) as db_ctx:
            pool = db_ctx.pool
            async with pool.acquire() as conn:
                # Fixed: the original interpolated schema/table names directly
                # into an f-string (unquoted identifiers + injection risk)
                # while ALSO passing them as query arguments, which asyncpg
                # rejects. Use $1/$2 placeholders like generate_select_query.
                columns = await conn.fetch("""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
""", db_ctx.schema, table_name)
            if not columns:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]
            columns_text = "\n".join([f"- {col['column_name']} ({col['data_type']})" for col in columns])
            return [
                PromptMessage(
                    f"""Please help me create analytical queries for the '{table_name}' table.
Table Schema:
{columns_text}
PostgreSQL SQL Best Practices:
- Use aggregation functions (COUNT, SUM, AVG, MIN, MAX) appropriately
- Group data using GROUP BY for meaningful aggregations
- Filter groups with HAVING clauses when needed
- Consider using window functions for advanced analytics
- Format SQL with proper indentation and line breaks
Create a set of analytical queries for this table:"""
                )
            ]
    except Exception as e:
        return [PromptMessage(f"Error generating analytical query: {str(e)}")]
@mcp.tool(
    description="Identifies both explicit and implied foreign key relationships for a given table using schema analysis and naming patterns.")
async def find_relationships(table_name: str, db_schema: str = 'public') -> str:
    """Find both explicit and implied relationships for a table.

    Args:
        table_name: The name of the table to analyze relationships for
        db_schema: The schema name (defaults to 'public')

    Returns:
        Two formatted sections (explicit FKs, implied relationships) or a
        no-relationships / error message.

    Fixed: the original f-string interpolated schema/table names unquoted
    into the SQL (syntax error + injection risk), called the async
    execute_query tool without awaiting it from a sync def, and compared
    against a "No results found" string execute_query never returns. Now
    async, parameterized ($1/$2), with local result formatting.
    """
    # Explicit foreign keys declared via FOREIGN KEY constraints.
    fk_sql = """
        SELECT
            kcu.column_name,
            ccu.table_name as foreign_table,
            ccu.column_name as foreign_column,
            'Explicit FK' as relationship_type,
            1 as confidence_level
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = $1
            AND tc.table_name = $2
    """
    # Implied relationships guessed from *_id / *id / *_fk naming patterns,
    # ranked by confidence (lower number = stronger match).
    implied_sql = """
        WITH source_columns AS (
            -- Get all ID-like columns from our table
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = $1
            AND table_name = $2
            AND (
                column_name LIKE '%id'
                OR column_name LIKE '%_id'
                OR column_name LIKE '%_fk'
            )
        ),
        potential_references AS (
            -- Find tables that might be referenced by our ID columns
            SELECT DISTINCT
                sc.column_name as source_column,
                sc.data_type as source_type,
                t.table_name as target_table,
                c.column_name as target_column,
                c.data_type as target_type,
                CASE
                    -- Highest confidence: column matches table_id pattern and types match
                    WHEN sc.column_name = t.table_name || '_id'
                        AND sc.data_type = c.data_type THEN 2
                    -- High confidence: column ends with _id and types match
                    WHEN sc.column_name LIKE '%_id'
                        AND sc.data_type = c.data_type THEN 3
                    -- Medium confidence: column contains table name and types match
                    WHEN sc.column_name LIKE '%' || t.table_name || '%'
                        AND sc.data_type = c.data_type THEN 4
                    -- Lower confidence: column ends with id and types match
                    WHEN sc.column_name LIKE '%id'
                        AND sc.data_type = c.data_type THEN 5
                END as confidence_level
            FROM source_columns sc
            CROSS JOIN information_schema.tables t
            JOIN information_schema.columns c
                ON c.table_schema = t.table_schema
                AND c.table_name = t.table_name
                AND (c.column_name = 'id' OR c.column_name = sc.column_name)
            WHERE t.table_schema = $1
            AND t.table_name != $2 -- Exclude self-references
        )
        SELECT
            source_column as column_name,
            target_table as foreign_table,
            target_column as foreign_column,
            CASE
                WHEN confidence_level = 2 THEN 'Strong implied relationship (exact match)'
                WHEN confidence_level = 3 THEN 'Strong implied relationship (_id pattern)'
                WHEN confidence_level = 4 THEN 'Likely implied relationship (name match)'
                ELSE 'Possible implied relationship'
            END as relationship_type,
            confidence_level
        FROM potential_references
        WHERE confidence_level IS NOT NULL
        ORDER BY confidence_level, source_column;
    """

    def _format(rows):
        # Render asyncpg records as "header / separator / rows" text.
        if not rows:
            return "No results found"
        header = " | ".join(rows[0].keys())
        separator = "-" * len(header)
        body = "\n".join(" | ".join(str(v) for v in row.values()) for row in rows)
        return f"{header}\n{separator}\n{body}"

    try:
        async with db_lifespan(mcp) as db_ctx:
            async with db_ctx.pool.acquire() as conn:
                fk_rows = await conn.fetch(fk_sql, db_schema, table_name)
                implied_rows = await conn.fetch(implied_sql, db_schema, table_name)
        if not fk_rows and not implied_rows:
            return "No relationships found for this table"
        return (f"Explicit Foreign Keys:\n{_format(fk_rows)}\n\n"
                f"Implied Relationships:\n{_format(implied_rows)}")
    except Exception as e:
        return f"Error finding relationships: {str(e)}"
@mcp.tool(description="Visualizes query results using a prompt and JSON data.")
async def visualize_results(json_data: dict, vis_prompt: str) -> str:
    """
    Generates a visualization based on query results using PandasAI.

    Args:
        json_data (dict): A dictionary containing the query results.
            It should have two keys:
            - 'columns': A list of column names (strings).
            - 'data': A list of lists, where each inner list represents a row of data.
              Each element in the inner list corresponds to a column in 'columns'.
            Example:
                {
                    'columns': ['Region', 'Product', 'Sales'],
                    'data': [
                        ['North', 'Widget', 150],
                        ['South', 'Widget', 200]
                    ]
                }
        vis_prompt (str): A natural language prompt describing the desired visualization
            (e.g., "Create a bar chart showing sales by region").

    Returns:
        str: The path to the saved visualization file (e.g., 'visualization_output.png')
        or an error message if the visualization fails.
    """
    try:
        # Configure an OpenAI-backed LLM for PandasAI when a key is available.
        openai_api_key = os.environ.get("OPENAI_API_KEY")
        if openai_api_key:
            pai.config.set({"llm": OpenAI(api_token=openai_api_key)})
        # Convert the JSON payload into a DataFrame.
        df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
        # Truncate long text values so chart labels stay readable.
        for column in df.select_dtypes(include=['object']).columns:
            df[column] = df[column].apply(lambda x: str(x)[:20] + '...' if len(str(x)) > 20 else str(x))
        # Initialize PandasAI and generate the chart.
        df_ai = pai.DataFrame(df)
        api_key = os.environ["PANDAS_KEY"]
        pai.api_key.set(api_key)
        df_ai.chat(vis_prompt)
        # Locate the exported chart file.
        exports_path = os.environ["PANDAS_EXPORTS_PATH"]
        generated_files = [f for f in os.listdir(exports_path) if f.startswith("temp_chart")]
        if not generated_files:
            # Fixed: the original fell through here and implicitly returned None.
            return "Visualization error: no chart file was produced."
        # Fixed: pick the most recently written chart instead of an arbitrary
        # directory entry, since old temp_chart files may linger.
        newest = max(generated_files,
                     key=lambda f: os.path.getmtime(os.path.join(exports_path, f)))
        return f"Visualization saved as {os.path.join(exports_path, newest)}"
    except Exception as e:
        return f"Visualization error: {str(e)}"
if __name__ == "__main__":
    # Run the MCP server when executed as a script.
    mcp.run()
|