amirkiarafiei committed
Commit 8ba5d9d · 1 Parent(s): e493d29

init branch

Files changed (4)
  1. .env.sample +1 -0
  2. gradio_app.py +42 -31
  3. langchain_mcp_client.py +30 -37
  4. memory_store.py +28 -0
.env.sample CHANGED
@@ -11,5 +11,6 @@ GEMINI_MODEL_PROVIDER=
 PANDAS_KEY=
 PANDAS_EXPORTS_PATH=
 
+OPENAI_MODEL_PROVIDER=
 OPENAI_MODEL=
 OPENAI_API_KEY=
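The sample file does not show how the new variable is consumed; as a rough sketch only, assuming the same init_chat_model call that langchain_mcp_client.py already uses, a provider switch driven by these variables might look like this (the if/else fallback logic is an assumption, not code from this commit):

# Sketch only, not part of this commit: pick a provider from the .env settings.
import os
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model

load_dotenv()

if os.getenv("OPENAI_MODEL_PROVIDER"):
    # Use the OpenAI block added by this commit
    llm = init_chat_model(
        model_provider=os.getenv("OPENAI_MODEL_PROVIDER"),
        model=os.getenv("OPENAI_MODEL"),
        api_key=os.getenv("OPENAI_API_KEY"),
    )
else:
    # Fall back to the existing Gemini settings
    llm = init_chat_model(
        model_provider=os.getenv("GEMINI_MODEL_PROVIDER"),
        model=os.getenv("GEMINI_MODEL"),
        api_key=os.getenv("GEMINI_API_KEY"),
    )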
gradio_app.py CHANGED
@@ -1,3 +1,4 @@
+from asyncio.log import logger
 import yaml
 from pathlib import Path
 import gradio as gr
@@ -6,6 +7,9 @@ from langchain_mcp_client import lc_mcp_exec
 from dotenv import load_dotenv
 import os
 import base64
+from memory_store import MemoryStore
+import logging
+
 
 # ======================================= Load DB configs
 def load_db_configs():
@@ -28,34 +32,44 @@ def image_to_base64_markdown(image_path, alt_text="Customer Status"):
 
 
 # ====================================== Async-compatible wrapper
-async def run_agent(request, history):
-    response, messages = await lc_mcp_exec(request, history)
-
-    image_path = ""
-    load_dotenv()
-    PANDAS_EXPORTS_PATH = os.getenv("PANDAS_EXPORTS_PATH", "exports/charts")
-
-    # Ensure the exports directory exists
-    os.makedirs(PANDAS_EXPORTS_PATH, exist_ok=True)
-
-    generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH) if f.startswith("temp_chart_") and f.endswith(".png")]
-    if generated_files:
-        image_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
-        try:
-            image_markdown = image_to_base64_markdown(image_path)
-            output = f"{image_markdown}\n\n{response}"
-            # Remove the image file after reading it
-            os.remove(image_path)
-        except Exception as e:
-            print(f"Error processing image: {e}")
+async def run_agent(request, history=None):
+    try:
+        logger.info(f"Current request: {request}")
+        memory = MemoryStore.get_memory()
+        logger.info(f"Current memory messages: {memory.messages}")
+
+        # Process request using existing memory
+        response, messages = await lc_mcp_exec(request)
+
+        # Handle image processing
+        image_path = ""
+        load_dotenv()
+        PANDAS_EXPORTS_PATH = os.getenv("PANDAS_EXPORTS_PATH", "exports/charts")
+
+        # Ensure the exports directory exists
+        os.makedirs(PANDAS_EXPORTS_PATH, exist_ok=True)
+
+        # Check for generated charts
+        generated_files = [f for f in os.listdir(PANDAS_EXPORTS_PATH)
+                           if f.startswith("temp_chart_") and f.endswith(".png")]
+
+        if generated_files:
+            image_path = os.path.join(PANDAS_EXPORTS_PATH, generated_files[0])
+            try:
+                image_markdown = image_to_base64_markdown(image_path)
+                output = f"{image_markdown}\n\n{response}"
+                os.remove(image_path)  # Clean up the image file
+            except Exception as e:
+                logger.error(f"Error processing image: {e}")
+                output = response
+        else:
             output = response
-    else:
-        output = response
 
-    print(f"Image path: {image_path}")
-    print(f"Output length: {len(output)}")
-
-    return output
+        return output
+
+    except Exception as e:
+        logger.error(f"Error in run_agent: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
 
 
 # ====================================== Gradio UI with history
@@ -86,9 +100,8 @@ custom_css = """
 
 with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
     with gr.Row(elem_classes="container"):
-        # with gr.Column(scale=1):
-        #     gr.Image(value=LOGO_PATH, height=200, show_label=False)
-
+        with gr.Column(scale=1):
+            gr.Image(value=LOGO_PATH, height=200, show_label=False)
         with gr.Column(scale=3):
             gr.Markdown(
                 """
@@ -96,7 +109,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
                 <p style='text-align: center'>Ask questions about your database, analyze and visualize data.</p>
                 """
             )
-
    with gr.Row(elem_classes="container"):
        with gr.Column(scale=3):
            chat = gr.ChatInterface(
@@ -128,7 +140,6 @@
                save_history=True,
                type="messages"
            )
-
        with gr.Column(scale=1):
            with gr.Accordion("Example Questions", open=True):
                gr.Markdown("""
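The hunks above do not show the arguments that bind the ChatInterface to the handler; as a minimal, self-contained sketch (the bare Blocks layout and the placeholder body are assumptions, not code from this commit), the wiring of an async handler like run_agent would look roughly like this:

# Sketch only: how an async handler plugs into gr.ChatInterface.
import gradio as gr

async def run_agent(request, history=None):
    # Placeholder body; the real run_agent calls lc_mcp_exec and handles chart images.
    return f"You asked: {request}"

with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=run_agent,        # Gradio awaits async handlers
        type="messages",     # history arrives as a list of role/content dicts
        save_history=True,   # keep past chats available in the UI
    )

if __name__ == "__main__":
    demo.launch()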
langchain_mcp_client.py CHANGED
@@ -11,8 +11,12 @@ from langchain.chat_models import init_chat_model
 import logging
 from dotenv import load_dotenv
 from langchain.globals import set_debug
+from langchain.memory import ChatMessageHistory
+from memory_store import MemoryStore
 
-set_debug(True)
+
+
+# set_debug(True)
 
 
 # Set up logging
@@ -23,85 +27,74 @@ load_dotenv()
 
 async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
     """
-    Execute the PostgreSQL MCP pipeline with persistent memory in a JSON file.
+    Execute the PostgreSQL MCP pipeline with in-memory chat history.
     Returns the response and the updated message history.
     """
     try:
-        # Define the path for the chat history JSON file
-        history_file = os.path.join(os.path.dirname(__file__), "chat_history.json")
-
-        # Initialize FileChatMessageHistory for persistent storage
-        message_history = FileChatMessageHistory(file_path=history_file)
+        # Get the singleton memory store instance
+        message_history = MemoryStore.get_memory()
 
         # Load table summary and server parameters
         table_summary = load_table_summary(os.environ["TABLE_SUMMARY_PATH"])
         server_params = get_server_params()
 
         # Initialize the LLM
-        llm = init_chat_model(model_provider=os.getenv("GEMINI_MODEL_PROVIDER"),model=os.getenv("GEMINI_MODEL"),api_key=os.getenv("GEMINI_API_KEY"))
-        #llm = init_chat_model(model_provider="openai", model=os.getenv("OPENAI_MODEL"), api_key=os.getenv("OPENAI_API_KEY"))
+        # llm = init_chat_model(
+        #     model_provider=os.getenv("OPENAI_MODEL_PROVIDER"),
+        #     model=os.getenv("OPENAI_MODEL"),
+        #     api_key=os.getenv("OPENAI_API_KEY")
+        # )
+
+        llm = init_chat_model(
+            model_provider=os.getenv("GEMINI_MODEL_PROVIDER"),
+            model=os.getenv("GEMINI_MODEL"),
+            api_key=os.getenv("GEMINI_API_KEY")
+        )
 
         # Initialize the MCP client
         async with stdio_client(server_params) as (read, write):
             async with ClientSession(read, write) as session:
                 await session.initialize()
-
-                # Load tools
                 tools = await load_and_enrich_tools(session)
-
-                # Create the ReAct agent with tools
                 agent = create_react_agent(llm, tools)
 
-                # Add the new user message to history as HumanMessage
+                # Add new user message to memory
                 message_history.add_user_message(request)
 
-                # Prepare the system prompt for the agent
+                # Get system prompt and create system message
                 system_prompt = await build_prompt(session, tools, table_summary)
+                system_message = SystemMessage(content=system_prompt)
 
-                # Init the system message
-                system_message = SystemMessage(
-                    content=system_prompt
-                )
-
-                # Add the system message to the history
+                # Combine system message with chat history
                 input_messages = [system_message] + message_history.messages
 
-                # Invoke the agent with the message list
+                # Invoke agent
                 agent_response = await agent.ainvoke(
                     {"messages": input_messages},
                     config={"configurable": {"thread_id": "conversation_123"}}
                 )
 
-                # Extract the latest response and save all new messages
+                # Process agent response
                 response_content = "No response generated"
                 if "messages" in agent_response and agent_response["messages"]:
-                    # Identify new messages (those not in input_messages)
                     new_messages = agent_response["messages"][len(input_messages):]
-
-                    # Save all new messages to history
+
+                    # Save new messages to memory
                     for msg in new_messages:
-                        if isinstance(msg, AIMessage):
-                            # Save AIMessage, including tool_calls if present
-                            message_history.add_message(msg)
-                        elif isinstance(msg, ToolMessage):
-                            # Save ToolMessage with content and tool_call_id
+                        if isinstance(msg, (AIMessage, ToolMessage)):
                            message_history.add_message(msg)
                         else:
-                            # Log unexpected message types for debugging
                            logger.debug(f"Skipping unexpected message type: {type(msg)}")
 
-                    # Use the last message's content as the response
                    response_content = agent_response["messages"][-1].content
-
                 else:
                     message_history.add_ai_message(response_content)
 
-                # Return the response and the updated history
                 return response_content, message_history.messages
-
+
     except Exception as e:
         logger.error(f"Error in execution: {str(e)}", exc_info=True)
-        return f"Error: {str(e)}", message_history.messages if 'message_history' in locals() else []
+        return f"Error: {str(e)}", []
 
 # ---------------- Helper Functions ---------------- #
 
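A rough usage sketch of the new contract, not part of the commit: lc_mcp_exec no longer receives the conversation through its history argument; it reads from and appends to the shared MemoryStore, so consecutive calls form one conversation. The example question strings are illustrative, and it assumes the environment variables referenced above (TABLE_SUMMARY_PATH, the Gemini settings, PANDAS_EXPORTS_PATH) are set:

# Sketch only: consecutive calls share the singleton chat history.
import asyncio
from langchain_mcp_client import lc_mcp_exec
from memory_store import MemoryStore

async def main():
    answer, history = await lc_mcp_exec("How many rows are in the customers table?")
    print(answer)

    # The follow-up relies on the in-memory history, not on a history argument.
    answer, history = await lc_mcp_exec("Now plot them by country.")
    print(answer)

    MemoryStore.clear_memory()  # start the next conversation from scratch

asyncio.run(main())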
 
memory_store.py ADDED
@@ -0,0 +1,28 @@
+from langchain.memory import ChatMessageHistory
+from typing import Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+class MemoryStore:
+    _instance: Optional['MemoryStore'] = None
+    _memory: Optional[ChatMessageHistory] = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super(MemoryStore, cls).__new__(cls)
+            cls._memory = ChatMessageHistory()
+            logger.info("New MemoryStore instance created")
+        return cls._instance
+
+    @classmethod
+    def get_memory(cls) -> ChatMessageHistory:
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._memory
+
+    @classmethod
+    def clear_memory(cls):
+        if cls._memory is not None:
+            cls._memory = ChatMessageHistory()
+            logger.info("Memory cleared")
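For illustration only, not part of the file above: the class-level singleton hands every caller the same ChatMessageHistory object, which is what lets gradio_app.py and langchain_mcp_client.py see one shared conversation.

# Sketch only: singleton semantics of MemoryStore.
from memory_store import MemoryStore

first = MemoryStore.get_memory()
second = MemoryStore.get_memory()
assert first is second           # both handles point at the same ChatMessageHistory

first.add_user_message("hello")
print(len(second.messages))      # 1 -- the turn is visible through either handle

MemoryStore.clear_memory()       # swaps in a fresh ChatMessageHistory
print(len(MemoryStore.get_memory().messages))  # 0; note `first` still holds the old history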