|
|
|
|
|
import sqlite3 |
|
|
import json |
|
|
import logging |
|
|
import uuid |
|
|
import hashlib |
|
|
import threading |
|
|
import time |
|
|
from contextlib import contextmanager |
|
|
from datetime import datetime, timedelta |
|
|
from typing import Dict, Optional, List |
|
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class TransactionManager:
    """Manage database transactions with proper locking.

    Each :meth:`transaction` call opens a fresh SQLite connection, hands the
    caller a cursor, commits on success and rolls back on error. A
    per-instance re-entrant lock is held for the whole transaction, so
    transactions issued through one manager are serialized across threads.
    """

    def __init__(self, db_path):
        """Create a manager for the SQLite database file at ``db_path``."""
        self.db_path = db_path
        self._lock = threading.RLock()
        # Not read by this class; kept so existing attribute access keeps working.
        self._connections = {}

    @contextmanager
    def transaction(self, session_id=None):
        """Context manager for database transactions with automatic rollback.

        Args:
            session_id: Optional identifier, used only in log messages.

        Yields:
            sqlite3.Cursor: cursor on a connection opened with
            ``isolation_level='IMMEDIATE'``, WAL journaling and a 5 s busy timeout.

        Raises:
            Exception: whatever the caller's block (or the commit) raised is
            re-raised after a rollback attempt. A failing rollback is logged
            and never replaces that original exception.
        """
        conn = None
        try:
            with self._lock:
                conn = sqlite3.connect(self.db_path, isolation_level='IMMEDIATE')
                conn.execute('PRAGMA journal_mode=WAL')
                conn.execute('PRAGMA busy_timeout=5000')
                yield conn.cursor()
                conn.commit()
                logger.debug(f"Transaction committed for session {session_id}")
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except sqlite3.Error as rollback_error:
                    # Keep the caller's exception as the one that propagates.
                    logger.error(f"Rollback failed for session {session_id}: {rollback_error}")
            logger.error(f"Transaction rolled back for session {session_id}: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()
|
|
|
|
|
class EfficientContextManager: |
|
|
def __init__(self, llm_router=None, db_path=None):
    """Build caches, resolve the database location and prepare the schema.

    Args:
        llm_router: Optional object stored as ``self.llm_router`` for later
            summarization calls.
        db_path: SQLite file path. If omitted, ``src.config.settings.db_path``
            is tried; on ImportError/AttributeError the ``DB_PATH``
            environment variable is used, falling back to ``/tmp/sessions.db``.
    """
    # Two cache dicts are created; ``session_cache`` is the one read by the
    # cache helpers visible in this file.
    self.session_cache = {}
    self._session_cache = {}
    self.cache_config = dict(
        max_session_size=10,
        ttl=3600,
        compression="gzip",
        eviction_policy="LRU",
    )

    resolved_path = db_path
    if resolved_path is None:
        try:
            from src.config import settings
            resolved_path = settings.db_path
        except (ImportError, AttributeError):
            import os
            resolved_path = os.getenv("DB_PATH", "/tmp/sessions.db")

    self.db_path = resolved_path
    self.llm_router = llm_router
    logger.info(f"Initializing ContextManager with DB path: {self.db_path}")

    self.transaction_manager = TransactionManager(self.db_path)
    self._init_database()
    self.optimize_database_indexes()
|
|
|
|
|
def _init_database(self): |
|
|
"""Initialize database and create tables""" |
|
|
try: |
|
|
logger.info("Initializing database...") |
|
|
conn = sqlite3.connect(self.db_path) |
|
|
cursor = conn.cursor() |
|
|
|
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS sessions ( |
|
|
session_id TEXT PRIMARY KEY, |
|
|
user_id TEXT DEFAULT 'Test_Any', |
|
|
created_at TIMESTAMP, |
|
|
last_activity TIMESTAMP, |
|
|
context_data TEXT, |
|
|
user_metadata TEXT |
|
|
) |
|
|
""") |
|
|
|
|
|
|
|
|
try: |
|
|
cursor.execute("ALTER TABLE sessions ADD COLUMN user_id TEXT DEFAULT 'Test_Any'") |
|
|
logger.info("✓ Added user_id column to sessions table") |
|
|
except sqlite3.OperationalError: |
|
|
|
|
|
pass |
|
|
|
|
|
logger.info("✓ Sessions table ready") |
|
|
|
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS interactions ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
session_id TEXT REFERENCES sessions(session_id), |
|
|
user_input TEXT, |
|
|
context_snapshot TEXT, |
|
|
created_at TIMESTAMP, |
|
|
FOREIGN KEY(session_id) REFERENCES sessions(session_id) |
|
|
) |
|
|
""") |
|
|
logger.info("✓ Interactions table ready") |
|
|
|
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS user_contexts ( |
|
|
user_id TEXT PRIMARY KEY, |
|
|
persona_summary TEXT, |
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP |
|
|
) |
|
|
""") |
|
|
logger.info("✓ User contexts table ready") |
|
|
|
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS session_contexts ( |
|
|
session_id TEXT PRIMARY KEY, |
|
|
user_id TEXT, |
|
|
session_summary TEXT, |
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
|
|
FOREIGN KEY(session_id) REFERENCES sessions(session_id), |
|
|
FOREIGN KEY(user_id) REFERENCES user_contexts(user_id) |
|
|
) |
|
|
""") |
|
|
logger.info("✓ Session contexts table ready") |
|
|
|
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS interaction_contexts ( |
|
|
interaction_id TEXT PRIMARY KEY, |
|
|
session_id TEXT, |
|
|
user_input TEXT, |
|
|
system_response TEXT, |
|
|
interaction_summary TEXT, |
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
|
|
FOREIGN KEY(session_id) REFERENCES sessions(session_id) |
|
|
) |
|
|
""") |
|
|
logger.info("✓ Interaction contexts table ready") |
|
|
|
|
|
conn.commit() |
|
|
conn.close() |
|
|
|
|
|
|
|
|
self._update_database_schema() |
|
|
|
|
|
logger.info("Database initialization complete") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Database initialization error: {e}", exc_info=True) |
|
|
|
|
|
def _update_database_schema(self): |
|
|
"""Add missing columns and tables for user change tracking""" |
|
|
try: |
|
|
conn = sqlite3.connect(self.db_path) |
|
|
cursor = conn.cursor() |
|
|
|
|
|
|
|
|
try: |
|
|
cursor.execute(""" |
|
|
ALTER TABLE interaction_contexts |
|
|
ADD COLUMN needs_refresh INTEGER DEFAULT 0 |
|
|
""") |
|
|
logger.info("✓ Added needs_refresh column to interaction_contexts") |
|
|
except sqlite3.OperationalError: |
|
|
pass |
|
|
|
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS user_change_log ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
session_id TEXT, |
|
|
old_user_id TEXT, |
|
|
new_user_id TEXT, |
|
|
timestamp TIMESTAMP, |
|
|
FOREIGN KEY(session_id) REFERENCES sessions(session_id) |
|
|
) |
|
|
""") |
|
|
|
|
|
conn.commit() |
|
|
conn.close() |
|
|
logger.info("✓ Database schema updated successfully for user change tracking") |
|
|
|
|
|
|
|
|
self._update_interactions_table() |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Schema update error: {e}", exc_info=True) |
|
|
|
|
|
def _update_interactions_table(self): |
|
|
"""Add interaction_hash column for deduplication""" |
|
|
try: |
|
|
conn = sqlite3.connect(self.db_path) |
|
|
cursor = conn.cursor() |
|
|
|
|
|
|
|
|
cursor.execute("PRAGMA table_info(interactions)") |
|
|
columns = [row[1] for row in cursor.fetchall()] |
|
|
|
|
|
|
|
|
if 'interaction_hash' not in columns: |
|
|
try: |
|
|
cursor.execute(""" |
|
|
ALTER TABLE interactions |
|
|
ADD COLUMN interaction_hash TEXT |
|
|
""") |
|
|
logger.info("✓ Added interaction_hash column to interactions table") |
|
|
except sqlite3.OperationalError: |
|
|
pass |
|
|
|
|
|
|
|
|
try: |
|
|
cursor.execute(""" |
|
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_interaction_hash_unique |
|
|
ON interactions(interaction_hash) |
|
|
""") |
|
|
logger.info("✓ Created unique index on interaction_hash") |
|
|
except sqlite3.OperationalError: |
|
|
|
|
|
cursor.execute(""" |
|
|
CREATE INDEX IF NOT EXISTS idx_interaction_hash |
|
|
ON interactions(interaction_hash) |
|
|
""") |
|
|
|
|
|
conn.commit() |
|
|
conn.close() |
|
|
logger.info("✓ Interactions table updated for deduplication") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error updating interactions table: {e}", exc_info=True) |
|
|
|
|
|
async def manage_context(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
    """
    Efficient context management with separated session/user caching

    STEP 1: Fetch User Context (if available)
    STEP 2: Get Previous Interaction Contexts
    STEP 3: Combine for workflow use

    Args:
        session_id: Session identifier; the session cache key is ``session_<id>``.
        user_input: Current user message, passed to ``_retrieve_from_db`` and
            ``_update_context``.
        user_id: Requesting user. A cached session whose ``user_id`` differs
            from this value is discarded and reloaded from the database.

    Returns:
        dict: ``self._optimize_context(...)`` applied to the session context
        merged with the cached or freshly loaded user context.
    """

    session_cache_key = f"session_{session_id}"
    user_cache_key = f"user_{user_id}"

    # TTL-aware lookup. As a side effect, an expired entry is deleted from
    # self.session_cache, so the raw lookup below only ever sees a live entry.
    session_context = self._get_from_memory_cache(session_cache_key)

    # Raw lookup so the cached entry's owner can be compared with user_id.
    # Entries may be TTL-wrapped ({'value': ..., 'expires': ...}) or raw dicts.
    cached_entry = self.session_cache.get(session_cache_key)
    if cached_entry:
        if isinstance(cached_entry, dict) and 'value' in cached_entry:
            actual_context = cached_entry.get('value', {})
        else:
            actual_context = cached_entry

        # An entry whose user_id differs (including a missing user_id) is
        # dropped so it cannot leak into this request.
        if actual_context and actual_context.get("user_id") != user_id:
            logger.info(f"User mismatch in cache for session {session_id}, invalidating cache")
            session_context = None
            if session_cache_key in self.session_cache:
                del self.session_cache[session_cache_key]
        else:
            session_context = actual_context

    user_context = self._get_from_memory_cache(user_cache_key)

    if not session_context:
        # Cache miss or invalidated entry: load from the database and re-cache.
        session_context = await self._retrieve_from_db(session_id, user_input, user_id)

        self.add_context_cache(session_cache_key, session_context, ttl=self.cache_config.get("ttl", 3600))

    # The persona summary is cached separately under user_<id> (via
    # _warm_memory_cache, which applies the configured TTL) and reused until
    # that entry expires.
    if not user_context or not user_context.get("user_context_loaded"):
        user_context_data = await self.get_user_context(user_id)
        user_context = {
            "user_context": user_context_data,
            "user_context_loaded": True,
            "user_id": user_id
        }

        self._warm_memory_cache(user_cache_key, user_context)
        logger.debug(f"User context loaded once for {user_id} and cached")
    else:
        logger.debug(f"Using cached user context for {user_id}")

    # Later keys win: user_context, user_context_loaded and user_id override
    # any same-named keys coming from the session context.
    merged_context = {
        **session_context,
        "user_context": user_context.get("user_context", ""),
        "user_context_loaded": True,
        "user_id": user_id
    }

    updated_context = self._update_context(merged_context, user_input, user_id=user_id)

    return self._optimize_context(updated_context)
|
|
|
|
|
async def get_user_context(self, user_id: str) -> str:
    """
    STEP 1: Fetch or generate User Context (500-token persona summary)
    Available for all interactions except first time per user

    A stored ``user_contexts.persona_summary`` is returned directly.
    Otherwise a new summary is generated from up to 50 session summaries plus
    up to 20 of the 100 most recent interaction summaries. Generation goes
    through ``self.llm_router`` and the result is upserted into
    ``user_contexts``.

    Every connection opened here is closed, even when a query fails.

    Args:
        user_id: User whose persona summary is requested.

    Returns:
        str: the persona summary, or ``""`` in any of these cases:
        - the user has no history;
        - no router is configured;
        - generation yields nothing usable;
        - an error occurs (errors are logged, not raised).
    """
    try:
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            cursor.execute("""
                SELECT persona_summary FROM user_contexts WHERE user_id = ?
            """, (user_id,))
            row = cursor.fetchone()
            if row and row[0]:
                logger.info(f"✓ User context loaded for {user_id}")
                return row[0]

            logger.info(f"Generating new user context for {user_id}")

            cursor.execute("""
                SELECT session_summary FROM session_contexts WHERE user_id = ?
                ORDER BY created_at DESC LIMIT 50
            """, (user_id,))
            all_session_summaries = [r[0] for r in cursor.fetchall() if r[0]]

            cursor.execute("""
                SELECT ic.interaction_summary
                FROM interaction_contexts ic
                JOIN sessions s ON ic.session_id = s.session_id
                WHERE s.user_id = ?
                ORDER BY ic.created_at DESC LIMIT 100
            """, (user_id,))
            all_interaction_summaries = [r[0] for r in cursor.fetchall() if r[0]]
        finally:
            # Previously leaked when either SELECT raised.
            conn.close()

        if not all_session_summaries and not all_interaction_summaries:
            logger.info(f"No historical data for {user_id} - first time user")
            return ""

        historical_data = "\n\n".join(all_session_summaries + all_interaction_summaries[:20])

        if self.llm_router:
            prompt = f"""Generate a concise 500-token persona summary for user {user_id} based on their interaction history:

Historical Context:
{historical_data}

Create a persona summary that captures:
- Communication style and preferences
- Common topics and interests
- Interaction patterns
- Key information shared across sessions

Keep the summary concise and focused (approximately 500 tokens)."""

            try:
                persona_summary = await self.llm_router.route_inference(
                    task_type="general_reasoning",
                    prompt=prompt,
                    max_tokens=500,
                    temperature=0.7
                )

                if persona_summary and isinstance(persona_summary, str) and persona_summary.strip():
                    persona_summary = persona_summary.strip()
                    conn = sqlite3.connect(self.db_path)
                    try:
                        conn.execute("""
                            INSERT OR REPLACE INTO user_contexts (user_id, persona_summary, updated_at)
                            VALUES (?, ?, ?)
                        """, (user_id, persona_summary, datetime.now().isoformat()))
                        conn.commit()
                    finally:
                        conn.close()

                    logger.info(f"✓ Generated and stored user context for {user_id}")
                    return persona_summary
            except Exception as e:
                logger.error(f"Error generating user context: {e}", exc_info=True)

        logger.warning(f"Could not generate user context for {user_id} - using empty")
        return ""

    except Exception as e:
        logger.error(f"Error getting user context: {e}", exc_info=True)
        return ""
|
|
|
|
|
async def generate_interaction_context(self, interaction_id: str, session_id: str,
                                       user_input: str, system_response: str,
                                       user_id: str = "Test_Any") -> str:
    """
    STEP 2: Generate Interaction Context (50-token summary)
    Called after each response

    Asks ``self.llm_router`` for a short summary of the exchange. A usable
    summary is:
    - upserted into ``interaction_contexts``, with input truncated to 500
      characters and response to 1000;
    - pushed into the session cache via
      ``self._update_cache_with_interaction_context``.

    The database connection is closed even if the write fails.

    Args:
        interaction_id: Primary key for the ``interaction_contexts`` row.
        session_id: Owning session.
        user_input: User message (first 200 chars go into the prompt).
        system_response: Assistant reply (first 300 chars go into the prompt).
        user_id: Accepted for interface compatibility; not used here.

    Returns:
        str: the stripped summary, or ``""`` in any of these cases:
        - no router is configured;
        - the router returned nothing usable;
        - an error occurred (errors are logged, not raised).
    """
    try:
        if not self.llm_router:
            return ""

        prompt = f"""Summarize this interaction in approximately 50 tokens:

User Input: {user_input[:200]}
System Response: {system_response[:300]}

Provide a brief summary capturing the key exchange."""

        try:
            summary = await self.llm_router.route_inference(
                task_type="general_reasoning",
                prompt=prompt,
                max_tokens=50,
                temperature=0.7
            )

            if summary and isinstance(summary, str) and summary.strip():
                summary = summary.strip()
                created_at = datetime.now().isoformat()
                conn = sqlite3.connect(self.db_path)
                try:
                    conn.execute("""
                        INSERT OR REPLACE INTO interaction_contexts
                        (interaction_id, session_id, user_input, system_response, interaction_summary, created_at)
                        VALUES (?, ?, ?, ?, ?, ?)
                    """, (
                        interaction_id,
                        session_id,
                        user_input[:500],
                        system_response[:1000],
                        summary,
                        created_at
                    ))
                    conn.commit()
                finally:
                    # Previously leaked when the INSERT or commit raised.
                    conn.close()

                # Keep the in-memory session cache in step with the database.
                self._update_cache_with_interaction_context(session_id, summary, created_at)

                logger.info(f"✓ Generated interaction context for {interaction_id} and updated cache")
                return summary
        except Exception as e:
            logger.error(f"Error generating interaction context: {e}", exc_info=True)

        return ""

    except Exception as e:
        logger.error(f"Error in generate_interaction_context: {e}", exc_info=True)
        return ""
|
|
|
|
|
async def generate_session_context(self, session_id: str, user_id: str = "Test_Any") -> str:
    """
    Generate Session Context (100-token summary) at every turn
    Uses cached interaction contexts instead of querying database
    Updates both database and cache immediately

    Interaction summaries are read from the ``session_<id>`` cache entry in
    two places, with top-level items first:
    - the entry's top-level ``'interaction_contexts'`` list;
    - the wrapped session dict under ``'value'``, which is how
      ``add_context_cache`` stores entries.

    When both names refer to the same list it is counted once. The resulting
    summary is upserted into ``session_contexts`` (connection always closed)
    and pushed to the cache via ``self._update_cache_with_session_context``.

    Args:
        session_id: Session to summarize.
        user_id: Stored alongside the summary in ``session_contexts``.

    Returns:
        str: the stripped session summary, or ``""`` in any of these cases:
        - nothing is cached;
        - there are no summaries;
        - no router is configured;
        - an error occurs (errors are logged, not raised).
    """
    try:
        session_cache_key = f"session_{session_id}"
        cached_context = self.session_cache.get(session_cache_key)

        if not cached_context:
            logger.warning(f"No cached context found for session {session_id}, cannot generate session context")
            return ""

        # Previously only the top level was read, so a freshly cached
        # (TTL-wrapped) session with DB-loaded interactions yielded nothing.
        top_level = cached_context.get('interaction_contexts')
        wrapped = cached_context.get('value')
        wrapped_list = (wrapped.get('interaction_contexts') or []) if isinstance(wrapped, dict) else []
        interaction_contexts = list(top_level or [])
        if wrapped_list is not top_level:
            interaction_contexts.extend(wrapped_list)

        if not interaction_contexts:
            logger.info(f"No interaction contexts available for session {session_id} to summarize")
            return ""

        interaction_summaries = [ic.get('summary', '') for ic in interaction_contexts if ic.get('summary')]

        if not interaction_summaries:
            logger.info(f"No interaction summaries available for session {session_id}")
            return ""

        if self.llm_router:
            combined_context = "\n".join(interaction_summaries)

            prompt = f"""Summarize this session's interactions in approximately 100 tokens:

Interaction Summaries:
{combined_context}

Create a concise session summary capturing:
- Main topics discussed
- Key outcomes or information shared
- User's focus areas

Keep the summary concise (approximately 100 tokens)."""

            try:
                session_summary = await self.llm_router.route_inference(
                    task_type="general_reasoning",
                    prompt=prompt,
                    max_tokens=100,
                    temperature=0.7
                )

                if session_summary and isinstance(session_summary, str) and session_summary.strip():
                    session_summary = session_summary.strip()
                    created_at = datetime.now().isoformat()
                    conn = sqlite3.connect(self.db_path)
                    try:
                        conn.execute("""
                            INSERT OR REPLACE INTO session_contexts
                            (session_id, user_id, session_summary, created_at)
                            VALUES (?, ?, ?, ?)
                        """, (session_id, user_id, session_summary, created_at))
                        conn.commit()
                    finally:
                        conn.close()

                    self._update_cache_with_session_context(session_id, session_summary, created_at)

                    logger.info(f"✓ Generated session context for {session_id} and updated cache")
                    return session_summary
            except Exception as e:
                logger.error(f"Error generating session context: {e}", exc_info=True)

        return ""

    except Exception as e:
        logger.error(f"Error in generate_session_context: {e}", exc_info=True)
        return ""
|
|
|
|
|
async def end_session(self, session_id: str, user_id: str = "Test_Any"):
    """Drop the in-memory ``session_<id>`` cache entry for a finished session.

    No summary is produced here: session context is already generated on
    every turn. ``user_id`` is accepted for interface compatibility and is
    not used. Errors are logged, never raised.
    """
    try:
        key = f"session_{session_id}"
        if key not in self.session_cache:
            return
        del self.session_cache[key]
        logger.info(f"✓ Cleared cache for session {session_id}")
    except Exception as e:
        logger.error(f"Error ending session: {e}", exc_info=True)
|
|
|
|
|
def _clear_user_cache_on_change(self, session_id: str, new_user_id: str, old_user_id: str): |
|
|
"""Clear cache entries when user changes""" |
|
|
if new_user_id != old_user_id: |
|
|
|
|
|
old_cache_key = f"{session_id}_{old_user_id}" |
|
|
if old_cache_key in self.session_cache: |
|
|
del self.session_cache[old_cache_key] |
|
|
logger.info(f"Cleared old cache for user {old_user_id} on session {session_id}") |
|
|
|
|
|
def _optimize_context(self, context: dict, relevance_classification: Optional[Dict] = None) -> dict: |
|
|
""" |
|
|
Optimize context for LLM consumption with relevance filtering support |
|
|
Format: [Session Context] + [User Context (conditional)] + [Interaction Context #N, #N-1, ...] |
|
|
|
|
|
Args: |
|
|
context: Base context dictionary |
|
|
relevance_classification: Optional relevance classification results with dynamic user context |
|
|
|
|
|
Applies smart pruning before formatting. |
|
|
""" |
|
|
|
|
|
pruned_context = self.prune_context(context, max_tokens=2000) |
|
|
|
|
|
|
|
|
session_id = pruned_context.get("session_id") |
|
|
context_mode = self.get_context_mode(session_id) |
|
|
|
|
|
interaction_contexts = pruned_context.get("interaction_contexts", []) |
|
|
session_context = pruned_context.get("session_context", {}) |
|
|
session_summary = session_context.get("summary", "") if isinstance(session_context, dict) else "" |
|
|
|
|
|
|
|
|
user_context = "" |
|
|
if context_mode == 'relevant' and relevance_classification: |
|
|
|
|
|
user_context = relevance_classification.get('combined_user_context', '') |
|
|
|
|
|
if user_context: |
|
|
logger.info( |
|
|
f"Using dynamic relevant context: {len(relevance_classification.get('relevant_summaries', []))} " |
|
|
f"sessions summarized for session {session_id}" |
|
|
) |
|
|
elif context_mode == 'relevant' and not relevance_classification: |
|
|
|
|
|
user_context = pruned_context.get("user_context", "") |
|
|
logger.debug(f"Relevant mode but no classification, using traditional user context") |
|
|
|
|
|
|
|
|
|
|
|
formatted_interactions = [] |
|
|
for idx, ic in enumerate(interaction_contexts[:10]): |
|
|
formatted_interactions.append(f"[Interaction Context #{len(interaction_contexts) - idx}]\n{ic.get('summary', '')}") |
|
|
|
|
|
|
|
|
combined_context = "" |
|
|
if session_summary: |
|
|
combined_context += f"[Session Context]\n{session_summary}\n\n" |
|
|
|
|
|
|
|
|
if user_context: |
|
|
context_label = "[Relevant User Context]" if context_mode == 'relevant' else "[User Context]" |
|
|
combined_context += f"{context_label}\n{user_context}\n\n" |
|
|
|
|
|
if formatted_interactions: |
|
|
combined_context += "\n\n".join(formatted_interactions) |
|
|
|
|
|
return { |
|
|
"session_id": pruned_context.get("session_id"), |
|
|
"user_id": pruned_context.get("user_id", "Test_Any"), |
|
|
"user_context": user_context, |
|
|
"session_context": session_context, |
|
|
"interaction_contexts": interaction_contexts, |
|
|
"combined_context": combined_context, |
|
|
"context_mode": context_mode, |
|
|
"relevance_metadata": relevance_classification.get('relevance_scores', {}) if relevance_classification else {}, |
|
|
"preferences": pruned_context.get("preferences", {}), |
|
|
"active_tasks": pruned_context.get("active_tasks", []), |
|
|
"last_activity": pruned_context.get("last_activity") |
|
|
} |
|
|
|
|
|
def _get_from_memory_cache(self, cache_key: str) -> dict: |
|
|
""" |
|
|
Retrieve context from in-memory session cache with expiration check |
|
|
""" |
|
|
cached = self.session_cache.get(cache_key) |
|
|
if not cached: |
|
|
return None |
|
|
|
|
|
|
|
|
if isinstance(cached, dict) and 'value' in cached: |
|
|
|
|
|
if self._is_cache_expired(cached): |
|
|
|
|
|
del self.session_cache[cache_key] |
|
|
logger.debug(f"Cache expired for key: {cache_key}") |
|
|
return None |
|
|
return cached.get('value') |
|
|
else: |
|
|
|
|
|
return cached |
|
|
|
|
|
def _is_cache_expired(self, cache_entry: dict) -> bool: |
|
|
""" |
|
|
Check if cache entry has expired based on TTL |
|
|
""" |
|
|
if not isinstance(cache_entry, dict): |
|
|
return True |
|
|
|
|
|
expires = cache_entry.get('expires') |
|
|
if not expires: |
|
|
return False |
|
|
|
|
|
return time.time() > expires |
|
|
|
|
|
def add_context_cache(self, key: str, value: dict, ttl: int = 3600):
    """
    Step 2: Implement Context Caching with TTL expiration

    Add context to cache with expiration time. The stored entry is
    ``{'value': value, 'expires': now + ttl, 'timestamp': now}``. Here
    ``now`` is a single ``time.time()`` reading, so ``expires - timestamp``
    is exactly the requested TTL (up to float rounding).

    Args:
        key: Cache key
        value: Value to cache (dict)
        ttl: Time to live in seconds (default 3600 = 1 hour)
    """
    # Relies on the module-level ``import time``; sample the clock once.
    now = time.time()
    self.session_cache[key] = {
        'value': value,
        'expires': now + ttl,
        'timestamp': now,
    }
    logger.debug(f"Cached context for key: {key} with TTL: {ttl}s")
|
|
|
|
|
def get_token_count(self, text: str) -> int:
    """Estimate the token count of ``text`` as one token per 4 characters.

    Falsy input (``""``, ``None``) yields 0; otherwise ``len(text) // 4``.
    """
    return len(text) // 4 if text else 0
|
|
|
|
|
def prune_context(self, context: dict, max_tokens: int = 2000) -> dict:
    """
    Step 4: Implement Smart Context Pruning

    Prune context to stay within token limit while keeping the leading (most
    recent) interaction summaries. ``_retrieve_from_db`` orders them
    newest first.

    Budget accounting matches ``_calculate_context_tokens``: user context,
    plus the session summary, plus each interaction summary. Pruning works
    in two steps:
    1. If user context and session summary alone exceed the budget and the
       user context is a string of more than half the budget, it is truncated
       to ``max_tokens * 2`` characters.
    2. Interactions are then kept in order until the remaining budget is used
       up. If nothing remains, none are kept.

    Args:
        context: Context dictionary to prune; it is never modified.
        max_tokens: Maximum token count (default 2000)

    Returns:
        ``context`` itself when already within budget, otherwise a shallow
        copy with ``user_context`` and/or ``interaction_contexts`` reduced.
        On any error the original ``context`` is returned.
    """
    try:
        current_tokens = self._calculate_context_tokens(context)
        if current_tokens <= max_tokens:
            return context

        logger.info(f"Context token count ({current_tokens}) exceeds limit ({max_tokens}), pruning...")

        pruned_context = context.copy()
        interaction_contexts = pruned_context.get('interaction_contexts') or []
        user_context = pruned_context.get('user_context', '')

        def essential_tokens():
            # Everything except interactions, measured the same way as
            # _calculate_context_tokens (session *summary*, not the dict repr).
            return self._calculate_context_tokens({**pruned_context, 'interaction_contexts': []})

        available_tokens = max_tokens - essential_tokens()
        if (available_tokens < 0
                and isinstance(user_context, str)
                and self.get_token_count(user_context) > max_tokens // 2):
            pruned_context['user_context'] = user_context[:max_tokens * 2]
            logger.warning(f"User context too large, truncated")
            available_tokens = max_tokens - essential_tokens()

        # Previously an over-budget context was returned with every interaction
        # still attached; now the remaining budget (possibly zero) always applies.
        kept_interactions = []
        used_tokens = 0
        for interaction in interaction_contexts:
            interaction_tokens = self.get_token_count(str(interaction.get('summary', '')))
            if used_tokens + interaction_tokens > available_tokens:
                break
            kept_interactions.append(interaction)
            used_tokens += interaction_tokens

        pruned_context['interaction_contexts'] = kept_interactions

        logger.info(f"Pruned context: kept {len(kept_interactions)}/{len(interaction_contexts)} interactions, "
                    f"reduced from {current_tokens} to {self._calculate_context_tokens(pruned_context)} tokens")

        return pruned_context

    except Exception as e:
        logger.error(f"Error pruning context: {e}", exc_info=True)
        return context
|
|
|
|
|
def _calculate_context_tokens(self, context: dict) -> int: |
|
|
"""Calculate total token count for context""" |
|
|
total = 0 |
|
|
|
|
|
|
|
|
user_context = context.get('user_context', '') |
|
|
total += self.get_token_count(str(user_context)) |
|
|
|
|
|
session_context = context.get('session_context', {}) |
|
|
if isinstance(session_context, dict): |
|
|
total += self.get_token_count(str(session_context.get('summary', ''))) |
|
|
else: |
|
|
total += self.get_token_count(str(session_context)) |
|
|
|
|
|
interaction_contexts = context.get('interaction_contexts', []) |
|
|
for interaction in interaction_contexts: |
|
|
summary = interaction.get('summary', '') |
|
|
total += self.get_token_count(str(summary)) |
|
|
|
|
|
return total |
|
|
|
|
|
async def _retrieve_from_db(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
    """
    Retrieve session context with proper user_id synchronization
    Uses transactions to ensure atomic updates of database and cache

    Existing session row, stored user_id differs from ``user_id``: inside one
    explicit transaction,
    - the session is re-assigned;
    - its interaction contexts are flagged ``needs_refresh = 1``;
    - the change is recorded in ``user_change_log``.

    Existing session row, in all cases: up to 20 unflagged interaction
    summaries (newest first) and the latest session summary are loaded.

    Unknown session: a new ``sessions`` row is inserted.

    Args:
        session_id: Session to load or create.
        user_input: Accepted for interface compatibility; not read here.
        user_id: Current user of the session.

    Returns:
        dict: context with keys ``session_id``, ``user_id``,
        ``interaction_contexts``, ``session_context``, ``preferences``,
        ``active_tasks``, ``user_context_loaded`` and ``user_changed``. It
        also contains ``last_activity`` for existing sessions and ``error``
        on failure. Errors are logged and an empty context is returned
        instead of raising.
    """
    conn = None
    try:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Explicit transaction so the re-assignment, refresh flags and
        # change-log row below are written together.
        cursor.execute("BEGIN TRANSACTION")

        cursor.execute("""
            SELECT context_data, user_metadata, last_activity, user_id
            FROM sessions
            WHERE session_id = ?
        """, (session_id,))

        row = cursor.fetchone()

        if row:
            context_data = json.loads(row[0]) if row[0] else {}
            user_metadata = json.loads(row[1]) if row[1] else {}
            last_activity = row[2]
            # The SELECT above always returns four columns, so the fallback
            # branch is never taken.
            session_user_id = row[3] if len(row) > 3 else user_id

            user_changed = False
            if session_user_id != user_id:
                logger.info(f"User change detected: {session_user_id} -> {user_id} for session {session_id}")
                user_changed = True

                # Re-assign the session to the new user.
                cursor.execute("""
                    UPDATE sessions
                    SET user_id = ?, last_activity = ?
                    WHERE session_id = ?
                """, (user_id, datetime.now().isoformat(), session_id))

                # Flag the previous user's interaction summaries so the
                # query below skips them. Ignored on databases without the
                # needs_refresh column.
                try:
                    cursor.execute("""
                        UPDATE interaction_contexts
                        SET needs_refresh = 1
                        WHERE session_id = ?
                    """, (session_id,))
                except sqlite3.OperationalError:
                    pass

                # Audit record; ignored on databases without user_change_log.
                try:
                    cursor.execute("""
                        INSERT INTO user_change_log (session_id, old_user_id, new_user_id, timestamp)
                        VALUES (?, ?, ?, ?)
                    """, (session_id, session_user_id, user_id, datetime.now().isoformat()))
                except sqlite3.OperationalError:
                    pass

                # Runs before COMMIT, i.e. before the database changes are durable.
                self._clear_user_cache_on_change(session_id, user_id, session_user_id)

            cursor.execute("COMMIT")

            # Prefer the query that filters out flagged rows; fall back to an
            # unfiltered query on databases lacking needs_refresh.
            try:
                cursor.execute("""
                    SELECT interaction_summary, created_at, needs_refresh
                    FROM interaction_contexts
                    WHERE session_id = ? AND (needs_refresh IS NULL OR needs_refresh = 0)
                    ORDER BY created_at DESC
                    LIMIT 20
                """, (session_id,))
            except sqlite3.OperationalError:
                cursor.execute("""
                    SELECT interaction_summary, created_at
                    FROM interaction_contexts
                    WHERE session_id = ?
                    ORDER BY created_at DESC
                    LIMIT 20
                """, (session_id,))

            interaction_contexts = []
            for ic_row in cursor.fetchall():
                # Rows have 3 columns from the first query, 2 from the fallback.
                if len(ic_row) >= 2:
                    summary = ic_row[0]
                    timestamp = ic_row[1]
                    needs_refresh = ic_row[2] if len(ic_row) > 2 else 0

                    # Skips empty summaries; the needs_refresh test only adds
                    # filtering beyond the SQL WHERE clause for the fallback query.
                    if summary and not needs_refresh:
                        interaction_contexts.append({
                            "summary": summary,
                            "timestamp": timestamp
                        })

            # Latest session summary, if the session_contexts table exists.
            session_context_data = None
            try:
                cursor.execute("""
                    SELECT session_summary, created_at
                    FROM session_contexts
                    WHERE session_id = ?
                    ORDER BY created_at DESC
                    LIMIT 1
                """, (session_id,))
                sc_row = cursor.fetchone()
                if sc_row and sc_row[0]:
                    session_context_data = {
                        "summary": sc_row[0],
                        "timestamp": sc_row[1]
                    }
            except sqlite3.OperationalError:
                pass

            # context_data is parsed above but not included in the result.
            context = {
                "session_id": session_id,
                "user_id": user_id,
                "interaction_contexts": interaction_contexts,
                "session_context": session_context_data,
                "preferences": user_metadata.get("preferences", {}),
                "active_tasks": user_metadata.get("active_tasks", []),
                "last_activity": last_activity,
                "user_context_loaded": False,
                "user_changed": user_changed
            }

            conn.close()
            return context
        else:
            # First time this session is seen: create its row.
            cursor.execute("""
                INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (session_id, user_id, datetime.now().isoformat(), datetime.now().isoformat(), "{}", "{}"))

            cursor.execute("COMMIT")
            conn.close()

            return {
                "session_id": session_id,
                "user_id": user_id,
                "interaction_contexts": [],
                "session_context": None,
                "preferences": {},
                "active_tasks": [],
                "user_context_loaded": False,
                "user_changed": False
            }

    except sqlite3.Error as e:
        logger.error(f"Database transaction error: {e}", exc_info=True)
        if conn:
            try:
                conn.rollback()
            # NOTE(review): bare except also swallows KeyboardInterrupt /
            # SystemExit; `except sqlite3.Error` is likely what is intended.
            except:
                pass
            conn.close()

        return {
            "session_id": session_id,
            "user_id": user_id,
            "interaction_contexts": [],
            "session_context": None,
            "preferences": {},
            "active_tasks": [],
            "user_context_loaded": False,
            "error": str(e),
            "user_changed": False
        }
    except Exception as e:
        logger.error(f"Database retrieval error: {e}", exc_info=True)
        if conn:
            try:
                conn.rollback()
            except:
                pass
            conn.close()

        return {
            "session_id": session_id,
            "user_id": user_id,
            "interaction_contexts": [],
            "session_context": None,
            "preferences": {},
            "active_tasks": [],
            "user_context_loaded": False,
            "error": str(e),
            "user_changed": False
        }
|
|
|
|
|
def _warm_memory_cache(self, cache_key: str, context: dict): |
|
|
""" |
|
|
Warm the in-memory cache with retrieved context |
|
|
Note: Use add_context_cache() instead for TTL support |
|
|
""" |
|
|
|
|
|
self.add_context_cache(cache_key, context, ttl=self.cache_config.get("ttl", 3600)) |
|
|
|
|
|
def _update_cache_with_interaction_context(self, session_id: str, interaction_summary: str,
                                           created_at: str, max_interactions: int = 20):
    """
    Record a new interaction summary in the cached session context.

    Meant to be called right after the interaction has been written to the
    database, so the cache mirrors it without another query. The newest entry
    is placed first and the list is capped at ``max_interactions`` entries.

    Args:
        session_id: Session whose cache entry (key ``session_<id>``) is updated.
        interaction_summary: Summary text of the new interaction.
        created_at: Timestamp string stored alongside the summary.
        max_interactions: Maximum number of interaction summaries kept
            (default 20, the previous hard-coded cap).
    """
    session_cache_key = f"session_{session_id}"
    new_interaction = {
        "summary": interaction_summary,
        "timestamp": created_at
    }

    cached_context = self.session_cache.get(session_cache_key)

    if cached_context:
        # Build a fresh list instead of inserting into the cached one, so a
        # caller still holding the old list object does not see it mutated.
        # `or []` also tolerates an explicit None stored under the key.
        previous = cached_context.get('interaction_contexts') or []
        interaction_contexts = [new_interaction, *previous][:max_interactions]

        cached_context['interaction_contexts'] = interaction_contexts
        self.session_cache[session_cache_key] = cached_context

        logger.debug(f"Cache updated with new interaction context for session {session_id} (total: {len(interaction_contexts)})")
    else:
        self.session_cache[session_cache_key] = {
            "session_id": session_id,
            "interaction_contexts": [new_interaction][:max_interactions],
            # Same key set as the other default context shapes in this class.
            "session_context": None,
            "preferences": {},
            "active_tasks": [],
            "user_context_loaded": False
        }
        logger.debug(f"Created new cache entry with interaction context for session {session_id}")
|
|
|
|
|
def _update_cache_with_session_context(self, session_id: str, session_summary: str, created_at: str):
    """
    Store a new session-level summary in the cached session context.

    Meant to run right after the summary is persisted, keeping the cache in
    step with the database without re-querying it. An existing (truthy) cache
    entry under ``session_<id>`` has its ``session_context`` replaced;
    otherwise a new minimal entry is created.
    """
    session_cache_key = f"session_{session_id}"
    summary_entry = {
        "summary": session_summary,
        "timestamp": created_at
    }
    existing = self.session_cache.get(session_cache_key)

    if not existing:
        self.session_cache[session_cache_key] = {
            "session_id": session_id,
            "session_context": summary_entry,
            "interaction_contexts": [],
            "preferences": {},
            "active_tasks": [],
            "user_context_loaded": False
        }
        logger.debug(f"Created new cache entry with session context for session {session_id}")
        return

    existing['session_context'] = summary_entry
    self.session_cache[session_cache_key] = existing
    logger.debug(f"Cache updated with new session context for session {session_id}")
|
|
|
|
|
def _update_context(self, context: dict, user_input: str, response: str = None, user_id: str = "Test_Any") -> dict:
    """
    Persist an interaction for the session, skipping repeated submissions.

    A SHA-256 hash of (session_id, user_id, stripped user_input) identifies the
    interaction. If that hash was already processed, nothing is written.
    Otherwise, one transaction updates the session's last_activity/user_id and
    inserts an ``interactions`` row (INSERT OR IGNORE).

    The hash is remembered in memory only AFTER the transaction has committed.
    Previously it was marked inside the transaction, so a failed commit left it
    flagged as processed and every retry was skipped as a "duplicate".

    Args:
        context: Session context; must contain "session_id". On success it is
            mutated in place ("last_interaction", "last_update").
        user_input: Raw user input.
        response: Accepted for interface compatibility; not used here.
        user_id: User identifier written to the session row. The UPDATE only
            applies when the stored last_activity is NULL or older than now.

    Returns:
        The same ``context`` dict, also when a duplicate is detected or an
        error occurs (errors are logged, not raised).
    """
    try:
        session_id = context["session_id"]
        interaction_hash = self._generate_interaction_hash(user_input, session_id, user_id)

        if self._is_duplicate_interaction(interaction_hash):
            logger.info(f"Duplicate interaction detected, skipping update: {interaction_hash[:8]}")
            return context

        current_time = datetime.now().isoformat()

        # Snapshot of the mutable session state stored with the interaction.
        session_context = {
            "preferences": context.get("preferences", {}),
            "active_tasks": context.get("active_tasks", [])
        }

        with self.transaction_manager.transaction(session_id) as cursor:
            cursor.execute("""
                UPDATE sessions
                SET last_activity = ?, user_id = ?
                WHERE session_id = ? AND (last_activity IS NULL OR last_activity < ?)
            """, (current_time, user_id, session_id, current_time))

            cursor.execute("""
                INSERT OR IGNORE INTO interactions (
                    interaction_hash,
                    session_id,
                    user_input,
                    context_snapshot,
                    created_at
                ) VALUES (?, ?, ?, ?, ?)
            """, (
                interaction_hash,
                session_id,
                user_input,
                json.dumps(session_context),
                current_time
            ))

        # Reached only after the transaction manager committed successfully.
        self._mark_interaction_processed(interaction_hash)

        context["last_interaction"] = user_input
        context["last_update"] = current_time

        logger.info(f"Context updated for session {context['session_id']} with hash {interaction_hash[:8]}")
        return context

    except Exception as e:
        # Best effort: the caller keeps working with the unchanged context.
        logger.error(f"Error updating context: {e}", exc_info=True)
        return context
|
|
|
|
|
def _generate_interaction_hash(self, user_input: str, session_id: str, user_id: str) -> str: |
|
|
"""Generate unique hash for interaction to prevent duplicates""" |
|
|
|
|
|
|
|
|
normalized_input = user_input.strip() |
|
|
content = f"{session_id}:{user_id}:{normalized_input}" |
|
|
return hashlib.sha256(content.encode()).hexdigest() |
|
|
|
|
|
def _is_duplicate_interaction(self, interaction_hash: str) -> bool: |
|
|
"""Check if interaction was already processed""" |
|
|
|
|
|
if not hasattr(self, '_processed_interactions'): |
|
|
self._processed_interactions = set() |
|
|
|
|
|
|
|
|
if interaction_hash in self._processed_interactions: |
|
|
return True |
|
|
|
|
|
|
|
|
try: |
|
|
conn = sqlite3.connect(self.db_path) |
|
|
cursor = conn.cursor() |
|
|
|
|
|
cursor.execute("PRAGMA table_info(interactions)") |
|
|
columns = [row[1] for row in cursor.fetchall()] |
|
|
if 'interaction_hash' in columns: |
|
|
cursor.execute(""" |
|
|
SELECT COUNT(*) FROM interactions |
|
|
WHERE interaction_hash IS NOT NULL AND interaction_hash = ? |
|
|
""", (interaction_hash,)) |
|
|
count = cursor.fetchone()[0] |
|
|
conn.close() |
|
|
return count > 0 |
|
|
else: |
|
|
conn.close() |
|
|
return False |
|
|
except sqlite3.OperationalError: |
|
|
|
|
|
return interaction_hash in self._processed_interactions |
|
|
|
|
|
def _mark_interaction_processed(self, interaction_hash: str): |
|
|
"""Mark interaction as processed""" |
|
|
if not hasattr(self, '_processed_interactions'): |
|
|
self._processed_interactions = set() |
|
|
self._processed_interactions.add(interaction_hash) |
|
|
|
|
|
|
|
|
if len(self._processed_interactions) > 1000: |
|
|
|
|
|
self._processed_interactions = set(list(self._processed_interactions)[-500:]) |
|
|
|
|
|
async def manage_context_optimized(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
    """
    Efficient context management with transaction optimization.

    Flow:
      1. Serve the in-memory cache entry ``session_<id>`` when it is still
         fresh (see _is_cache_valid) AND belongs to ``user_id``.
      2. Otherwise, in one transaction, read the session row (or create it).
         If the stored user differs, run _handle_user_change_atomic(). Then
         load recent interaction contexts.
      3. Outside the transaction, load the user-level context.
      4. Warm the cache with the assembled context and return it passed
         through _optimize_context().

    Fixes:
      - A fresh cache entry for a *different* user is no longer returned.
        Previously, that bypassed the user-change handling in step 2 and
        served the previous user's context.
      - A ``context_mode`` chosen via set_context_mode() for the same user is
        carried over when the cache is re-warmed. Previously the cache entry
        was replaced and the mode silently reverted to the default.

    NOTE(review): a cache hit returns the cached dict as-is, while a miss
    returns _optimize_context(...) of it. Confirm this asymmetry is intended.

    Args:
        session_id: Session identifier.
        user_input: Current user input (not used by this method).
        user_id: Requesting user.

    Returns:
        The session context dict.
    """
    session_cache_key = f"session_{session_id}"

    cached_context = self._get_from_memory_cache(session_cache_key)
    same_user_cache = bool(cached_context) and cached_context.get("user_id") == user_id
    if same_user_cache and self._is_cache_valid(cached_context):
        logger.debug(f"Using cached context for session {session_id}")
        return cached_context

    with self.transaction_manager.transaction(session_id) as cursor:
        cursor.execute("""
            SELECT s.context_data, s.user_metadata, s.last_activity, s.user_id,
                   COUNT(ic.interaction_id) as interaction_count
            FROM sessions s
            LEFT JOIN interaction_contexts ic ON s.session_id = ic.session_id
            WHERE s.session_id = ?
            GROUP BY s.session_id
        """, (session_id,))

        row = cursor.fetchone()

        if row:
            # context_data is not used below, but it is still parsed so that
            # malformed JSON aborts the transaction exactly as before.
            json.loads(row[0] or '{}')
            user_metadata = json.loads(row[1] or '{}')
            stored_user_id = row[3] or user_id
            interaction_count = row[4] or 0

            if stored_user_id != user_id:
                self._handle_user_change_atomic(cursor, session_id, stored_user_id, user_id)

            interaction_contexts = self._get_interaction_contexts_atomic(cursor, session_id)
        else:
            cursor.execute("""
                INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
                VALUES (?, ?, datetime('now'), datetime('now'), '{}', '{}')
            """, (session_id, user_id))

            user_metadata = {}
            interaction_contexts = []
            interaction_count = 0

    # Awaited after the transaction has been committed and closed, so no
    # write transaction is held open across the await.
    user_context = await self._load_user_context_async(user_id)

    final_context = {
        "session_id": session_id,
        "user_id": user_id,
        "interaction_contexts": interaction_contexts,
        "user_context": user_context,
        "preferences": user_metadata.get("preferences", {}),
        "active_tasks": user_metadata.get("active_tasks", []),
        "interaction_count": interaction_count,
        "cache_timestamp": datetime.now().isoformat()
    }

    # set_context_mode() stores the mode in this same cache entry; keep it
    # for the same user instead of discarding it when the entry is replaced.
    if same_user_cache and "context_mode" in cached_context:
        final_context["context_mode"] = cached_context["context_mode"]
        if "context_mode_timestamp" in cached_context:
            final_context["context_mode_timestamp"] = cached_context["context_mode_timestamp"]

    self._warm_memory_cache(session_cache_key, final_context)

    return self._optimize_context(final_context)
|
|
|
|
|
def _handle_user_change_atomic(self, cursor, session_id: str, old_user_id: str, new_user_id: str):
    """
    Re-assign a session to a new user inside the caller's open transaction.

    Steps:
      1. Set ``sessions.user_id`` and bump ``last_activity``.
      2. Best effort: append a row to ``user_change_log``.
      3. Best effort: flag the session's ``interaction_contexts`` rows with
         ``needs_refresh = 1`` (_get_interaction_contexts_atomic filters these
         rows out).

    Steps 2-3 tolerate sqlite3.OperationalError so that databases without
    those tables or columns keep working. The skip is now logged at debug
    level instead of being silently swallowed. Nothing is committed here; the
    caller's transaction owns the commit.

    Args:
        cursor: Cursor bound to the active transaction.
        session_id: Session being re-assigned.
        old_user_id: Previously stored user id (recorded in the log).
        new_user_id: User id to store.
    """
    logger.info(f"Handling user change in transaction: {old_user_id} -> {new_user_id}")

    cursor.execute("""
        UPDATE sessions
        SET user_id = ?, last_activity = datetime('now')
        WHERE session_id = ?
    """, (new_user_id, session_id))

    try:
        cursor.execute("""
            INSERT INTO user_change_log (session_id, old_user_id, new_user_id, timestamp)
            VALUES (?, ?, ?, datetime('now'))
        """, (session_id, old_user_id, new_user_id))
    except sqlite3.OperationalError as e:
        # Optional audit table; absence must not block the user change.
        logger.debug(f"Skipping user change log for session {session_id}: {e}")

    try:
        cursor.execute("""
            UPDATE interaction_contexts
            SET needs_refresh = 1
            WHERE session_id = ?
        """, (session_id,))
    except sqlite3.OperationalError as e:
        # Older schemas may lack the needs_refresh column.
        logger.debug(f"Skipping needs_refresh flag for session {session_id}: {e}")
|
|
|
|
|
def _get_interaction_contexts_atomic(self, cursor, session_id: str, limit: int = 20): |
|
|
"""Get interaction contexts within transaction""" |
|
|
try: |
|
|
cursor.execute(""" |
|
|
SELECT interaction_summary, created_at, interaction_id |
|
|
FROM interaction_contexts |
|
|
WHERE session_id = ? AND (needs_refresh IS NULL OR needs_refresh = 0) |
|
|
ORDER BY created_at DESC |
|
|
LIMIT ? |
|
|
""", (session_id, limit)) |
|
|
except sqlite3.OperationalError: |
|
|
|
|
|
cursor.execute(""" |
|
|
SELECT interaction_summary, created_at, interaction_id |
|
|
FROM interaction_contexts |
|
|
WHERE session_id = ? |
|
|
ORDER BY created_at DESC |
|
|
LIMIT ? |
|
|
""", (session_id, limit)) |
|
|
|
|
|
contexts = [] |
|
|
for row in cursor.fetchall(): |
|
|
if row[0]: |
|
|
contexts.append({ |
|
|
"summary": row[0], |
|
|
"timestamp": row[1], |
|
|
"id": row[2] if len(row) > 2 else None |
|
|
}) |
|
|
|
|
|
return contexts |
|
|
|
|
|
async def _load_user_context_async(self, user_id: str):
    """
    Resolve the user-level context string for ``user_id``.

    Uses the in-memory cache entry ``user_<id>`` when one exists; otherwise it
    awaits get_user_context(). Any exception is logged and yields "".
    """
    try:
        hit = self._get_from_memory_cache(f"user_{user_id}")
        if not hit:
            return await self.get_user_context(user_id)
        return hit.get("user_context", "")
    except Exception as e:
        logger.error(f"Error loading user context: {e}")
        return ""
|
|
|
|
|
def _is_cache_valid(self, cached_context: dict, max_age_seconds: int = 60) -> bool: |
|
|
"""Check if cached context is still valid""" |
|
|
if not cached_context: |
|
|
return False |
|
|
|
|
|
cache_timestamp = cached_context.get("cache_timestamp") |
|
|
if not cache_timestamp: |
|
|
return False |
|
|
|
|
|
try: |
|
|
cache_time = datetime.fromisoformat(cache_timestamp) |
|
|
age = (datetime.now() - cache_time).total_seconds() |
|
|
return age < max_age_seconds |
|
|
except: |
|
|
return False |
|
|
|
|
|
def invalidate_session_cache(self, session_id: str):
    """
    Invalidate cached context for a session to force fresh retrieval.

    Only affects cache management; application functionality is unchanged.

    Both session caches are cleared:
      - ``session_cache`` under ``session_<id>``.
      - ``_session_cache`` under the bare session id, used by
        get_or_create_session_context.

    Previously only the first was cleared, so get_or_create_session_context
    kept serving the stale entry for up to its 300-second window. The info
    log is emitted only when something was actually removed.
    """
    removed = False

    session_cache_key = f"session_{session_id}"
    if session_cache_key in self.session_cache:
        del self.session_cache[session_cache_key]
        removed = True

    # getattr: tolerate instances constructed without __init__.
    per_session_cache = getattr(self, "_session_cache", None)
    if per_session_cache is not None and session_id in per_session_cache:
        del per_session_cache[session_id]
        removed = True

    if removed:
        logger.info(f"Cache invalidated for session {session_id} to ensure fresh context retrieval")
|
|
|
|
|
def optimize_database_indexes(self):
    """
    Create the secondary indexes used by this class's queries, then run ANALYZE.

    Each CREATE INDEX IF NOT EXISTS is attempted independently. One that fails
    with sqlite3.OperationalError (e.g. because its table does not exist yet)
    is logged at debug level and skipped. A failing ANALYZE is likewise logged
    and skipped. Any other error is logged and not re-raised.

    The connection is now closed on every path. Previously it leaked if
    anything after connect() raised.
    """
    indexes = [
        "CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)",
        "CREATE INDEX IF NOT EXISTS idx_sessions_last_activity ON sessions(last_activity)",
        "CREATE INDEX IF NOT EXISTS idx_interactions_session_id ON interactions(session_id)",
        "CREATE INDEX IF NOT EXISTS idx_interaction_contexts_session_id ON interaction_contexts(session_id)",
        "CREATE INDEX IF NOT EXISTS idx_interaction_contexts_created_at ON interaction_contexts(created_at)",
        "CREATE INDEX IF NOT EXISTS idx_user_change_log_session_id ON user_change_log(session_id)",
        "CREATE INDEX IF NOT EXISTS idx_user_contexts_updated_at ON user_contexts(updated_at)"
    ]

    conn = None
    try:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        for index in indexes:
            try:
                cursor.execute(index)
            except sqlite3.OperationalError as e:
                logger.debug(f"Skipping index creation (table may not exist): {e}")

        try:
            # Refreshes planner statistics; the indexes are useful without it.
            cursor.execute("ANALYZE")
        except sqlite3.OperationalError as e:
            logger.debug(f"Skipping ANALYZE: {e}")

        conn.commit()
        logger.info("✓ Database indexes optimized successfully")

    except Exception as e:
        logger.error(f"Error optimizing database indexes: {e}", exc_info=True)
    finally:
        if conn is not None:
            conn.close()
|
|
|
|
|
def set_context_mode(self, session_id: str, mode: str, user_id: str = "Test_Any"):
    """
    Record the context mode for a session in the in-memory cache.

    Args:
        session_id: Session identifier (cache key ``session_<id>``).
        mode: 'fresh' (no user context) or 'relevant' (only relevant context).
            Any other value is logged and replaced by 'fresh'.
        user_id: User identifier stored with the mode.

    Returns:
        bool: True when the mode was stored, False if an exception occurred.
        The entry is written with a 3600-second TTL.
    """
    try:
        import time

        if mode not in ('fresh', 'relevant'):
            logger.warning(f"Invalid context mode '{mode}', defaulting to 'fresh'")
            mode = 'fresh'

        cache_key = f"session_{session_id}"
        entry = self._get_from_memory_cache(cache_key)

        if entry:
            # Update the existing entry in place, keeping its other fields.
            entry['context_mode'] = mode
            entry['context_mode_timestamp'] = time.time()
            entry['user_id'] = user_id
        else:
            entry = {
                'session_id': session_id,
                'user_id': user_id,
                'preferences': {},
                'context_mode': mode,
                'context_mode_timestamp': time.time()
            }

        self.add_context_cache(cache_key, entry, ttl=3600)

        logger.info(f"Context mode set to '{mode}' for session {session_id} (user: {user_id})")
        return True
    except Exception as e:
        logger.error(f"Error setting context mode: {e}", exc_info=True)
        return False
|
|
|
|
|
def get_context_mode(self, session_id: str) -> str:
    """
    Return the context mode recorded for a session.

    Args:
        session_id: Session identifier (cache key ``session_<id>``).

    Returns:
        str: 'fresh' or 'relevant'. Defaults to 'fresh' when nothing is
        cached or an error occurs. An invalid stored value is logged,
        reset to 'fresh' in the cache (3600-second TTL), and 'fresh' is
        returned.
    """
    try:
        cache_key = f"session_{session_id}"
        entry = self._get_from_memory_cache(cache_key)

        if not entry:
            return 'fresh'

        mode = entry.get('context_mode', 'fresh')
        if mode in ['fresh', 'relevant']:
            return mode

        logger.warning(f"Invalid cached mode '{mode}', resetting to 'fresh'")
        entry['context_mode'] = 'fresh'
        import time
        entry['context_mode_timestamp'] = time.time()
        self.add_context_cache(cache_key, entry, ttl=3600)
        return 'fresh'
    except Exception as e:
        logger.error(f"Error getting context mode: {e}", exc_info=True)
        return 'fresh'
|
|
|
|
|
async def get_all_user_sessions(self, user_id: str) -> List[Dict]:
    """
    Fetch a user's session contexts, for relevance classification.

    Returns up to 50 ``session_contexts`` rows of the user's sessions (newest
    first), each with its 10 most recent interaction summaries.

    Previous bugs, now fixed:
      - A single query put ORDER BY/LIMIT inside a scalar GROUP_CONCAT
        subquery. Those clauses apply to the one aggregated row, not to the
        concatenated rows, so every interaction was concatenated in
        unspecified order instead of the 10 newest.
      - Summaries containing the ' ||| ' separator were split apart.
      - The connection leaked on error.

    Interactions are now fetched per session with an ordered, limited query
    (at most 50 small queries).

    Args:
        user_id: User identifier (matched against ``sessions.user_id``).

    Returns:
        A list of {"session_id", "summary", "created_at",
        "interaction_contexts"} dicts. Each interaction entry is
        {"summary", "timestamp"}, newest first, with blank summaries skipped.
        The timestamp is the interaction's own created_at, falling back to the
        session's when NULL. Returns [] on any error.
    """
    conn = None
    try:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            SELECT DISTINCT
                sc.session_id,
                sc.session_summary,
                sc.created_at
            FROM session_contexts sc
            JOIN sessions s ON sc.session_id = s.session_id
            WHERE s.user_id = ?
            ORDER BY sc.created_at DESC
            LIMIT 50
        """, (user_id,))
        session_rows = cursor.fetchall()

        sessions = []
        for session_id, session_summary, created_at in session_rows:
            cursor.execute("""
                SELECT interaction_summary, created_at
                FROM interaction_contexts
                WHERE session_id = ?
                ORDER BY created_at DESC
                LIMIT 10
            """, (session_id,))

            interaction_list = []
            for summary, interaction_created_at in cursor.fetchall():
                if summary is None:
                    continue
                # str(): GROUP_CONCAT previously coerced non-text values too.
                text = str(summary).strip()
                if text:
                    interaction_list.append({
                        'summary': text,
                        'timestamp': interaction_created_at or created_at
                    })

            sessions.append({
                'session_id': session_id,
                'summary': session_summary or '',
                'created_at': created_at,
                'interaction_contexts': interaction_list
            })

        logger.info(f"Fetched {len(sessions)} sessions for user {user_id}")
        return sessions

    except Exception as e:
        logger.error(f"Error fetching user sessions: {e}", exc_info=True)
        return []
    finally:
        if conn is not None:
            conn.close()
|
|
|
|
|
def _extract_entities(self, context: dict) -> list: |
|
|
""" |
|
|
Extract essential entities from context |
|
|
""" |
|
|
|
|
|
return [] |
|
|
|
|
|
def _generate_summary(self, context: dict) -> str: |
|
|
""" |
|
|
Generate conversation summary |
|
|
""" |
|
|
|
|
|
return "" |
|
|
|
|
|
def get_or_create_session_context(self, session_id: str, user_id: Optional[str] = None) -> Dict:
    """
    Enhanced context retrieval with caching.

    Cache behavior:
      - Results are cached per session in ``_session_cache`` for 300 seconds.
      - A cached entry is reused only when it was built for the same
        (effective) user. Previously the cache ignored ``user_id``, so a
        different user on the same session received the previous user's
        context.

    On a cache miss:
      - One LEFT JOIN query reads the session row, the owner's persona
        summary (``user_contexts``), and up to 10 recent interaction
        summaries with their ``created_at``.
      - _build_context_from_results() turns the rows into a context dict,
        which is then cached.

    NOTE: despite the name, no session row is inserted here. A missing
    session yields a default context built from empty results.

    Args:
        session_id: Session identifier.
        user_id: Requesting user; ``None`` is reported as "Test_Any".

    Returns:
        The context dict. On any error, a default context is returned and is
        not cached.
    """
    import time

    effective_user_id = user_id or "Test_Any"

    cache_entry = self._session_cache.get(session_id)
    if (
        cache_entry is not None
        and time.time() - cache_entry['timestamp'] < 300
        and cache_entry['context'].get("user_id") == effective_user_id
    ):
        logger.debug(f"Cache hit for session {session_id}")
        return cache_entry['context']

    conn = None
    try:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Column order must match _build_context_from_results; the trailing
        # ic.created_at is an optional sixth column used as the timestamp.
        query = """
            SELECT
                s.context_data,
                s.user_metadata,
                s.last_activity,
                u.persona_summary,
                ic.interaction_summary,
                ic.created_at
            FROM sessions s
            LEFT JOIN user_contexts u ON s.user_id = u.user_id
            LEFT JOIN interaction_contexts ic ON s.session_id = ic.session_id
            WHERE s.session_id = ?
            ORDER BY ic.created_at DESC
            LIMIT 10
        """

        cursor.execute(query, (session_id,))
        results = cursor.fetchall()

        context = self._build_context_from_results(results, session_id, user_id)

        self._session_cache[session_id] = {
            'context': context,
            'timestamp': time.time()
        }

        return context

    except Exception as e:
        logger.error(f"Error in get_or_create_session_context: {e}", exc_info=True)

        return {
            "session_id": session_id,
            "user_id": effective_user_id,
            "interaction_contexts": [],
            "session_context": None,
            "preferences": {},
            "active_tasks": [],
            "user_context_loaded": False
        }
    finally:
        if conn:
            conn.close()
|
|
|
|
|
def _build_context_from_results(self, results: list, session_id: str, user_id: Optional[str]) -> Dict: |
|
|
"""Build context dictionary from batch query results""" |
|
|
context = { |
|
|
"session_id": session_id, |
|
|
"user_id": user_id or "Test_Any", |
|
|
"interaction_contexts": [], |
|
|
"session_context": None, |
|
|
"user_context": "", |
|
|
"preferences": {}, |
|
|
"active_tasks": [], |
|
|
"user_context_loaded": False |
|
|
} |
|
|
|
|
|
if not results: |
|
|
return context |
|
|
|
|
|
|
|
|
first_row = results[0] |
|
|
if first_row[0]: |
|
|
try: |
|
|
session_data = json.loads(first_row[0]) |
|
|
context["preferences"] = session_data.get("preferences", {}) |
|
|
context["active_tasks"] = session_data.get("active_tasks", []) |
|
|
except: |
|
|
pass |
|
|
|
|
|
if first_row[1]: |
|
|
try: |
|
|
user_metadata = json.loads(first_row[1]) |
|
|
context["preferences"].update(user_metadata.get("preferences", {})) |
|
|
except: |
|
|
pass |
|
|
|
|
|
context["last_activity"] = first_row[2] |
|
|
|
|
|
if first_row[3]: |
|
|
context["user_context"] = first_row[3] |
|
|
context["user_context_loaded"] = True |
|
|
|
|
|
|
|
|
seen_interactions = set() |
|
|
for row in results: |
|
|
if row[4]: |
|
|
|
|
|
if row[4] not in seen_interactions: |
|
|
seen_interactions.add(row[4]) |
|
|
context["interaction_contexts"].append({ |
|
|
"summary": row[4], |
|
|
"timestamp": None |
|
|
}) |
|
|
|
|
|
return context |
|
|
|