cevheri committed on
Commit
6a39607
·
1 Parent(s): 685f01f

feat: add openai for client

Browse files
Files changed (2) hide show
  1. langchain_mcp_client.py +10 -10
  2. postgre_mcp_server.py +5 -17
langchain_mcp_client.py CHANGED
@@ -39,19 +39,19 @@ async def lc_mcp_exec(request: str, history=None) -> Tuple[str, list]:
39
  server_params = get_server_params()
40
 
41
  # Initialize the LLM for OpenAI
42
- # llm = init_chat_model(
43
- # model_provider=os.environ["OPENAI_MODEL_PROVIDER"],
44
- # model=os.environ["OPENAI_MODEL"],
45
- # api_key=os.environ["OPENAI_API_KEY"]
46
- # )
47
-
48
- # Initialize the LLM for Gemini
49
  llm = init_chat_model(
50
- model_provider=os.environ["GEMINI_MODEL_PROVIDER"],
51
- model=os.environ["GEMINI_MODEL"],
52
- api_key=os.environ["GEMINI_API_KEY"]
53
  )
54
 
 
 
 
 
 
 
 
55
  # Initialize the MCP client
56
  async with stdio_client(server_params) as (read, write):
57
  async with ClientSession(read, write) as session:
 
39
  server_params = get_server_params()
40
 
41
  # Initialize the LLM for OpenAI
 
 
 
 
 
 
 
42
  llm = init_chat_model(
43
+ model_provider=os.environ["OPENAI_MODEL_PROVIDER"],
44
+ model=os.environ["OPENAI_MODEL"],
45
+ api_key=os.environ["OPENAI_API_KEY"]
46
  )
47
 
48
+ # Initialize the LLM for Gemini
49
+ # llm = init_chat_model(
50
+ # model_provider=os.environ["GEMINI_MODEL_PROVIDER"],
51
+ # model=os.environ["GEMINI_MODEL"],
52
+ # api_key=os.environ["GEMINI_API_KEY"]
53
+ # )
54
+
55
  # Initialize the MCP client
56
  async with stdio_client(server_params) as (read, write):
57
  async with ClientSession(read, write) as session:
postgre_mcp_server.py CHANGED
@@ -2,14 +2,14 @@ import os
2
  from contextlib import asynccontextmanager
3
  from dataclasses import dataclass
4
  from typing import Optional, AsyncIterator
5
- import asyncpg
6
- from dotenv import load_dotenv
7
  from mcp.server.fastmcp import FastMCP, Context
8
  from pydantic import Field
9
  import pandasai as pai
10
  import matplotlib as plt
11
  import pandas as pd
12
  import logging
 
13
 
14
 
15
  # Constants
@@ -21,8 +21,8 @@ DEFAULT_QUERY_LIMIT = 100
21
  # get logger
22
  logger = logging.getLogger(__name__)
23
 
24
- # Load environment variables
25
- load_dotenv()
26
 
27
  # Define our own PromptMessage class if the MCP one isn't available
28
  @dataclass
@@ -37,16 +37,11 @@ class DbContext:
37
  pool: asyncpg.Pool
38
  schema: str
39
 
40
- # Load environment variables
41
- load_dotenv()
42
-
43
 
44
  # Database connection lifecycle manager
45
  @asynccontextmanager
46
  async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
47
  """Manage database connection lifecycle"""
48
- # Initialize DB connection from environment variables
49
- load_dotenv()
50
  dsn = os.environ["DB_URL"]
51
  schema = os.environ["DB_SCHEMA"]
52
 
@@ -453,8 +448,7 @@ async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
453
  """ Generate analytical queries for a table
454
  Args:
455
  table_name: The name of the table to generate analytical queries for
456
- """
457
- load_dotenv()
458
  schema = os.environ["DB_SCHEMA"]
459
  try:
460
  async with db_lifespan(mcp) as db_ctx:
@@ -627,11 +621,7 @@ async def visualize_results(json_data: dict, vis_prompt: str) -> str:
627
  """
628
  try:
629
  # Debug prints to see what's being received
630
- from pandasai_openai import OpenAI
631
 
632
- pllm = OpenAI(api_token=os.environ["OPENAI_API_KEY"])
633
- pai.config.set({"llm": pllm})
634
-
635
  # Convert JSON to DataFrame
636
  df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
637
 
@@ -642,8 +632,6 @@ async def visualize_results(json_data: dict, vis_prompt: str) -> str:
642
  # Initialize PandasAI
643
  df_ai = pai.DataFrame(df)
644
 
645
- # Load api key
646
- load_dotenv()
647
  api_key = os.environ["PANDAS_KEY"]
648
  pai.api_key.set(api_key)
649
 
 
2
  from contextlib import asynccontextmanager
3
  from dataclasses import dataclass
4
  from typing import Optional, AsyncIterator
5
+ import asyncpg
 
6
  from mcp.server.fastmcp import FastMCP, Context
7
  from pydantic import Field
8
  import pandasai as pai
9
  import matplotlib as plt
10
  import pandas as pd
11
  import logging
12
+ from pandasai_openai import OpenAI
13
 
14
 
15
  # Constants
 
21
  # get logger
22
  logger = logging.getLogger(__name__)
23
 
24
+ pllm = OpenAI(api_token=os.environ["OPENAI_API_KEY"])
25
+ pai.config.set({"llm": pllm})
26
 
27
  # Define our own PromptMessage class if the MCP one isn't available
28
  @dataclass
 
37
  pool: asyncpg.Pool
38
  schema: str
39
 
 
 
 
40
 
41
  # Database connection lifecycle manager
42
  @asynccontextmanager
43
  async def db_lifespan(server: FastMCP) -> AsyncIterator[DbContext]:
44
  """Manage database connection lifecycle"""
 
 
45
  dsn = os.environ["DB_URL"]
46
  schema = os.environ["DB_SCHEMA"]
47
 
 
448
  """ Generate analytical queries for a table
449
  Args:
450
  table_name: The name of the table to generate analytical queries for
451
+ """
 
452
  schema = os.environ["DB_SCHEMA"]
453
  try:
454
  async with db_lifespan(mcp) as db_ctx:
 
621
  """
622
  try:
623
  # Debug prints to see what's being received
 
624
 
 
 
 
625
  # Convert JSON to DataFrame
626
  df = pd.DataFrame(json_data["data"], columns=json_data["columns"])
627
 
 
632
  # Initialize PandasAI
633
  df_ai = pai.DataFrame(df)
634
 
 
 
635
  api_key = os.environ["PANDAS_KEY"]
636
  pai.api_key.set(api_key)
637