Commit
·
8eaf76a
1
Parent(s):
3b91e5c
feat!: add visualizations to UI
Browse files- .env.sample +1 -0
- exports/charts/temp_chart_5252d436-6dfe-46e3-a348-49308898940d.png +0 -0
- exports/charts/temp_chart_70308ee4-1919-43a6-b30a-c7024fb0440e.png +0 -0
- exports/charts/temp_chart_730e3451-365a-4a3f-9c4b-d95f991c4bbd.png +0 -0
- exports/charts/temp_chart_7ee5d17e-a9a1-40e1-ab90-cc48a82ef03e.png +0 -0
- exports/charts/temp_chart_91edd69e-ff3b-49e5-834b-399d57898523.png +0 -0
- gradio_app.py +33 -5
- langchain_mcp_client.py +1 -1
- postgre_mcp_server.py +11 -20
.env.sample
CHANGED
|
@@ -7,6 +7,7 @@ GEMINI_API_KEY=
|
|
| 7 |
GEMINI_MODEL=
|
| 8 |
|
| 9 |
PANDAS_KEY=
|
|
|
|
| 10 |
|
| 11 |
OPENAI_MODEL=
|
| 12 |
OPENAI_API_KEY=
|
|
|
|
| 7 |
GEMINI_MODEL=
|
| 8 |
|
| 9 |
PANDAS_KEY=
|
| 10 |
+
PANDAS_EXPORTS_PATH=
|
| 11 |
|
| 12 |
OPENAI_MODEL=
|
| 13 |
OPENAI_API_KEY=
|
exports/charts/temp_chart_5252d436-6dfe-46e3-a348-49308898940d.png
DELETED
|
Binary file (20.4 kB)
|
|
|
exports/charts/temp_chart_70308ee4-1919-43a6-b30a-c7024fb0440e.png
DELETED
|
Binary file (39.8 kB)
|
|
|
exports/charts/temp_chart_730e3451-365a-4a3f-9c4b-d95f991c4bbd.png
DELETED
|
Binary file (30.9 kB)
|
|
|
exports/charts/temp_chart_7ee5d17e-a9a1-40e1-ab90-cc48a82ef03e.png
DELETED
|
Binary file (28.5 kB)
|
|
|
exports/charts/temp_chart_91edd69e-ff3b-49e5-834b-399d57898523.png
DELETED
|
Binary file (28.3 kB)
|
|
|
gradio_app.py
CHANGED
|
@@ -3,6 +3,9 @@ from pathlib import Path
|
|
| 3 |
import gradio as gr
|
| 4 |
import asyncio
|
| 5 |
from langchain_mcp_client import lc_mcp_exec
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# ======================================= Load DB configs
|
| 8 |
def load_db_configs():
|
|
@@ -18,23 +21,49 @@ def load_db_configs():
|
|
| 18 |
return configs["db_configs"]
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ====================================== Async-compatible wrapper
|
| 22 |
async def run_agent(request, history):
|
| 23 |
# configs = load_db_configs()
|
| 24 |
# final_answer, last_tool_answer, = await pg_mcp_exec(request)
|
| 25 |
# return final_answer, last_tool_answer
|
| 26 |
|
| 27 |
-
response,
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
# ====================================== Gradio UI with history
|
| 32 |
demo = gr.ChatInterface(
|
| 33 |
-
run_agent,
|
| 34 |
type="messages",
|
| 35 |
flagging_mode="manual",
|
| 36 |
flagging_options=["Like", "Spam", "Inappropriate", "Other"],
|
| 37 |
-
# additional_outputs=[gr.Image(label="Visualization Output")],
|
| 38 |
examples=[
|
| 39 |
"List all tables in the database",
|
| 40 |
"Show me the schema of the dim_customer table",
|
|
@@ -42,7 +71,6 @@ demo = gr.ChatInterface(
|
|
| 42 |
"Create a pie chart showing the distribution of customer statuses in the dim_customer table.",
|
| 43 |
"Plot a line chart showing the trend of order quantities over order dates from the dim_product_order_item table.",
|
| 44 |
"Generate a bar chart displaying the count of products in each product class from the dim_product table. Use beautiful and vivid colors fro visualization.",
|
| 45 |
-
"Visualize the relationship between tax_rate and tax_included_amount from the dim_product table using a scatter plot.",
|
| 46 |
"Visualize the number of product orders over time from the dim_product_order_item table, using the order_date. Show the trend monthly.",
|
| 47 |
],
|
| 48 |
save_history=True,
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import asyncio
|
| 5 |
from langchain_mcp_client import lc_mcp_exec
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import os
|
| 8 |
+
import base64
|
| 9 |
|
| 10 |
# ======================================= Load DB configs
|
| 11 |
def load_db_configs():
|
|
|
|
| 21 |
return configs["db_configs"]
|
| 22 |
|
| 23 |
|
| 24 |
+
def image_to_base64_markdown(image_path, alt_text="Customer Status"):
|
| 25 |
+
with open(image_path, "rb") as f:
|
| 26 |
+
encoded = base64.b64encode(f.read()).decode("utf-8")
|
| 27 |
+
return f""
|
| 28 |
+
|
| 29 |
+
|
| 30 |
# ====================================== Async-compatible wrapper
|
| 31 |
async def run_agent(request, history):
|
| 32 |
# configs = load_db_configs()
|
| 33 |
# final_answer, last_tool_answer, = await pg_mcp_exec(request)
|
| 34 |
# return final_answer, last_tool_answer
|
| 35 |
|
| 36 |
+
response, messages = await lc_mcp_exec(request, history)
|
| 37 |
+
|
| 38 |
+
image_path = ""
|
| 39 |
+
load_dotenv()
|
| 40 |
+
PANDAS_EXPORTS_PATH = os.getenv("PANDAS_EXPORTS_PATH")
|
| 41 |
+
generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.beginswith("temp_chart_") and f.endswith(".png")]
|
| 42 |
+
if generated_files:
|
| 43 |
+
image_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
|
| 44 |
+
image_markdown = image_to_base64_markdown(image_path)
|
| 45 |
+
|
| 46 |
+
# Remove the image file after reading it
|
| 47 |
+
os.remove(image_path)
|
| 48 |
+
else:
|
| 49 |
+
image_path = f"{PANDAS_EXPORTS_PATH}/blank_chart.png"
|
| 50 |
+
image_markdown = image_to_base64_markdown(image_path)
|
| 51 |
+
|
| 52 |
+
print(f"Image path: {image_path}")
|
| 53 |
+
|
| 54 |
+
output = image_markdown + response
|
| 55 |
+
|
| 56 |
+
# print(output)
|
| 57 |
+
|
| 58 |
+
return output
|
| 59 |
|
| 60 |
|
| 61 |
# ====================================== Gradio UI with history
|
| 62 |
demo = gr.ChatInterface(
|
| 63 |
+
fn=run_agent,
|
| 64 |
type="messages",
|
| 65 |
flagging_mode="manual",
|
| 66 |
flagging_options=["Like", "Spam", "Inappropriate", "Other"],
|
|
|
|
| 67 |
examples=[
|
| 68 |
"List all tables in the database",
|
| 69 |
"Show me the schema of the dim_customer table",
|
|
|
|
| 71 |
"Create a pie chart showing the distribution of customer statuses in the dim_customer table.",
|
| 72 |
"Plot a line chart showing the trend of order quantities over order dates from the dim_product_order_item table.",
|
| 73 |
"Generate a bar chart displaying the count of products in each product class from the dim_product table. Use beautiful and vivid colors fro visualization.",
|
|
|
|
| 74 |
"Visualize the number of product orders over time from the dim_product_order_item table, using the order_date. Show the trend monthly.",
|
| 75 |
],
|
| 76 |
save_history=True,
|
langchain_mcp_client.py
CHANGED
|
@@ -101,7 +101,7 @@ async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
|
|
| 101 |
|
| 102 |
# Return the response and the updated history
|
| 103 |
return response_content, message_history.messages
|
| 104 |
-
|
| 105 |
except Exception as e:
|
| 106 |
logger.error(f"Error in execution: {str(e)}", exc_info=True)
|
| 107 |
return f"Error: {str(e)}", message_history.messages if 'message_history' in locals() else []
|
|
|
|
| 101 |
|
| 102 |
# Return the response and the updated history
|
| 103 |
return response_content, message_history.messages
|
| 104 |
+
|
| 105 |
except Exception as e:
|
| 106 |
logger.error(f"Error in execution: {str(e)}", exc_info=True)
|
| 107 |
return f"Error: {str(e)}", message_history.messages if 'message_history' in locals() else []
|
postgre_mcp_server.py
CHANGED
|
@@ -150,13 +150,9 @@ You can use the following FastMCP tools to create **read-only** queries (e.g., `
|
|
| 150 |
|
| 151 |
Present your final answer using the following structure **exactly** in markdown language. When necessary, bold the important parts of your answer or use `` for inline code blocks:
|
| 152 |
|
| 153 |
-
```markdown
|
| 154 |
# Result
|
| 155 |
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a Markdown table for tabular data (rows and columns) including headers. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}
|
| 156 |
|
| 157 |
-
# Visualization (if requested)
|
| 158 |
-
{{If the user requested a visualization, include the result from the visualize_results tool, e.g., "Visualization saved as visualization_output.png". Otherwise, omit this section.}}
|
| 159 |
-
|
| 160 |
# Explanation
|
| 161 |
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}
|
| 162 |
|
|
@@ -165,7 +161,9 @@ Present your final answer using the following structure **exactly** in markdown
|
|
| 165 |
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
|
| 166 |
```
|
| 167 |
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
|
| 170 |
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
|
| 171 |
- A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
|
|
@@ -611,37 +609,30 @@ async def visualize_results(json_data: dict, vis_prompt: str) -> str:
|
|
| 611 |
"""
|
| 612 |
try:
|
| 613 |
# Debug prints to see what's being received
|
| 614 |
-
print("\nVisualization Tool Debug:")
|
| 615 |
-
print(f"Received json_data: {json_data}")
|
| 616 |
-
print(f"Received vis_prompt: {vis_prompt}")
|
| 617 |
|
| 618 |
# Convert JSON to DataFrame
|
| 619 |
df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
|
| 620 |
-
print(f"Created DataFrame:\n{df.head()}")
|
| 621 |
|
| 622 |
# Initialize PandasAI
|
| 623 |
df_ai = pai.DataFrame(df)
|
| 624 |
-
print("Initialized PandasAI DataFrame")
|
| 625 |
|
|
|
|
| 626 |
load_dotenv()
|
| 627 |
api_key = os.environ.get("PANDAS_KEY")
|
| 628 |
-
print(f"Using PandasAI API key: {api_key[:5]}...")
|
| 629 |
pai.api_key.set(api_key)
|
| 630 |
|
| 631 |
# Generate visualization
|
| 632 |
-
print(f"Attempting to generate visualization with prompt: '{vis_prompt}'")
|
| 633 |
df_ai.chat(vis_prompt)
|
| 634 |
|
| 635 |
-
#
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
|
|
|
|
|
|
| 640 |
|
| 641 |
-
return f"Visualization saved as {output_file}"
|
| 642 |
except Exception as e:
|
| 643 |
-
print(f"Visualization error: {str(e)}")
|
| 644 |
-
print(f"Error type: {type(e)}")
|
| 645 |
return f"Visualization error: {str(e)}"
|
| 646 |
|
| 647 |
|
|
|
|
| 150 |
|
| 151 |
Present your final answer using the following structure **exactly** in markdown language. When necessary, bold the important parts of your answer or use `` for inline code blocks:
|
| 152 |
|
|
|
|
| 153 |
# Result
|
| 154 |
{{Take the result from the execute_query tool and format it nicely using Markdown. Use a Markdown table for tabular data (rows and columns) including headers. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}
|
| 155 |
|
|
|
|
|
|
|
|
|
|
| 156 |
# Explanation
|
| 157 |
{{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}
|
| 158 |
|
|
|
|
| 161 |
{{Display the exact SQL query you generated and executed here to answer the user's request.}}
|
| 162 |
```
|
| 163 |
|
| 164 |
+
==========================
|
| 165 |
+
# Reminder
|
| 166 |
+
==========================
|
| 167 |
- **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
|
| 168 |
- **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
|
| 169 |
- A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
|
|
|
|
| 609 |
"""
|
| 610 |
try:
|
| 611 |
# Debug prints to see what's being received
|
|
|
|
|
|
|
|
|
|
| 612 |
|
| 613 |
# Convert JSON to DataFrame
|
| 614 |
df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
|
|
|
|
| 615 |
|
| 616 |
# Initialize PandasAI
|
| 617 |
df_ai = pai.DataFrame(df)
|
|
|
|
| 618 |
|
| 619 |
+
# Load api key
|
| 620 |
load_dotenv()
|
| 621 |
api_key = os.environ.get("PANDAS_KEY")
|
|
|
|
| 622 |
pai.api_key.set(api_key)
|
| 623 |
|
| 624 |
# Generate visualization
|
|
|
|
| 625 |
df_ai.chat(vis_prompt)
|
| 626 |
|
| 627 |
+
# Get the visualization path
|
| 628 |
+
PANDAS_EXPORTS_PATH = os.getenv("PANDAS_EXPORTS_PATH")
|
| 629 |
+
generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.startswith("temp_chart")]
|
| 630 |
+
if generated_files:
|
| 631 |
+
visualization_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
|
| 632 |
+
|
| 633 |
+
return f"Visualization saved as {visualization_path}"
|
| 634 |
|
|
|
|
| 635 |
except Exception as e:
|
|
|
|
|
|
|
| 636 |
return f"Visualization error: {str(e)}"
|
| 637 |
|
| 638 |
|