# context_manager.py
import sqlite3
import json
import logging
import hashlib
import threading
import time
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Optional, List
logger = logging.getLogger(__name__)
class TransactionManager:
"""Manage database transactions with proper locking"""
def __init__(self, db_path):
self.db_path = db_path
self._lock = threading.RLock()
        self._connections = {}  # currently unused
@contextmanager
def transaction(self, session_id=None):
"""Context manager for database transactions with automatic rollback"""
conn = None
cursor = None
try:
with self._lock:
conn = sqlite3.connect(self.db_path, isolation_level='IMMEDIATE')
conn.execute('PRAGMA journal_mode=WAL') # Write-Ahead Logging for better concurrency
conn.execute('PRAGMA busy_timeout=5000') # 5 second timeout for locks
cursor = conn.cursor()
yield cursor
conn.commit()
logger.debug(f"Transaction committed for session {session_id}")
except Exception as e:
if conn:
conn.rollback()
logger.error(f"Transaction rolled back for session {session_id}: {e}")
raise
finally:
if conn:
conn.close()
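# Usage sketch (hypothetical path and session id): the context manager yields a
# cursor, commits on clean exit, and rolls back automatically on any exception.
#
#     tm = TransactionManager("/tmp/sessions.db")
#     with tm.transaction(session_id="demo") as cur:
#         cur.execute("SELECT COUNT(*) FROM sessions")
#         print(cur.fetchone()[0])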
class EfficientContextManager:
def __init__(self, llm_router=None, db_path=None):
        self.session_cache = {}  # primary in-memory cache for active sessions (TTL-wrapped entries)
        self._session_cache = {}  # separate timestamped cache used only by get_or_create_session_context
self.cache_config = {
"max_session_size": 10, # MB per session
"ttl": 3600, # 1 hour
"compression": "gzip",
"eviction_policy": "LRU"
}
# Use provided db_path or get from settings, fallback to /tmp/sessions.db
if db_path is None:
try:
from src.config import settings
db_path = settings.db_path
except (ImportError, AttributeError):
# Fallback to writable location in containers
import os
db_path = os.getenv("DB_PATH", "/tmp/sessions.db")
self.db_path = db_path
self.llm_router = llm_router # For generating context summaries
logger.info(f"Initializing ContextManager with DB path: {self.db_path}")
self.transaction_manager = TransactionManager(self.db_path)
self._init_database()
self.optimize_database_indexes()
def _init_database(self):
"""Initialize database and create tables"""
try:
logger.info("Initializing database...")
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Create sessions table if not exists
cursor.execute("""
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
user_id TEXT DEFAULT 'Test_Any',
created_at TIMESTAMP,
last_activity TIMESTAMP,
context_data TEXT,
user_metadata TEXT
)
""")
# Add user_id column to existing sessions table if it doesn't exist
try:
cursor.execute("ALTER TABLE sessions ADD COLUMN user_id TEXT DEFAULT 'Test_Any'")
logger.info("✓ Added user_id column to sessions table")
except sqlite3.OperationalError:
# Column already exists
pass
logger.info("✓ Sessions table ready")
# Create interactions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS interactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,
user_input TEXT,
context_snapshot TEXT,
created_at TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
)
""")
logger.info("✓ Interactions table ready")
# Create user_contexts table (persistent user persona summaries)
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_contexts (
user_id TEXT PRIMARY KEY,
persona_summary TEXT,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
logger.info("✓ User contexts table ready")
# Create session_contexts table (session summaries)
cursor.execute("""
CREATE TABLE IF NOT EXISTS session_contexts (
session_id TEXT PRIMARY KEY,
user_id TEXT,
session_summary TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id),
FOREIGN KEY(user_id) REFERENCES user_contexts(user_id)
)
""")
logger.info("✓ Session contexts table ready")
# Create interaction_contexts table (individual interaction summaries)
cursor.execute("""
CREATE TABLE IF NOT EXISTS interaction_contexts (
interaction_id TEXT PRIMARY KEY,
session_id TEXT,
user_input TEXT,
system_response TEXT,
interaction_summary TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
)
""")
logger.info("✓ Interaction contexts table ready")
conn.commit()
conn.close()
# Update schema with new columns and tables for user change tracking
self._update_database_schema()
logger.info("Database initialization complete")
except Exception as e:
logger.error(f"Database initialization error: {e}", exc_info=True)
def _update_database_schema(self):
"""Add missing columns and tables for user change tracking"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Add needs_refresh column to interaction_contexts
try:
cursor.execute("""
ALTER TABLE interaction_contexts
ADD COLUMN needs_refresh INTEGER DEFAULT 0
""")
logger.info("✓ Added needs_refresh column to interaction_contexts")
except sqlite3.OperationalError:
pass # Column already exists
# Create user change log table
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_change_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT,
old_user_id TEXT,
new_user_id TEXT,
timestamp TIMESTAMP,
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
)
""")
conn.commit()
conn.close()
logger.info("✓ Database schema updated successfully for user change tracking")
# Update interactions table for deduplication
self._update_interactions_table()
except Exception as e:
logger.error(f"Schema update error: {e}", exc_info=True)
def _update_interactions_table(self):
"""Add interaction_hash column for deduplication"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Check if column already exists
cursor.execute("PRAGMA table_info(interactions)")
columns = [row[1] for row in cursor.fetchall()]
# Add interaction_hash column if it doesn't exist
if 'interaction_hash' not in columns:
try:
cursor.execute("""
ALTER TABLE interactions
ADD COLUMN interaction_hash TEXT
""")
logger.info("✓ Added interaction_hash column to interactions table")
except sqlite3.OperationalError:
pass # Column already exists
# Create unique index for deduplication (this enforces uniqueness)
try:
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS idx_interaction_hash_unique
ON interactions(interaction_hash)
""")
logger.info("✓ Created unique index on interaction_hash")
except sqlite3.OperationalError:
                # Duplicate hash values already in the table can block a UNIQUE index; fall back to a non-unique one
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_interaction_hash
ON interactions(interaction_hash)
""")
conn.commit()
conn.close()
logger.info("✓ Interactions table updated for deduplication")
except Exception as e:
logger.error(f"Error updating interactions table: {e}", exc_info=True)
async def manage_context(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
"""
Efficient context management with separated session/user caching
STEP 1: Fetch User Context (if available)
STEP 2: Get Previous Interaction Contexts
STEP 3: Combine for workflow use
"""
# Use session-only cache key to prevent user_id conflicts
session_cache_key = f"session_{session_id}"
user_cache_key = f"user_{user_id}"
# Get session context from cache
session_context = self._get_from_memory_cache(session_cache_key)
# Check if cached session context matches current user_id
# Handle both old and new cache formats
cached_entry = self.session_cache.get(session_cache_key)
if cached_entry:
# Extract actual context from cache entry
if isinstance(cached_entry, dict) and 'value' in cached_entry:
actual_context = cached_entry.get('value', {})
else:
actual_context = cached_entry
if actual_context and actual_context.get("user_id") != user_id:
# User changed, invalidate session cache
logger.info(f"User mismatch in cache for session {session_id}, invalidating cache")
session_context = None
if session_cache_key in self.session_cache:
del self.session_cache[session_cache_key]
else:
session_context = actual_context
# Get user context separately
user_context = self._get_from_memory_cache(user_cache_key)
if not session_context:
# Retrieve from database with user context
session_context = await self._retrieve_from_db(session_id, user_input, user_id)
# Step 2: Cache session context with TTL
self.add_context_cache(session_cache_key, session_context, ttl=self.cache_config.get("ttl", 3600))
        # Handle user context separately: load from the database once, then
        # serve it from cache for subsequent interactions (no repeat DB reads)
if not user_context or not user_context.get("user_context_loaded"):
user_context_data = await self.get_user_context(user_id)
user_context = {
"user_context": user_context_data,
"user_context_loaded": True,
"user_id": user_id
}
# Cache user context separately - this is the only database query for user context
self._warm_memory_cache(user_cache_key, user_context)
logger.debug(f"User context loaded once for {user_id} and cached")
else:
# User context already cached, use it without database query
logger.debug(f"Using cached user context for {user_id}")
# Merge contexts without duplication
merged_context = {
**session_context,
"user_context": user_context.get("user_context", ""),
"user_context_loaded": True,
"user_id": user_id # Ensure current user_id is used
}
# Update context with new interaction
updated_context = self._update_context(merged_context, user_input, user_id=user_id)
return self._optimize_context(updated_context)
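    # Usage sketch (hypothetical manager instance and ids; must run inside an
    # event loop, e.g. via asyncio.run):
    #
    #     ctx = await manager.manage_context("sess-1", "What is WAL mode?", user_id="alice")
    #     print(ctx["combined_context"])  # formatted by _optimize_context below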
async def get_user_context(self, user_id: str) -> str:
"""
STEP 1: Fetch or generate User Context (500-token persona summary)
        Available for every interaction except a user's first, when no history exists yet
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Check if user context exists
cursor.execute("""
SELECT persona_summary FROM user_contexts WHERE user_id = ?
""", (user_id,))
row = cursor.fetchone()
if row and row[0]:
# Existing user context found
conn.close()
logger.info(f"✓ User context loaded for {user_id}")
return row[0]
# Generate new user context from all historical data
logger.info(f"Generating new user context for {user_id}")
# Fetch all historical Session and Interaction contexts for this user
all_session_summaries = []
all_interaction_summaries = []
# Get all session contexts
cursor.execute("""
SELECT session_summary FROM session_contexts WHERE user_id = ?
ORDER BY created_at DESC LIMIT 50
""", (user_id,))
for row in cursor.fetchall():
if row[0]:
all_session_summaries.append(row[0])
# Get all interaction contexts
cursor.execute("""
SELECT ic.interaction_summary
FROM interaction_contexts ic
JOIN sessions s ON ic.session_id = s.session_id
WHERE s.user_id = ?
ORDER BY ic.created_at DESC LIMIT 100
""", (user_id,))
for row in cursor.fetchall():
if row[0]:
all_interaction_summaries.append(row[0])
conn.close()
if not all_session_summaries and not all_interaction_summaries:
# First time user - no context to generate
logger.info(f"No historical data for {user_id} - first time user")
return ""
# Generate persona summary using LLM (500 tokens)
historical_data = "\n\n".join(all_session_summaries + all_interaction_summaries[:20])
if self.llm_router:
prompt = f"""Generate a concise 500-token persona summary for user {user_id} based on their interaction history:
Historical Context:
{historical_data}
Create a persona summary that captures:
- Communication style and preferences
- Common topics and interests
- Interaction patterns
- Key information shared across sessions
Keep the summary concise and focused (approximately 500 tokens)."""
try:
persona_summary = await self.llm_router.route_inference(
task_type="general_reasoning",
prompt=prompt,
max_tokens=500,
temperature=0.7
)
if persona_summary and isinstance(persona_summary, str) and persona_summary.strip():
# Store in database
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO user_contexts (user_id, persona_summary, updated_at)
VALUES (?, ?, ?)
""", (user_id, persona_summary.strip(), datetime.now().isoformat()))
conn.commit()
conn.close()
logger.info(f"✓ Generated and stored user context for {user_id}")
return persona_summary.strip()
except Exception as e:
logger.error(f"Error generating user context: {e}", exc_info=True)
# Fallback: Return empty if LLM fails
logger.warning(f"Could not generate user context for {user_id} - using empty")
return ""
except Exception as e:
logger.error(f"Error getting user context: {e}", exc_info=True)
return ""
async def generate_interaction_context(self, interaction_id: str, session_id: str,
user_input: str, system_response: str,
user_id: str = "Test_Any") -> str:
"""
STEP 2: Generate Interaction Context (50-token summary)
Called after each response
"""
try:
if not self.llm_router:
return ""
prompt = f"""Summarize this interaction in approximately 50 tokens:
User Input: {user_input[:200]}
System Response: {system_response[:300]}
Provide a brief summary capturing the key exchange."""
try:
summary = await self.llm_router.route_inference(
task_type="general_reasoning",
prompt=prompt,
max_tokens=50,
temperature=0.7
)
if summary and isinstance(summary, str) and summary.strip():
# Store in database
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
created_at = datetime.now().isoformat()
cursor.execute("""
INSERT OR REPLACE INTO interaction_contexts
(interaction_id, session_id, user_input, system_response, interaction_summary, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""", (
interaction_id,
session_id,
user_input[:500],
system_response[:1000],
summary.strip(),
created_at
))
conn.commit()
conn.close()
# Update cache immediately with new interaction context
# This ensures cache is synchronized with database at the same time
self._update_cache_with_interaction_context(session_id, summary.strip(), created_at)
logger.info(f"✓ Generated interaction context for {interaction_id} and updated cache")
return summary.strip()
except Exception as e:
logger.error(f"Error generating interaction context: {e}", exc_info=True)
# Fallback on LLM failure
return ""
except Exception as e:
logger.error(f"Error in generate_interaction_context: {e}", exc_info=True)
return ""
async def generate_session_context(self, session_id: str, user_id: str = "Test_Any") -> str:
"""
Generate Session Context (100-token summary) at every turn
Uses cached interaction contexts instead of querying database
Updates both database and cache immediately
"""
try:
# Get interaction contexts from cache (no database query)
session_cache_key = f"session_{session_id}"
            # Read via the TTL-aware accessor so wrapped cache entries are unwrapped
            cached_context = self._get_from_memory_cache(session_cache_key)
if not cached_context:
logger.warning(f"No cached context found for session {session_id}, cannot generate session context")
return ""
interaction_contexts = cached_context.get('interaction_contexts', [])
if not interaction_contexts:
logger.info(f"No interaction contexts available for session {session_id} to summarize")
return ""
# Use cached interaction contexts (from cache, not database)
interaction_summaries = [ic.get('summary', '') for ic in interaction_contexts if ic.get('summary')]
if not interaction_summaries:
logger.info(f"No interaction summaries available for session {session_id}")
return ""
# Generate session summary using LLM (100 tokens)
if self.llm_router:
combined_context = "\n".join(interaction_summaries)
prompt = f"""Summarize this session's interactions in approximately 100 tokens:
Interaction Summaries:
{combined_context}
Create a concise session summary capturing:
- Main topics discussed
- Key outcomes or information shared
- User's focus areas
Keep the summary concise (approximately 100 tokens)."""
try:
session_summary = await self.llm_router.route_inference(
task_type="general_reasoning",
prompt=prompt,
max_tokens=100,
temperature=0.7
)
if session_summary and isinstance(session_summary, str) and session_summary.strip():
# Store in database
created_at = datetime.now().isoformat()
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO session_contexts
(session_id, user_id, session_summary, created_at)
VALUES (?, ?, ?, ?)
""", (session_id, user_id, session_summary.strip(), created_at))
conn.commit()
conn.close()
# Update cache immediately with new session context
# This ensures cache is synchronized with database at the same time
self._update_cache_with_session_context(session_id, session_summary.strip(), created_at)
logger.info(f"✓ Generated session context for {session_id} and updated cache")
return session_summary.strip()
except Exception as e:
logger.error(f"Error generating session context: {e}", exc_info=True)
# Fallback on LLM failure
return ""
except Exception as e:
logger.error(f"Error in generate_session_context: {e}", exc_info=True)
return ""
async def end_session(self, session_id: str, user_id: str = "Test_Any"):
"""
End session and clear cache
Note: Session context is already generated at every turn, so this just clears cache
"""
try:
# Session context is already generated at every turn (no need to regenerate)
# Clear in-memory cache for this session (session-only key)
session_cache_key = f"session_{session_id}"
if session_cache_key in self.session_cache:
del self.session_cache[session_cache_key]
logger.info(f"✓ Cleared cache for session {session_id}")
except Exception as e:
logger.error(f"Error ending session: {e}", exc_info=True)
    def _clear_user_cache_on_change(self, session_id: str, new_user_id: str, old_user_id: str):
        """Clear stale cache entries when the session's user changes"""
        if new_user_id != old_user_id:
            # Drop the old user's cached context plus any legacy composite key
            for stale_key in (f"user_{old_user_id}", f"{session_id}_{old_user_id}"):
                if stale_key in self.session_cache:
                    del self.session_cache[stale_key]
            logger.info(f"Cleared stale cache for user {old_user_id} on session {session_id}")
def _optimize_context(self, context: dict, relevance_classification: Optional[Dict] = None) -> dict:
"""
Optimize context for LLM consumption with relevance filtering support
Format: [Session Context] + [User Context (conditional)] + [Interaction Context #N, #N-1, ...]
Args:
context: Base context dictionary
relevance_classification: Optional relevance classification results with dynamic user context
Applies smart pruning before formatting.
"""
# Step 4: Prune context if it exceeds token limits
pruned_context = self.prune_context(context, max_tokens=2000)
# Get context mode (fresh or relevant)
session_id = pruned_context.get("session_id")
context_mode = self.get_context_mode(session_id)
interaction_contexts = pruned_context.get("interaction_contexts", [])
session_context = pruned_context.get("session_context", {})
session_summary = session_context.get("summary", "") if isinstance(session_context, dict) else ""
# MODIFIED: Conditional user context inclusion based on mode and relevance
user_context = ""
if context_mode == 'relevant' and relevance_classification:
# Use dynamic relevant summaries from relevance classification
user_context = relevance_classification.get('combined_user_context', '')
if user_context:
logger.info(
f"Using dynamic relevant context: {len(relevance_classification.get('relevant_summaries', []))} "
f"sessions summarized for session {session_id}"
)
elif context_mode == 'relevant' and not relevance_classification:
# Fallback: Use traditional user context if relevance classification unavailable
user_context = pruned_context.get("user_context", "")
logger.debug(f"Relevant mode but no classification, using traditional user context")
# If context_mode == 'fresh', user_context remains empty (no user context)
# Format interaction contexts as requested
formatted_interactions = []
for idx, ic in enumerate(interaction_contexts[:10]): # Last 10 interactions
formatted_interactions.append(f"[Interaction Context #{len(interaction_contexts) - idx}]\n{ic.get('summary', '')}")
# Combine Session Context + (Conditional) User Context + Interaction Contexts
combined_context = ""
if session_summary:
combined_context += f"[Session Context]\n{session_summary}\n\n"
# Include user context only if available and in relevant mode
if user_context:
context_label = "[Relevant User Context]" if context_mode == 'relevant' else "[User Context]"
combined_context += f"{context_label}\n{user_context}\n\n"
if formatted_interactions:
combined_context += "\n\n".join(formatted_interactions)
return {
"session_id": pruned_context.get("session_id"),
"user_id": pruned_context.get("user_id", "Test_Any"),
"user_context": user_context, # Dynamic summaries OR empty
"session_context": session_context,
"interaction_contexts": interaction_contexts,
"combined_context": combined_context,
"context_mode": context_mode, # Include mode for debugging
"relevance_metadata": relevance_classification.get('relevance_scores', {}) if relevance_classification else {},
"preferences": pruned_context.get("preferences", {}),
"active_tasks": pruned_context.get("active_tasks", []),
"last_activity": pruned_context.get("last_activity")
}
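    # Shape of the combined_context string produced above (illustrative):
    #
    #     [Session Context]
    #     <100-token session summary>
    #
    #     [User Context]            <- or [Relevant User Context] in 'relevant' mode
    #     <persona summary>
    #
    #     [Interaction Context #3]
    #     <most recent 50-token summary>
    #
    #     [Interaction Context #2]
    #     <next most recent summary>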
def _get_from_memory_cache(self, cache_key: str) -> dict:
"""
Retrieve context from in-memory session cache with expiration check
"""
cached = self.session_cache.get(cache_key)
if not cached:
return None
# Check if it's the new format with expiration
if isinstance(cached, dict) and 'value' in cached:
# New format with TTL
if self._is_cache_expired(cached):
# Remove expired cache entry
del self.session_cache[cache_key]
logger.debug(f"Cache expired for key: {cache_key}")
return None
return cached.get('value')
else:
# Old format (direct value) - return as-is for backward compatibility
return cached
def _is_cache_expired(self, cache_entry: dict) -> bool:
"""
Check if cache entry has expired based on TTL
"""
if not isinstance(cache_entry, dict):
return True
expires = cache_entry.get('expires')
if not expires:
return False # No expiration set, consider valid
return time.time() > expires
def add_context_cache(self, key: str, value: dict, ttl: int = 3600):
"""
Step 2: Implement Context Caching with TTL expiration
Add context to cache with expiration time.
Args:
key: Cache key
value: Value to cache (dict)
ttl: Time to live in seconds (default 3600 = 1 hour)
"""
self.session_cache[key] = {
'value': value,
'expires': time.time() + ttl,
'timestamp': time.time()
}
logger.debug(f"Cached context for key: {key} with TTL: {ttl}s")
def get_token_count(self, text: str) -> int:
"""
Approximate token count for text (4 characters ≈ 1 token)
Args:
text: Text to count tokens for
Returns:
Approximate token count
"""
if not text:
return 0
# Simple approximation: 4 characters per token
return len(text) // 4
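    # Examples of the 4-chars-per-token heuristic (integer division):
    #
    #     get_token_count("")         -> 0
    #     get_token_count("abcd")     -> 1
    #     get_token_count("a" * 8000) -> 2000  (exactly the default prune budget)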
def prune_context(self, context: dict, max_tokens: int = 2000) -> dict:
"""
Step 4: Implement Smart Context Pruning
Prune context to stay within token limit while keeping most recent and relevant content.
Args:
context: Context dictionary to prune
max_tokens: Maximum token count (default 2000)
Returns:
Pruned context dictionary
"""
try:
# Calculate current token count
current_tokens = self._calculate_context_tokens(context)
if current_tokens <= max_tokens:
return context # No pruning needed
logger.info(f"Context token count ({current_tokens}) exceeds limit ({max_tokens}), pruning...")
# Create a copy to avoid modifying original
pruned_context = context.copy()
# Priority: Keep most recent interactions + session context + user context
interaction_contexts = pruned_context.get('interaction_contexts', [])
session_context = pruned_context.get('session_context', {})
user_context = pruned_context.get('user_context', '')
# Keep user context and session context (essential)
essential_tokens = (
self.get_token_count(user_context) +
self.get_token_count(str(session_context))
)
# Calculate how many interaction contexts we can keep
available_tokens = max_tokens - essential_tokens
if available_tokens < 0:
# Essential context itself is too large - summarize user context
                if self.get_token_count(user_context) > max_tokens // 2:
                    # Rough cut: max_tokens * 2 chars ~= max_tokens / 2 tokens at 4 chars per token
                    pruned_context['user_context'] = user_context[:max_tokens * 2]
                    logger.warning("User context too large, truncated to roughly half the token budget")
return pruned_context
# Keep most recent interactions that fit in token budget
kept_interactions = []
current_size = 0
for interaction in interaction_contexts:
summary = interaction.get('summary', '')
interaction_tokens = self.get_token_count(summary)
if current_size + interaction_tokens <= available_tokens:
kept_interactions.append(interaction)
current_size += interaction_tokens
else:
break # Can't fit any more
pruned_context['interaction_contexts'] = kept_interactions
logger.info(f"Pruned context: kept {len(kept_interactions)}/{len(interaction_contexts)} interactions, "
f"reduced from {current_tokens} to {self._calculate_context_tokens(pruned_context)} tokens")
return pruned_context
except Exception as e:
logger.error(f"Error pruning context: {e}", exc_info=True)
return context # Return original on error
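    # Pruning sketch (hypothetical numbers): with max_tokens=2000 and user plus
    # session context costing ~500 tokens, about 1500 tokens remain for
    # interaction summaries; at ~50 tokens each, roughly the 30 most recent
    # interactions fit before the loop breaks.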
def _calculate_context_tokens(self, context: dict) -> int:
"""Calculate total token count for context"""
total = 0
# Count tokens in each component
user_context = context.get('user_context', '')
total += self.get_token_count(str(user_context))
session_context = context.get('session_context', {})
if isinstance(session_context, dict):
total += self.get_token_count(str(session_context.get('summary', '')))
else:
total += self.get_token_count(str(session_context))
interaction_contexts = context.get('interaction_contexts', [])
for interaction in interaction_contexts:
summary = interaction.get('summary', '')
total += self.get_token_count(str(summary))
return total
async def _retrieve_from_db(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
"""
Retrieve session context with proper user_id synchronization
Uses transactions to ensure atomic updates of database and cache
"""
conn = None
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Use transaction to ensure atomic updates
cursor.execute("BEGIN TRANSACTION")
# Get session data (SQLite doesn't support FOR UPDATE, but transaction ensures consistency)
cursor.execute("""
SELECT context_data, user_metadata, last_activity, user_id
FROM sessions
WHERE session_id = ?
""", (session_id,))
row = cursor.fetchone()
if row:
context_data = json.loads(row[0]) if row[0] else {}
user_metadata = json.loads(row[1]) if row[1] else {}
last_activity = row[2]
session_user_id = row[3] if len(row) > 3 else user_id
# Check for user_id change and update atomically
user_changed = False
if session_user_id != user_id:
logger.info(f"User change detected: {session_user_id} -> {user_id} for session {session_id}")
user_changed = True
# Update session with new user_id
cursor.execute("""
UPDATE sessions
SET user_id = ?, last_activity = ?
WHERE session_id = ?
""", (user_id, datetime.now().isoformat(), session_id))
# Clear any cached interaction contexts for old user by marking for refresh
try:
cursor.execute("""
UPDATE interaction_contexts
SET needs_refresh = 1
WHERE session_id = ?
""", (session_id,))
except sqlite3.OperationalError:
# Column might not exist yet, will be created by schema update
pass
# Log user change event
try:
cursor.execute("""
INSERT INTO user_change_log (session_id, old_user_id, new_user_id, timestamp)
VALUES (?, ?, ?, ?)
""", (session_id, session_user_id, user_id, datetime.now().isoformat()))
except sqlite3.OperationalError:
# Table might not exist yet, will be created by schema update
pass
# Clear old cache entries when user changes
self._clear_user_cache_on_change(session_id, user_id, session_user_id)
cursor.execute("COMMIT")
# Get interaction contexts with refresh flag check
try:
cursor.execute("""
SELECT interaction_summary, created_at, needs_refresh
FROM interaction_contexts
WHERE session_id = ? AND (needs_refresh IS NULL OR needs_refresh = 0)
ORDER BY created_at DESC
LIMIT 20
""", (session_id,))
except sqlite3.OperationalError:
# Column might not exist yet, fall back to query without needs_refresh
cursor.execute("""
SELECT interaction_summary, created_at
FROM interaction_contexts
WHERE session_id = ?
ORDER BY created_at DESC
LIMIT 20
""", (session_id,))
interaction_contexts = []
for ic_row in cursor.fetchall():
# Handle both query formats (with and without needs_refresh)
if len(ic_row) >= 2:
summary = ic_row[0]
timestamp = ic_row[1]
needs_refresh = ic_row[2] if len(ic_row) > 2 else 0
if summary and not needs_refresh:
interaction_contexts.append({
"summary": summary,
"timestamp": timestamp
})
# Get session context from database
session_context_data = None
try:
cursor.execute("""
SELECT session_summary, created_at
FROM session_contexts
WHERE session_id = ?
ORDER BY created_at DESC
LIMIT 1
""", (session_id,))
sc_row = cursor.fetchone()
if sc_row and sc_row[0]:
session_context_data = {
"summary": sc_row[0],
"timestamp": sc_row[1]
}
except sqlite3.OperationalError:
# Table might not exist yet
pass
context = {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": interaction_contexts,
"session_context": session_context_data,
"preferences": user_metadata.get("preferences", {}),
"active_tasks": user_metadata.get("active_tasks", []),
"last_activity": last_activity,
"user_context_loaded": False,
"user_changed": user_changed
}
conn.close()
return context
else:
# Create new session with transaction
cursor.execute("""
INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
VALUES (?, ?, ?, ?, ?, ?)
""", (session_id, user_id, datetime.now().isoformat(), datetime.now().isoformat(), "{}", "{}"))
cursor.execute("COMMIT")
conn.close()
return {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": [],
"session_context": None,
"preferences": {},
"active_tasks": [],
"user_context_loaded": False,
"user_changed": False
}
except sqlite3.Error as e:
logger.error(f"Database transaction error: {e}", exc_info=True)
if conn:
try:
conn.rollback()
                except Exception:
                    pass
conn.close()
# Return safe fallback
return {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": [],
"session_context": None,
"preferences": {},
"active_tasks": [],
"user_context_loaded": False,
"error": str(e),
"user_changed": False
}
except Exception as e:
logger.error(f"Database retrieval error: {e}", exc_info=True)
if conn:
try:
conn.rollback()
                except Exception:
                    pass
conn.close()
# Return safe fallback
return {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": [],
"session_context": None,
"preferences": {},
"active_tasks": [],
"user_context_loaded": False,
"error": str(e),
"user_changed": False
}
def _warm_memory_cache(self, cache_key: str, context: dict):
"""
Warm the in-memory cache with retrieved context
Note: Use add_context_cache() instead for TTL support
"""
# Use add_context_cache for consistency with TTL
self.add_context_cache(cache_key, context, ttl=self.cache_config.get("ttl", 3600))
def _update_cache_with_interaction_context(self, session_id: str, interaction_summary: str, created_at: str):
"""
Update cache with new interaction context immediately after database update
This keeps cache synchronized with database without requiring database queries
"""
session_cache_key = f"session_{session_id}"
        # Get the current cached context via the TTL-aware accessor (handles wrapped entries)
        cached_context = self._get_from_memory_cache(session_cache_key)
if cached_context:
# Add new interaction context to the beginning of the list (most recent first)
interaction_contexts = cached_context.get('interaction_contexts', [])
new_interaction = {
"summary": interaction_summary,
"timestamp": created_at
}
# Insert at beginning and keep only last 20 (matches DB query limit)
interaction_contexts.insert(0, new_interaction)
interaction_contexts = interaction_contexts[:20]
# Update cached context with new interaction contexts
cached_context['interaction_contexts'] = interaction_contexts
            # Re-store through the TTL wrapper so the entry format stays consistent
            self._warm_memory_cache(session_cache_key, cached_context)
logger.debug(f"Cache updated with new interaction context for session {session_id} (total: {len(interaction_contexts)})")
else:
# If cache doesn't exist, create new entry
new_context = {
"session_id": session_id,
"interaction_contexts": [{
"summary": interaction_summary,
"timestamp": created_at
}],
"preferences": {},
"active_tasks": [],
"user_context_loaded": False
}
            self._warm_memory_cache(session_cache_key, new_context)
logger.debug(f"Created new cache entry with interaction context for session {session_id}")
def _update_cache_with_session_context(self, session_id: str, session_summary: str, created_at: str):
"""
Update cache with new session context immediately after database update
This keeps cache synchronized with database without requiring database queries
"""
session_cache_key = f"session_{session_id}"
        # Get the current cached context via the TTL-aware accessor (handles wrapped entries)
        cached_context = self._get_from_memory_cache(session_cache_key)
if cached_context:
# Update session context in cache
cached_context['session_context'] = {
"summary": session_summary,
"timestamp": created_at
}
            # Re-store through the TTL wrapper so the entry format stays consistent
            self._warm_memory_cache(session_cache_key, cached_context)
logger.debug(f"Cache updated with new session context for session {session_id}")
else:
# If cache doesn't exist, create new entry
new_context = {
"session_id": session_id,
"session_context": {
"summary": session_summary,
"timestamp": created_at
},
"interaction_contexts": [],
"preferences": {},
"active_tasks": [],
"user_context_loaded": False
}
            self._warm_memory_cache(session_cache_key, new_context)
logger.debug(f"Created new cache entry with session context for session {session_id}")
def _update_context(self, context: dict, user_input: str, response: str = None, user_id: str = "Test_Any") -> dict:
"""
Update context with deduplication and idempotency checks
Prevents duplicate context updates using interaction hashes
"""
try:
# Generate unique interaction hash to prevent duplicates
interaction_hash = self._generate_interaction_hash(user_input, context["session_id"], user_id)
# Check if this interaction was already processed
if self._is_duplicate_interaction(interaction_hash):
logger.info(f"Duplicate interaction detected, skipping update: {interaction_hash[:8]}")
return context
# Use transaction for atomic updates
current_time = datetime.now().isoformat()
with self.transaction_manager.transaction(context["session_id"]) as cursor:
# Update session activity (only if last_activity is older to prevent unnecessary updates)
cursor.execute("""
UPDATE sessions
SET last_activity = ?, user_id = ?
WHERE session_id = ? AND (last_activity IS NULL OR last_activity < ?)
""", (current_time, user_id, context["session_id"], current_time))
# Store interaction with duplicate prevention using INSERT OR IGNORE
session_context = {
"preferences": context.get("preferences", {}),
"active_tasks": context.get("active_tasks", [])
}
cursor.execute("""
INSERT OR IGNORE INTO interactions (
interaction_hash,
session_id,
user_input,
context_snapshot,
created_at
) VALUES (?, ?, ?, ?, ?)
""", (
interaction_hash,
context["session_id"],
user_input,
json.dumps(session_context),
current_time
))
# Mark interaction as processed (outside transaction)
self._mark_interaction_processed(interaction_hash)
# Update in-memory context
context["last_interaction"] = user_input
context["last_update"] = current_time
logger.info(f"Context updated for session {context['session_id']} with hash {interaction_hash[:8]}")
return context
except Exception as e:
logger.error(f"Error updating context: {e}", exc_info=True)
return context
def _generate_interaction_hash(self, user_input: str, session_id: str, user_id: str) -> str:
"""Generate unique hash for interaction to prevent duplicates"""
# Use session_id, user_id, and user_input for exact duplicate detection
# Normalize user input by stripping whitespace
normalized_input = user_input.strip()
content = f"{session_id}:{user_id}:{normalized_input}"
return hashlib.sha256(content.encode()).hexdigest()
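    # Dedup example (hypothetical manager instance): whitespace-only differences
    # hash identically, so a retried request cannot create a second interactions row.
    #
    #     h1 = mgr._generate_interaction_hash("hello ", "sess-1", "alice")
    #     h2 = mgr._generate_interaction_hash("hello", "sess-1", "alice")
    #     assert h1 == h2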
def _is_duplicate_interaction(self, interaction_hash: str) -> bool:
"""Check if interaction was already processed"""
# Keep a rolling window of recent interaction hashes in memory
if not hasattr(self, '_processed_interactions'):
self._processed_interactions = set()
# Check in-memory cache first
if interaction_hash in self._processed_interactions:
return True
# Also check database for persistent duplicates
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Check if interaction_hash column exists and query for duplicates
cursor.execute("PRAGMA table_info(interactions)")
columns = [row[1] for row in cursor.fetchall()]
if 'interaction_hash' in columns:
cursor.execute("""
SELECT COUNT(*) FROM interactions
WHERE interaction_hash IS NOT NULL AND interaction_hash = ?
""", (interaction_hash,))
count = cursor.fetchone()[0]
conn.close()
return count > 0
else:
conn.close()
return False
except sqlite3.OperationalError:
# Column might not exist yet, only check in-memory
return interaction_hash in self._processed_interactions
def _mark_interaction_processed(self, interaction_hash: str):
"""Mark interaction as processed"""
if not hasattr(self, '_processed_interactions'):
self._processed_interactions = set()
self._processed_interactions.add(interaction_hash)
        # Limit memory usage by keeping only the last 1000 hashes
        if len(self._processed_interactions) > 1000:
            # Sets are unordered, so this truncation drops an arbitrary half rather
            # than the oldest entries; the database check still catches evicted hashes
            self._processed_interactions = set(list(self._processed_interactions)[-500:])
async def manage_context_optimized(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
"""
Efficient context management with transaction optimization
"""
# Use session-only cache key
session_cache_key = f"session_{session_id}"
# Try to get from cache first (no DB access)
cached_context = self._get_from_memory_cache(session_cache_key)
        if cached_context and self._is_cache_valid(cached_context):
            logger.debug(f"Using cached context for session {session_id}")
            # Optimize on the way out so cache hits and misses return the same shape
            return self._optimize_context(cached_context)
# Use transaction for all DB operations
with self.transaction_manager.transaction(session_id) as cursor:
# Atomic session retrieval and update
cursor.execute("""
SELECT s.context_data, s.user_metadata, s.last_activity, s.user_id,
COUNT(ic.interaction_id) as interaction_count
FROM sessions s
LEFT JOIN interaction_contexts ic ON s.session_id = ic.session_id
WHERE s.session_id = ?
GROUP BY s.session_id
""", (session_id,))
row = cursor.fetchone()
if row:
# Parse existing session data
context_data = json.loads(row[0] or '{}')
user_metadata = json.loads(row[1] or '{}')
last_activity = row[2]
stored_user_id = row[3] or user_id
interaction_count = row[4] or 0
# Handle user change atomically
if stored_user_id != user_id:
self._handle_user_change_atomic(cursor, session_id, stored_user_id, user_id)
# Get interaction contexts efficiently
interaction_contexts = self._get_interaction_contexts_atomic(cursor, session_id)
else:
# Create new session atomically
cursor.execute("""
INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
VALUES (?, ?, datetime('now'), datetime('now'), '{}', '{}')
""", (session_id, user_id))
context_data = {}
user_metadata = {}
interaction_contexts = []
interaction_count = 0
# Load user context asynchronously (outside transaction)
user_context = await self._load_user_context_async(user_id)
# Build final context
final_context = {
"session_id": session_id,
"user_id": user_id,
"interaction_contexts": interaction_contexts,
"user_context": user_context,
"preferences": user_metadata.get("preferences", {}),
"active_tasks": user_metadata.get("active_tasks", []),
"interaction_count": interaction_count,
"cache_timestamp": datetime.now().isoformat()
}
# Update cache
self._warm_memory_cache(session_cache_key, final_context)
return self._optimize_context(final_context)
def _handle_user_change_atomic(self, cursor, session_id: str, old_user_id: str, new_user_id: str):
"""Handle user change within transaction"""
logger.info(f"Handling user change in transaction: {old_user_id} -> {new_user_id}")
# Update session
cursor.execute("""
UPDATE sessions
SET user_id = ?, last_activity = datetime('now')
WHERE session_id = ?
""", (new_user_id, session_id))
# Log the change
try:
cursor.execute("""
INSERT INTO user_change_log (session_id, old_user_id, new_user_id, timestamp)
VALUES (?, ?, ?, datetime('now'))
""", (session_id, old_user_id, new_user_id))
except sqlite3.OperationalError:
# Table might not exist yet
pass
# Invalidate related caches
try:
cursor.execute("""
UPDATE interaction_contexts
SET needs_refresh = 1
WHERE session_id = ?
""", (session_id,))
except sqlite3.OperationalError:
# Column might not exist yet
pass
def _get_interaction_contexts_atomic(self, cursor, session_id: str, limit: int = 20):
"""Get interaction contexts within transaction"""
try:
cursor.execute("""
SELECT interaction_summary, created_at, interaction_id
FROM interaction_contexts
WHERE session_id = ? AND (needs_refresh IS NULL OR needs_refresh = 0)
ORDER BY created_at DESC
LIMIT ?
""", (session_id, limit))
except sqlite3.OperationalError:
# Fallback if needs_refresh column doesn't exist
cursor.execute("""
SELECT interaction_summary, created_at, interaction_id
FROM interaction_contexts
WHERE session_id = ?
ORDER BY created_at DESC
LIMIT ?
""", (session_id, limit))
contexts = []
for row in cursor.fetchall():
if row[0]:
contexts.append({
"summary": row[0],
"timestamp": row[1],
"id": row[2] if len(row) > 2 else None
})
return contexts
async def _load_user_context_async(self, user_id: str):
"""Load user context asynchronously to avoid blocking"""
try:
# Check memory cache first
user_cache_key = f"user_{user_id}"
cached = self._get_from_memory_cache(user_cache_key)
if cached:
return cached.get("user_context", "")
# Load from database
return await self.get_user_context(user_id)
except Exception as e:
logger.error(f"Error loading user context: {e}")
return ""
def _is_cache_valid(self, cached_context: dict, max_age_seconds: int = 60) -> bool:
"""Check if cached context is still valid"""
if not cached_context:
return False
cache_timestamp = cached_context.get("cache_timestamp")
if not cache_timestamp:
return False
try:
cache_time = datetime.fromisoformat(cache_timestamp)
age = (datetime.now() - cache_time).total_seconds()
return age < max_age_seconds
        except (ValueError, TypeError):
            return False
def invalidate_session_cache(self, session_id: str):
"""
Invalidate cached context for a session to force fresh retrieval
Only affects cache management - does not change application functionality
"""
session_cache_key = f"session_{session_id}"
if session_cache_key in self.session_cache:
del self.session_cache[session_cache_key]
logger.info(f"Cache invalidated for session {session_id} to ensure fresh context retrieval")
def optimize_database_indexes(self):
"""Create database indexes for better query performance"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Create indexes for frequently queried columns
indexes = [
"CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)",
"CREATE INDEX IF NOT EXISTS idx_sessions_last_activity ON sessions(last_activity)",
"CREATE INDEX IF NOT EXISTS idx_interactions_session_id ON interactions(session_id)",
"CREATE INDEX IF NOT EXISTS idx_interaction_contexts_session_id ON interaction_contexts(session_id)",
"CREATE INDEX IF NOT EXISTS idx_interaction_contexts_created_at ON interaction_contexts(created_at)",
"CREATE INDEX IF NOT EXISTS idx_user_change_log_session_id ON user_change_log(session_id)",
"CREATE INDEX IF NOT EXISTS idx_user_contexts_updated_at ON user_contexts(updated_at)"
]
for index in indexes:
try:
cursor.execute(index)
except sqlite3.OperationalError as e:
# Table might not exist yet, skip this index
logger.debug(f"Skipping index creation (table may not exist): {e}")
# Analyze database for query optimization
try:
cursor.execute("ANALYZE")
except sqlite3.OperationalError:
                # ANALYZE can fail if the database is locked or read-only
pass
conn.commit()
conn.close()
logger.info("✓ Database indexes optimized successfully")
except Exception as e:
logger.error(f"Error optimizing database indexes: {e}", exc_info=True)
def set_context_mode(self, session_id: str, mode: str, user_id: str = "Test_Any"):
"""
Set context mode for session (fresh or relevant)
Args:
session_id: Session identifier
mode: 'fresh' (no user context) or 'relevant' (only relevant context)
user_id: User identifier
Returns:
bool: True if successful, False otherwise
"""
try:
# VALIDATION: Ensure mode is valid
if mode not in ['fresh', 'relevant']:
logger.warning(f"Invalid context mode '{mode}', defaulting to 'fresh'")
mode = 'fresh'
# Get or create cache entry
cache_key = f"session_{session_id}"
cached_context = self._get_from_memory_cache(cache_key)
if not cached_context:
cached_context = {
'session_id': session_id,
'user_id': user_id,
'preferences': {},
'context_mode': mode,
'context_mode_timestamp': time.time()
}
else:
# Update existing context (preserve other data)
cached_context['context_mode'] = mode
cached_context['context_mode_timestamp'] = time.time()
cached_context['user_id'] = user_id # Update user_id if changed
# Update cache with TTL
self.add_context_cache(cache_key, cached_context, ttl=3600)
logger.info(f"Context mode set to '{mode}' for session {session_id} (user: {user_id})")
return True
except Exception as e:
logger.error(f"Error setting context mode: {e}", exc_info=True)
return False # Failure doesn't break existing flow
def get_context_mode(self, session_id: str) -> str:
"""
Get current context mode for session
Args:
session_id: Session identifier
Returns:
str: 'fresh' or 'relevant' (default: 'fresh')
"""
try:
cache_key = f"session_{session_id}"
cached_context = self._get_from_memory_cache(cache_key)
if cached_context:
mode = cached_context.get('context_mode', 'fresh')
# VALIDATION: Ensure mode is still valid
if mode in ['fresh', 'relevant']:
return mode
else:
logger.warning(f"Invalid cached mode '{mode}', resetting to 'fresh'")
cached_context['context_mode'] = 'fresh'
cached_context['context_mode_timestamp'] = time.time()
self.add_context_cache(cache_key, cached_context, ttl=3600)
return 'fresh'
# Default for new sessions
return 'fresh'
except Exception as e:
logger.error(f"Error getting context mode: {e}", exc_info=True)
return 'fresh' # Safe default - no degradation
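    # Mode toggling sketch (hypothetical session ids): 'fresh' suppresses user
    # context entirely; 'relevant' injects it, dynamically filtered when a
    # relevance classification is passed to _optimize_context.
    #
    #     mgr.set_context_mode("sess-1", "relevant", user_id="alice")
    #     mgr.get_context_mode("sess-1")  # -> "relevant"
    #     mgr.get_context_mode("sess-2")  # -> "fresh" (default for unknown sessions)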
async def get_all_user_sessions(self, user_id: str) -> List[Dict]:
"""
Fetch all session contexts for a user (for relevance classification)
Performance: Single database query with JOIN
Args:
user_id: User identifier
Returns:
List of session context dictionaries with summaries and interactions
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
            # Fetch all session contexts for the user with their recent interaction
            # summaries. Note: ORDER BY / LIMIT must sit in an inner subquery; written
            # directly inside GROUP_CONCAT's SELECT they act on the single aggregated
            # row and silently do nothing. DISTINCT is unnecessary because
            # session_contexts.session_id is the primary key.
            cursor.execute("""
                SELECT
                    sc.session_id,
                    sc.session_summary,
                    sc.created_at,
                    (SELECT GROUP_CONCAT(x.interaction_summary, ' ||| ')
                     FROM (SELECT ic.interaction_summary
                           FROM interaction_contexts ic
                           WHERE ic.session_id = sc.session_id
                           ORDER BY ic.created_at DESC
                           LIMIT 10) AS x) AS recent_interactions
                FROM session_contexts sc
                JOIN sessions s ON sc.session_id = s.session_id
                WHERE s.user_id = ?
                ORDER BY sc.created_at DESC
                LIMIT 50
            """, (user_id,))
sessions = []
for row in cursor.fetchall():
session_id, session_summary, created_at, interactions_str = row
# Parse interaction summaries
interaction_list = []
if interactions_str:
for summary in interactions_str.split(' ||| '):
if summary.strip():
interaction_list.append({
'summary': summary.strip(),
'timestamp': created_at
})
sessions.append({
'session_id': session_id,
'summary': session_summary or '',
'created_at': created_at,
'interaction_contexts': interaction_list
})
conn.close()
logger.info(f"Fetched {len(sessions)} sessions for user {user_id}")
return sessions
except Exception as e:
logger.error(f"Error fetching user sessions: {e}", exc_info=True)
return [] # Safe fallback - no degradation
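    # Return shape (illustrative):
    #
    #     [{"session_id": "sess-1",
    #       "summary": "<100-token session summary>",
    #       "created_at": "2024-01-01T12:00:00",
    #       "interaction_contexts": [{"summary": "...", "timestamp": "..."}]},
    #      ...]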
def _extract_entities(self, context: dict) -> list:
"""
Extract essential entities from context
"""
# TODO: Implement entity extraction
return []
def _generate_summary(self, context: dict) -> str:
"""
Generate conversation summary
"""
# TODO: Implement summary generation
return ""
def get_or_create_session_context(self, session_id: str, user_id: Optional[str] = None) -> Dict:
"""Enhanced context retrieval with caching"""
# In-memory cache check first
if session_id in self._session_cache:
cache_entry = self._session_cache[session_id]
if time.time() - cache_entry['timestamp'] < 300: # 5 min cache
logger.debug(f"Cache hit for session {session_id}")
return cache_entry['context']
# Batch database queries
conn = None
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Single query for all context data
query = """
SELECT
s.context_data,
s.user_metadata,
s.last_activity,
u.persona_summary,
ic.interaction_summary
FROM sessions s
LEFT JOIN user_contexts u ON s.user_id = u.user_id
LEFT JOIN interaction_contexts ic ON s.session_id = ic.session_id
WHERE s.session_id = ?
ORDER BY ic.created_at DESC
LIMIT 10
"""
cursor.execute(query, (session_id,))
results = cursor.fetchall()
# Process results efficiently
context = self._build_context_from_results(results, session_id, user_id)
# Update cache
self._session_cache[session_id] = {
'context': context,
'timestamp': time.time()
}
return context
except Exception as e:
logger.error(f"Error in get_or_create_session_context: {e}", exc_info=True)
# Return safe fallback
return {
"session_id": session_id,
"user_id": user_id or "Test_Any",
"interaction_contexts": [],
"session_context": None,
"preferences": {},
"active_tasks": [],
"user_context_loaded": False
}
finally:
if conn:
conn.close()
def _build_context_from_results(self, results: list, session_id: str, user_id: Optional[str]) -> Dict:
"""Build context dictionary from batch query results"""
context = {
"session_id": session_id,
"user_id": user_id or "Test_Any",
"interaction_contexts": [],
"session_context": None,
"user_context": "",
"preferences": {},
"active_tasks": [],
"user_context_loaded": False
}
if not results:
return context
# Process first row for session data
first_row = results[0]
if first_row[0]: # context_data
try:
session_data = json.loads(first_row[0])
context["preferences"] = session_data.get("preferences", {})
context["active_tasks"] = session_data.get("active_tasks", [])
            except (json.JSONDecodeError, TypeError):
                pass
if first_row[1]: # user_metadata
try:
user_metadata = json.loads(first_row[1])
context["preferences"].update(user_metadata.get("preferences", {}))
            except (json.JSONDecodeError, TypeError):
                pass
context["last_activity"] = first_row[2] # last_activity
if first_row[3]: # persona_summary
context["user_context"] = first_row[3]
context["user_context_loaded"] = True
# Process interaction contexts
seen_interactions = set()
for row in results:
if row[4]: # interaction_summary
# Deduplicate interactions
if row[4] not in seen_interactions:
seen_interactions.add(row[4])
context["interaction_contexts"].append({
"summary": row[4],
"timestamp": None # Could extract from row if available
})
return context
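

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the production flow. Assumes a
    # writable /tmp. No llm_router is supplied, so persona and interaction
    # summaries stay empty and only the raw session plumbing is exercised.
    import asyncio

    async def _demo():
        mgr = EfficientContextManager(db_path="/tmp/demo_sessions.db")
        ctx = await mgr.manage_context("demo-session", "Hello there", user_id="demo_user")
        print("context mode:", ctx["context_mode"])
        print(ctx["combined_context"] or "(no prior context yet)")

    asyncio.run(_demo())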