amirkiarafiei commited on
Commit
8eaf76a
·
1 Parent(s): 3b91e5c

feat!: add visualizations to UI

Browse files
.env.sample CHANGED
@@ -7,6 +7,7 @@ GEMINI_API_KEY=
7
  GEMINI_MODEL=
8
 
9
  PANDAS_KEY=
 
10
 
11
  OPENAI_MODEL=
12
  OPENAI_API_KEY=
 
7
  GEMINI_MODEL=
8
 
9
  PANDAS_KEY=
10
+ PANDAS_EXPORTS_PATH=
11
 
12
  OPENAI_MODEL=
13
  OPENAI_API_KEY=
exports/charts/temp_chart_5252d436-6dfe-46e3-a348-49308898940d.png DELETED
Binary file (20.4 kB)
 
exports/charts/temp_chart_70308ee4-1919-43a6-b30a-c7024fb0440e.png DELETED
Binary file (39.8 kB)
 
exports/charts/temp_chart_730e3451-365a-4a3f-9c4b-d95f991c4bbd.png DELETED
Binary file (30.9 kB)
 
exports/charts/temp_chart_7ee5d17e-a9a1-40e1-ab90-cc48a82ef03e.png DELETED
Binary file (28.5 kB)
 
exports/charts/temp_chart_91edd69e-ff3b-49e5-834b-399d57898523.png DELETED
Binary file (28.3 kB)
 
gradio_app.py CHANGED
@@ -3,6 +3,9 @@ from pathlib import Path
3
  import gradio as gr
4
  import asyncio
5
  from langchain_mcp_client import lc_mcp_exec
 
 
 
6
 
7
  # ======================================= Load DB configs
8
  def load_db_configs():
@@ -18,23 +21,49 @@ def load_db_configs():
18
  return configs["db_configs"]
19
 
20
 
 
 
 
 
 
 
21
  # ====================================== Async-compatible wrapper
22
  async def run_agent(request, history):
23
  # configs = load_db_configs()
24
  # final_answer, last_tool_answer, = await pg_mcp_exec(request)
25
  # return final_answer, last_tool_answer
26
 
27
- response, message = await lc_mcp_exec(request, history)
28
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  # ====================================== Gradio UI with history
32
  demo = gr.ChatInterface(
33
- run_agent,
34
  type="messages",
35
  flagging_mode="manual",
36
  flagging_options=["Like", "Spam", "Inappropriate", "Other"],
37
- # additional_outputs=[gr.Image(label="Visualization Output")],
38
  examples=[
39
  "List all tables in the database",
40
  "Show me the schema of the dim_customer table",
@@ -42,7 +71,6 @@ demo = gr.ChatInterface(
42
  "Create a pie chart showing the distribution of customer statuses in the dim_customer table.",
43
  "Plot a line chart showing the trend of order quantities over order dates from the dim_product_order_item table.",
44
  "Generate a bar chart displaying the count of products in each product class from the dim_product table. Use beautiful and vivid colors fro visualization.",
45
- "Visualize the relationship between tax_rate and tax_included_amount from the dim_product table using a scatter plot.",
46
  "Visualize the number of product orders over time from the dim_product_order_item table, using the order_date. Show the trend monthly.",
47
  ],
48
  save_history=True,
 
3
  import gradio as gr
4
  import asyncio
5
  from langchain_mcp_client import lc_mcp_exec
6
+ from dotenv import load_dotenv
7
+ import os
8
+ import base64
9
 
10
  # ======================================= Load DB configs
11
  def load_db_configs():
 
21
  return configs["db_configs"]
22
 
23
 
24
+ def image_to_base64_markdown(image_path, alt_text="Customer Status"):
25
+ with open(image_path, "rb") as f:
26
+ encoded = base64.b64encode(f.read()).decode("utf-8")
27
+ return f"![{alt_text}](data:image/png;base64,{encoded})"
28
+
29
+
30
  # ====================================== Async-compatible wrapper
31
  async def run_agent(request, history):
32
  # configs = load_db_configs()
33
  # final_answer, last_tool_answer, = await pg_mcp_exec(request)
34
  # return final_answer, last_tool_answer
35
 
36
+ response, messages = await lc_mcp_exec(request, history)
37
+
38
+ image_path = ""
39
+ load_dotenv()
40
+ PANDAS_EXPORTS_PATH = os.getenv("PANDAS_EXPORTS_PATH")
41
+ generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.beginswith("temp_chart_") and f.endswith(".png")]
42
+ if generated_files:
43
+ image_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
44
+ image_markdown = image_to_base64_markdown(image_path)
45
+
46
+ # Remove the image file after reading it
47
+ os.remove(image_path)
48
+ else:
49
+ image_path = f"{PANDAS_EXPORTS_PATH}/blank_chart.png"
50
+ image_markdown = image_to_base64_markdown(image_path)
51
+
52
+ print(f"Image path: {image_path}")
53
+
54
+ output = image_markdown + response
55
+
56
+ # print(output)
57
+
58
+ return output
59
 
60
 
61
  # ====================================== Gradio UI with history
62
  demo = gr.ChatInterface(
63
+ fn=run_agent,
64
  type="messages",
65
  flagging_mode="manual",
66
  flagging_options=["Like", "Spam", "Inappropriate", "Other"],
 
67
  examples=[
68
  "List all tables in the database",
69
  "Show me the schema of the dim_customer table",
 
71
  "Create a pie chart showing the distribution of customer statuses in the dim_customer table.",
72
  "Plot a line chart showing the trend of order quantities over order dates from the dim_product_order_item table.",
73
  "Generate a bar chart displaying the count of products in each product class from the dim_product table. Use beautiful and vivid colors fro visualization.",
 
74
  "Visualize the number of product orders over time from the dim_product_order_item table, using the order_date. Show the trend monthly.",
75
  ],
76
  save_history=True,
langchain_mcp_client.py CHANGED
@@ -101,7 +101,7 @@ async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
101
 
102
  # Return the response and the updated history
103
  return response_content, message_history.messages
104
-
105
  except Exception as e:
106
  logger.error(f"Error in execution: {str(e)}", exc_info=True)
107
  return f"Error: {str(e)}", message_history.messages if 'message_history' in locals() else []
 
101
 
102
  # Return the response and the updated history
103
  return response_content, message_history.messages
104
+
105
  except Exception as e:
106
  logger.error(f"Error in execution: {str(e)}", exc_info=True)
107
  return f"Error: {str(e)}", message_history.messages if 'message_history' in locals() else []
postgre_mcp_server.py CHANGED
@@ -150,13 +150,9 @@ You can use the following FastMCP tools to create **read-only** queries (e.g., `
150
 
151
  Present your final answer using the following structure **exactly** in markdown language. When necessary, bold the important parts of your answer or use `` for inline code blocks:
152
 
153
- ```markdown
154
  # Result
155
  {{Take the result from the execute_query tool and format it nicely using Markdown. Use a Markdown table for tabular data (rows and columns) including headers. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}
156
 
157
- # Visualization (if requested)
158
- {{If the user requested a visualization, include the result from the visualize_results tool, e.g., "Visualization saved as visualization_output.png". Otherwise, omit this section.}}
159
-
160
  # Explanation
161
  {{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}
162
 
@@ -165,7 +161,9 @@ Present your final answer using the following structure **exactly** in markdown
165
  {{Display the exact SQL query you generated and executed here to answer the user's request.}}
166
  ```
167
 
168
- **Reminder:**
 
 
169
  - **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
170
  - **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
171
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
@@ -611,37 +609,30 @@ async def visualize_results(json_data: dict, vis_prompt: str) -> str:
611
  """
612
  try:
613
  # Debug prints to see what's being received
614
- print("\nVisualization Tool Debug:")
615
- print(f"Received json_data: {json_data}")
616
- print(f"Received vis_prompt: {vis_prompt}")
617
 
618
  # Convert JSON to DataFrame
619
  df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
620
- print(f"Created DataFrame:\n{df.head()}")
621
 
622
  # Initialize PandasAI
623
  df_ai = pai.DataFrame(df)
624
- print("Initialized PandasAI DataFrame")
625
 
 
626
  load_dotenv()
627
  api_key = os.environ.get("PANDAS_KEY")
628
- print(f"Using PandasAI API key: {api_key[:5]}...")
629
  pai.api_key.set(api_key)
630
 
631
  # Generate visualization
632
- print(f"Attempting to generate visualization with prompt: '{vis_prompt}'")
633
  df_ai.chat(vis_prompt)
634
 
635
- # Save plot
636
- output_file = "visualization_output.png"
637
- plt.savefig(output_file)
638
- plt.close()
639
- print(f"Saved visualization to {output_file}")
 
 
640
 
641
- return f"Visualization saved as {output_file}"
642
  except Exception as e:
643
- print(f"Visualization error: {str(e)}")
644
- print(f"Error type: {type(e)}")
645
  return f"Visualization error: {str(e)}"
646
 
647
 
 
150
 
151
  Present your final answer using the following structure **exactly** in markdown language. When necessary, bold the important parts of your answer or use `` for inline code blocks:
152
 
 
153
  # Result
154
  {{Take the result from the execute_query tool and format it nicely using Markdown. Use a Markdown table for tabular data (rows and columns) including headers. Use bullet points or items in markdown for answers that include lists of names or descriptions. Use plain text for single values or simple messages. Ensure data alignment and clarity.}}
155
 
 
 
 
156
  # Explanation
157
  {{Provide a concise explanation or interpretation of the results (and visualization, if applicable) in 1-3 sentences. Explain what the data and visualization (if any) represent in the context of the user's request.}}
158
 
 
161
  {{Display the exact SQL query you generated and executed here to answer the user's request.}}
162
  ```
163
 
164
+ ==========================
165
+ # Reminder
166
+ ==========================
167
  - **Every time you generate a SQL query, call `execute_query` immediately and include the result.**
168
  - **If the user requests a visualization (e.g., "create a chart", "visualize", "plot"), call `visualize_results` with:**
169
  - A visualization prompt starting with "plot" (e.g., "plot a bar chart showing sales by region").
 
609
  """
610
  try:
611
  # Debug prints to see what's being received
 
 
 
612
 
613
  # Convert JSON to DataFrame
614
  df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
 
615
 
616
  # Initialize PandasAI
617
  df_ai = pai.DataFrame(df)
 
618
 
619
+ # Load api key
620
  load_dotenv()
621
  api_key = os.environ.get("PANDAS_KEY")
 
622
  pai.api_key.set(api_key)
623
 
624
  # Generate visualization
 
625
  df_ai.chat(vis_prompt)
626
 
627
+ # Get the visualization path
628
+ PANDAS_EXPORTS_PATH = os.getenv("PANDAS_EXPORTS_PATH")
629
+ generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.startswith("temp_chart")]
630
+ if generated_files:
631
+ visualization_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
632
+
633
+ return f"Visualization saved as {visualization_path}"
634
 
 
635
  except Exception as e:
 
 
636
  return f"Visualization error: {str(e)}"
637
 
638