# Source: Research_AI_Assistant / src/context_manager.py
# Commit 29048d9 — "cumulative upgrade - context + safety + response length v1" (JatsTheAIGen)
# NOTE(review): this header replaces Hugging Face blob-page UI residue
# ("raw / history blame / 23.4 kB") that was captured with the file and is
# not part of the Python source.
# context_manager.py
import sqlite3
import json
import logging
import uuid
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
class EfficientContextManager:
    """Multi-level conversational context manager.

    Combines an in-memory per-session cache (level 1) with a SQLite store
    (level 2) holding sessions, raw interactions, and LLM-generated
    summaries at three granularities: user persona, session, and single
    interaction.
    """

    def __init__(self, llm_router=None):
        """Set up the caches and ensure the SQLite schema exists.

        Args:
            llm_router: optional router exposing an async ``route_inference``
                method, used to generate persona/session/interaction
                summaries. When None, summary generation is skipped.
        """
        self.session_cache = {}  # In-memory for active sessions
        # NOTE(review): these limits are currently informational only — no
        # code in this class enforces size, TTL, compression, or eviction.
        self.cache_config = {
            "max_session_size": 10,  # MB per session
            "ttl": 3600,             # seconds (1 hour)
            "compression": "gzip",
            "eviction_policy": "LRU",
        }
        self.db_path = "sessions.db"
        self.llm_router = llm_router  # For generating context summaries
        logger.info(f"Initializing ContextManager with DB path: {self.db_path}")
        self._init_database()
def _init_database(self):
"""Initialize database and create tables"""
try:
logger.info("Initializing database...")
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Create sessions table if not exists
cursor.execute("""
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
user_id TEXT DEFAULT 'Test_Any',
created_at TIMESTAMP,
last_activity TIMESTAMP,
context_data TEXT,
user_metadata TEXT
)
""")
# Add user_id column to existing sessions table if it doesn't exist
try:
cursor.execute("ALTER TABLE sessions ADD COLUMN user_id TEXT DEFAULT 'Test_Any'")
logger.info("✓ Added user_id column to sessions table")
except sqlite3.OperationalError:
# Column already exists
pass
logger.info("✓ Sessions table ready")
# Create interactions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS interactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT REFERENCES sessions(session_id),
user_input TEXT,
context_snapshot TEXT,
created_at TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
)
""")
logger.info("✓ Interactions table ready")
# Create user_contexts table (persistent user persona summaries)
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_contexts (
user_id TEXT PRIMARY KEY,
persona_summary TEXT,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
logger.info("✓ User contexts table ready")
# Create session_contexts table (session summaries)
cursor.execute("""
CREATE TABLE IF NOT EXISTS session_contexts (
session_id TEXT PRIMARY KEY,
user_id TEXT,
session_summary TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id),
FOREIGN KEY(user_id) REFERENCES user_contexts(user_id)
)
""")
logger.info("✓ Session contexts table ready")
# Create interaction_contexts table (individual interaction summaries)
cursor.execute("""
CREATE TABLE IF NOT EXISTS interaction_contexts (
interaction_id TEXT PRIMARY KEY,
session_id TEXT,
user_input TEXT,
system_response TEXT,
interaction_summary TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
)
""")
logger.info("✓ Interaction contexts table ready")
conn.commit()
conn.close()
logger.info("Database initialization complete")
except Exception as e:
logger.error(f"Database initialization error: {e}", exc_info=True)
async def manage_context(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
"""
Efficient context management with user-based context system
STEP 1: Fetch User Context (if available)
STEP 2: Get Previous Interaction Contexts
STEP 3: Combine for workflow use
"""
# Level 1: In-memory session cache
cache_key = f"{session_id}_{user_id}"
context = self._get_from_memory_cache(cache_key)
if not context:
# Level 2: Database retrieval with user context
context = await self._retrieve_from_db(session_id, user_input, user_id)
# STEP 1: Fetch or generate User Context at session start (if first interaction in session)
if not context.get("user_context_loaded"):
user_context = await self.get_user_context(user_id)
context["user_context"] = user_context
context["user_context_loaded"] = True
# Cache warming
self._warm_memory_cache(cache_key, context)
# Update context with new interaction
updated_context = self._update_context(context, user_input, user_id=user_id)
return self._optimize_context(updated_context)
async def get_user_context(self, user_id: str) -> str:
"""
STEP 1: Fetch or generate User Context (500-token persona summary)
Available for all interactions except first time per user
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Check if user context exists
cursor.execute("""
SELECT persona_summary FROM user_contexts WHERE user_id = ?
""", (user_id,))
row = cursor.fetchone()
if row and row[0]:
# Existing user context found
conn.close()
logger.info(f"✓ User context loaded for {user_id}")
return row[0]
# Generate new user context from all historical data
logger.info(f"Generating new user context for {user_id}")
# Fetch all historical Session and Interaction contexts for this user
all_session_summaries = []
all_interaction_summaries = []
# Get all session contexts
cursor.execute("""
SELECT session_summary FROM session_contexts WHERE user_id = ?
ORDER BY created_at DESC LIMIT 50
""", (user_id,))
for row in cursor.fetchall():
if row[0]:
all_session_summaries.append(row[0])
# Get all interaction contexts
cursor.execute("""
SELECT ic.interaction_summary
FROM interaction_contexts ic
JOIN sessions s ON ic.session_id = s.session_id
WHERE s.user_id = ?
ORDER BY ic.created_at DESC LIMIT 100
""", (user_id,))
for row in cursor.fetchall():
if row[0]:
all_interaction_summaries.append(row[0])
conn.close()
if not all_session_summaries and not all_interaction_summaries:
# First time user - no context to generate
logger.info(f"No historical data for {user_id} - first time user")
return ""
# Generate persona summary using LLM (500 tokens)
historical_data = "\n\n".join(all_session_summaries + all_interaction_summaries[:20])
if self.llm_router:
prompt = f"""Generate a concise 500-token persona summary for user {user_id} based on their interaction history:
Historical Context:
{historical_data}
Create a persona summary that captures:
- Communication style and preferences
- Common topics and interests
- Interaction patterns
- Key information shared across sessions
Keep the summary concise and focused (approximately 500 tokens)."""
try:
persona_summary = await self.llm_router.route_inference(
task_type="general_reasoning",
prompt=prompt,
max_tokens=500,
temperature=0.7
)
if persona_summary and isinstance(persona_summary, str) and persona_summary.strip():
# Store in database
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO user_contexts (user_id, persona_summary, updated_at)
VALUES (?, ?, ?)
""", (user_id, persona_summary.strip(), datetime.now().isoformat()))
conn.commit()
conn.close()
logger.info(f"✓ Generated and stored user context for {user_id}")
return persona_summary.strip()
except Exception as e:
logger.error(f"Error generating user context: {e}", exc_info=True)
# Fallback: Return empty if LLM fails
logger.warning(f"Could not generate user context for {user_id} - using empty")
return ""
except Exception as e:
logger.error(f"Error getting user context: {e}", exc_info=True)
return ""
async def generate_interaction_context(self, interaction_id: str, session_id: str,
user_input: str, system_response: str,
user_id: str = "Test_Any") -> str:
"""
STEP 2: Generate Interaction Context (50-token summary)
Called after each response
"""
try:
if not self.llm_router:
return ""
prompt = f"""Summarize this interaction in approximately 50 tokens:
User Input: {user_input[:200]}
System Response: {system_response[:300]}
Provide a brief summary capturing the key exchange."""
try:
summary = await self.llm_router.route_inference(
task_type="general_reasoning",
prompt=prompt,
max_tokens=50,
temperature=0.7
)
if summary and isinstance(summary, str) and summary.strip():
# Store in database
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO interaction_contexts
(interaction_id, session_id, user_input, system_response, interaction_summary, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""", (
interaction_id,
session_id,
user_input[:500],
system_response[:1000],
summary.strip(),
datetime.now().isoformat()
))
conn.commit()
conn.close()
logger.info(f"✓ Generated interaction context for {interaction_id}")
return summary.strip()
except Exception as e:
logger.error(f"Error generating interaction context: {e}", exc_info=True)
# Fallback on LLM failure
return ""
except Exception as e:
logger.error(f"Error in generate_interaction_context: {e}", exc_info=True)
return ""
async def generate_session_context(self, session_id: str, user_id: str = "Test_Any") -> str:
"""
FINAL STEP: Generate Session Context (100-token summary)
Called at session end
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Get all interaction contexts for this session
cursor.execute("""
SELECT interaction_summary FROM interaction_contexts
WHERE session_id = ?
ORDER BY created_at ASC
""", (session_id,))
interaction_summaries = [row[0] for row in cursor.fetchall() if row[0]]
conn.close()
if not interaction_summaries:
logger.info(f"No interactions to summarize for session {session_id}")
return ""
# Generate session summary using LLM (100 tokens)
if self.llm_router:
combined_context = "\n".join(interaction_summaries)
prompt = f"""Summarize this session's interactions in approximately 100 tokens:
Interaction Summaries:
{combined_context}
Create a concise session summary capturing:
- Main topics discussed
- Key outcomes or information shared
- User's focus areas
Keep the summary concise (approximately 100 tokens)."""
try:
session_summary = await self.llm_router.route_inference(
task_type="general_reasoning",
prompt=prompt,
max_tokens=100,
temperature=0.7
)
if session_summary and isinstance(session_summary, str) and session_summary.strip():
# Store in database
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO session_contexts
(session_id, user_id, session_summary, created_at)
VALUES (?, ?, ?, ?)
""", (session_id, user_id, session_summary.strip(), datetime.now().isoformat()))
conn.commit()
conn.close()
logger.info(f"✓ Generated session context for {session_id}")
return session_summary.strip()
except Exception as e:
logger.error(f"Error generating session context: {e}", exc_info=True)
# Fallback on LLM failure
return ""
except Exception as e:
logger.error(f"Error in generate_session_context: {e}", exc_info=True)
return ""
async def end_session(self, session_id: str, user_id: str = "Test_Any"):
"""
FINAL STEP: Generate Session Context and clear cache
"""
try:
# Generate session context
await self.generate_session_context(session_id, user_id)
# Clear in-memory cache for this session
cache_key = f"{session_id}_{user_id}"
if cache_key in self.session_cache:
del self.session_cache[cache_key]
logger.info(f"✓ Cleared cache for session {session_id}")
except Exception as e:
logger.error(f"Error ending session: {e}", exc_info=True)
def _optimize_context(self, context: dict) -> dict:
"""
Optimize context for LLM consumption
Format: [Interaction Context #N, #N-1, ...] + User Context
"""
user_context = context.get("user_context", "")
interaction_contexts = context.get("interaction_contexts", [])
# Format interaction contexts as requested
formatted_interactions = []
for idx, ic in enumerate(interaction_contexts[:10]): # Last 10 interactions
formatted_interactions.append(f"[Interaction Context #{len(interaction_contexts) - idx}]\n{ic.get('summary', '')}")
# Combine User Context + Interaction Contexts
combined_context = ""
if user_context:
combined_context += f"[User Context]\n{user_context}\n\n"
if formatted_interactions:
combined_context += "\n\n".join(formatted_interactions)
return {
"session_id": context.get("session_id"),
"user_id": context.get("user_id", "Test_Any"),
"user_context": user_context,
"interaction_contexts": interaction_contexts,
"combined_context": combined_context, # For direct use in prompts
"preferences": context.get("preferences", {}),
"active_tasks": context.get("active_tasks", []),
"last_activity": context.get("last_activity")
}
def _get_from_memory_cache(self, cache_key: str) -> dict:
"""
Retrieve context from in-memory session cache
"""
return self.session_cache.get(cache_key)
async def _retrieve_from_db(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
"""
Retrieve context from database with semantic search
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Get session data
cursor.execute("""
SELECT context_data, user_metadata, last_activity, user_id
FROM sessions
WHERE session_id = ?
""", (session_id,))
row = cursor.fetchone()
if row:
context_data = json.loads(row[0]) if row[0] else {}
user_metadata = json.loads(row[1]) if row[1] else {}
last_activity = row[2]
session_user_id = row[3] if len(row) > 3 else user_id
# Update user_id if it changed
if session_user_id != user_id:
cursor.execute("""
UPDATE sessions SET user_id = ? WHERE session_id = ?
""", (user_id, session_id))
conn.commit()
# Get previous interaction contexts for this session
cursor.execute("""
SELECT interaction_summary, created_at
FROM interaction_contexts
WHERE session_id = ?
ORDER BY created_at DESC
LIMIT 20
""", (session_id,))
interaction_contexts = []
for ic_row in cursor.fetchall():
if ic_row[0]:
interaction_contexts.append({
"summary": ic_row[0],
"timestamp": ic_row[1]
})
context = {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": interaction_contexts,
"preferences": user_metadata.get("preferences", {}),
"active_tasks": user_metadata.get("active_tasks", []),
"last_activity": last_activity,
"user_context_loaded": False # Will be loaded in manage_context
}
conn.close()
return context
else:
# Create new session
cursor.execute("""
INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
VALUES (?, ?, ?, ?, ?, ?)
""", (session_id, user_id, datetime.now().isoformat(), datetime.now().isoformat(), "{}", "{}"))
conn.commit()
conn.close()
return {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": [],
"preferences": {},
"active_tasks": [],
"user_context_loaded": False
}
except Exception as e:
logger.error(f"Database retrieval error: {e}", exc_info=True)
# Fallback to empty context
return {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": [],
"preferences": {},
"active_tasks": [],
"user_context_loaded": False
}
def _warm_memory_cache(self, cache_key: str, context: dict):
"""
Warm the in-memory cache with retrieved context
"""
self.session_cache[cache_key] = context
def _update_context(self, context: dict, user_input: str, response: str = None, user_id: str = "Test_Any") -> dict:
"""
Update context with new user interaction and persist to database
Note: Interaction context generation happens separately after response is generated
"""
try:
# Update session activity
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Update session last_activity
cursor.execute("""
UPDATE sessions
SET last_activity = ?, user_id = ?
WHERE session_id = ?
""", (datetime.now().isoformat(), user_id, context["session_id"]))
# Insert basic interaction record (for backward compatibility)
session_context = {
"preferences": context.get("preferences", {}),
"active_tasks": context.get("active_tasks", [])
}
cursor.execute("""
INSERT INTO interactions (session_id, user_input, context_snapshot, created_at)
VALUES (?, ?, ?, ?)
""", (context["session_id"], user_input, json.dumps(session_context), datetime.now().isoformat()))
conn.commit()
conn.close()
except Exception as e:
logger.error(f"Context update error: {e}", exc_info=True)
return context
def _extract_entities(self, context: dict) -> list:
"""
Extract essential entities from context
"""
# TODO: Implement entity extraction
return []
def _generate_summary(self, context: dict) -> str:
"""
Generate conversation summary
"""
# TODO: Implement summary generation
return ""