Commit
·
8f4d405
0
Parent(s):
Initial commit: Research AI Assistant API
Browse files- .gitignore +92 -0
- Dockerfile +42 -0
- Dockerfile.flask +39 -0
- README.md +395 -0
- config.py +63 -0
- flask_api_standalone.py +257 -0
- requirements.txt +89 -0
- src/__init__.py +15 -0
- src/agents/__init__.py +21 -0
- src/agents/intent_agent.py +301 -0
- src/agents/safety_agent.py +453 -0
- src/agents/skills_identification_agent.py +547 -0
- src/agents/synthesis_agent.py +735 -0
- src/config.py +42 -0
- src/context_manager.py +1695 -0
- src/context_relevance_classifier.py +491 -0
- src/database.py +97 -0
- src/event_handlers.py +125 -0
- src/llm_router.py +471 -0
- src/local_model_loader.py +322 -0
- src/mobile_handlers.py +169 -0
- src/models_config.py +43 -0
- src/orchestrator_engine.py +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
venv/
|
| 25 |
+
env/
|
| 26 |
+
ENV/
|
| 27 |
+
.venv/
|
| 28 |
+
|
| 29 |
+
# IDEs
|
| 30 |
+
.vscode/
|
| 31 |
+
.idea/
|
| 32 |
+
*.swp
|
| 33 |
+
*.swo
|
| 34 |
+
*~
|
| 35 |
+
.DS_Store
|
| 36 |
+
|
| 37 |
+
# Environment variables
|
| 38 |
+
.env
|
| 39 |
+
.env.local
|
| 40 |
+
|
| 41 |
+
# Database files
|
| 42 |
+
*.db
|
| 43 |
+
*.sqlite
|
| 44 |
+
*.sqlite3
|
| 45 |
+
sessions.db
|
| 46 |
+
embeddings.faiss
|
| 47 |
+
embeddings.faiss.index
|
| 48 |
+
|
| 49 |
+
# Logs
|
| 50 |
+
*.log
|
| 51 |
+
logs/
|
| 52 |
+
*.log.*
|
| 53 |
+
|
| 54 |
+
# Cache
|
| 55 |
+
.cache/
|
| 56 |
+
__pycache__/
|
| 57 |
+
*.pyc
|
| 58 |
+
.pytest_cache/
|
| 59 |
+
.mypy_cache/
|
| 60 |
+
|
| 61 |
+
# Model cache (optional - uncomment if you don't want to commit model cache)
|
| 62 |
+
# models/
|
| 63 |
+
# .huggingface/
|
| 64 |
+
|
| 65 |
+
# Temporary files
|
| 66 |
+
tmp/
|
| 67 |
+
temp/
|
| 68 |
+
*.tmp
|
| 69 |
+
|
| 70 |
+
# OS files
|
| 71 |
+
Thumbs.db
|
| 72 |
+
desktop.ini
|
| 73 |
+
|
| 74 |
+
# Jupyter Notebooks
|
| 75 |
+
.ipynb_checkpoints/
|
| 76 |
+
|
| 77 |
+
# Distribution / packaging
|
| 78 |
+
*.zip
|
| 79 |
+
*.tar.gz
|
| 80 |
+
*.rar
|
| 81 |
+
|
| 82 |
+
# Testing
|
| 83 |
+
.coverage
|
| 84 |
+
htmlcov/
|
| 85 |
+
.tox/
|
| 86 |
+
.pytest_cache/
|
| 87 |
+
|
| 88 |
+
# Type checking
|
| 89 |
+
.mypy_cache/
|
| 90 |
+
.dmypy.json
|
| 91 |
+
dmypy.json
|
| 92 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for Hugging Face Spaces
|
| 2 |
+
# Based on HF Spaces Docker SDK documentation: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 3 |
+
|
| 4 |
+
FROM python:3.10-slim
|
| 5 |
+
|
| 6 |
+
# Set working directory
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
# Install system dependencies
|
| 10 |
+
RUN apt-get update && apt-get install -y \
|
| 11 |
+
gcc \
|
| 12 |
+
g++ \
|
| 13 |
+
cmake \
|
| 14 |
+
libopenblas-dev \
|
| 15 |
+
libomp-dev \
|
| 16 |
+
curl \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
# Copy requirements file first (for better caching)
|
| 20 |
+
COPY requirements.txt .
|
| 21 |
+
|
| 22 |
+
# Install Python dependencies
|
| 23 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 24 |
+
pip install --no-cache-dir -r requirements.txt
|
| 25 |
+
|
| 26 |
+
# Copy application code
|
| 27 |
+
COPY . .
|
| 28 |
+
|
| 29 |
+
# Expose port 7860 (HF Spaces standard)
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
# Set environment variables
|
| 33 |
+
ENV PYTHONUNBUFFERED=1
|
| 34 |
+
ENV PORT=7860
|
| 35 |
+
|
| 36 |
+
# Health check
|
| 37 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=120s --retries=3 \
|
| 38 |
+
CMD curl -f http://localhost:7860/api/health || exit 1
|
| 39 |
+
|
| 40 |
+
# Run Flask application on port 7860
|
| 41 |
+
CMD ["python", "flask_api_standalone.py"]
|
| 42 |
+
|
Dockerfile.flask
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
gcc \
|
| 9 |
+
g++ \
|
| 10 |
+
cmake \
|
| 11 |
+
libopenblas-dev \
|
| 12 |
+
libomp-dev \
|
| 13 |
+
curl \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 15 |
+
|
| 16 |
+
# Copy requirements file
|
| 17 |
+
COPY requirements.txt .
|
| 18 |
+
|
| 19 |
+
# Install Python dependencies
|
| 20 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 21 |
+
|
| 22 |
+
# Copy application code
|
| 23 |
+
COPY . .
|
| 24 |
+
|
| 25 |
+
# Expose port 7860 (HF Spaces standard)
|
| 26 |
+
EXPOSE 7860
|
| 27 |
+
|
| 28 |
+
# Set environment variables
|
| 29 |
+
ENV PYTHONUNBUFFERED=1
|
| 30 |
+
ENV PORT=7860
|
| 31 |
+
|
| 32 |
+
# Health check
|
| 33 |
+
HEALTHCHECK --interval=30s --timeout=30s --start-period=120s --retries=3 \
|
| 34 |
+
CMD curl -f http://localhost:7860/api/health || exit 1
|
| 35 |
+
|
| 36 |
+
# Run Flask application
|
| 37 |
+
# Note: For Flask-only deployment, use this Dockerfile with README_FLASK_API.md
|
| 38 |
+
CMD ["python", "flask_api_standalone.py"]
|
| 39 |
+
|
README.md
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AI Research Assistant MVP
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
license: apache-2.0
|
| 10 |
+
tags:
|
| 11 |
+
- ai
|
| 12 |
+
- chatbot
|
| 13 |
+
- research
|
| 14 |
+
- education
|
| 15 |
+
- transformers
|
| 16 |
+
models:
|
| 17 |
+
- mistralai/Mistral-7B-Instruct-v0.2
|
| 18 |
+
- sentence-transformers/all-MiniLM-L6-v2
|
| 19 |
+
- cardiffnlp/twitter-roberta-base-emotion
|
| 20 |
+
- unitary/unbiased-toxic-roberta
|
| 21 |
+
datasets:
|
| 22 |
+
- wikipedia
|
| 23 |
+
- commoncrawl
|
| 24 |
+
base_path: research-assistant
|
| 25 |
+
hf_oauth: true
|
| 26 |
+
hf_token: true
|
| 27 |
+
disable_embedding: false
|
| 28 |
+
duplicated_from: null
|
| 29 |
+
extra_gated_prompt: null
|
| 30 |
+
extra_gated_fields: {}
|
| 31 |
+
gated: false
|
| 32 |
+
public: true
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
# AI Research Assistant - MVP
|
| 36 |
+
|
| 37 |
+
<div align="center">
|
| 38 |
+
|
| 39 |
+

|
| 40 |
+

|
| 41 |
+

|
| 42 |
+

|
| 43 |
+
|
| 44 |
+
**Academic-grade AI assistant with transparent reasoning and mobile-optimized interface**
|
| 45 |
+
|
| 46 |
+
[](https://huggingface.co/spaces/your-username/research-assistant)
|
| 47 |
+
[](https://github.com/your-org/research-assistant/wiki)
|
| 48 |
+
|
| 49 |
+
</div>
|
| 50 |
+
|
| 51 |
+
## 🎯 Overview
|
| 52 |
+
|
| 53 |
+
This MVP demonstrates an intelligent research assistant framework featuring **transparent reasoning chains**, **specialized agent architecture**, and **mobile-first design**. Built for Hugging Face Spaces with NVIDIA T4 GPU acceleration for local model inference.
|
| 54 |
+
|
| 55 |
+
### Key Differentiators
|
| 56 |
+
- **🔍 Transparent Reasoning**: Watch the AI think step-by-step with Chain of Thought
|
| 57 |
+
- **🧠 Specialized Agents**: Multiple AI models working together for optimal performance
|
| 58 |
+
- **📱 Mobile-First**: Optimized for seamless mobile web experience
|
| 59 |
+
- **🎓 Academic Focus**: Designed for research and educational use cases
|
| 60 |
+
|
| 61 |
+
## 🚀 Quick Start
|
| 62 |
+
|
| 63 |
+
### Option 1: Use Our Demo
|
| 64 |
+
Visit our live demo on Hugging Face Spaces:
|
| 65 |
+
```bash
|
| 66 |
+
https://huggingface.co/spaces/your-username/research-assistant
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Option 2: Deploy Your Own Instance
|
| 70 |
+
|
| 71 |
+
#### Prerequisites
|
| 72 |
+
- Hugging Face account with [write token](https://huggingface.co/settings/tokens)
|
| 73 |
+
- Basic understanding of Hugging Face Spaces
|
| 74 |
+
|
| 75 |
+
#### Deployment Steps
|
| 76 |
+
|
| 77 |
+
1. **Fork this space** using the Hugging Face UI
|
| 78 |
+
2. **Add your HF token** in Space Settings:
|
| 79 |
+
- Go to your Space → Settings → Repository secrets
|
| 80 |
+
- Add `HF_TOKEN` with your Hugging Face token
|
| 81 |
+
3. **The space will auto-build** (takes 5-10 minutes)
|
| 82 |
+
|
| 83 |
+
#### Manual Build (Advanced)
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
# Clone the repository
|
| 87 |
+
git clone https://huggingface.co/spaces/your-username/research-assistant
|
| 88 |
+
cd research-assistant
|
| 89 |
+
|
| 90 |
+
# Install dependencies
|
| 91 |
+
pip install -r requirements.txt
|
| 92 |
+
|
| 93 |
+
# Set up environment
|
| 94 |
+
export HF_TOKEN="your_hugging_face_token_here"
|
| 95 |
+
|
| 96 |
+
# Launch the application (multiple options)
|
| 97 |
+
python main.py # Full integration with error handling
|
| 98 |
+
python launch.py # Simple launcher
|
| 99 |
+
python app.py # UI-only mode
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
## 📁 Integration Structure
|
| 103 |
+
|
| 104 |
+
The MVP now includes complete integration files for deployment:
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
├── main.py # 🎯 Main integration entry point
|
| 108 |
+
├── launch.py # 🚀 Simple launcher for HF Spaces
|
| 109 |
+
├── app.py # 📱 Mobile-optimized UI
|
| 110 |
+
├── requirements.txt # 📦 Dependencies
|
| 111 |
+
└── src/
|
| 112 |
+
├── __init__.py # 📦 Package initialization
|
| 113 |
+
├── database.py # 🗄️ SQLite database management
|
| 114 |
+
├── event_handlers.py # 🔗 UI event integration
|
| 115 |
+
├── config.py # ⚙️ Configuration
|
| 116 |
+
├── llm_router.py # 🤖 LLM routing
|
| 117 |
+
├── orchestrator_engine.py # 🎭 Request orchestration
|
| 118 |
+
├── context_manager.py # 🧠 Context management
|
| 119 |
+
├── mobile_handlers.py # 📱 Mobile UX handlers
|
| 120 |
+
└── agents/
|
| 121 |
+
├── __init__.py # 🤖 Agents package
|
| 122 |
+
├── intent_agent.py # 🎯 Intent recognition
|
| 123 |
+
├── synthesis_agent.py # ✨ Response synthesis
|
| 124 |
+
└── safety_agent.py # 🛡️ Safety checking
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Key Features:
|
| 128 |
+
- **🔄 Graceful Degradation**: Falls back to mock mode if components fail
|
| 129 |
+
- **📱 Mobile-First**: Optimized for mobile devices and small screens
|
| 130 |
+
- **🗄️ Database Ready**: SQLite integration with session management
|
| 131 |
+
- **🔗 Event Handling**: Complete UI-to-backend integration
|
| 132 |
+
- **⚡ Error Recovery**: Robust error handling throughout
|
| 133 |
+
|
| 134 |
+
## 🏗️ Architecture
|
| 135 |
+
|
| 136 |
+
```
|
| 137 |
+
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
| 138 |
+
│ Mobile Web │ ── │ ORCHESTRATOR │ ── │ AGENT SWARM │
|
| 139 |
+
│ Interface │ │ (Core Engine) │ │ (5 Specialists)│
|
| 140 |
+
└─────────────────┘ └���─────────────────┘ └─────────────────┘
|
| 141 |
+
│ │ │
|
| 142 |
+
└─────────────────────────┼────────────────────────┘
|
| 143 |
+
│
|
| 144 |
+
┌─────────────────────────────┐
|
| 145 |
+
│ PERSISTENCE LAYER │
|
| 146 |
+
│ (SQLite + FAISS Lite) │
|
| 147 |
+
└─────────────────────────────┘
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Core Components
|
| 151 |
+
|
| 152 |
+
| Component | Purpose | Technology |
|
| 153 |
+
|-----------|---------|------------|
|
| 154 |
+
| **Orchestrator** | Main coordination engine | Python + Async |
|
| 155 |
+
| **Intent Recognition** | Understand user goals | RoBERTa-base + CoT |
|
| 156 |
+
| **Context Manager** | Session memory & recall | FAISS + SQLite |
|
| 157 |
+
| **Response Synthesis** | Generate final answers | Mistral-7B |
|
| 158 |
+
| **Safety Checker** | Content moderation | Unbiased-Toxic-RoBERTa |
|
| 159 |
+
| **Research Agent** | Information gathering | Web search + analysis |
|
| 160 |
+
|
| 161 |
+
## 💡 Usage Examples
|
| 162 |
+
|
| 163 |
+
### Basic Research Query
|
| 164 |
+
```
|
| 165 |
+
User: "Explain quantum entanglement in simple terms"
|
| 166 |
+
|
| 167 |
+
Assistant:
|
| 168 |
+
1. 🤔 [Reasoning] Breaking down quantum physics concepts...
|
| 169 |
+
2. 🔍 [Research] Gathering latest explanations...
|
| 170 |
+
3. ✍️ [Synthesis] Creating simplified explanation...
|
| 171 |
+
|
| 172 |
+
[Final Response]: Quantum entanglement is when two particles become linked...
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### Technical Analysis
|
| 176 |
+
```
|
| 177 |
+
User: "Compare transformer models for text classification"
|
| 178 |
+
|
| 179 |
+
Assistant:
|
| 180 |
+
1. 🏷️ [Intent] Identifying technical comparison request
|
| 181 |
+
2. 📊 [Analysis] Evaluating BERT vs RoBERTa vs DistilBERT
|
| 182 |
+
3. 📈 [Synthesis] Creating comparison table with metrics...
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## ⚙️ Configuration
|
| 186 |
+
|
| 187 |
+
### Environment Variables
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
# Required
|
| 191 |
+
HF_TOKEN="your_hugging_face_token"
|
| 192 |
+
|
| 193 |
+
# Optional
|
| 194 |
+
MAX_WORKERS=2
|
| 195 |
+
CACHE_TTL=3600
|
| 196 |
+
DEFAULT_MODEL="mistralai/Mistral-7B-Instruct-v0.2"
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
### Model Configuration
|
| 200 |
+
|
| 201 |
+
The system uses multiple specialized models:
|
| 202 |
+
|
| 203 |
+
| Task | Model | Purpose |
|
| 204 |
+
|------|-------|---------|
|
| 205 |
+
| Primary Reasoning | `mistralai/Mistral-7B-Instruct-v0.2` | General responses |
|
| 206 |
+
| Embeddings | `sentence-transformers/all-MiniLM-L6-v2` | Semantic search |
|
| 207 |
+
| Intent Classification | `cardiffnlp/twitter-roberta-base-emotion` | User goal detection |
|
| 208 |
+
| Safety Checking | `unitary/unbiased-toxic-roberta` | Content moderation |
|
| 209 |
+
|
| 210 |
+
## 📱 Mobile Optimization
|
| 211 |
+
|
| 212 |
+
### Key Mobile Features
|
| 213 |
+
- **Touch-friendly** interface (44px+ touch targets)
|
| 214 |
+
- **Progressive Web App** capabilities
|
| 215 |
+
- **Offline functionality** for cached sessions
|
| 216 |
+
- **Reduced data usage** with optimized responses
|
| 217 |
+
- **Keyboard-aware** layout adjustments
|
| 218 |
+
|
| 219 |
+
### Supported Devices
|
| 220 |
+
- ✅ Smartphones (iOS/Android)
|
| 221 |
+
- ✅ Tablets
|
| 222 |
+
- ✅ Desktop browsers
|
| 223 |
+
- ✅ Screen readers (accessibility)
|
| 224 |
+
|
| 225 |
+
## 🛠️ Development
|
| 226 |
+
|
| 227 |
+
### Project Structure
|
| 228 |
+
```
|
| 229 |
+
research-assistant/
|
| 230 |
+
├── app.py # Main Gradio application
|
| 231 |
+
├── requirements.txt # Dependencies
|
| 232 |
+
├── Dockerfile # Container configuration
|
| 233 |
+
├── src/
|
| 234 |
+
│ ├── orchestrator.py # Core orchestration engine
|
| 235 |
+
│ ├── agents/ # Specialized agent modules
|
| 236 |
+
│ ├── llm_router.py # Multi-model routing
|
| 237 |
+
│ └── mobile_ux.py # Mobile optimizations
|
| 238 |
+
├── tests/ # Test suites
|
| 239 |
+
└── docs/ # Documentation
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
### Adding New Agents
|
| 243 |
+
|
| 244 |
+
1. Create agent module in `src/agents/`
|
| 245 |
+
2. Implement agent protocol:
|
| 246 |
+
```python
|
| 247 |
+
class YourNewAgent:
|
| 248 |
+
async def execute(self, user_input: str, context: dict) -> dict:
|
| 249 |
+
# Your agent logic here
|
| 250 |
+
return {
|
| 251 |
+
"result": processed_output,
|
| 252 |
+
"confidence": 0.95,
|
| 253 |
+
"metadata": {}
|
| 254 |
+
}
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
3. Register agent in orchestrator configuration
|
| 258 |
+
|
| 259 |
+
## 🧪 Testing
|
| 260 |
+
|
| 261 |
+
### Run Test Suite
|
| 262 |
+
```bash
|
| 263 |
+
# Install test dependencies
|
| 264 |
+
pip install -r requirements.txt
|
| 265 |
+
|
| 266 |
+
# Run all tests
|
| 267 |
+
pytest tests/ -v
|
| 268 |
+
|
| 269 |
+
# Run specific test categories
|
| 270 |
+
pytest tests/test_agents.py -v
|
| 271 |
+
pytest tests/test_mobile_ux.py -v
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
### Test Coverage
|
| 275 |
+
- ✅ Agent functionality
|
| 276 |
+
- ✅ Mobile UX components
|
| 277 |
+
- ✅ LLM routing logic
|
| 278 |
+
- ✅ Error handling
|
| 279 |
+
- ✅ Performance benchmarks
|
| 280 |
+
|
| 281 |
+
## 🚨 Troubleshooting
|
| 282 |
+
|
| 283 |
+
### Common Build Issues
|
| 284 |
+
|
| 285 |
+
| Issue | Solution |
|
| 286 |
+
|-------|----------|
|
| 287 |
+
| **HF_TOKEN not found** | Add token in Space Settings → Secrets |
|
| 288 |
+
| **Build timeout** | Reduce model sizes in requirements |
|
| 289 |
+
| **Memory errors** | Check GPU memory usage, optimize model loading |
|
| 290 |
+
| **Import errors** | Check Python version (3.9+) |
|
| 291 |
+
|
| 292 |
+
### Performance Optimization
|
| 293 |
+
|
| 294 |
+
1. **Enable caching** in context manager
|
| 295 |
+
2. **Use smaller models** for initial deployment
|
| 296 |
+
3. **Implement lazy loading** for mobile users
|
| 297 |
+
4. **Monitor memory usage** with built-in tools
|
| 298 |
+
|
| 299 |
+
### Debug Mode
|
| 300 |
+
|
| 301 |
+
Enable detailed logging:
|
| 302 |
+
```python
|
| 303 |
+
import logging
|
| 304 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 305 |
+
```
|
| 306 |
+
|
| 307 |
+
## 📊 Performance Metrics
|
| 308 |
+
|
| 309 |
+
| Metric | Target | Current |
|
| 310 |
+
|--------|---------|---------|
|
| 311 |
+
| Response Time | <10s | ~7s |
|
| 312 |
+
| Cache Hit Rate | >60% | ~65% |
|
| 313 |
+
| Mobile UX Score | >80/100 | 85/100 |
|
| 314 |
+
| Error Rate | <5% | ~3% |
|
| 315 |
+
|
| 316 |
+
## 🔮 Roadmap
|
| 317 |
+
|
| 318 |
+
### Phase 1 (Current - MVP)
|
| 319 |
+
- ✅ Basic agent orchestration
|
| 320 |
+
- ✅ Mobile-optimized interface
|
| 321 |
+
- ✅ Multi-model routing
|
| 322 |
+
- ✅ Transparent reasoning display
|
| 323 |
+
|
| 324 |
+
### Phase 2 (Next 3 months)
|
| 325 |
+
- 🚧 Advanced research capabilities
|
| 326 |
+
- 🚧 Plugin system for tools
|
| 327 |
+
- 🚧 Enhanced mobile PWA features
|
| 328 |
+
- 🚧 Multi-language support
|
| 329 |
+
|
| 330 |
+
### Phase 3 (Future)
|
| 331 |
+
- 🔮 Autonomous agent swarms
|
| 332 |
+
- 🔮 Voice interface integration
|
| 333 |
+
- 🔮 Enterprise features
|
| 334 |
+
- 🔮 Advanced analytics
|
| 335 |
+
|
| 336 |
+
## 👥 Contributing
|
| 337 |
+
|
| 338 |
+
We welcome contributions! Please see:
|
| 339 |
+
|
| 340 |
+
1. [Contributing Guidelines](docs/CONTRIBUTING.md)
|
| 341 |
+
2. [Code of Conduct](docs/CODE_OF_CONDUCT.md)
|
| 342 |
+
3. [Development Setup](docs/DEVELOPMENT.md)
|
| 343 |
+
|
| 344 |
+
### Quick Contribution Steps
|
| 345 |
+
```bash
|
| 346 |
+
# 1. Fork the repository
|
| 347 |
+
# 2. Create feature branch
|
| 348 |
+
git checkout -b feature/amazing-feature
|
| 349 |
+
|
| 350 |
+
# 3. Commit changes
|
| 351 |
+
git commit -m "Add amazing feature"
|
| 352 |
+
|
| 353 |
+
# 4. Push to branch
|
| 354 |
+
git push origin feature/amazing-feature
|
| 355 |
+
|
| 356 |
+
# 5. Open Pull Request
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
## 📄 Citation
|
| 360 |
+
|
| 361 |
+
If you use this framework in your research, please cite:
|
| 362 |
+
|
| 363 |
+
```bibtex
|
| 364 |
+
@software{research_assistant_mvp,
|
| 365 |
+
title = {AI Research Assistant - MVP},
|
| 366 |
+
author = {Your Name},
|
| 367 |
+
year = {2024},
|
| 368 |
+
url = {https://huggingface.co/spaces/your-username/research-assistant}
|
| 369 |
+
}
|
| 370 |
+
```
|
| 371 |
+
|
| 372 |
+
## 📜 License
|
| 373 |
+
|
| 374 |
+
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
| 375 |
+
|
| 376 |
+
## 🙏 Acknowledgments
|
| 377 |
+
|
| 378 |
+
- [Hugging Face](https://huggingface.co) for the infrastructure
|
| 379 |
+
- [Gradio](https://gradio.app) for the web framework
|
| 380 |
+
- Model contributors from the HF community
|
| 381 |
+
- Early testers and feedback providers
|
| 382 |
+
|
| 383 |
+
---
|
| 384 |
+
|
| 385 |
+
<div align="center">
|
| 386 |
+
|
| 387 |
+
**Need help?**
|
| 388 |
+
- [Open an Issue](https://github.com/your-org/research-assistant/issues)
|
| 389 |
+
- [Join our Discord](https://discord.gg/your-discord)
|
| 390 |
+
- [Email Support](mailto:support@your-domain.com)
|
| 391 |
+
|
| 392 |
+
*Built with ❤️ for the research community*
|
| 393 |
+
|
| 394 |
+
</div>
|
| 395 |
+
|
config.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
import os
|
| 3 |
+
from pydantic_settings import BaseSettings
|
| 4 |
+
|
| 5 |
+
class Settings(BaseSettings):
|
| 6 |
+
# HF Spaces specific settings
|
| 7 |
+
hf_token: str = os.getenv("HF_TOKEN", "")
|
| 8 |
+
hf_cache_dir: str = os.getenv("HF_HOME", "/tmp/huggingface")
|
| 9 |
+
|
| 10 |
+
# Model settings
|
| 11 |
+
default_model: str = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 12 |
+
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
| 13 |
+
classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
|
| 14 |
+
|
| 15 |
+
# Performance settings
|
| 16 |
+
max_workers: int = int(os.getenv("MAX_WORKERS", "4"))
|
| 17 |
+
cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
|
| 18 |
+
|
| 19 |
+
# Database settings
|
| 20 |
+
db_path: str = os.getenv("DB_PATH", "sessions.db")
|
| 21 |
+
faiss_index_path: str = os.getenv("FAISS_INDEX_PATH", "embeddings.faiss")
|
| 22 |
+
|
| 23 |
+
# Session settings
|
| 24 |
+
session_timeout: int = int(os.getenv("SESSION_TIMEOUT", "3600"))
|
| 25 |
+
max_session_size_mb: int = int(os.getenv("MAX_SESSION_SIZE_MB", "10"))
|
| 26 |
+
|
| 27 |
+
# Mobile optimization settings
|
| 28 |
+
mobile_max_tokens: int = int(os.getenv("MOBILE_MAX_TOKENS", "800"))
|
| 29 |
+
mobile_timeout: int = int(os.getenv("MOBILE_TIMEOUT", "15000"))
|
| 30 |
+
|
| 31 |
+
# Gradio settings
|
| 32 |
+
gradio_port: int = int(os.getenv("GRADIO_PORT", "7860"))
|
| 33 |
+
gradio_host: str = os.getenv("GRADIO_HOST", "0.0.0.0")
|
| 34 |
+
|
| 35 |
+
# Logging settings
|
| 36 |
+
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
| 37 |
+
log_format: str = os.getenv("LOG_FORMAT", "json")
|
| 38 |
+
|
| 39 |
+
class Config:
|
| 40 |
+
env_file = ".env"
|
| 41 |
+
|
| 42 |
+
settings = Settings()
|
| 43 |
+
|
| 44 |
+
# Context configuration
|
| 45 |
+
CONTEXT_CONFIG = {
|
| 46 |
+
'max_context_tokens': int(os.getenv("MAX_CONTEXT_TOKENS", "4000")),
|
| 47 |
+
'cache_ttl_seconds': int(os.getenv("CACHE_TTL_SECONDS", "300")),
|
| 48 |
+
'max_cache_size': int(os.getenv("MAX_CACHE_SIZE", "100")),
|
| 49 |
+
'parallel_processing': os.getenv("PARALLEL_PROCESSING", "True").lower() == "true",
|
| 50 |
+
'context_decay_factor': float(os.getenv("CONTEXT_DECAY_FACTOR", "0.8")),
|
| 51 |
+
'max_interactions_to_keep': int(os.getenv("MAX_INTERACTIONS_TO_KEEP", "10")),
|
| 52 |
+
'enable_metrics': os.getenv("ENABLE_METRICS", "True").lower() == "true",
|
| 53 |
+
'compression_enabled': os.getenv("COMPRESSION_ENABLED", "True").lower() == "true",
|
| 54 |
+
'summarization_threshold': int(os.getenv("SUMMARIZATION_THRESHOLD", "2000")) # tokens
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# Model selection for context operations
|
| 58 |
+
CONTEXT_MODELS = {
|
| 59 |
+
'summarization': os.getenv("CONTEXT_SUMMARIZATION_MODEL", "Qwen/Qwen2.5-7B-Instruct"),
|
| 60 |
+
'intent': os.getenv("CONTEXT_INTENT_MODEL", "Qwen/Qwen2.5-7B-Instruct"),
|
| 61 |
+
'synthesis': os.getenv("CONTEXT_SYNTHESIS_MODEL", "Qwen/Qwen2.5-72B-Instruct")
|
| 62 |
+
}
|
| 63 |
+
|
flask_api_standalone.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Pure Flask API for Hugging Face Spaces
|
| 4 |
+
No Gradio - Just Flask REST API
|
| 5 |
+
Uses local GPU models for inference
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from flask import Flask, request, jsonify
|
| 9 |
+
from flask_cors import CORS
|
| 10 |
+
import logging
|
| 11 |
+
import sys
|
| 12 |
+
import os
|
| 13 |
+
import asyncio
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Setup logging
|
| 17 |
+
logging.basicConfig(
|
| 18 |
+
level=logging.INFO,
|
| 19 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 20 |
+
)
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# Add project root to path
|
| 24 |
+
project_root = Path(__file__).parent
|
| 25 |
+
sys.path.insert(0, str(project_root))
|
| 26 |
+
|
| 27 |
+
# Create Flask app
|
| 28 |
+
app = Flask(__name__)
|
| 29 |
+
CORS(app) # Enable CORS for all origins
|
| 30 |
+
|
| 31 |
+
# Global orchestrator
|
| 32 |
+
orchestrator = None
|
| 33 |
+
orchestrator_available = False
|
| 34 |
+
|
| 35 |
+
def initialize_orchestrator():
|
| 36 |
+
"""Initialize the AI orchestrator with local GPU models"""
|
| 37 |
+
global orchestrator, orchestrator_available
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
logger.info("=" * 60)
|
| 41 |
+
logger.info("INITIALIZING AI ORCHESTRATOR (Local GPU Models)")
|
| 42 |
+
logger.info("=" * 60)
|
| 43 |
+
|
| 44 |
+
from src.agents.intent_agent import create_intent_agent
|
| 45 |
+
from src.agents.synthesis_agent import create_synthesis_agent
|
| 46 |
+
from src.agents.safety_agent import create_safety_agent
|
| 47 |
+
from src.agents.skills_identification_agent import create_skills_identification_agent
|
| 48 |
+
from src.llm_router import LLMRouter
|
| 49 |
+
from src.orchestrator_engine import MVPOrchestrator
|
| 50 |
+
from src.context_manager import EfficientContextManager
|
| 51 |
+
|
| 52 |
+
logger.info("✓ Imports successful")
|
| 53 |
+
|
| 54 |
+
hf_token = os.getenv('HF_TOKEN', '')
|
| 55 |
+
if not hf_token:
|
| 56 |
+
logger.warning("HF_TOKEN not set - API fallback will be used if local models fail")
|
| 57 |
+
|
| 58 |
+
# Initialize LLM Router with local model loading enabled
|
| 59 |
+
logger.info("Initializing LLM Router with local GPU model loading...")
|
| 60 |
+
llm_router = LLMRouter(hf_token, use_local_models=True)
|
| 61 |
+
|
| 62 |
+
logger.info("Initializing Agents...")
|
| 63 |
+
agents = {
|
| 64 |
+
'intent_recognition': create_intent_agent(llm_router),
|
| 65 |
+
'response_synthesis': create_synthesis_agent(llm_router),
|
| 66 |
+
'safety_check': create_safety_agent(llm_router),
|
| 67 |
+
'skills_identification': create_skills_identification_agent(llm_router)
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
logger.info("Initializing Context Manager...")
|
| 71 |
+
context_manager = EfficientContextManager(llm_router=llm_router)
|
| 72 |
+
|
| 73 |
+
logger.info("Initializing Orchestrator...")
|
| 74 |
+
orchestrator = MVPOrchestrator(llm_router, context_manager, agents)
|
| 75 |
+
|
| 76 |
+
orchestrator_available = True
|
| 77 |
+
logger.info("=" * 60)
|
| 78 |
+
logger.info("✓ AI ORCHESTRATOR READY")
|
| 79 |
+
logger.info(" - Local GPU models enabled")
|
| 80 |
+
logger.info(" - MAX_WORKERS: 4")
|
| 81 |
+
logger.info("=" * 60)
|
| 82 |
+
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logger.error(f"Failed to initialize: {e}", exc_info=True)
|
| 87 |
+
orchestrator_available = False
|
| 88 |
+
return False
|
| 89 |
+
|
| 90 |
+
# Root endpoint
|
| 91 |
+
@app.route('/', methods=['GET'])
|
| 92 |
+
def root():
|
| 93 |
+
"""API information"""
|
| 94 |
+
return jsonify({
|
| 95 |
+
'name': 'AI Assistant Flask API',
|
| 96 |
+
'version': '1.0',
|
| 97 |
+
'status': 'running',
|
| 98 |
+
'orchestrator_ready': orchestrator_available,
|
| 99 |
+
'features': {
|
| 100 |
+
'local_gpu_models': True,
|
| 101 |
+
'max_workers': 4,
|
| 102 |
+
'hardware': 'NVIDIA T4 Medium'
|
| 103 |
+
},
|
| 104 |
+
'endpoints': {
|
| 105 |
+
'health': 'GET /api/health',
|
| 106 |
+
'chat': 'POST /api/chat',
|
| 107 |
+
'initialize': 'POST /api/initialize'
|
| 108 |
+
}
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
# Health check
|
| 112 |
+
@app.route('/api/health', methods=['GET'])
|
| 113 |
+
def health_check():
|
| 114 |
+
"""Health check endpoint"""
|
| 115 |
+
return jsonify({
|
| 116 |
+
'status': 'healthy' if orchestrator_available else 'initializing',
|
| 117 |
+
'orchestrator_ready': orchestrator_available
|
| 118 |
+
})
|
| 119 |
+
|
| 120 |
+
# Chat endpoint
|
| 121 |
+
@app.route('/api/chat', methods=['POST'])
|
| 122 |
+
def chat():
|
| 123 |
+
"""
|
| 124 |
+
Process chat message
|
| 125 |
+
|
| 126 |
+
POST /api/chat
|
| 127 |
+
{
|
| 128 |
+
"message": "user message",
|
| 129 |
+
"history": [[user, assistant], ...],
|
| 130 |
+
"session_id": "session-123",
|
| 131 |
+
"user_id": "user-456"
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
{
|
| 136 |
+
"success": true,
|
| 137 |
+
"message": "AI response",
|
| 138 |
+
"history": [...],
|
| 139 |
+
"reasoning": {...},
|
| 140 |
+
"performance": {...}
|
| 141 |
+
}
|
| 142 |
+
"""
|
| 143 |
+
try:
|
| 144 |
+
data = request.get_json()
|
| 145 |
+
|
| 146 |
+
if not data or 'message' not in data:
|
| 147 |
+
return jsonify({
|
| 148 |
+
'success': False,
|
| 149 |
+
'error': 'Message is required'
|
| 150 |
+
}), 400
|
| 151 |
+
|
| 152 |
+
message = data['message']
|
| 153 |
+
history = data.get('history', [])
|
| 154 |
+
session_id = data.get('session_id')
|
| 155 |
+
user_id = data.get('user_id', 'anonymous')
|
| 156 |
+
|
| 157 |
+
logger.info(f"Chat request - User: {user_id}, Session: {session_id}")
|
| 158 |
+
logger.info(f"Message: {message[:100]}...")
|
| 159 |
+
|
| 160 |
+
if not orchestrator_available or orchestrator is None:
|
| 161 |
+
return jsonify({
|
| 162 |
+
'success': False,
|
| 163 |
+
'error': 'Orchestrator not ready',
|
| 164 |
+
'message': 'AI system is initializing. Please try again in a moment.'
|
| 165 |
+
}), 503
|
| 166 |
+
|
| 167 |
+
# Process with orchestrator (async method)
|
| 168 |
+
# Set user_id for session tracking
|
| 169 |
+
if session_id:
|
| 170 |
+
orchestrator.set_user_id(session_id, user_id)
|
| 171 |
+
|
| 172 |
+
# Run async process_request in event loop
|
| 173 |
+
loop = asyncio.new_event_loop()
|
| 174 |
+
asyncio.set_event_loop(loop)
|
| 175 |
+
try:
|
| 176 |
+
result = loop.run_until_complete(
|
| 177 |
+
orchestrator.process_request(
|
| 178 |
+
session_id=session_id or f"session-{user_id}",
|
| 179 |
+
user_input=message
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
finally:
|
| 183 |
+
loop.close()
|
| 184 |
+
|
| 185 |
+
# Extract response
|
| 186 |
+
if isinstance(result, dict):
|
| 187 |
+
response_text = result.get('response', '')
|
| 188 |
+
reasoning = result.get('reasoning', {})
|
| 189 |
+
performance = result.get('performance', {})
|
| 190 |
+
else:
|
| 191 |
+
response_text = str(result)
|
| 192 |
+
reasoning = {}
|
| 193 |
+
performance = {}
|
| 194 |
+
|
| 195 |
+
updated_history = history + [[message, response_text]]
|
| 196 |
+
|
| 197 |
+
logger.info(f"✓ Response generated (length: {len(response_text)})")
|
| 198 |
+
|
| 199 |
+
return jsonify({
|
| 200 |
+
'success': True,
|
| 201 |
+
'message': response_text,
|
| 202 |
+
'history': updated_history,
|
| 203 |
+
'reasoning': reasoning,
|
| 204 |
+
'performance': performance
|
| 205 |
+
})
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
logger.error(f"Chat error: {e}", exc_info=True)
|
| 209 |
+
return jsonify({
|
| 210 |
+
'success': False,
|
| 211 |
+
'error': str(e),
|
| 212 |
+
'message': 'Error processing your request. Please try again.'
|
| 213 |
+
}), 500
|
| 214 |
+
|
| 215 |
+
# Manual initialization endpoint
|
| 216 |
+
@app.route('/api/initialize', methods=['POST'])
|
| 217 |
+
def initialize():
|
| 218 |
+
"""Manually trigger initialization"""
|
| 219 |
+
success = initialize_orchestrator()
|
| 220 |
+
|
| 221 |
+
if success:
|
| 222 |
+
return jsonify({
|
| 223 |
+
'success': True,
|
| 224 |
+
'message': 'Orchestrator initialized successfully'
|
| 225 |
+
})
|
| 226 |
+
else:
|
| 227 |
+
return jsonify({
|
| 228 |
+
'success': False,
|
| 229 |
+
'message': 'Initialization failed. Check logs for details.'
|
| 230 |
+
}), 500
|
| 231 |
+
|
| 232 |
+
# Initialize on startup
|
| 233 |
+
if __name__ == '__main__':
|
| 234 |
+
logger.info("=" * 60)
|
| 235 |
+
logger.info("STARTING PURE FLASK API")
|
| 236 |
+
logger.info("=" * 60)
|
| 237 |
+
|
| 238 |
+
# Initialize orchestrator
|
| 239 |
+
initialize_orchestrator()
|
| 240 |
+
|
| 241 |
+
port = int(os.getenv('PORT', 7860))
|
| 242 |
+
|
| 243 |
+
logger.info(f"Starting Flask on port {port}")
|
| 244 |
+
logger.info("Endpoints available:")
|
| 245 |
+
logger.info(" GET /")
|
| 246 |
+
logger.info(" GET /api/health")
|
| 247 |
+
logger.info(" POST /api/chat")
|
| 248 |
+
logger.info(" POST /api/initialize")
|
| 249 |
+
logger.info("=" * 60)
|
| 250 |
+
|
| 251 |
+
app.run(
|
| 252 |
+
host='0.0.0.0',
|
| 253 |
+
port=port,
|
| 254 |
+
debug=False,
|
| 255 |
+
threaded=True # Enable threading for concurrent requests
|
| 256 |
+
)
|
| 257 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements.txt for Hugging Face Spaces with NVIDIA T4 GPU
|
| 2 |
+
# Core Framework Dependencies
|
| 3 |
+
|
| 4 |
+
# Note: gradio, fastapi, uvicorn, datasets, huggingface-hub,
|
| 5 |
+
# pydantic==2.10.6, and protobuf<4 are installed by HF Spaces SDK
|
| 6 |
+
|
| 7 |
+
# PyTorch with CUDA support (for GPU inference)
|
| 8 |
+
# Note: HF Spaces provides torch, but we ensure GPU support
|
| 9 |
+
torch>=2.0.0
|
| 10 |
+
|
| 11 |
+
# Web Framework & Interface
|
| 12 |
+
aiohttp>=3.9.0
|
| 13 |
+
httpx>=0.25.0
|
| 14 |
+
|
| 15 |
+
# Hugging Face Ecosystem
|
| 16 |
+
transformers>=4.35.0
|
| 17 |
+
accelerate>=0.24.0
|
| 18 |
+
tokenizers>=0.15.0
|
| 19 |
+
sentence-transformers>=2.2.0
|
| 20 |
+
|
| 21 |
+
# Vector Database & Search
|
| 22 |
+
faiss-cpu>=1.7.4
|
| 23 |
+
numpy>=1.24.0
|
| 24 |
+
scipy>=1.11.0
|
| 25 |
+
|
| 26 |
+
# Data Processing & Utilities
|
| 27 |
+
pandas>=2.1.0
|
| 28 |
+
scikit-learn>=1.3.0
|
| 29 |
+
|
| 30 |
+
# Database & Persistence
|
| 31 |
+
sqlalchemy>=2.0.0
|
| 32 |
+
alembic>=1.12.0
|
| 33 |
+
|
| 34 |
+
# Caching & Performance
|
| 35 |
+
cachetools>=5.3.0
|
| 36 |
+
redis>=5.0.0
|
| 37 |
+
python-multipart>=0.0.6
|
| 38 |
+
|
| 39 |
+
# Security & Validation
|
| 40 |
+
pydantic-settings>=2.1.0
|
| 41 |
+
python-jose[cryptography]>=3.3.0
|
| 42 |
+
bcrypt>=4.0.0
|
| 43 |
+
|
| 44 |
+
# Mobile Optimization & UI
|
| 45 |
+
cssutils>=2.7.0
|
| 46 |
+
pillow>=10.1.0
|
| 47 |
+
requests>=2.31.0
|
| 48 |
+
|
| 49 |
+
# Async & Concurrency
|
| 50 |
+
aiofiles>=23.2.0
|
| 51 |
+
concurrent-log-handler>=0.9.0
|
| 52 |
+
|
| 53 |
+
# Logging & Monitoring
|
| 54 |
+
structlog>=23.2.0
|
| 55 |
+
prometheus-client>=0.19.0
|
| 56 |
+
psutil>=5.9.0
|
| 57 |
+
|
| 58 |
+
# Development & Testing
|
| 59 |
+
pytest>=7.4.0
|
| 60 |
+
pytest-asyncio>=0.21.0
|
| 61 |
+
pytest-cov>=4.1.0
|
| 62 |
+
black>=23.11.0
|
| 63 |
+
flake8>=6.1.0
|
| 64 |
+
mypy>=1.7.0
|
| 65 |
+
|
| 66 |
+
# Utility Libraries
|
| 67 |
+
python-dateutil>=2.8.0
|
| 68 |
+
pytz>=2023.3
|
| 69 |
+
tzdata>=2023.3
|
| 70 |
+
ujson>=5.8.0
|
| 71 |
+
orjson>=3.9.0
|
| 72 |
+
|
| 73 |
+
# Flask API for external integrations
|
| 74 |
+
flask>=3.0.0
|
| 75 |
+
flask-cors>=4.0.0
|
| 76 |
+
|
| 77 |
+
# HF Spaces Specific Dependencies
|
| 78 |
+
# Note: huggingface-cli is part of huggingface-hub (installed by SDK)
|
| 79 |
+
gradio-client>=0.8.0
|
| 80 |
+
gradio-pdf>=0.0.6
|
| 81 |
+
|
| 82 |
+
# Model-specific dependencies
|
| 83 |
+
safetensors>=0.4.0
|
| 84 |
+
|
| 85 |
+
# Development/debugging
|
| 86 |
+
ipython>=8.17.0
|
| 87 |
+
ipdb>=0.13.0
|
| 88 |
+
debugpy>=1.7.0
|
| 89 |
+
|
src/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Research Assistant MVP Package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
__version__ = "1.0.0"
|
| 6 |
+
__author__ = "Research Assistant Team"
|
| 7 |
+
__description__ = "Academic AI assistant with transparent reasoning"
|
| 8 |
+
|
| 9 |
+
# Import key components for easy access
|
| 10 |
+
try:
|
| 11 |
+
from .config import settings
|
| 12 |
+
__all__ = ['settings']
|
| 13 |
+
except ImportError:
|
| 14 |
+
# Fallback if config is not available
|
| 15 |
+
__all__ = []
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AI Research Assistant Agents
|
| 3 |
+
Specialized agents for different tasks
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .intent_agent import IntentRecognitionAgent, create_intent_agent
|
| 7 |
+
from .synthesis_agent import ResponseSynthesisAgent, create_synthesis_agent
|
| 8 |
+
from .safety_agent import SafetyCheckAgent, create_safety_agent
|
| 9 |
+
from .skills_identification_agent import SkillsIdentificationAgent, create_skills_identification_agent
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
'IntentRecognitionAgent',
|
| 13 |
+
'create_intent_agent',
|
| 14 |
+
'ResponseSynthesisAgent',
|
| 15 |
+
'create_synthesis_agent',
|
| 16 |
+
'SafetyCheckAgent',
|
| 17 |
+
'create_safety_agent',
|
| 18 |
+
'SkillsIdentificationAgent',
|
| 19 |
+
'create_skills_identification_agent'
|
| 20 |
+
]
|
| 21 |
+
|
src/agents/intent_agent.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Intent Recognition Agent
|
| 3 |
+
Specialized in understanding user goals using Chain of Thought reasoning
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Dict, Any, List
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class IntentRecognitionAgent:
|
| 13 |
+
def __init__(self, llm_router=None):
|
| 14 |
+
self.llm_router = llm_router
|
| 15 |
+
self.agent_id = "INTENT_REC_001"
|
| 16 |
+
self.specialization = "Multi-class intent classification with context awareness"
|
| 17 |
+
|
| 18 |
+
# Intent categories for classification
|
| 19 |
+
self.intent_categories = [
|
| 20 |
+
"information_request", # Asking for facts, explanations
|
| 21 |
+
"task_execution", # Requesting actions, automation
|
| 22 |
+
"creative_generation", # Content creation, writing
|
| 23 |
+
"analysis_research", # Data analysis, research
|
| 24 |
+
"casual_conversation", # Chat, social interaction
|
| 25 |
+
"troubleshooting", # Problem solving, debugging
|
| 26 |
+
"education_learning", # Learning, tutorials
|
| 27 |
+
"technical_support" # Technical help, guidance
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
async def execute(self, user_input: str, context: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
|
| 31 |
+
"""
|
| 32 |
+
Execute intent recognition with Chain of Thought reasoning
|
| 33 |
+
"""
|
| 34 |
+
try:
|
| 35 |
+
logger.info(f"{self.agent_id} processing user input: {user_input[:100]}...")
|
| 36 |
+
|
| 37 |
+
# Use LLM for sophisticated intent recognition if available
|
| 38 |
+
if self.llm_router:
|
| 39 |
+
intent_result = await self._llm_based_intent_recognition(user_input, context)
|
| 40 |
+
else:
|
| 41 |
+
# Fallback to rule-based classification
|
| 42 |
+
intent_result = await self._rule_based_intent_recognition(user_input, context)
|
| 43 |
+
|
| 44 |
+
# Add agent metadata
|
| 45 |
+
intent_result.update({
|
| 46 |
+
"agent_id": self.agent_id,
|
| 47 |
+
"processing_time": intent_result.get("processing_time", 0),
|
| 48 |
+
"confidence_calibration": self._calibrate_confidence(intent_result)
|
| 49 |
+
})
|
| 50 |
+
|
| 51 |
+
logger.info(f"{self.agent_id} completed with intent: {intent_result['primary_intent']}")
|
| 52 |
+
return intent_result
|
| 53 |
+
|
| 54 |
+
except Exception as e:
|
| 55 |
+
logger.error(f"{self.agent_id} error: {str(e)}")
|
| 56 |
+
return self._get_fallback_intent(user_input, context)
|
| 57 |
+
|
| 58 |
+
async def _llm_based_intent_recognition(self, user_input: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 59 |
+
"""Use LLM for sophisticated intent classification with Chain of Thought"""
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
cot_prompt = self._build_chain_of_thought_prompt(user_input, context)
|
| 63 |
+
|
| 64 |
+
logger.info(f"{self.agent_id} calling LLM for intent recognition")
|
| 65 |
+
llm_response = await self.llm_router.route_inference(
|
| 66 |
+
task_type="intent_classification",
|
| 67 |
+
prompt=cot_prompt,
|
| 68 |
+
max_tokens=1000,
|
| 69 |
+
temperature=0.3
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if llm_response and isinstance(llm_response, str) and len(llm_response.strip()) > 0:
|
| 73 |
+
# Parse LLM response
|
| 74 |
+
parsed_result = self._parse_llm_intent_response(llm_response)
|
| 75 |
+
parsed_result["processing_time"] = 0.8
|
| 76 |
+
parsed_result["method"] = "llm_enhanced"
|
| 77 |
+
return parsed_result
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
logger.error(f"{self.agent_id} LLM intent recognition failed: {e}")
|
| 81 |
+
|
| 82 |
+
# Fallback to rule-based classification if LLM fails
|
| 83 |
+
logger.info(f"{self.agent_id} falling back to rule-based classification")
|
| 84 |
+
return await self._rule_based_intent_recognition(user_input, context)
|
| 85 |
+
|
| 86 |
+
async def _rule_based_intent_recognition(self, user_input: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 87 |
+
"""Rule-based fallback intent classification"""
|
| 88 |
+
|
| 89 |
+
primary_intent, confidence = self._analyze_intent_patterns(user_input)
|
| 90 |
+
secondary_intents = self._get_secondary_intents(user_input, primary_intent)
|
| 91 |
+
|
| 92 |
+
return {
|
| 93 |
+
"primary_intent": primary_intent,
|
| 94 |
+
"secondary_intents": secondary_intents,
|
| 95 |
+
"confidence_scores": {primary_intent: confidence},
|
| 96 |
+
"reasoning_chain": ["Rule-based pattern matching applied"],
|
| 97 |
+
"context_tags": [],
|
| 98 |
+
"processing_time": 0.02
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
def _build_chain_of_thought_prompt(self, user_input: str, context: Dict[str, Any]) -> str:
|
| 102 |
+
"""Build Chain of Thought prompt for intent recognition"""
|
| 103 |
+
|
| 104 |
+
# Extract context information from Context Manager structure
|
| 105 |
+
# Session context, user context, and interaction contexts are all from cache
|
| 106 |
+
context_info = ""
|
| 107 |
+
if context:
|
| 108 |
+
# Use combined_context if available (pre-formatted by Context Manager, includes session context)
|
| 109 |
+
combined_context = context.get('combined_context', '')
|
| 110 |
+
if combined_context:
|
| 111 |
+
# Use the pre-formatted context from Context Manager (includes session context)
|
| 112 |
+
context_info = f"\n\nAvailable Context:\n{combined_context[:1000]}..." # Truncate if too long
|
| 113 |
+
else:
|
| 114 |
+
# Fallback: Build from session_context, user_context, and interaction_contexts (all from cache)
|
| 115 |
+
session_context = context.get('session_context', {})
|
| 116 |
+
session_summary = session_context.get('summary', '') if isinstance(session_context, dict) else ""
|
| 117 |
+
interaction_contexts = context.get('interaction_contexts', [])
|
| 118 |
+
user_context = context.get('user_context', '')
|
| 119 |
+
|
| 120 |
+
context_parts = []
|
| 121 |
+
if session_summary:
|
| 122 |
+
context_parts.append(f"Session Context: {session_summary[:300]}...")
|
| 123 |
+
if user_context:
|
| 124 |
+
context_parts.append(f"User Context: {user_context[:300]}...")
|
| 125 |
+
|
| 126 |
+
if interaction_contexts:
|
| 127 |
+
# Show last 2 interaction summaries for context
|
| 128 |
+
recent_contexts = interaction_contexts[-2:]
|
| 129 |
+
context_parts.append("Recent Interactions:")
|
| 130 |
+
for idx, ic in enumerate(recent_contexts, 1):
|
| 131 |
+
summary = ic.get('summary', '')
|
| 132 |
+
if summary:
|
| 133 |
+
context_parts.append(f" {idx}. {summary}")
|
| 134 |
+
|
| 135 |
+
if context_parts:
|
| 136 |
+
context_info = "\n\nAvailable Context:\n" + "\n".join(context_parts)
|
| 137 |
+
|
| 138 |
+
if not context_info:
|
| 139 |
+
context_info = "\n\nAvailable Context: No previous context available (first interaction in session)."
|
| 140 |
+
|
| 141 |
+
return f"""
|
| 142 |
+
Analyze the user's intent step by step:
|
| 143 |
+
|
| 144 |
+
User Input: "{user_input}"
|
| 145 |
+
{context_info}
|
| 146 |
+
|
| 147 |
+
Step 1: Identify key entities, actions, and questions in the input
|
| 148 |
+
Step 2: Map to intent categories: {', '.join(self.intent_categories)}
|
| 149 |
+
Step 3: Consider the conversation flow and user's likely goals (if context available)
|
| 150 |
+
Step 4: Assign confidence scores (0.0-1.0) for each relevant intent
|
| 151 |
+
Step 5: Provide reasoning for the classification
|
| 152 |
+
|
| 153 |
+
Respond with JSON format containing primary_intent, secondary_intents, confidence_scores, and reasoning_chain.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def _analyze_intent_patterns(self, user_input: str) -> tuple:
|
| 157 |
+
"""Analyze user input patterns to determine intent"""
|
| 158 |
+
user_input_lower = user_input.lower()
|
| 159 |
+
|
| 160 |
+
# Pattern matching for different intents
|
| 161 |
+
patterns = {
|
| 162 |
+
"information_request": [
|
| 163 |
+
"what is", "how to", "explain", "tell me about", "what are",
|
| 164 |
+
"define", "meaning of", "information about"
|
| 165 |
+
],
|
| 166 |
+
"task_execution": [
|
| 167 |
+
"do this", "make a", "create", "build", "generate", "automate",
|
| 168 |
+
"set up", "configure", "execute", "run"
|
| 169 |
+
],
|
| 170 |
+
"creative_generation": [
|
| 171 |
+
"write a", "compose", "create content", "make a story",
|
| 172 |
+
"generate poem", "creative", "artistic"
|
| 173 |
+
],
|
| 174 |
+
"analysis_research": [
|
| 175 |
+
"analyze", "research", "compare", "study", "investigate",
|
| 176 |
+
"data analysis", "find patterns", "statistics"
|
| 177 |
+
],
|
| 178 |
+
"troubleshooting": [
|
| 179 |
+
"error", "problem", "fix", "debug", "not working",
|
| 180 |
+
"help with", "issue", "broken"
|
| 181 |
+
],
|
| 182 |
+
"technical_support": [
|
| 183 |
+
"how do i", "help me", "guide me", "tutorial", "step by step"
|
| 184 |
+
]
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
# Find matching patterns
|
| 188 |
+
for intent, pattern_list in patterns.items():
|
| 189 |
+
for pattern in pattern_list:
|
| 190 |
+
if pattern in user_input_lower:
|
| 191 |
+
confidence = min(0.9, 0.6 + (len(pattern) * 0.1)) # Basic confidence calculation
|
| 192 |
+
return intent, confidence
|
| 193 |
+
|
| 194 |
+
# Default to casual conversation
|
| 195 |
+
return "casual_conversation", 0.7
|
| 196 |
+
|
| 197 |
+
def _get_secondary_intents(self, user_input: str, primary_intent: str) -> List[str]:
|
| 198 |
+
"""Get secondary intents based on input complexity"""
|
| 199 |
+
user_input_lower = user_input.lower()
|
| 200 |
+
secondary = []
|
| 201 |
+
|
| 202 |
+
# Add secondary intents based on content
|
| 203 |
+
if "research" in user_input_lower and primary_intent != "analysis_research":
|
| 204 |
+
secondary.append("analysis_research")
|
| 205 |
+
if "help" in user_input_lower and primary_intent != "technical_support":
|
| 206 |
+
secondary.append("technical_support")
|
| 207 |
+
|
| 208 |
+
return secondary[:2] # Limit to 2 secondary intents
|
| 209 |
+
|
| 210 |
+
def _extract_context_tags(self, user_input: str, context: Dict[str, Any]) -> List[str]:
|
| 211 |
+
"""Extract relevant context tags from user input"""
|
| 212 |
+
tags = []
|
| 213 |
+
user_input_lower = user_input.lower()
|
| 214 |
+
|
| 215 |
+
# Simple tag extraction
|
| 216 |
+
if "research" in user_input_lower:
|
| 217 |
+
tags.append("research")
|
| 218 |
+
if "technical" in user_input_lower or "code" in user_input_lower:
|
| 219 |
+
tags.append("technical")
|
| 220 |
+
if "academic" in user_input_lower or "study" in user_input_lower:
|
| 221 |
+
tags.append("academic")
|
| 222 |
+
if "quick" in user_input_lower or "simple" in user_input_lower:
|
| 223 |
+
tags.append("quick_request")
|
| 224 |
+
|
| 225 |
+
return tags
|
| 226 |
+
|
| 227 |
+
def _calibrate_confidence(self, intent_result: Dict[str, Any]) -> Dict[str, Any]:
|
| 228 |
+
"""Calibrate confidence scores based on various factors"""
|
| 229 |
+
primary_intent = intent_result["primary_intent"]
|
| 230 |
+
confidence = intent_result["confidence_scores"][primary_intent]
|
| 231 |
+
|
| 232 |
+
calibration_factors = {
|
| 233 |
+
"input_length_impact": min(1.0, len(intent_result.get('user_input', '')) / 100),
|
| 234 |
+
"context_enhancement": 0.1 if intent_result.get('context_tags') else 0.0,
|
| 235 |
+
"reasoning_depth_bonus": 0.05 if len(intent_result.get('reasoning_chain', [])) > 2 else 0.0
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
calibrated_confidence = min(0.95, confidence + sum(calibration_factors.values()))
|
| 239 |
+
|
| 240 |
+
return {
|
| 241 |
+
"original_confidence": confidence,
|
| 242 |
+
"calibrated_confidence": calibrated_confidence,
|
| 243 |
+
"calibration_factors": calibration_factors
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
def _parse_llm_intent_response(self, response: str) -> Dict[str, Any]:
|
| 247 |
+
"""Parse LLM response for intent classification"""
|
| 248 |
+
try:
|
| 249 |
+
import json
|
| 250 |
+
import re
|
| 251 |
+
|
| 252 |
+
# Try to extract JSON from response
|
| 253 |
+
json_match = re.search(r'\{.*\}', response, re.DOTALL)
|
| 254 |
+
if json_match:
|
| 255 |
+
parsed = json.loads(json_match.group())
|
| 256 |
+
return parsed
|
| 257 |
+
except json.JSONDecodeError:
|
| 258 |
+
logger.warning(f"{self.agent_id} Failed to parse LLM intent JSON")
|
| 259 |
+
|
| 260 |
+
# Fallback parsing - extract intent from text
|
| 261 |
+
response_lower = response.lower()
|
| 262 |
+
primary_intent = "casual_conversation"
|
| 263 |
+
confidence = 0.7
|
| 264 |
+
|
| 265 |
+
# Simple pattern matching for intent extraction
|
| 266 |
+
if any(word in response_lower for word in ['question', 'ask', 'what', 'how', 'why']):
|
| 267 |
+
primary_intent = "information_request"
|
| 268 |
+
confidence = 0.8
|
| 269 |
+
elif any(word in response_lower for word in ['task', 'action', 'do', 'help', 'assist']):
|
| 270 |
+
primary_intent = "task_execution"
|
| 271 |
+
confidence = 0.8
|
| 272 |
+
elif any(word in response_lower for word in ['create', 'generate', 'write', 'make']):
|
| 273 |
+
primary_intent = "creative_generation"
|
| 274 |
+
confidence = 0.8
|
| 275 |
+
|
| 276 |
+
return {
|
| 277 |
+
"primary_intent": primary_intent,
|
| 278 |
+
"secondary_intents": [],
|
| 279 |
+
"confidence_scores": {primary_intent: confidence},
|
| 280 |
+
"reasoning_chain": [f"LLM response parsed: {response[:100]}..."],
|
| 281 |
+
"context_tags": ["llm_parsed"],
|
| 282 |
+
"method": "llm_parsed"
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
def _get_fallback_intent(self, user_input: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 286 |
+
"""Provide fallback intent when processing fails"""
|
| 287 |
+
return {
|
| 288 |
+
"primary_intent": "casual_conversation",
|
| 289 |
+
"secondary_intents": [],
|
| 290 |
+
"confidence_scores": {"casual_conversation": 0.5},
|
| 291 |
+
"reasoning_chain": ["Fallback: Default to casual conversation"],
|
| 292 |
+
"context_tags": ["fallback"],
|
| 293 |
+
"processing_time": 0.01,
|
| 294 |
+
"agent_id": self.agent_id,
|
| 295 |
+
"error_handled": True
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
# Factory function for easy instantiation
|
| 299 |
+
def create_intent_agent(llm_router=None):
|
| 300 |
+
return IntentRecognitionAgent(llm_router)
|
| 301 |
+
|
src/agents/safety_agent.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Safety & Bias Mitigation Agent
|
| 3 |
+
Specialized in content moderation and bias detection with non-blocking warnings
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import re
|
| 8 |
+
from typing import Dict, Any, List, Tuple
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class SafetyCheckAgent:
|
| 13 |
+
def __init__(self, llm_router=None):
|
| 14 |
+
self.llm_router = llm_router
|
| 15 |
+
self.agent_id = "SAFETY_BIAS_001"
|
| 16 |
+
self.specialization = "Content moderation and bias detection with warning-based approach"
|
| 17 |
+
|
| 18 |
+
# Safety thresholds (non-blocking, warning-only)
|
| 19 |
+
self.safety_thresholds = {
|
| 20 |
+
"toxicity": 0.8, # High threshold for warnings
|
| 21 |
+
"bias": 0.7, # Moderate threshold for bias detection
|
| 22 |
+
"safety": 0.6, # Lower threshold for general safety
|
| 23 |
+
"privacy": 0.9 # Very high threshold for privacy concerns
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
# Warning templates (non-blocking)
|
| 27 |
+
self.warning_templates = {
|
| 28 |
+
"toxicity": "⚠️ Note: Content may contain strong language",
|
| 29 |
+
"bias": "🔍 Note: Potential biases detected in response",
|
| 30 |
+
"safety": "📝 Note: Response should be verified for accuracy",
|
| 31 |
+
"privacy": "🔒 Note: Privacy-sensitive topics discussed",
|
| 32 |
+
"controversial": "💭 Note: This topic may have multiple perspectives"
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
# Pattern-based detection for quick analysis
|
| 36 |
+
self.sensitive_patterns = {
|
| 37 |
+
"toxicity": [
|
| 38 |
+
r'\b(hate|violence|harm|attack|destroy)\b',
|
| 39 |
+
r'\b(kill|hurt|harm|danger)\b',
|
| 40 |
+
r'racial slurs', # Placeholder for actual sensitive terms
|
| 41 |
+
],
|
| 42 |
+
"bias": [
|
| 43 |
+
r'\b(all|always|never|every)\b', # Overgeneralizations
|
| 44 |
+
r'\b(should|must|have to)\b', # Prescriptive language
|
| 45 |
+
r'stereotypes?', # Stereotype indicators
|
| 46 |
+
],
|
| 47 |
+
"privacy": [
|
| 48 |
+
r'\b(ssn|social security|password|credit card)\b',
|
| 49 |
+
r'\b(address|phone|email|personal)\b',
|
| 50 |
+
r'\b(confidential|secret|private)\b',
|
| 51 |
+
]
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
async def execute(self, response, context: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
|
| 55 |
+
"""
|
| 56 |
+
Execute safety check with non-blocking warnings
|
| 57 |
+
Returns original response with added warnings
|
| 58 |
+
"""
|
| 59 |
+
try:
|
| 60 |
+
# Handle both string and dict inputs
|
| 61 |
+
if isinstance(response, dict):
|
| 62 |
+
# Extract the actual response string from the dict
|
| 63 |
+
response_text = response.get('final_response', response.get('response', str(response)))
|
| 64 |
+
else:
|
| 65 |
+
response_text = str(response)
|
| 66 |
+
|
| 67 |
+
logger.info(f"{self.agent_id} analyzing response of length {len(response_text)}")
|
| 68 |
+
|
| 69 |
+
# Perform safety analysis
|
| 70 |
+
safety_analysis = await self._analyze_safety(response_text, context)
|
| 71 |
+
|
| 72 |
+
# Generate warnings without modifying response
|
| 73 |
+
warnings = self._generate_warnings(safety_analysis)
|
| 74 |
+
|
| 75 |
+
# Add safety metadata to response
|
| 76 |
+
result = {
|
| 77 |
+
"original_response": response_text,
|
| 78 |
+
"safety_checked_response": response_text, # Response never modified
|
| 79 |
+
"warnings": warnings,
|
| 80 |
+
"safety_analysis": safety_analysis,
|
| 81 |
+
"blocked": False, # Never blocks content
|
| 82 |
+
"confidence_scores": safety_analysis.get("confidence_scores", {}),
|
| 83 |
+
"agent_id": self.agent_id
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
logger.info(f"{self.agent_id} completed with {len(warnings)} warnings")
|
| 87 |
+
return result
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.error(f"{self.agent_id} error: {str(e)}", exc_info=True)
|
| 91 |
+
# Fail-safe: return original response with error note
|
| 92 |
+
response_text = str(response) if not isinstance(response, dict) else response.get('final_response', str(response))
|
| 93 |
+
return self._get_fallback_result(response_text)
|
| 94 |
+
|
| 95 |
+
async def _analyze_safety(self, response: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 96 |
+
"""Analyze response for safety concerns using multiple methods"""
|
| 97 |
+
|
| 98 |
+
if self.llm_router:
|
| 99 |
+
return await self._llm_based_safety_analysis(response, context)
|
| 100 |
+
else:
|
| 101 |
+
return await self._pattern_based_safety_analysis(response)
|
| 102 |
+
|
| 103 |
+
async def _llm_based_safety_analysis(self, response: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 104 |
+
"""Use LLM for sophisticated safety analysis"""
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
safety_prompt = self._build_safety_prompt(response, context)
|
| 108 |
+
|
| 109 |
+
logger.info(f"{self.agent_id} calling LLM for safety analysis")
|
| 110 |
+
llm_response = await self.llm_router.route_inference(
|
| 111 |
+
task_type="safety_check",
|
| 112 |
+
prompt=safety_prompt,
|
| 113 |
+
max_tokens=800,
|
| 114 |
+
temperature=0.3
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if llm_response and isinstance(llm_response, str) and len(llm_response.strip()) > 0:
|
| 118 |
+
# Parse LLM response
|
| 119 |
+
parsed_analysis = self._parse_llm_safety_response(llm_response)
|
| 120 |
+
parsed_analysis["processing_time"] = 0.6
|
| 121 |
+
parsed_analysis["method"] = "llm_enhanced"
|
| 122 |
+
return parsed_analysis
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
logger.error(f"{self.agent_id} LLM safety analysis failed: {e}")
|
| 126 |
+
|
| 127 |
+
# Fallback to pattern-based analysis if LLM fails
|
| 128 |
+
logger.info(f"{self.agent_id} falling back to pattern-based safety analysis")
|
| 129 |
+
return await self._pattern_based_safety_analysis(response)
|
| 130 |
+
|
| 131 |
+
async def _pattern_based_safety_analysis(self, response: str) -> Dict[str, Any]:
|
| 132 |
+
"""Pattern-based safety analysis as fallback"""
|
| 133 |
+
|
| 134 |
+
detected_issues = self._pattern_based_detection(response)
|
| 135 |
+
|
| 136 |
+
return {
|
| 137 |
+
"toxicity_score": self._calculate_toxicity_score(response),
|
| 138 |
+
"bias_indicators": self._detect_bias_indicators(response),
|
| 139 |
+
"privacy_concerns": self._check_privacy_issues(response),
|
| 140 |
+
"overall_safety_score": 0.75, # Conservative estimate
|
| 141 |
+
"confidence_scores": {
|
| 142 |
+
"toxicity": 0.6,
|
| 143 |
+
"bias": 0.5,
|
| 144 |
+
"safety": 0.7,
|
| 145 |
+
"privacy": 0.8
|
| 146 |
+
},
|
| 147 |
+
"detected_issues": detected_issues,
|
| 148 |
+
"analysis_method": "pattern_based"
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
def _build_safety_prompt(self, response: str, context: Dict[str, Any]) -> str:
|
| 152 |
+
"""Build prompt for safety analysis with context awareness"""
|
| 153 |
+
|
| 154 |
+
# Extract relevant context information for safety analysis
|
| 155 |
+
context_info = ""
|
| 156 |
+
if context:
|
| 157 |
+
# Get session context, user context, and interaction contexts (all from cache)
|
| 158 |
+
session_context = context.get('session_context', {})
|
| 159 |
+
session_summary = session_context.get('summary', '') if isinstance(session_context, dict) else ""
|
| 160 |
+
user_context = context.get('user_context', '')
|
| 161 |
+
interaction_contexts = context.get('interaction_contexts', [])
|
| 162 |
+
|
| 163 |
+
if session_summary:
|
| 164 |
+
context_info = f"\n\nSession Context (for safety context): {session_summary[:200]}..."
|
| 165 |
+
if user_context:
|
| 166 |
+
context_info += f"\n\nUser Context (for safety context): {user_context[:200]}..."
|
| 167 |
+
|
| 168 |
+
# Optionally include recent interaction context to understand conversation flow
|
| 169 |
+
if interaction_contexts:
|
| 170 |
+
recent_context = interaction_contexts[-1].get('summary', '') if interaction_contexts else ''
|
| 171 |
+
if recent_context:
|
| 172 |
+
context_info += f"\nRecent conversation: {recent_context[:200]}..."
|
| 173 |
+
|
| 174 |
+
return f"""
|
| 175 |
+
Analyze the following text for safety concerns and provide warnings if needed:
|
| 176 |
+
|
| 177 |
+
Text to analyze: "{response[:500]}..." # Truncate for large responses
|
| 178 |
+
{context_info}
|
| 179 |
+
|
| 180 |
+
Analyze for:
|
| 181 |
+
1. Toxicity or harmful language
|
| 182 |
+
2. Potential biases or stereotypes
|
| 183 |
+
3. Privacy or security concerns
|
| 184 |
+
4. Controversial or sensitive topics
|
| 185 |
+
|
| 186 |
+
Consider the conversation context when assessing whether content is appropriate in this specific conversational context.
|
| 187 |
+
|
| 188 |
+
Provide analysis in JSON format with:
|
| 189 |
+
- safety_scores for each category (0-1 scale)
|
| 190 |
+
- detected_issues list
|
| 191 |
+
- confidence_level
|
| 192 |
+
- recommended_warnings (non-blocking, advisory only)
|
| 193 |
+
|
| 194 |
+
IMPORTANT: Never block or modify the content, only provide warnings.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def _pattern_based_detection(self, response: str) -> List[Dict[str, Any]]:
|
| 198 |
+
"""Detect safety issues using pattern matching"""
|
| 199 |
+
issues = []
|
| 200 |
+
response_lower = response.lower()
|
| 201 |
+
|
| 202 |
+
# Check each category
|
| 203 |
+
for category, patterns in self.sensitive_patterns.items():
|
| 204 |
+
for pattern in patterns:
|
| 205 |
+
if re.search(pattern, response_lower, re.IGNORECASE):
|
| 206 |
+
issues.append({
|
| 207 |
+
"category": category,
|
| 208 |
+
"pattern": pattern,
|
| 209 |
+
"severity": "low", # Always low for warning-only approach
|
| 210 |
+
"confidence": 0.7
|
| 211 |
+
})
|
| 212 |
+
break # Only report one pattern match per category
|
| 213 |
+
|
| 214 |
+
return issues
|
| 215 |
+
|
| 216 |
+
def _calculate_toxicity_score(self, response: str) -> float:
|
| 217 |
+
"""Calculate toxicity score (simplified version)"""
|
| 218 |
+
# Simple heuristic-based toxicity detection
|
| 219 |
+
toxic_indicators = [
|
| 220 |
+
'hate', 'violence', 'harm', 'attack', 'destroy', 'kill', 'hurt'
|
| 221 |
+
]
|
| 222 |
+
|
| 223 |
+
score = 0.0
|
| 224 |
+
words = response.lower().split()
|
| 225 |
+
for indicator in toxic_indicators:
|
| 226 |
+
if indicator in words:
|
| 227 |
+
score += 0.2
|
| 228 |
+
|
| 229 |
+
return min(1.0, score)
|
| 230 |
+
|
| 231 |
+
def _detect_bias_indicators(self, response: str) -> List[str]:
|
| 232 |
+
"""Detect potential bias indicators"""
|
| 233 |
+
biases = []
|
| 234 |
+
|
| 235 |
+
# Overgeneralization detection
|
| 236 |
+
if re.search(r'\b(all|always|never|every)\s+\w+s\b', response, re.IGNORECASE):
|
| 237 |
+
biases.append("overgeneralization")
|
| 238 |
+
|
| 239 |
+
# Prescriptive language
|
| 240 |
+
if re.search(r'\b(should|must|have to|ought to)\b', response, re.IGNORECASE):
|
| 241 |
+
biases.append("prescriptive_language")
|
| 242 |
+
|
| 243 |
+
# Stereotype indicators
|
| 244 |
+
stereotype_patterns = [
|
| 245 |
+
r'\b(all|most)\s+\w+\s+people\b',
|
| 246 |
+
r'\b(typical|usual|normal)\s+\w+\b',
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
for pattern in stereotype_patterns:
|
| 250 |
+
if re.search(pattern, response, re.IGNORECASE):
|
| 251 |
+
biases.append("potential_stereotype")
|
| 252 |
+
break
|
| 253 |
+
|
| 254 |
+
return biases
|
| 255 |
+
|
| 256 |
+
def _check_privacy_issues(self, response: str) -> List[str]:
|
| 257 |
+
"""Check for privacy-sensitive content"""
|
| 258 |
+
privacy_issues = []
|
| 259 |
+
|
| 260 |
+
# Personal information patterns
|
| 261 |
+
personal_info_patterns = [
|
| 262 |
+
r'\b\d{3}-\d{2}-\d{4}\b', # SSN-like pattern
|
| 263 |
+
r'\b\d{16}\b', # Credit card-like pattern
|
| 264 |
+
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', # Email
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
for pattern in personal_info_patterns:
|
| 268 |
+
if re.search(pattern, response):
|
| 269 |
+
privacy_issues.append("potential_personal_info")
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
return privacy_issues
|
| 273 |
+
|
| 274 |
+
def _generate_warnings(self, safety_analysis: Dict[str, Any]) -> List[str]:
|
| 275 |
+
"""Generate non-blocking warnings based on safety analysis"""
|
| 276 |
+
warnings = []
|
| 277 |
+
|
| 278 |
+
try:
|
| 279 |
+
# Safely extract and validate confidence_scores
|
| 280 |
+
confidence_scores = safety_analysis.get("confidence_scores", {})
|
| 281 |
+
if not isinstance(confidence_scores, dict):
|
| 282 |
+
confidence_scores = {}
|
| 283 |
+
|
| 284 |
+
# Safely extract detected_issues
|
| 285 |
+
detected_issues = safety_analysis.get("detected_issues", [])
|
| 286 |
+
if not isinstance(detected_issues, list):
|
| 287 |
+
detected_issues = []
|
| 288 |
+
|
| 289 |
+
# Toxicity warnings
|
| 290 |
+
if isinstance(confidence_scores.get("toxicity"), (int, float)) and confidence_scores.get("toxicity", 0) > self.safety_thresholds["toxicity"]:
|
| 291 |
+
warnings.append(self.warning_templates["toxicity"])
|
| 292 |
+
|
| 293 |
+
# Bias warnings
|
| 294 |
+
has_bias_score = isinstance(confidence_scores.get("bias"), (int, float)) and confidence_scores.get("bias", 0) > self.safety_thresholds["bias"]
|
| 295 |
+
has_bias_indicators = safety_analysis.get("bias_indicators")
|
| 296 |
+
if has_bias_score or (has_bias_indicators and has_bias_indicators):
|
| 297 |
+
warnings.append(self.warning_templates["bias"])
|
| 298 |
+
|
| 299 |
+
# Privacy warnings
|
| 300 |
+
has_privacy_score = isinstance(confidence_scores.get("privacy"), (int, float)) and confidence_scores.get("privacy", 0) > self.safety_thresholds["privacy"]
|
| 301 |
+
has_privacy_concerns = safety_analysis.get("privacy_concerns")
|
| 302 |
+
if has_privacy_score or (has_privacy_concerns and has_privacy_concerns):
|
| 303 |
+
warnings.append(self.warning_templates["privacy"])
|
| 304 |
+
|
| 305 |
+
# General safety warning if overall score is low
|
| 306 |
+
overall_score = safety_analysis.get("overall_safety_score", 1.0)
|
| 307 |
+
if isinstance(overall_score, (int, float)) and overall_score < 0.7:
|
| 308 |
+
warnings.append(self.warning_templates["safety"])
|
| 309 |
+
|
| 310 |
+
# Add context-specific warnings for detected issues
|
| 311 |
+
for issue in detected_issues:
|
| 312 |
+
try:
|
| 313 |
+
if isinstance(issue, dict):
|
| 314 |
+
category = issue.get("category")
|
| 315 |
+
if category and isinstance(category, str) and category in self.warning_templates:
|
| 316 |
+
category_warning = self.warning_templates[category]
|
| 317 |
+
if category_warning not in warnings:
|
| 318 |
+
warnings.append(category_warning)
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.debug(f"Error processing issue: {e}")
|
| 321 |
+
continue
|
| 322 |
+
|
| 323 |
+
# Deduplicate warnings and ensure all are strings
|
| 324 |
+
warnings = [w for w in warnings if isinstance(w, str)]
|
| 325 |
+
# Create set and convert back to list (safely handle any edge cases)
|
| 326 |
+
seen = set()
|
| 327 |
+
unique_warnings = []
|
| 328 |
+
for w in warnings:
|
| 329 |
+
if w not in seen:
|
| 330 |
+
seen.add(w)
|
| 331 |
+
unique_warnings.append(w)
|
| 332 |
+
return unique_warnings
|
| 333 |
+
|
| 334 |
+
except Exception as e:
|
| 335 |
+
logger.error(f"Error generating warnings: {e}", exc_info=True)
|
| 336 |
+
# Return empty list on error
|
| 337 |
+
return []
|
| 338 |
+
|
| 339 |
+
def _parse_llm_safety_response(self, response: str) -> Dict[str, Any]:
|
| 340 |
+
"""Parse LLM response for safety analysis"""
|
| 341 |
+
try:
|
| 342 |
+
import json
|
| 343 |
+
import re
|
| 344 |
+
|
| 345 |
+
# Try to extract JSON from response
|
| 346 |
+
json_match = re.search(r'\{.*\}', response, re.DOTALL)
|
| 347 |
+
if json_match:
|
| 348 |
+
parsed = json.loads(json_match.group())
|
| 349 |
+
return parsed
|
| 350 |
+
except json.JSONDecodeError:
|
| 351 |
+
logger.warning(f"{self.agent_id} Failed to parse LLM safety JSON")
|
| 352 |
+
|
| 353 |
+
# Fallback parsing - extract safety info from text
|
| 354 |
+
response_lower = response.lower()
|
| 355 |
+
|
| 356 |
+
# Simple safety analysis based on keywords
|
| 357 |
+
toxicity_score = 0.1
|
| 358 |
+
bias_score = 0.1
|
| 359 |
+
safety_score = 0.9
|
| 360 |
+
|
| 361 |
+
if any(word in response_lower for word in ['toxic', 'harmful', 'dangerous', 'inappropriate']):
|
| 362 |
+
toxicity_score = 0.8
|
| 363 |
+
safety_score = 0.3
|
| 364 |
+
elif any(word in response_lower for word in ['bias', 'discriminatory', 'unfair', 'prejudiced']):
|
| 365 |
+
bias_score = 0.7
|
| 366 |
+
safety_score = 0.5
|
| 367 |
+
|
| 368 |
+
return {
|
| 369 |
+
"toxicity_score": toxicity_score,
|
| 370 |
+
"bias_indicators": [],
|
| 371 |
+
"privacy_concerns": [],
|
| 372 |
+
"overall_safety_score": safety_score,
|
| 373 |
+
"confidence_scores": {
|
| 374 |
+
"toxicity": 0.7,
|
| 375 |
+
"bias": 0.6,
|
| 376 |
+
"safety": safety_score,
|
| 377 |
+
"privacy": 0.9
|
| 378 |
+
},
|
| 379 |
+
"detected_issues": [],
|
| 380 |
+
"analysis_method": "llm_parsed",
|
| 381 |
+
"llm_response": response[:200] + "..." if len(response) > 200 else response
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
def _get_fallback_result(self, response: str) -> Dict[str, Any]:
|
| 385 |
+
"""Fallback result when safety check fails"""
|
| 386 |
+
return {
|
| 387 |
+
"original_response": response,
|
| 388 |
+
"safety_checked_response": response,
|
| 389 |
+
"warnings": ["🔧 Note: Safety analysis temporarily unavailable"],
|
| 390 |
+
"safety_analysis": {
|
| 391 |
+
"overall_safety_score": 0.5,
|
| 392 |
+
"confidence_scores": {"safety": 0.5},
|
| 393 |
+
"detected_issues": [],
|
| 394 |
+
"analysis_method": "fallback"
|
| 395 |
+
},
|
| 396 |
+
"blocked": False,
|
| 397 |
+
"agent_id": self.agent_id,
|
| 398 |
+
"error_handled": True
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
def get_safety_summary(self, analysis_result: Dict[str, Any]) -> str:
|
| 402 |
+
"""Generate a user-friendly safety summary"""
|
| 403 |
+
warnings = analysis_result.get("warnings", [])
|
| 404 |
+
safety_score = analysis_result.get("safety_analysis", {}).get("overall_safety_score", 1.0)
|
| 405 |
+
|
| 406 |
+
if not warnings:
|
| 407 |
+
return "✅ Content appears safe based on automated analysis"
|
| 408 |
+
|
| 409 |
+
warning_count = len(warnings)
|
| 410 |
+
if safety_score > 0.8:
|
| 411 |
+
severity = "low"
|
| 412 |
+
elif safety_score > 0.6:
|
| 413 |
+
severity = "medium"
|
| 414 |
+
else:
|
| 415 |
+
severity = "high"
|
| 416 |
+
|
| 417 |
+
return f"⚠️ {warning_count} advisory note(s) - {severity} severity"
|
| 418 |
+
|
| 419 |
+
async def batch_analyze(self, responses: List[str]) -> List[Dict[str, Any]]:
|
| 420 |
+
"""Analyze multiple responses efficiently"""
|
| 421 |
+
results = []
|
| 422 |
+
for response in responses:
|
| 423 |
+
result = await self.execute(response)
|
| 424 |
+
results.append(result)
|
| 425 |
+
return results
|
| 426 |
+
|
| 427 |
+
# Factory function for easy instantiation
|
| 428 |
+
def create_safety_agent(llm_router=None):
|
| 429 |
+
return SafetyCheckAgent(llm_router)
|
| 430 |
+
|
| 431 |
+
# Example usage
|
| 432 |
+
if __name__ == "__main__":
|
| 433 |
+
# Test the safety agent
|
| 434 |
+
agent = SafetyCheckAgent()
|
| 435 |
+
|
| 436 |
+
test_responses = [
|
| 437 |
+
"This is a perfectly normal response with no issues.",
|
| 438 |
+
"Some content that might contain controversial topics.",
|
| 439 |
+
"Discussion about sensitive personal information."
|
| 440 |
+
]
|
| 441 |
+
|
| 442 |
+
import asyncio
|
| 443 |
+
|
| 444 |
+
async def test_agent():
|
| 445 |
+
for response in test_responses:
|
| 446 |
+
result = await agent.execute(response)
|
| 447 |
+
print(f"Response: {response[:50]}...")
|
| 448 |
+
print(f"Warnings: {result['warnings']}")
|
| 449 |
+
print(f"Safety Score: {result['safety_analysis']['overall_safety_score']}")
|
| 450 |
+
print("-" * 50)
|
| 451 |
+
|
| 452 |
+
asyncio.run(test_agent())
|
| 453 |
+
|
src/agents/skills_identification_agent.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Skills Identification Agent
|
| 3 |
+
Specialized in analyzing user prompts and identifying relevant expert skills based on market analysis
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Dict, Any, List, Tuple
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class SkillsIdentificationAgent:
|
| 14 |
+
def __init__(self, llm_router=None):
|
| 15 |
+
self.llm_router = llm_router
|
| 16 |
+
self.agent_id = "SKILLS_ID_001"
|
| 17 |
+
self.specialization = "Expert skills identification and market analysis"
|
| 18 |
+
|
| 19 |
+
# Market analysis data from Expert_Skills_Market_Analysis_2024.md
|
| 20 |
+
self.market_categories = {
|
| 21 |
+
"IT and Software Development": {
|
| 22 |
+
"market_share": 25,
|
| 23 |
+
"growth_rate": 25.0,
|
| 24 |
+
"specialized_skills": [
|
| 25 |
+
"Cybersecurity", "Artificial Intelligence & Machine Learning",
|
| 26 |
+
"Cloud Computing", "Data Analytics & Big Data",
|
| 27 |
+
"Software Engineering", "Blockchain Technology", "Quantum Computing"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
"Finance and Accounting": {
|
| 31 |
+
"market_share": 20,
|
| 32 |
+
"growth_rate": 6.8,
|
| 33 |
+
"specialized_skills": [
|
| 34 |
+
"Financial Analysis & Modeling", "Risk Management",
|
| 35 |
+
"Regulatory Compliance", "Fintech Solutions",
|
| 36 |
+
"ESG Reporting", "Tax Preparation", "Investment Analysis"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
"Healthcare and Medicine": {
|
| 40 |
+
"market_share": 15,
|
| 41 |
+
"growth_rate": 8.5,
|
| 42 |
+
"specialized_skills": [
|
| 43 |
+
"Telemedicine Training", "Advanced Nursing Certifications",
|
| 44 |
+
"Healthcare Informatics", "Clinical Research",
|
| 45 |
+
"Medical Device Technology", "Public Health", "Mental Health Services"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
"Education and Teaching": {
|
| 49 |
+
"market_share": 10,
|
| 50 |
+
"growth_rate": 3.2,
|
| 51 |
+
"specialized_skills": [
|
| 52 |
+
"Instructional Design", "Educational Technology Integration",
|
| 53 |
+
"Digital Literacy Training", "Special Education",
|
| 54 |
+
"Career Coaching", "E-learning Development", "STEM Education"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
"Engineering and Construction": {
|
| 58 |
+
"market_share": 10,
|
| 59 |
+
"growth_rate": 8.5,
|
| 60 |
+
"specialized_skills": [
|
| 61 |
+
"Automation Engineering", "Sustainable Design",
|
| 62 |
+
"Project Management", "Environmental Engineering",
|
| 63 |
+
"Advanced Manufacturing", "Infrastructure Development", "Quality Control"
|
| 64 |
+
]
|
| 65 |
+
},
|
| 66 |
+
"Marketing and Sales": {
|
| 67 |
+
"market_share": 10,
|
| 68 |
+
"growth_rate": 7.1,
|
| 69 |
+
"specialized_skills": [
|
| 70 |
+
"Digital Marketing", "Data Analytics",
|
| 71 |
+
"Customer Relationship Management", "Content Marketing",
|
| 72 |
+
"E-commerce Management", "Market Research", "Sales Strategy"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
"Consulting and Strategy": {
|
| 76 |
+
"market_share": 5,
|
| 77 |
+
"growth_rate": 6.0,
|
| 78 |
+
"specialized_skills": [
|
| 79 |
+
"Business Analysis", "Change Management",
|
| 80 |
+
"Strategic Planning", "Operations Research",
|
| 81 |
+
"Industry-Specific Knowledge", "Problem-Solving", "Leadership Development"
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
"Environmental and Sustainability": {
|
| 85 |
+
"market_share": 5,
|
| 86 |
+
"growth_rate": 15.0,
|
| 87 |
+
"specialized_skills": [
|
| 88 |
+
"Renewable Energy Technologies", "Environmental Policy",
|
| 89 |
+
"Sustainability Reporting", "Ecological Conservation",
|
| 90 |
+
"Carbon Management", "Green Technology", "Circular Economy"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
"Arts and Humanities": {
|
| 94 |
+
"market_share": 5,
|
| 95 |
+
"growth_rate": 2.5,
|
| 96 |
+
"specialized_skills": [
|
| 97 |
+
"Creative Thinking", "Cultural Analysis",
|
| 98 |
+
"Communication", "Digital Media",
|
| 99 |
+
"Language Services", "Historical Research", "Philosophical Analysis"
|
| 100 |
+
]
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
# Skill classification categories for the classification_specialist model
|
| 105 |
+
self.skill_categories = [
|
| 106 |
+
"technical_programming", "data_analysis", "cybersecurity", "cloud_computing",
|
| 107 |
+
"financial_analysis", "risk_management", "regulatory_compliance", "fintech",
|
| 108 |
+
"healthcare_technology", "medical_research", "telemedicine", "nursing",
|
| 109 |
+
"educational_technology", "curriculum_design", "online_learning", "teaching",
|
| 110 |
+
"project_management", "engineering_design", "sustainable_engineering", "manufacturing",
|
| 111 |
+
"digital_marketing", "sales_strategy", "customer_management", "market_research",
|
| 112 |
+
"business_consulting", "strategic_planning", "change_management", "leadership",
|
| 113 |
+
"environmental_science", "sustainability", "renewable_energy", "green_technology",
|
| 114 |
+
"creative_design", "content_creation", "communication", "cultural_analysis"
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
async def execute(self, user_input: str, context: Dict[str, Any] = None, **kwargs) -> Dict[str, Any]:
|
| 118 |
+
"""
|
| 119 |
+
Execute skills identification with two-step process:
|
| 120 |
+
1. Market analysis using reasoning_primary model
|
| 121 |
+
2. Skill classification using classification_specialist model
|
| 122 |
+
"""
|
| 123 |
+
try:
|
| 124 |
+
logger.info(f"{self.agent_id} processing user input: {user_input[:100]}...")
|
| 125 |
+
|
| 126 |
+
# Step 1: Market Analysis with reasoning_primary model
|
| 127 |
+
market_analysis = await self._analyze_market_relevance(user_input, context)
|
| 128 |
+
|
| 129 |
+
# Step 2: Skill Classification with classification_specialist model
|
| 130 |
+
skill_classification = await self._classify_skills(user_input, context)
|
| 131 |
+
|
| 132 |
+
# Combine results
|
| 133 |
+
combined_data = {
|
| 134 |
+
"market_analysis": market_analysis,
|
| 135 |
+
"skill_classification": skill_classification,
|
| 136 |
+
"user_input": user_input,
|
| 137 |
+
"context": context
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
result = {
|
| 141 |
+
"agent_id": self.agent_id,
|
| 142 |
+
"market_analysis": market_analysis,
|
| 143 |
+
"skill_classification": skill_classification,
|
| 144 |
+
"identified_skills": self._extract_high_probability_skills(combined_data),
|
| 145 |
+
"processing_time": market_analysis.get("processing_time", 0) + skill_classification.get("processing_time", 0),
|
| 146 |
+
"confidence_score": self._calculate_overall_confidence(market_analysis, skill_classification)
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
logger.info(f"{self.agent_id} completed with {len(result['identified_skills'])} skills identified")
|
| 150 |
+
return result
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
logger.error(f"{self.agent_id} error: {str(e)}")
|
| 154 |
+
return self._get_fallback_result(user_input, context)
|
| 155 |
+
|
| 156 |
+
async def _analyze_market_relevance(self, user_input: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 157 |
+
"""Use reasoning_primary model to analyze market relevance"""
|
| 158 |
+
|
| 159 |
+
if self.llm_router:
|
| 160 |
+
try:
|
| 161 |
+
# Build market analysis prompt with context
|
| 162 |
+
market_prompt = self._build_market_analysis_prompt(user_input, context)
|
| 163 |
+
|
| 164 |
+
logger.info(f"{self.agent_id} calling reasoning_primary for market analysis")
|
| 165 |
+
llm_response = await self.llm_router.route_inference(
|
| 166 |
+
task_type="general_reasoning",
|
| 167 |
+
prompt=market_prompt,
|
| 168 |
+
max_tokens=2000,
|
| 169 |
+
temperature=0.7
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if llm_response and isinstance(llm_response, str) and len(llm_response.strip()) > 0:
|
| 173 |
+
# Parse LLM response
|
| 174 |
+
parsed_analysis = self._parse_market_analysis_response(llm_response)
|
| 175 |
+
parsed_analysis["processing_time"] = 0.8
|
| 176 |
+
parsed_analysis["method"] = "llm_enhanced"
|
| 177 |
+
return parsed_analysis
|
| 178 |
+
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"{self.agent_id} LLM market analysis failed: {e}")
|
| 181 |
+
|
| 182 |
+
# Fallback to rule-based analysis
|
| 183 |
+
return self._rule_based_market_analysis(user_input)
|
| 184 |
+
|
| 185 |
+
async def _classify_skills(self, user_input: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 186 |
+
"""Use classification_specialist model to classify skills"""
|
| 187 |
+
|
| 188 |
+
if self.llm_router:
|
| 189 |
+
try:
|
| 190 |
+
# Build classification prompt
|
| 191 |
+
classification_prompt = self._build_classification_prompt(user_input)
|
| 192 |
+
|
| 193 |
+
logger.info(f"{self.agent_id} calling classification_specialist for skill classification")
|
| 194 |
+
llm_response = await self.llm_router.route_inference(
|
| 195 |
+
task_type="intent_classification",
|
| 196 |
+
prompt=classification_prompt,
|
| 197 |
+
max_tokens=512,
|
| 198 |
+
temperature=0.3
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if llm_response and isinstance(llm_response, str) and len(llm_response.strip()) > 0:
|
| 202 |
+
# Parse classification response
|
| 203 |
+
parsed_classification = self._parse_classification_response(llm_response)
|
| 204 |
+
parsed_classification["processing_time"] = 0.3
|
| 205 |
+
parsed_classification["method"] = "llm_enhanced"
|
| 206 |
+
return parsed_classification
|
| 207 |
+
|
| 208 |
+
except Exception as e:
|
| 209 |
+
logger.error(f"{self.agent_id} LLM classification failed: {e}")
|
| 210 |
+
|
| 211 |
+
# Fallback to rule-based classification
|
| 212 |
+
return self._rule_based_skill_classification(user_input)
|
| 213 |
+
|
| 214 |
+
def _build_market_analysis_prompt(self, user_input: str, context: Dict[str, Any] = None) -> str:
|
| 215 |
+
"""Build prompt for market analysis using reasoning_primary model with optional context"""
|
| 216 |
+
|
| 217 |
+
market_data = "\n".join([
|
| 218 |
+
f"- {category}: {data['market_share']}% market share, {data['growth_rate']}% growth rate"
|
| 219 |
+
for category, data in self.market_categories.items()
|
| 220 |
+
])
|
| 221 |
+
|
| 222 |
+
specialized_skills = "\n".join([
|
| 223 |
+
f"- {category}: {', '.join(data['specialized_skills'][:3])}"
|
| 224 |
+
for category, data in self.market_categories.items()
|
| 225 |
+
])
|
| 226 |
+
|
| 227 |
+
# Add context information if available (all from cache)
|
| 228 |
+
context_info = ""
|
| 229 |
+
if context:
|
| 230 |
+
session_context = context.get('session_context', {})
|
| 231 |
+
session_summary = session_context.get('summary', '') if isinstance(session_context, dict) else ""
|
| 232 |
+
user_context = context.get('user_context', '')
|
| 233 |
+
interaction_contexts = context.get('interaction_contexts', [])
|
| 234 |
+
|
| 235 |
+
if session_summary:
|
| 236 |
+
context_info = f"\n\nSession Context (session summary): {session_summary[:300]}..."
|
| 237 |
+
if user_context:
|
| 238 |
+
context_info += f"\n\nUser Context (persona summary): {user_context[:300]}..."
|
| 239 |
+
|
| 240 |
+
if interaction_contexts:
|
| 241 |
+
# Include recent interaction context to understand topic continuity
|
| 242 |
+
recent_contexts = interaction_contexts[-2:] # Last 2 interactions
|
| 243 |
+
if recent_contexts:
|
| 244 |
+
context_info += "\n\nRecent conversation context:"
|
| 245 |
+
for idx, ic in enumerate(recent_contexts, 1):
|
| 246 |
+
summary = ic.get('summary', '')
|
| 247 |
+
if summary:
|
| 248 |
+
context_info += f"\n {idx}. {summary}"
|
| 249 |
+
|
| 250 |
+
return f"""Analyze the following user input and identify the most relevant industry categories and specialized skills based on current market data.
|
| 251 |
+
|
| 252 |
+
User Input: "{user_input}"
|
| 253 |
+
{context_info}
|
| 254 |
+
|
| 255 |
+
Current Market Distribution:
|
| 256 |
+
{market_data}
|
| 257 |
+
|
| 258 |
+
Specialized Skills by Category (top 3 per category):
|
| 259 |
+
{specialized_skills}
|
| 260 |
+
|
| 261 |
+
Task:
|
| 262 |
+
1. Identify which industry categories are most relevant to the user's input (consider conversation context if provided)
|
| 263 |
+
2. Select 1-3 specialized skills from each relevant category that best match the user's needs
|
| 264 |
+
3. Provide market share percentages and growth rates for identified categories
|
| 265 |
+
4. Explain your reasoning for each selection
|
| 266 |
+
5. If conversation context is available, consider how previous topics might inform the skill identification
|
| 267 |
+
|
| 268 |
+
Respond in JSON format:
|
| 269 |
+
{{
|
| 270 |
+
"relevant_categories": [
|
| 271 |
+
{{
|
| 272 |
+
"category": "category_name",
|
| 273 |
+
"market_share": percentage,
|
| 274 |
+
"growth_rate": percentage,
|
| 275 |
+
"relevance_score": 0.0-1.0,
|
| 276 |
+
"reasoning": "explanation"
|
| 277 |
+
}}
|
| 278 |
+
],
|
| 279 |
+
"selected_skills": [
|
| 280 |
+
{{
|
| 281 |
+
"skill": "skill_name",
|
| 282 |
+
"category": "category_name",
|
| 283 |
+
"relevance_score": 0.0-1.0,
|
| 284 |
+
"reasoning": "explanation"
|
| 285 |
+
}}
|
| 286 |
+
],
|
| 287 |
+
"overall_analysis": "summary of findings"
|
| 288 |
+
}}"""
|
| 289 |
+
|
| 290 |
+
def _build_classification_prompt(self, user_input: str) -> str:
|
| 291 |
+
"""Build prompt for skill classification using classification_specialist model"""
|
| 292 |
+
|
| 293 |
+
skill_categories_str = ", ".join(self.skill_categories)
|
| 294 |
+
|
| 295 |
+
return f"""Classify the following user input into relevant skill categories. For each category, provide a probability score (0.0-1.0) indicating how likely the input relates to that skill.
|
| 296 |
+
|
| 297 |
+
User Input: "{user_input}"
|
| 298 |
+
|
| 299 |
+
Available Skill Categories: {skill_categories_str}
|
| 300 |
+
|
| 301 |
+
Task: Provide probability scores for each skill category that passes a 20% threshold.
|
| 302 |
+
|
| 303 |
+
Respond in JSON format:
|
| 304 |
+
{{
|
| 305 |
+
"skill_probabilities": {{
|
| 306 |
+
"category_name": probability_score,
|
| 307 |
+
...
|
| 308 |
+
}},
|
| 309 |
+
"top_skills": [
|
| 310 |
+
{{
|
| 311 |
+
"skill": "category_name",
|
| 312 |
+
"probability": score,
|
| 313 |
+
"confidence": "high/medium/low"
|
| 314 |
+
}}
|
| 315 |
+
],
|
| 316 |
+
"classification_reasoning": "explanation of classification decisions"
|
| 317 |
+
}}"""
|
| 318 |
+
|
| 319 |
+
def _parse_market_analysis_response(self, response: str) -> Dict[str, Any]:
|
| 320 |
+
"""Parse LLM response for market analysis"""
|
| 321 |
+
try:
|
| 322 |
+
# Try to extract JSON from response
|
| 323 |
+
json_match = re.search(r'\{.*\}', response, re.DOTALL)
|
| 324 |
+
if json_match:
|
| 325 |
+
parsed = json.loads(json_match.group())
|
| 326 |
+
return parsed
|
| 327 |
+
except json.JSONDecodeError:
|
| 328 |
+
logger.warning(f"{self.agent_id} Failed to parse market analysis JSON")
|
| 329 |
+
|
| 330 |
+
# Fallback parsing
|
| 331 |
+
return {
|
| 332 |
+
"relevant_categories": [{"category": "General", "market_share": 10, "growth_rate": 5.0, "relevance_score": 0.7, "reasoning": "General analysis"}],
|
| 333 |
+
"selected_skills": [{"skill": "General Analysis", "category": "General", "relevance_score": 0.7, "reasoning": "Broad applicability"}],
|
| 334 |
+
"overall_analysis": "Market analysis completed with fallback parsing",
|
| 335 |
+
"method": "fallback_parsing"
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
def _parse_classification_response(self, response: str) -> Dict[str, Any]:
|
| 339 |
+
"""Parse LLM response for skill classification"""
|
| 340 |
+
try:
|
| 341 |
+
# Try to extract JSON from response
|
| 342 |
+
json_match = re.search(r'\{.*\}', response, re.DOTALL)
|
| 343 |
+
if json_match:
|
| 344 |
+
parsed = json.loads(json_match.group())
|
| 345 |
+
return parsed
|
| 346 |
+
except json.JSONDecodeError:
|
| 347 |
+
logger.warning(f"{self.agent_id} Failed to parse classification JSON")
|
| 348 |
+
|
| 349 |
+
# Fallback parsing
|
| 350 |
+
return {
|
| 351 |
+
"skill_probabilities": {"general_analysis": 0.7},
|
| 352 |
+
"top_skills": [{"skill": "general_analysis", "probability": 0.7, "confidence": "medium"}],
|
| 353 |
+
"classification_reasoning": "Classification completed with fallback parsing",
|
| 354 |
+
"method": "fallback_parsing"
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
def _rule_based_market_analysis(self, user_input: str) -> Dict[str, Any]:
|
| 358 |
+
"""Rule-based fallback for market analysis"""
|
| 359 |
+
user_input_lower = user_input.lower()
|
| 360 |
+
|
| 361 |
+
relevant_categories = []
|
| 362 |
+
selected_skills = []
|
| 363 |
+
|
| 364 |
+
# Pattern matching for different categories
|
| 365 |
+
patterns = {
|
| 366 |
+
"IT and Software Development": ["code", "programming", "software", "tech", "ai", "machine learning", "data", "cyber", "cloud"],
|
| 367 |
+
"Finance and Accounting": ["finance", "money", "investment", "banking", "accounting", "financial", "risk", "compliance"],
|
| 368 |
+
"Healthcare and Medicine": ["health", "medical", "doctor", "nurse", "patient", "clinical", "medicine", "healthcare"],
|
| 369 |
+
"Education and Teaching": ["teach", "education", "learn", "student", "school", "curriculum", "instruction"],
|
| 370 |
+
"Engineering and Construction": ["engineer", "construction", "build", "project", "manufacturing", "design"],
|
| 371 |
+
"Marketing and Sales": ["marketing", "sales", "customer", "advertising", "promotion", "brand"],
|
| 372 |
+
"Consulting and Strategy": ["consulting", "strategy", "business", "management", "planning"],
|
| 373 |
+
"Environmental and Sustainability": ["environment", "sustainable", "green", "renewable", "climate", "carbon"],
|
| 374 |
+
"Arts and Humanities": ["art", "creative", "culture", "humanities", "design", "communication"]
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
for category, keywords in patterns.items():
|
| 378 |
+
relevance_score = 0.0
|
| 379 |
+
for keyword in keywords:
|
| 380 |
+
if keyword in user_input_lower:
|
| 381 |
+
relevance_score += 0.2
|
| 382 |
+
|
| 383 |
+
if relevance_score > 0.0:
|
| 384 |
+
category_data = self.market_categories[category]
|
| 385 |
+
relevant_categories.append({
|
| 386 |
+
"category": category,
|
| 387 |
+
"market_share": category_data["market_share"],
|
| 388 |
+
"growth_rate": category_data["growth_rate"],
|
| 389 |
+
"relevance_score": min(1.0, relevance_score),
|
| 390 |
+
"reasoning": f"Matched keywords: {[k for k in keywords if k in user_input_lower]}"
|
| 391 |
+
})
|
| 392 |
+
|
| 393 |
+
# Add top skills from this category
|
| 394 |
+
for skill in category_data["specialized_skills"][:2]:
|
| 395 |
+
selected_skills.append({
|
| 396 |
+
"skill": skill,
|
| 397 |
+
"category": category,
|
| 398 |
+
"relevance_score": relevance_score * 0.8,
|
| 399 |
+
"reasoning": f"From {category} category"
|
| 400 |
+
})
|
| 401 |
+
|
| 402 |
+
return {
|
| 403 |
+
"relevant_categories": relevant_categories,
|
| 404 |
+
"selected_skills": selected_skills,
|
| 405 |
+
"overall_analysis": f"Rule-based analysis identified {len(relevant_categories)} relevant categories",
|
| 406 |
+
"processing_time": 0.1,
|
| 407 |
+
"method": "rule_based"
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
def _rule_based_skill_classification(self, user_input: str) -> Dict[str, Any]:
|
| 411 |
+
"""Rule-based fallback for skill classification"""
|
| 412 |
+
user_input_lower = user_input.lower()
|
| 413 |
+
|
| 414 |
+
skill_probabilities = {}
|
| 415 |
+
top_skills = []
|
| 416 |
+
|
| 417 |
+
# Simple keyword matching for skill categories
|
| 418 |
+
skill_keywords = {
|
| 419 |
+
"technical_programming": ["code", "programming", "software", "development", "python", "java"],
|
| 420 |
+
"data_analysis": ["data", "analysis", "statistics", "analytics", "research"],
|
| 421 |
+
"cybersecurity": ["security", "cyber", "hack", "protection", "vulnerability"],
|
| 422 |
+
"financial_analysis": ["finance", "money", "investment", "financial", "economic"],
|
| 423 |
+
"healthcare_technology": ["health", "medical", "healthcare", "clinical", "patient"],
|
| 424 |
+
"educational_technology": ["education", "teach", "learn", "student", "curriculum"],
|
| 425 |
+
"project_management": ["project", "manage", "planning", "coordination", "leadership"],
|
| 426 |
+
"digital_marketing": ["marketing", "advertising", "promotion", "social media", "brand"],
|
| 427 |
+
"environmental_science": ["environment", "sustainable", "green", "climate", "carbon"],
|
| 428 |
+
"creative_design": ["design", "creative", "art", "visual", "graphic"]
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
for skill, keywords in skill_keywords.items():
|
| 432 |
+
probability = 0.0
|
| 433 |
+
for keyword in keywords:
|
| 434 |
+
if keyword in user_input_lower:
|
| 435 |
+
probability += 0.3
|
| 436 |
+
|
| 437 |
+
if probability > 0.2: # 20% threshold
|
| 438 |
+
skill_probabilities[skill] = min(1.0, probability)
|
| 439 |
+
top_skills.append({
|
| 440 |
+
"skill": skill,
|
| 441 |
+
"probability": skill_probabilities[skill],
|
| 442 |
+
"confidence": "high" if probability > 0.6 else "medium" if probability > 0.4 else "low"
|
| 443 |
+
})
|
| 444 |
+
|
| 445 |
+
return {
|
| 446 |
+
"skill_probabilities": skill_probabilities,
|
| 447 |
+
"top_skills": top_skills,
|
| 448 |
+
"classification_reasoning": f"Rule-based classification identified {len(top_skills)} relevant skills",
|
| 449 |
+
"processing_time": 0.05,
|
| 450 |
+
"method": "rule_based"
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
def _extract_high_probability_skills(self, classification: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 454 |
+
"""Extract skills that pass the 20% probability threshold"""
|
| 455 |
+
high_prob_skills = []
|
| 456 |
+
|
| 457 |
+
# From market analysis
|
| 458 |
+
market_analysis = classification.get("market_analysis", {})
|
| 459 |
+
market_skills = market_analysis.get("selected_skills", [])
|
| 460 |
+
for skill in market_skills:
|
| 461 |
+
if skill.get("relevance_score", 0) > 0.2:
|
| 462 |
+
high_prob_skills.append({
|
| 463 |
+
"skill": skill["skill"],
|
| 464 |
+
"category": skill["category"],
|
| 465 |
+
"probability": skill["relevance_score"],
|
| 466 |
+
"source": "market_analysis"
|
| 467 |
+
})
|
| 468 |
+
|
| 469 |
+
# From skill classification
|
| 470 |
+
skill_classification = classification.get("skill_classification", {})
|
| 471 |
+
classification_skills = skill_classification.get("top_skills", [])
|
| 472 |
+
for skill in classification_skills:
|
| 473 |
+
if skill.get("probability", 0) > 0.2:
|
| 474 |
+
high_prob_skills.append({
|
| 475 |
+
"skill": skill["skill"],
|
| 476 |
+
"category": "classified",
|
| 477 |
+
"probability": skill["probability"],
|
| 478 |
+
"source": "skill_classification"
|
| 479 |
+
})
|
| 480 |
+
|
| 481 |
+
# If no skills found from LLM, use rule-based fallback
|
| 482 |
+
if not high_prob_skills:
|
| 483 |
+
logger.warning(f"{self.agent_id} No skills identified from LLM, using rule-based fallback")
|
| 484 |
+
# Extract user input from context if available
|
| 485 |
+
user_input = ""
|
| 486 |
+
if isinstance(classification, dict) and "user_input" in classification:
|
| 487 |
+
user_input = classification["user_input"]
|
| 488 |
+
elif isinstance(classification, dict) and "context" in classification:
|
| 489 |
+
context = classification["context"]
|
| 490 |
+
if isinstance(context, dict) and "user_input" in context:
|
| 491 |
+
user_input = context["user_input"]
|
| 492 |
+
|
| 493 |
+
if user_input:
|
| 494 |
+
rule_based_result = self._rule_based_skill_classification(user_input)
|
| 495 |
+
rule_skills = rule_based_result.get("top_skills", [])
|
| 496 |
+
for skill in rule_skills:
|
| 497 |
+
if skill.get("probability", 0) > 0.2:
|
| 498 |
+
high_prob_skills.append({
|
| 499 |
+
"skill": skill["skill"],
|
| 500 |
+
"category": "rule_based",
|
| 501 |
+
"probability": skill["probability"],
|
| 502 |
+
"source": "rule_based_fallback"
|
| 503 |
+
})
|
| 504 |
+
|
| 505 |
+
# Remove duplicates and sort by probability
|
| 506 |
+
unique_skills = {}
|
| 507 |
+
for skill in high_prob_skills:
|
| 508 |
+
skill_name = skill["skill"]
|
| 509 |
+
if skill_name not in unique_skills or skill["probability"] > unique_skills[skill_name]["probability"]:
|
| 510 |
+
unique_skills[skill_name] = skill
|
| 511 |
+
|
| 512 |
+
return sorted(unique_skills.values(), key=lambda x: x["probability"], reverse=True)
|
| 513 |
+
|
| 514 |
+
def _calculate_overall_confidence(self, market_analysis: Dict[str, Any], skill_classification: Dict[str, Any]) -> float:
|
| 515 |
+
"""Calculate overall confidence score"""
|
| 516 |
+
market_confidence = len(market_analysis.get("relevant_categories", [])) * 0.1
|
| 517 |
+
classification_confidence = len(skill_classification.get("top_skills", [])) * 0.1
|
| 518 |
+
|
| 519 |
+
return min(1.0, market_confidence + classification_confidence + 0.3)
|
| 520 |
+
|
| 521 |
+
def _get_fallback_result(self, user_input: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 522 |
+
"""Provide fallback result when processing fails"""
|
| 523 |
+
return {
|
| 524 |
+
"agent_id": self.agent_id,
|
| 525 |
+
"market_analysis": {
|
| 526 |
+
"relevant_categories": [{"category": "General", "market_share": 10, "growth_rate": 5.0, "relevance_score": 0.5, "reasoning": "Fallback analysis"}],
|
| 527 |
+
"selected_skills": [{"skill": "General Analysis", "category": "General", "relevance_score": 0.5, "reasoning": "Fallback skill"}],
|
| 528 |
+
"overall_analysis": "Fallback analysis due to processing error",
|
| 529 |
+
"processing_time": 0.01,
|
| 530 |
+
"method": "fallback"
|
| 531 |
+
},
|
| 532 |
+
"skill_classification": {
|
| 533 |
+
"skill_probabilities": {"general_analysis": 0.5},
|
| 534 |
+
"top_skills": [{"skill": "general_analysis", "probability": 0.5, "confidence": "low"}],
|
| 535 |
+
"classification_reasoning": "Fallback classification due to processing error",
|
| 536 |
+
"processing_time": 0.01,
|
| 537 |
+
"method": "fallback"
|
| 538 |
+
},
|
| 539 |
+
"identified_skills": [{"skill": "General Analysis", "category": "General", "probability": 0.5, "source": "fallback"}],
|
| 540 |
+
"processing_time": 0.02,
|
| 541 |
+
"confidence_score": 0.3,
|
| 542 |
+
"error_handled": True
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
# Factory function for easy instantiation
|
| 546 |
+
def create_skills_identification_agent(llm_router=None):
|
| 547 |
+
return SkillsIdentificationAgent(llm_router)
|
src/agents/synthesis_agent.py
ADDED
|
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced Synthesis Agent with Expert Consultant Assignment
|
| 3 |
+
Based on skill probability scores from Skills Identification Agent
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
from typing import Dict, List, Any, Optional, Tuple
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ExpertConsultantAssigner:
|
| 16 |
+
"""
|
| 17 |
+
Assigns expert consultant profiles based on skill probabilities
|
| 18 |
+
and generates weighted expertise for response synthesis
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# Expert consultant profiles with skill mappings
|
| 22 |
+
EXPERT_PROFILES = {
|
| 23 |
+
"data_analysis": {
|
| 24 |
+
"title": "Senior Data Analytics Consultant",
|
| 25 |
+
"expertise": ["Statistical Analysis", "Data Visualization", "Business Intelligence", "Predictive Modeling"],
|
| 26 |
+
"background": "15+ years in data science across finance, healthcare, and tech sectors",
|
| 27 |
+
"style": "methodical, evidence-based, quantitative reasoning"
|
| 28 |
+
},
|
| 29 |
+
"technical_programming": {
|
| 30 |
+
"title": "Principal Software Engineering Consultant",
|
| 31 |
+
"expertise": ["Full-Stack Development", "System Architecture", "DevOps", "Code Optimization"],
|
| 32 |
+
"background": "20+ years leading technical teams at Fortune 500 companies",
|
| 33 |
+
"style": "practical, solution-oriented, best practices focused"
|
| 34 |
+
},
|
| 35 |
+
"project_management": {
|
| 36 |
+
"title": "Strategic Project Management Consultant",
|
| 37 |
+
"expertise": ["Agile/Scrum", "Risk Management", "Stakeholder Communication", "Resource Optimization"],
|
| 38 |
+
"background": "12+ years managing complex enterprise projects across industries",
|
| 39 |
+
"style": "structured, process-driven, outcome-focused"
|
| 40 |
+
},
|
| 41 |
+
"financial_analysis": {
|
| 42 |
+
"title": "Executive Financial Strategy Consultant",
|
| 43 |
+
"expertise": ["Financial Modeling", "Investment Analysis", "Risk Assessment", "Corporate Finance"],
|
| 44 |
+
"background": "18+ years in investment banking and corporate finance advisory",
|
| 45 |
+
"style": "analytical, risk-aware, ROI-focused"
|
| 46 |
+
},
|
| 47 |
+
"digital_marketing": {
|
| 48 |
+
"title": "Chief Marketing Strategy Consultant",
|
| 49 |
+
"expertise": ["Digital Campaign Strategy", "Customer Analytics", "Brand Development", "Growth Hacking"],
|
| 50 |
+
"background": "14+ years scaling marketing for startups to enterprise clients",
|
| 51 |
+
"style": "creative, data-driven, customer-centric"
|
| 52 |
+
},
|
| 53 |
+
"business_consulting": {
|
| 54 |
+
"title": "Senior Management Consultant",
|
| 55 |
+
"expertise": ["Strategic Planning", "Organizational Development", "Process Improvement", "Change Management"],
|
| 56 |
+
"background": "16+ years at top-tier consulting firms (McKinsey, BCG equivalent)",
|
| 57 |
+
"style": "strategic, framework-driven, holistic thinking"
|
| 58 |
+
},
|
| 59 |
+
"cybersecurity": {
|
| 60 |
+
"title": "Chief Information Security Consultant",
|
| 61 |
+
"expertise": ["Threat Assessment", "Security Architecture", "Compliance", "Incident Response"],
|
| 62 |
+
"background": "12+ years protecting critical infrastructure across government and private sectors",
|
| 63 |
+
"style": "security-first, compliance-aware, risk mitigation focused"
|
| 64 |
+
},
|
| 65 |
+
"healthcare_technology": {
|
| 66 |
+
"title": "Healthcare Innovation Consultant",
|
| 67 |
+
"expertise": ["Health Informatics", "Telemedicine", "Medical Device Integration", "HIPAA Compliance"],
|
| 68 |
+
"background": "10+ years implementing healthcare technology solutions",
|
| 69 |
+
"style": "patient-centric, regulation-compliant, evidence-based"
|
| 70 |
+
},
|
| 71 |
+
"educational_technology": {
|
| 72 |
+
"title": "Learning Technology Strategy Consultant",
|
| 73 |
+
"expertise": ["Instructional Design", "EdTech Implementation", "Learning Analytics", "Curriculum Development"],
|
| 74 |
+
"background": "13+ years transforming educational experiences through technology",
|
| 75 |
+
"style": "learner-focused, pedagogy-driven, accessibility-minded"
|
| 76 |
+
},
|
| 77 |
+
"environmental_science": {
|
| 78 |
+
"title": "Sustainability Strategy Consultant",
|
| 79 |
+
"expertise": ["Environmental Impact Assessment", "Carbon Footprint Analysis", "Green Technology", "ESG Reporting"],
|
| 80 |
+
"background": "11+ years driving environmental initiatives for corporations",
|
| 81 |
+
"style": "sustainability-focused, data-driven, long-term thinking"
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
def assign_expert_consultant(self, skill_probabilities: Dict[str, float]) -> Dict[str, Any]:
|
| 86 |
+
"""
|
| 87 |
+
Create ultra-expert profile combining all relevant consultants
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
skill_probabilities: Dict mapping skill categories to probability scores (0.0-1.0)
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Dict containing ultra-expert profile with combined expertise
|
| 94 |
+
"""
|
| 95 |
+
if not skill_probabilities:
|
| 96 |
+
return self._get_default_consultant()
|
| 97 |
+
|
| 98 |
+
# Calculate weighted scores for available expert profiles
|
| 99 |
+
expert_scores = {}
|
| 100 |
+
total_weight = 0
|
| 101 |
+
|
| 102 |
+
for skill, probability in skill_probabilities.items():
|
| 103 |
+
if skill in self.EXPERT_PROFILES and probability >= 0.2: # 20% threshold
|
| 104 |
+
expert_scores[skill] = probability
|
| 105 |
+
total_weight += probability
|
| 106 |
+
|
| 107 |
+
if not expert_scores:
|
| 108 |
+
return self._get_default_consultant()
|
| 109 |
+
|
| 110 |
+
# Create ultra-expert combining all relevant consultants
|
| 111 |
+
ultra_expert = self._create_ultra_expert(expert_scores, total_weight)
|
| 112 |
+
|
| 113 |
+
return {
|
| 114 |
+
"assigned_consultant": ultra_expert,
|
| 115 |
+
"expertise_weights": expert_scores,
|
| 116 |
+
"total_weight": total_weight,
|
| 117 |
+
"assignment_rationale": self._generate_ultra_expert_rationale(expert_scores, total_weight)
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
def _get_default_consultant(self) -> Dict[str, Any]:
|
| 121 |
+
"""Default consultant for general inquiries"""
|
| 122 |
+
return {
|
| 123 |
+
"assigned_consultant": {
|
| 124 |
+
"primary_expertise": "business_consulting",
|
| 125 |
+
"title": "Senior Management Consultant",
|
| 126 |
+
"expertise": ["Strategic Planning", "Problem Solving", "Analysis", "Communication"],
|
| 127 |
+
"background": "Generalist consultant with broad industry experience",
|
| 128 |
+
"style": "balanced, analytical, comprehensive",
|
| 129 |
+
"secondary_expertise": [],
|
| 130 |
+
"confidence_score": 0.7
|
| 131 |
+
},
|
| 132 |
+
"expertise_weights": {"business_consulting": 0.7},
|
| 133 |
+
"total_weight": 0.7,
|
| 134 |
+
"assignment_rationale": "Default consultant assigned for general business inquiry"
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
def _create_ultra_expert(self, expert_scores: Dict[str, float], total_weight: float) -> Dict[str, Any]:
|
| 138 |
+
"""Create ultra-expert profile combining all relevant consultants"""
|
| 139 |
+
|
| 140 |
+
# Sort skills by probability (highest first)
|
| 141 |
+
sorted_skills = sorted(expert_scores.items(), key=lambda x: x[1], reverse=True)
|
| 142 |
+
|
| 143 |
+
# Combine all expertise areas with weights
|
| 144 |
+
combined_expertise = []
|
| 145 |
+
combined_background_elements = []
|
| 146 |
+
combined_style_elements = []
|
| 147 |
+
|
| 148 |
+
for skill, weight in sorted_skills:
|
| 149 |
+
if skill in self.EXPERT_PROFILES:
|
| 150 |
+
profile = self.EXPERT_PROFILES[skill]
|
| 151 |
+
|
| 152 |
+
# Weight-based contribution
|
| 153 |
+
contribution_ratio = weight / total_weight
|
| 154 |
+
|
| 155 |
+
# Add expertise areas with weight indicators
|
| 156 |
+
for expertise in profile["expertise"]:
|
| 157 |
+
weighted_expertise = f"{expertise} (Weight: {contribution_ratio:.1%})"
|
| 158 |
+
combined_expertise.append(weighted_expertise)
|
| 159 |
+
|
| 160 |
+
# Extract background years and combine
|
| 161 |
+
background = profile["background"]
|
| 162 |
+
combined_background_elements.append(f"{background} [{skill}]")
|
| 163 |
+
|
| 164 |
+
# Combine style elements
|
| 165 |
+
style_parts = [s.strip() for s in profile["style"].split(",")]
|
| 166 |
+
combined_style_elements.extend(style_parts)
|
| 167 |
+
|
| 168 |
+
# Create ultra-expert title combining top skills
|
| 169 |
+
top_skills = [skill.replace("_", " ").title() for skill, _ in sorted_skills[:3]]
|
| 170 |
+
ultra_title = f"Visionary Ultra-Expert: {' + '.join(top_skills)} Integration Specialist"
|
| 171 |
+
|
| 172 |
+
# Combine backgrounds into comprehensive experience
|
| 173 |
+
total_years = sum([self._extract_years_from_background(bg) for bg in combined_background_elements])
|
| 174 |
+
ultra_background = f"{total_years}+ years combined experience across {len(sorted_skills)} domains: " + \
|
| 175 |
+
"; ".join(combined_background_elements[:3]) # Limit for readability
|
| 176 |
+
|
| 177 |
+
# Create unified style combining all approaches
|
| 178 |
+
unique_styles = list(set(combined_style_elements))
|
| 179 |
+
ultra_style = ", ".join(unique_styles[:6]) # Top 6 style elements
|
| 180 |
+
|
| 181 |
+
return {
|
| 182 |
+
"primary_expertise": "ultra_expert_integration",
|
| 183 |
+
"title": ultra_title,
|
| 184 |
+
"expertise": combined_expertise,
|
| 185 |
+
"background": ultra_background,
|
| 186 |
+
"style": ultra_style,
|
| 187 |
+
"domain_integration": sorted_skills,
|
| 188 |
+
"confidence_score": total_weight / len(sorted_skills), # Average confidence
|
| 189 |
+
"ultra_expert": True,
|
| 190 |
+
"expertise_count": len(sorted_skills),
|
| 191 |
+
"total_experience_years": total_years
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
def _extract_years_from_background(self, background: str) -> int:
|
| 195 |
+
"""Extract years of experience from background string"""
|
| 196 |
+
years_match = re.search(r'(\d+)\+?\s*years?', background.lower())
|
| 197 |
+
return int(years_match.group(1)) if years_match else 10 # Default to 10 years
|
| 198 |
+
|
| 199 |
+
def _generate_ultra_expert_rationale(self, expert_scores: Dict[str, float], total_weight: float) -> str:
|
| 200 |
+
"""Generate explanation for ultra-expert assignment"""
|
| 201 |
+
sorted_skills = sorted(expert_scores.items(), key=lambda x: x[1], reverse=True)
|
| 202 |
+
|
| 203 |
+
rationale_parts = [
|
| 204 |
+
f"Ultra-Expert Profile combining {len(sorted_skills)} specialized domains",
|
| 205 |
+
f"Total expertise weight: {total_weight:.2f} across integrated skill areas"
|
| 206 |
+
]
|
| 207 |
+
|
| 208 |
+
# Add top 3 contributions
|
| 209 |
+
top_contributions = []
|
| 210 |
+
for skill, weight in sorted_skills[:3]:
|
| 211 |
+
contribution = (weight / total_weight) * 100
|
| 212 |
+
top_contributions.append(f"{skill} ({weight:.1%}, {contribution:.0f}% contribution)")
|
| 213 |
+
|
| 214 |
+
rationale_parts.append(f"Primary domains: {'; '.join(top_contributions)}")
|
| 215 |
+
|
| 216 |
+
if len(sorted_skills) > 3:
|
| 217 |
+
additional_count = len(sorted_skills) - 3
|
| 218 |
+
rationale_parts.append(f"Plus {additional_count} additional specialized areas")
|
| 219 |
+
|
| 220 |
+
return " | ".join(rationale_parts)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class EnhancedSynthesisAgent:
|
| 224 |
+
"""
|
| 225 |
+
Enhanced synthesis agent with expert consultant assignment
|
| 226 |
+
Compatible with existing ResponseSynthesisAgent interface
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
def __init__(self, llm_router, agent_id: str = "RESP_SYNTH_001"):
|
| 230 |
+
self.llm_router = llm_router
|
| 231 |
+
self.agent_id = agent_id
|
| 232 |
+
self.specialization = "Multi-source information integration and coherent response generation"
|
| 233 |
+
self.expert_assigner = ExpertConsultantAssigner()
|
| 234 |
+
self._current_user_input = None
|
| 235 |
+
|
| 236 |
+
async def execute(self, user_input: str = None, agent_outputs: List[Dict[str, Any]] = None,
|
| 237 |
+
context: Dict[str, Any] = None, skills_result: Dict[str, Any] = None,
|
| 238 |
+
**kwargs) -> Dict[str, Any]:
|
| 239 |
+
"""
|
| 240 |
+
Execute synthesis with expert consultant assignment
|
| 241 |
+
Compatible with both old interface (agent_outputs first) and new interface (user_input first)
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
user_input: Original user question
|
| 245 |
+
agent_outputs: Results from other agents (can be first positional arg for compatibility)
|
| 246 |
+
context: Conversation context
|
| 247 |
+
skills_result: Output from skills identification agent
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Dict containing synthesized response and metadata
|
| 251 |
+
"""
|
| 252 |
+
# Handle backward compatibility and normalize arguments
|
| 253 |
+
# Case 1: First arg is agent_outputs (old interface)
|
| 254 |
+
if isinstance(user_input, list) and agent_outputs is None:
|
| 255 |
+
agent_outputs = user_input
|
| 256 |
+
user_input = kwargs.get('user_input', '')
|
| 257 |
+
context = kwargs.get('context', context)
|
| 258 |
+
skills_result = kwargs.get('skills_result', skills_result)
|
| 259 |
+
# Case 2: All args via kwargs
|
| 260 |
+
elif user_input is None:
|
| 261 |
+
user_input = kwargs.get('user_input', '')
|
| 262 |
+
agent_outputs = kwargs.get('agent_outputs', agent_outputs)
|
| 263 |
+
context = kwargs.get('context', context)
|
| 264 |
+
skills_result = kwargs.get('skills_result', skills_result)
|
| 265 |
+
|
| 266 |
+
# Ensure user_input is a string
|
| 267 |
+
if not isinstance(user_input, str):
|
| 268 |
+
user_input = str(user_input) if user_input else ''
|
| 269 |
+
|
| 270 |
+
# Default agent_outputs to empty list and normalize format
|
| 271 |
+
if agent_outputs is None:
|
| 272 |
+
agent_outputs = []
|
| 273 |
+
|
| 274 |
+
# Normalize agent_outputs: convert dict to list if needed
|
| 275 |
+
if isinstance(agent_outputs, dict):
|
| 276 |
+
# Convert dict {task_name: result} to list of dicts
|
| 277 |
+
normalized_outputs = []
|
| 278 |
+
for task_name, result in agent_outputs.items():
|
| 279 |
+
if isinstance(result, dict):
|
| 280 |
+
# Add task name to the result dict for context
|
| 281 |
+
result_with_task = result.copy()
|
| 282 |
+
result_with_task['task_name'] = task_name
|
| 283 |
+
normalized_outputs.append(result_with_task)
|
| 284 |
+
else:
|
| 285 |
+
# Wrap non-dict results
|
| 286 |
+
normalized_outputs.append({
|
| 287 |
+
'task_name': task_name,
|
| 288 |
+
'content': str(result),
|
| 289 |
+
'result': str(result)
|
| 290 |
+
})
|
| 291 |
+
agent_outputs = normalized_outputs
|
| 292 |
+
|
| 293 |
+
# Ensure it's a list
|
| 294 |
+
if not isinstance(agent_outputs, list):
|
| 295 |
+
agent_outputs = [agent_outputs] if agent_outputs else []
|
| 296 |
+
|
| 297 |
+
logger.info(f"{self.agent_id} synthesizing {len(agent_outputs)} agent outputs")
|
| 298 |
+
if context:
|
| 299 |
+
interaction_count = len(context.get('interaction_contexts', [])) if context else 0
|
| 300 |
+
logger.info(f"{self.agent_id} context has {interaction_count} interaction contexts")
|
| 301 |
+
|
| 302 |
+
# STEP 1: Extract skill probabilities from skills_result
|
| 303 |
+
skill_probabilities = self._extract_skill_probabilities(skills_result)
|
| 304 |
+
logger.info(f"Extracted skill probabilities: {skill_probabilities}")
|
| 305 |
+
|
| 306 |
+
# STEP 2: Assign expert consultant based on probabilities
|
| 307 |
+
consultant_assignment = self.expert_assigner.assign_expert_consultant(skill_probabilities)
|
| 308 |
+
assigned_consultant = consultant_assignment["assigned_consultant"]
|
| 309 |
+
logger.info(f"Assigned consultant: {assigned_consultant['title']} ({assigned_consultant.get('primary_expertise', 'N/A')})")
|
| 310 |
+
|
| 311 |
+
# STEP 3: Generate expert consultant preamble
|
| 312 |
+
expert_preamble = self._generate_expert_preamble(assigned_consultant, consultant_assignment)
|
| 313 |
+
|
| 314 |
+
# STEP 4: Build synthesis prompt with expert context
|
| 315 |
+
synthesis_prompt = self._build_synthesis_prompt_with_expert(
|
| 316 |
+
user_input=user_input,
|
| 317 |
+
context=context,
|
| 318 |
+
agent_outputs=agent_outputs,
|
| 319 |
+
expert_preamble=expert_preamble,
|
| 320 |
+
assigned_consultant=assigned_consultant
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
logger.info(f"{self.agent_id} calling LLM for response synthesis")
|
| 324 |
+
|
| 325 |
+
# Call LLM with enhanced prompt
|
| 326 |
+
try:
|
| 327 |
+
response = await self.llm_router.route_inference(
|
| 328 |
+
task_type="response_synthesis",
|
| 329 |
+
prompt=synthesis_prompt,
|
| 330 |
+
max_tokens=2000,
|
| 331 |
+
temperature=0.7
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# Only use fallback if LLM actually fails (returns None, empty, or invalid)
|
| 335 |
+
if not response or not isinstance(response, str) or len(response.strip()) == 0:
|
| 336 |
+
logger.warning(f"{self.agent_id} LLM returned empty/invalid response, using fallback")
|
| 337 |
+
return self._get_fallback_response(user_input, agent_outputs, assigned_consultant)
|
| 338 |
+
|
| 339 |
+
clean_response = response.strip()
|
| 340 |
+
logger.info(f"{self.agent_id} received LLM response (length: {len(clean_response)})")
|
| 341 |
+
|
| 342 |
+
# Build comprehensive result compatible with existing interface
|
| 343 |
+
result = {
|
| 344 |
+
"synthesized_response": clean_response,
|
| 345 |
+
"draft_response": clean_response,
|
| 346 |
+
"final_response": clean_response, # Main response field - used by UI
|
| 347 |
+
"assigned_consultant": assigned_consultant,
|
| 348 |
+
"expertise_weights": consultant_assignment["expertise_weights"],
|
| 349 |
+
"assignment_rationale": consultant_assignment["assignment_rationale"],
|
| 350 |
+
"source_references": self._extract_source_references(agent_outputs),
|
| 351 |
+
"coherence_score": 0.90,
|
| 352 |
+
"improvement_opportunities": self._identify_improvements(clean_response),
|
| 353 |
+
"synthesis_method": "expert_enhanced_llm",
|
| 354 |
+
"agent_id": self.agent_id,
|
| 355 |
+
"synthesis_quality_metrics": self._calculate_quality_metrics({"final_response": clean_response}),
|
| 356 |
+
"synthesis_metadata": {
|
| 357 |
+
"agent_outputs_count": len(agent_outputs),
|
| 358 |
+
"context_interactions": len(context.get('interaction_contexts', [])) if context else 0,
|
| 359 |
+
"user_context_available": bool(context.get('user_context', '')) if context else False,
|
| 360 |
+
"expert_enhanced": True,
|
| 361 |
+
"processing_timestamp": datetime.now().isoformat()
|
| 362 |
+
}
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
# Add intent alignment if available
|
| 366 |
+
intent_info = self._extract_intent_info(agent_outputs)
|
| 367 |
+
if intent_info:
|
| 368 |
+
result["intent_alignment"] = self._check_intent_alignment(result, intent_info)
|
| 369 |
+
|
| 370 |
+
return result
|
| 371 |
+
|
| 372 |
+
except Exception as e:
|
| 373 |
+
logger.error(f"{self.agent_id} synthesis failed: {str(e)}", exc_info=True)
|
| 374 |
+
return self._get_fallback_response(user_input, agent_outputs, assigned_consultant)
|
| 375 |
+
|
| 376 |
+
def _extract_skill_probabilities(self, skills_result: Dict[str, Any]) -> Dict[str, float]:
|
| 377 |
+
"""Extract skill probabilities from skills identification result"""
|
| 378 |
+
if not skills_result:
|
| 379 |
+
return {}
|
| 380 |
+
|
| 381 |
+
# Check for skill_classification structure
|
| 382 |
+
skill_classification = skills_result.get('skill_classification', {})
|
| 383 |
+
if 'skill_probabilities' in skill_classification:
|
| 384 |
+
return skill_classification['skill_probabilities']
|
| 385 |
+
|
| 386 |
+
# Check for direct skill_probabilities
|
| 387 |
+
if 'skill_probabilities' in skills_result:
|
| 388 |
+
return skills_result['skill_probabilities']
|
| 389 |
+
|
| 390 |
+
# Extract from identified_skills if structured differently
|
| 391 |
+
identified_skills = skills_result.get('identified_skills', [])
|
| 392 |
+
if isinstance(identified_skills, list):
|
| 393 |
+
probabilities = {}
|
| 394 |
+
for skill in identified_skills:
|
| 395 |
+
if isinstance(skill, dict) and 'skill' in skill and 'probability' in skill:
|
| 396 |
+
# Map skill name to expert profile name if needed
|
| 397 |
+
skill_name = skill['skill']
|
| 398 |
+
probability = skill['probability']
|
| 399 |
+
probabilities[skill_name] = probability
|
| 400 |
+
elif isinstance(skill, dict) and 'category' in skill:
|
| 401 |
+
skill_name = skill['category']
|
| 402 |
+
probability = skill.get('probability', skill.get('confidence', 0.5))
|
| 403 |
+
probabilities[skill_name] = probability
|
| 404 |
+
return probabilities
|
| 405 |
+
|
| 406 |
+
return {}
|
| 407 |
+
|
| 408 |
+
def _generate_expert_preamble(self, assigned_consultant: Dict[str, Any],
|
| 409 |
+
consultant_assignment: Dict[str, Any]) -> str:
|
| 410 |
+
"""Generate expert consultant preamble for LLM prompt"""
|
| 411 |
+
|
| 412 |
+
if assigned_consultant.get('ultra_expert'):
|
| 413 |
+
# Ultra-expert preamble
|
| 414 |
+
preamble = f"""You are responding as a {assigned_consultant['title']} - an unprecedented combination of industry-leading experts.
|
| 415 |
+
|
| 416 |
+
ULTRA-EXPERT PROFILE:
|
| 417 |
+
- Integrated Expertise: {assigned_consultant['expertise_count']} specialized domains
|
| 418 |
+
- Combined Experience: {assigned_consultant['total_experience_years']}+ years across multiple industries
|
| 419 |
+
- Integration Approach: Cross-domain synthesis with deep specialization
|
| 420 |
+
- Response Style: {assigned_consultant['style']}
|
| 421 |
+
|
| 422 |
+
DOMAIN INTEGRATION: {', '.join([f"{skill} ({weight:.1%})" for skill, weight in assigned_consultant['domain_integration']])}
|
| 423 |
+
|
| 424 |
+
SPECIALIZED EXPERTISE AREAS:
|
| 425 |
+
{chr(10).join([f"• {expertise}" for expertise in assigned_consultant['expertise'][:8]])}
|
| 426 |
+
|
| 427 |
+
ASSIGNMENT RATIONALE: {consultant_assignment['assignment_rationale']}
|
| 428 |
+
|
| 429 |
+
KNOWLEDGE DEPTH REQUIREMENT:
|
| 430 |
+
- Provide insights equivalent to a visionary thought leader combining expertise from multiple domains
|
| 431 |
+
- Synthesize knowledge across {assigned_consultant['expertise_count']} specialization areas
|
| 432 |
+
- Apply interdisciplinary thinking and cross-domain innovation
|
| 433 |
+
- Leverage combined {assigned_consultant['total_experience_years']}+ years of integrated experience
|
| 434 |
+
|
| 435 |
+
ULTRA-EXPERT RESPONSE GUIDELINES:
|
| 436 |
+
- Draw from extensive cross-domain experience and pattern recognition
|
| 437 |
+
- Provide multi-perspective analysis combining different expert viewpoints
|
| 438 |
+
- Include interdisciplinary frameworks and innovative approaches
|
| 439 |
+
- Acknowledge complexity while providing actionable, synthesized recommendations
|
| 440 |
+
- Balance broad visionary thinking with deep domain-specific insights
|
| 441 |
+
- Use integrative problem-solving that spans multiple expertise areas
|
| 442 |
+
"""
|
| 443 |
+
else:
|
| 444 |
+
# Standard single expert preamble
|
| 445 |
+
preamble = f"""You are responding as a {assigned_consultant['title']} with the following profile:
|
| 446 |
+
|
| 447 |
+
EXPERTISE PROFILE:
|
| 448 |
+
- Primary Expertise: {assigned_consultant['primary_expertise']}
|
| 449 |
+
- Core Skills: {', '.join(assigned_consultant['expertise'])}
|
| 450 |
+
- Background: {assigned_consultant['background']}
|
| 451 |
+
- Response Style: {assigned_consultant['style']}
|
| 452 |
+
|
| 453 |
+
ASSIGNMENT RATIONALE: {consultant_assignment['assignment_rationale']}
|
| 454 |
+
|
| 455 |
+
EXPERTISE WEIGHTS: {', '.join([f"{skill}: {weight:.1%}" for skill, weight in consultant_assignment['expertise_weights'].items()])}
|
| 456 |
+
|
| 457 |
+
"""
|
| 458 |
+
|
| 459 |
+
if assigned_consultant.get('secondary_expertise'):
|
| 460 |
+
preamble += f"SECONDARY EXPERTISE: {', '.join(assigned_consultant['secondary_expertise'])}\n"
|
| 461 |
+
|
| 462 |
+
preamble += f"""
|
| 463 |
+
KNOWLEDGE DEPTH REQUIREMENT: Provide insights equivalent to a highly experienced, industry-leading {assigned_consultant['title']} with deep domain expertise and practical experience.
|
| 464 |
+
|
| 465 |
+
RESPONSE GUIDELINES:
|
| 466 |
+
- Draw from extensive practical experience in your field
|
| 467 |
+
- Provide industry-specific insights and best practices
|
| 468 |
+
- Include relevant frameworks, methodologies, or tools
|
| 469 |
+
- Acknowledge complexity while remaining actionable
|
| 470 |
+
- Balance theoretical knowledge with real-world application
|
| 471 |
+
"""
|
| 472 |
+
|
| 473 |
+
return preamble
|
| 474 |
+
|
| 475 |
+
def _build_synthesis_prompt_with_expert(self, user_input: str, context: Dict[str, Any],
|
| 476 |
+
agent_outputs: List[Dict[str, Any]],
|
| 477 |
+
expert_preamble: str,
|
| 478 |
+
assigned_consultant: Dict[str, Any]) -> str:
|
| 479 |
+
"""Build synthesis prompt with expert consultant context"""
|
| 480 |
+
|
| 481 |
+
# Build context section with summarization for long conversations
|
| 482 |
+
context_section = self._build_context_section(context)
|
| 483 |
+
|
| 484 |
+
# Build agent outputs section if any
|
| 485 |
+
agent_outputs_section = ""
|
| 486 |
+
if agent_outputs:
|
| 487 |
+
# Handle both dict and list formats
|
| 488 |
+
if isinstance(agent_outputs, dict):
|
| 489 |
+
# Convert dict to list format
|
| 490 |
+
outputs_list = []
|
| 491 |
+
for task_name, result in agent_outputs.items():
|
| 492 |
+
if isinstance(result, dict):
|
| 493 |
+
outputs_list.append(result)
|
| 494 |
+
else:
|
| 495 |
+
# Wrap string/non-dict results in dict format
|
| 496 |
+
outputs_list.append({
|
| 497 |
+
'task': task_name,
|
| 498 |
+
'content': str(result),
|
| 499 |
+
'result': str(result)
|
| 500 |
+
})
|
| 501 |
+
agent_outputs = outputs_list
|
| 502 |
+
|
| 503 |
+
# Ensure it's a list now
|
| 504 |
+
if isinstance(agent_outputs, list):
|
| 505 |
+
agent_outputs_section = f"\n\nAgent Analysis Results:\n"
|
| 506 |
+
for i, output in enumerate(agent_outputs, 1):
|
| 507 |
+
# Handle both dict and string outputs
|
| 508 |
+
if isinstance(output, dict):
|
| 509 |
+
output_text = output.get('content') or output.get('result') or output.get('final_response') or str(output)
|
| 510 |
+
else:
|
| 511 |
+
# If output is a string or other type
|
| 512 |
+
output_text = str(output)
|
| 513 |
+
agent_outputs_section += f"Agent {i}: {output_text}\n"
|
| 514 |
+
else:
|
| 515 |
+
# Fallback for unexpected types
|
| 516 |
+
agent_outputs_section = f"\n\nAgent Analysis Results:\n{str(agent_outputs)}\n"
|
| 517 |
+
|
| 518 |
+
# Construct full prompt
|
| 519 |
+
prompt = f"""{expert_preamble}
|
| 520 |
+
|
| 521 |
+
User Question: {user_input}
|
| 522 |
+
|
| 523 |
+
{context_section}{agent_outputs_section}
|
| 524 |
+
|
| 525 |
+
Instructions: Provide a comprehensive, helpful response that directly addresses the question from your expert perspective. If there's conversation context, use it to answer the current question appropriately. Be detailed, informative, and leverage your specialized expertise in {assigned_consultant.get('primary_expertise', 'general consulting')}.
|
| 526 |
+
|
| 527 |
+
Response:"""
|
| 528 |
+
|
| 529 |
+
return prompt
|
| 530 |
+
|
| 531 |
+
def _build_context_section(self, context: Dict[str, Any]) -> str:
|
| 532 |
+
"""Build context section with summarization for long conversations
|
| 533 |
+
|
| 534 |
+
Uses Context Manager structure:
|
| 535 |
+
- combined_context: Pre-formatted context string (preferred)
|
| 536 |
+
- interaction_contexts: List of interaction summaries with 'summary' and 'timestamp'
|
| 537 |
+
- user_context: User persona summary string
|
| 538 |
+
"""
|
| 539 |
+
if not context:
|
| 540 |
+
return ""
|
| 541 |
+
|
| 542 |
+
# Prefer combined_context if available (pre-formatted by Context Manager)
|
| 543 |
+
# combined_context includes Session Context, User Context, and Interaction Contexts
|
| 544 |
+
combined_context = context.get('combined_context', '')
|
| 545 |
+
if combined_context:
|
| 546 |
+
# Use the pre-formatted context from Context Manager
|
| 547 |
+
# It already includes Session Context, User Context, and Interaction Contexts formatted
|
| 548 |
+
return f"\n\nConversation Context:\n{combined_context}"
|
| 549 |
+
|
| 550 |
+
# Fallback: Build from individual components if combined_context not available
|
| 551 |
+
# All components are from cache
|
| 552 |
+
session_context = context.get('session_context', {})
|
| 553 |
+
session_summary = session_context.get('summary', '') if isinstance(session_context, dict) else ""
|
| 554 |
+
interaction_contexts = context.get('interaction_contexts', [])
|
| 555 |
+
user_context = context.get('user_context', '')
|
| 556 |
+
|
| 557 |
+
context_section = ""
|
| 558 |
+
|
| 559 |
+
# Add session context if available (from cache)
|
| 560 |
+
if session_summary:
|
| 561 |
+
context_section += f"\n\nSession Context (Session Summary):\n{session_summary[:500]}...\n"
|
| 562 |
+
# Add user context if available
|
| 563 |
+
if user_context:
|
| 564 |
+
context_section += f"\n\nUser Context (Persona Summary):\n{user_context[:500]}...\n"
|
| 565 |
+
|
| 566 |
+
# Add interaction contexts
|
| 567 |
+
if interaction_contexts:
|
| 568 |
+
if len(interaction_contexts) <= 8:
|
| 569 |
+
# Show all interaction summaries for short conversations
|
| 570 |
+
context_section += "\n\nPrevious Conversation Summary:\n"
|
| 571 |
+
for i, ic in enumerate(interaction_contexts, 1):
|
| 572 |
+
summary = ic.get('summary', '')
|
| 573 |
+
if summary:
|
| 574 |
+
context_section += f" {i}. {summary}\n"
|
| 575 |
+
else:
|
| 576 |
+
# Summarize older interactions, show recent ones
|
| 577 |
+
recent_contexts = interaction_contexts[-8:] # Last 8 interactions
|
| 578 |
+
older_contexts = interaction_contexts[:-8] # Everything before last 8
|
| 579 |
+
|
| 580 |
+
# Create summary of older interactions
|
| 581 |
+
summary = self._summarize_interaction_contexts(older_contexts)
|
| 582 |
+
|
| 583 |
+
context_section += f"\n\nConversation Summary (earlier context):\n{summary}\n\nRecent Conversation:\n"
|
| 584 |
+
|
| 585 |
+
for i, ic in enumerate(recent_contexts, 1):
|
| 586 |
+
summary_text = ic.get('summary', '')
|
| 587 |
+
if summary_text:
|
| 588 |
+
context_section += f" {i}. {summary_text}\n"
|
| 589 |
+
|
| 590 |
+
return context_section
|
| 591 |
+
|
| 592 |
+
def _summarize_interaction_contexts(self, interaction_contexts: List[Dict[str, Any]]) -> str:
|
| 593 |
+
"""Summarize older interaction contexts to preserve key context
|
| 594 |
+
|
| 595 |
+
Uses Context Manager structure where interaction_contexts contains:
|
| 596 |
+
- summary: 50-token interaction summary string
|
| 597 |
+
- timestamp: Interaction timestamp
|
| 598 |
+
"""
|
| 599 |
+
if not interaction_contexts:
|
| 600 |
+
return "No prior context."
|
| 601 |
+
|
| 602 |
+
# Extract key topics and themes from summaries
|
| 603 |
+
topics = []
|
| 604 |
+
key_points = []
|
| 605 |
+
|
| 606 |
+
for ic in interaction_contexts:
|
| 607 |
+
summary = ic.get('summary', '')
|
| 608 |
+
|
| 609 |
+
if summary:
|
| 610 |
+
# Extract topics from summary (simple keyword extraction)
|
| 611 |
+
# Summaries are already condensed, so extract meaningful terms
|
| 612 |
+
words = summary.lower().split()
|
| 613 |
+
key_terms = [word for word in words if len(word) > 4][:3]
|
| 614 |
+
topics.extend(key_terms)
|
| 615 |
+
|
| 616 |
+
# Use summary as key point (already a summary)
|
| 617 |
+
key_points.append(summary[:150])
|
| 618 |
+
|
| 619 |
+
# Build summary
|
| 620 |
+
unique_topics = list(set(topics))[:5] # Top 5 unique topics
|
| 621 |
+
recent_points = key_points[-5:] # Last 5 key points
|
| 622 |
+
|
| 623 |
+
summary_text = f"Topics discussed: {', '.join(unique_topics) if unique_topics else 'General discussion'}\n"
|
| 624 |
+
summary_text += f"Key points: {' | '.join(recent_points) if recent_points else 'No specific points'}"
|
| 625 |
+
|
| 626 |
+
return summary_text
|
| 627 |
+
|
| 628 |
+
def _summarize_interactions(self, interactions: List[Dict[str, Any]]) -> str:
|
| 629 |
+
"""Legacy method for backward compatibility - delegates to _summarize_interaction_contexts"""
|
| 630 |
+
# Convert old format to new format if needed
|
| 631 |
+
if interactions and 'summary' in interactions[0]:
|
| 632 |
+
# Already in new format
|
| 633 |
+
return self._summarize_interaction_contexts(interactions)
|
| 634 |
+
else:
|
| 635 |
+
# Old format - convert
|
| 636 |
+
interaction_contexts = []
|
| 637 |
+
for interaction in interactions:
|
| 638 |
+
user_input = interaction.get('user_input', '')
|
| 639 |
+
assistant_response = interaction.get('assistant_response') or interaction.get('response', '')
|
| 640 |
+
# Create a simple summary
|
| 641 |
+
summary = f"User asked: {user_input[:100]}..." if user_input else ""
|
| 642 |
+
if summary:
|
| 643 |
+
interaction_contexts.append({'summary': summary})
|
| 644 |
+
return self._summarize_interaction_contexts(interaction_contexts)
|
| 645 |
+
|
| 646 |
+
def _extract_intent_info(self, agent_outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 647 |
+
"""Extract intent information from agent outputs"""
|
| 648 |
+
for output in agent_outputs:
|
| 649 |
+
if 'primary_intent' in output:
|
| 650 |
+
return {
|
| 651 |
+
'primary_intent': output['primary_intent'],
|
| 652 |
+
'confidence': output.get('confidence_scores', {}).get(output['primary_intent'], 0.5),
|
| 653 |
+
'source_agent': output.get('agent_id', 'unknown')
|
| 654 |
+
}
|
| 655 |
+
return None
|
| 656 |
+
|
| 657 |
+
def _extract_source_references(self, agent_outputs: List[Dict[str, Any]]) -> List[str]:
|
| 658 |
+
"""Extract source references from agent outputs"""
|
| 659 |
+
sources = []
|
| 660 |
+
for output in agent_outputs:
|
| 661 |
+
agent_id = output.get('agent_id', 'unknown')
|
| 662 |
+
sources.append(agent_id)
|
| 663 |
+
return list(set(sources)) # Remove duplicates
|
| 664 |
+
|
| 665 |
+
def _calculate_quality_metrics(self, synthesis_result: Dict[str, Any]) -> Dict[str, Any]:
|
| 666 |
+
"""Calculate quality metrics for synthesis"""
|
| 667 |
+
response = synthesis_result.get('final_response', '')
|
| 668 |
+
|
| 669 |
+
return {
|
| 670 |
+
"length": len(response),
|
| 671 |
+
"word_count": len(response.split()) if response else 0,
|
| 672 |
+
"coherence_score": synthesis_result.get('coherence_score', 0.7),
|
| 673 |
+
"source_count": len(synthesis_result.get('source_references', [])),
|
| 674 |
+
"has_structured_elements": bool(re.search(r'[•\d+\.]', response)) if response else False
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
def _check_intent_alignment(self, synthesis_result: Dict[str, Any], intent_info: Dict[str, Any]) -> Dict[str, Any]:
|
| 678 |
+
"""Check if synthesis aligns with detected intent"""
|
| 679 |
+
# Calculate alignment based on intent confidence and response quality
|
| 680 |
+
intent_confidence = intent_info.get('confidence', 0.5)
|
| 681 |
+
coherence_score = synthesis_result.get('coherence_score', 0.7)
|
| 682 |
+
# Alignment is average of intent confidence and coherence
|
| 683 |
+
alignment_score = (intent_confidence + coherence_score) / 2.0
|
| 684 |
+
|
| 685 |
+
return {
|
| 686 |
+
"intent_detected": intent_info.get('primary_intent'),
|
| 687 |
+
"alignment_score": alignment_score,
|
| 688 |
+
"alignment_verified": alignment_score > 0.7
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
def _identify_improvements(self, response: str) -> List[str]:
|
| 692 |
+
"""Identify opportunities to improve the response"""
|
| 693 |
+
improvements = []
|
| 694 |
+
|
| 695 |
+
if len(response) < 50:
|
| 696 |
+
improvements.append("Could be more detailed")
|
| 697 |
+
|
| 698 |
+
if "?" not in response and len(response.split()) < 100:
|
| 699 |
+
improvements.append("Consider adding examples")
|
| 700 |
+
|
| 701 |
+
return improvements
|
| 702 |
+
|
| 703 |
+
def _get_fallback_response(self, user_input: str, agent_outputs: List[Dict[str, Any]],
|
| 704 |
+
assigned_consultant: Dict[str, Any]) -> Dict[str, Any]:
|
| 705 |
+
"""Provide fallback response when synthesis fails (LLM API failure only)"""
|
| 706 |
+
# Only use fallback when LLM API actually fails - not as default
|
| 707 |
+
if user_input:
|
| 708 |
+
fallback_text = f"Thank you for your question: '{user_input}'. I'm processing your request and will provide a detailed response shortly."
|
| 709 |
+
else:
|
| 710 |
+
fallback_text = "I apologize, but I encountered an issue processing your request. Please try again."
|
| 711 |
+
|
| 712 |
+
return {
|
| 713 |
+
"synthesized_response": fallback_text,
|
| 714 |
+
"draft_response": fallback_text,
|
| 715 |
+
"final_response": fallback_text,
|
| 716 |
+
"assigned_consultant": assigned_consultant,
|
| 717 |
+
"source_references": self._extract_source_references(agent_outputs),
|
| 718 |
+
"coherence_score": 0.5,
|
| 719 |
+
"improvement_opportunities": ["LLM API error - fallback activated"],
|
| 720 |
+
"synthesis_method": "expert_enhanced_fallback",
|
| 721 |
+
"agent_id": self.agent_id,
|
| 722 |
+
"synthesis_quality_metrics": self._calculate_quality_metrics({"final_response": fallback_text}),
|
| 723 |
+
"error": True,
|
| 724 |
+
"synthesis_metadata": {"expert_enhanced": True, "error": True, "llm_api_failed": True}
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
# Backward compatibility: ResponseSynthesisAgent is now EnhancedSynthesisAgent
|
| 729 |
+
ResponseSynthesisAgent = EnhancedSynthesisAgent
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
# Factory function for compatibility
|
| 733 |
+
def create_synthesis_agent(llm_router) -> EnhancedSynthesisAgent:
|
| 734 |
+
"""Factory function to create enhanced synthesis agent"""
|
| 735 |
+
return EnhancedSynthesisAgent(llm_router)
|
src/config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
import os
|
| 3 |
+
from pydantic_settings import BaseSettings
|
| 4 |
+
|
| 5 |
+
class Settings(BaseSettings):
|
| 6 |
+
# HF Spaces specific settings
|
| 7 |
+
hf_token: str = os.getenv("HF_TOKEN", "")
|
| 8 |
+
hf_cache_dir: str = os.getenv("HF_HOME", "/tmp/huggingface")
|
| 9 |
+
|
| 10 |
+
# Model settings
|
| 11 |
+
default_model: str = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 12 |
+
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
| 13 |
+
classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
|
| 14 |
+
|
| 15 |
+
# Performance settings
|
| 16 |
+
max_workers: int = int(os.getenv("MAX_WORKERS", "4"))
|
| 17 |
+
cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
|
| 18 |
+
|
| 19 |
+
# Database settings
|
| 20 |
+
db_path: str = os.getenv("DB_PATH", "sessions.db")
|
| 21 |
+
faiss_index_path: str = os.getenv("FAISS_INDEX_PATH", "embeddings.faiss")
|
| 22 |
+
|
| 23 |
+
# Session settings
|
| 24 |
+
session_timeout: int = int(os.getenv("SESSION_TIMEOUT", "3600"))
|
| 25 |
+
max_session_size_mb: int = int(os.getenv("MAX_SESSION_SIZE_MB", "10"))
|
| 26 |
+
|
| 27 |
+
# Mobile optimization settings
|
| 28 |
+
mobile_max_tokens: int = int(os.getenv("MOBILE_MAX_TOKENS", "800"))
|
| 29 |
+
mobile_timeout: int = int(os.getenv("MOBILE_TIMEOUT", "15000"))
|
| 30 |
+
|
| 31 |
+
# Gradio settings
|
| 32 |
+
gradio_port: int = int(os.getenv("GRADIO_PORT", "7860"))
|
| 33 |
+
gradio_host: str = os.getenv("GRADIO_HOST", "0.0.0.0")
|
| 34 |
+
|
| 35 |
+
# Logging settings
|
| 36 |
+
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
| 37 |
+
log_format: str = os.getenv("LOG_FORMAT", "json")
|
| 38 |
+
|
| 39 |
+
class Config:
|
| 40 |
+
env_file = ".env"
|
| 41 |
+
|
| 42 |
+
settings = Settings()
|
src/context_manager.py
ADDED
|
@@ -0,0 +1,1695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# context_manager.py
|
| 2 |
+
import sqlite3
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import uuid
|
| 6 |
+
import hashlib
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
from typing import Dict, Optional, List
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TransactionManager:
|
| 17 |
+
"""Manage database transactions with proper locking"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, db_path):
|
| 20 |
+
self.db_path = db_path
|
| 21 |
+
self._lock = threading.RLock()
|
| 22 |
+
self._connections = {}
|
| 23 |
+
|
| 24 |
+
@contextmanager
|
| 25 |
+
def transaction(self, session_id=None):
|
| 26 |
+
"""Context manager for database transactions with automatic rollback"""
|
| 27 |
+
conn = None
|
| 28 |
+
cursor = None
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
with self._lock:
|
| 32 |
+
conn = sqlite3.connect(self.db_path, isolation_level='IMMEDIATE')
|
| 33 |
+
conn.execute('PRAGMA journal_mode=WAL') # Write-Ahead Logging for better concurrency
|
| 34 |
+
conn.execute('PRAGMA busy_timeout=5000') # 5 second timeout for locks
|
| 35 |
+
cursor = conn.cursor()
|
| 36 |
+
|
| 37 |
+
yield cursor
|
| 38 |
+
|
| 39 |
+
conn.commit()
|
| 40 |
+
logger.debug(f"Transaction committed for session {session_id}")
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
if conn:
|
| 44 |
+
conn.rollback()
|
| 45 |
+
logger.error(f"Transaction rolled back for session {session_id}: {e}")
|
| 46 |
+
raise
|
| 47 |
+
finally:
|
| 48 |
+
if conn:
|
| 49 |
+
conn.close()
|
| 50 |
+
|
| 51 |
+
class EfficientContextManager:
|
| 52 |
+
def __init__(self, llm_router=None):
|
| 53 |
+
self.session_cache = {} # In-memory for active sessions
|
| 54 |
+
self._session_cache = {} # Enhanced in-memory cache with timestamps
|
| 55 |
+
self.cache_config = {
|
| 56 |
+
"max_session_size": 10, # MB per session
|
| 57 |
+
"ttl": 3600, # 1 hour
|
| 58 |
+
"compression": "gzip",
|
| 59 |
+
"eviction_policy": "LRU"
|
| 60 |
+
}
|
| 61 |
+
self.db_path = "sessions.db"
|
| 62 |
+
self.llm_router = llm_router # For generating context summaries
|
| 63 |
+
logger.info(f"Initializing ContextManager with DB path: {self.db_path}")
|
| 64 |
+
self.transaction_manager = TransactionManager(self.db_path)
|
| 65 |
+
self._init_database()
|
| 66 |
+
self.optimize_database_indexes()
|
| 67 |
+
|
| 68 |
+
def _init_database(self):
|
| 69 |
+
"""Initialize database and create tables"""
|
| 70 |
+
try:
|
| 71 |
+
logger.info("Initializing database...")
|
| 72 |
+
conn = sqlite3.connect(self.db_path)
|
| 73 |
+
cursor = conn.cursor()
|
| 74 |
+
|
| 75 |
+
# Create sessions table if not exists
|
| 76 |
+
cursor.execute("""
|
| 77 |
+
CREATE TABLE IF NOT EXISTS sessions (
|
| 78 |
+
session_id TEXT PRIMARY KEY,
|
| 79 |
+
user_id TEXT DEFAULT 'Test_Any',
|
| 80 |
+
created_at TIMESTAMP,
|
| 81 |
+
last_activity TIMESTAMP,
|
| 82 |
+
context_data TEXT,
|
| 83 |
+
user_metadata TEXT
|
| 84 |
+
)
|
| 85 |
+
""")
|
| 86 |
+
|
| 87 |
+
# Add user_id column to existing sessions table if it doesn't exist
|
| 88 |
+
try:
|
| 89 |
+
cursor.execute("ALTER TABLE sessions ADD COLUMN user_id TEXT DEFAULT 'Test_Any'")
|
| 90 |
+
logger.info("✓ Added user_id column to sessions table")
|
| 91 |
+
except sqlite3.OperationalError:
|
| 92 |
+
# Column already exists
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
logger.info("✓ Sessions table ready")
|
| 96 |
+
|
| 97 |
+
# Create interactions table
|
| 98 |
+
cursor.execute("""
|
| 99 |
+
CREATE TABLE IF NOT EXISTS interactions (
|
| 100 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 101 |
+
session_id TEXT REFERENCES sessions(session_id),
|
| 102 |
+
user_input TEXT,
|
| 103 |
+
context_snapshot TEXT,
|
| 104 |
+
created_at TIMESTAMP,
|
| 105 |
+
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
|
| 106 |
+
)
|
| 107 |
+
""")
|
| 108 |
+
logger.info("✓ Interactions table ready")
|
| 109 |
+
|
| 110 |
+
# Create user_contexts table (persistent user persona summaries)
|
| 111 |
+
cursor.execute("""
|
| 112 |
+
CREATE TABLE IF NOT EXISTS user_contexts (
|
| 113 |
+
user_id TEXT PRIMARY KEY,
|
| 114 |
+
persona_summary TEXT,
|
| 115 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 116 |
+
)
|
| 117 |
+
""")
|
| 118 |
+
logger.info("✓ User contexts table ready")
|
| 119 |
+
|
| 120 |
+
# Create session_contexts table (session summaries)
|
| 121 |
+
cursor.execute("""
|
| 122 |
+
CREATE TABLE IF NOT EXISTS session_contexts (
|
| 123 |
+
session_id TEXT PRIMARY KEY,
|
| 124 |
+
user_id TEXT,
|
| 125 |
+
session_summary TEXT,
|
| 126 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 127 |
+
FOREIGN KEY(session_id) REFERENCES sessions(session_id),
|
| 128 |
+
FOREIGN KEY(user_id) REFERENCES user_contexts(user_id)
|
| 129 |
+
)
|
| 130 |
+
""")
|
| 131 |
+
logger.info("✓ Session contexts table ready")
|
| 132 |
+
|
| 133 |
+
# Create interaction_contexts table (individual interaction summaries)
|
| 134 |
+
cursor.execute("""
|
| 135 |
+
CREATE TABLE IF NOT EXISTS interaction_contexts (
|
| 136 |
+
interaction_id TEXT PRIMARY KEY,
|
| 137 |
+
session_id TEXT,
|
| 138 |
+
user_input TEXT,
|
| 139 |
+
system_response TEXT,
|
| 140 |
+
interaction_summary TEXT,
|
| 141 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 142 |
+
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
|
| 143 |
+
)
|
| 144 |
+
""")
|
| 145 |
+
logger.info("✓ Interaction contexts table ready")
|
| 146 |
+
|
| 147 |
+
conn.commit()
|
| 148 |
+
conn.close()
|
| 149 |
+
|
| 150 |
+
# Update schema with new columns and tables for user change tracking
|
| 151 |
+
self._update_database_schema()
|
| 152 |
+
|
| 153 |
+
logger.info("Database initialization complete")
|
| 154 |
+
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logger.error(f"Database initialization error: {e}", exc_info=True)
|
| 157 |
+
|
| 158 |
+
def _update_database_schema(self):
|
| 159 |
+
"""Add missing columns and tables for user change tracking"""
|
| 160 |
+
try:
|
| 161 |
+
conn = sqlite3.connect(self.db_path)
|
| 162 |
+
cursor = conn.cursor()
|
| 163 |
+
|
| 164 |
+
# Add needs_refresh column to interaction_contexts
|
| 165 |
+
try:
|
| 166 |
+
cursor.execute("""
|
| 167 |
+
ALTER TABLE interaction_contexts
|
| 168 |
+
ADD COLUMN needs_refresh INTEGER DEFAULT 0
|
| 169 |
+
""")
|
| 170 |
+
logger.info("✓ Added needs_refresh column to interaction_contexts")
|
| 171 |
+
except sqlite3.OperationalError:
|
| 172 |
+
pass # Column already exists
|
| 173 |
+
|
| 174 |
+
# Create user change log table
|
| 175 |
+
cursor.execute("""
|
| 176 |
+
CREATE TABLE IF NOT EXISTS user_change_log (
|
| 177 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 178 |
+
session_id TEXT,
|
| 179 |
+
old_user_id TEXT,
|
| 180 |
+
new_user_id TEXT,
|
| 181 |
+
timestamp TIMESTAMP,
|
| 182 |
+
FOREIGN KEY(session_id) REFERENCES sessions(session_id)
|
| 183 |
+
)
|
| 184 |
+
""")
|
| 185 |
+
|
| 186 |
+
conn.commit()
|
| 187 |
+
conn.close()
|
| 188 |
+
logger.info("✓ Database schema updated successfully for user change tracking")
|
| 189 |
+
|
| 190 |
+
# Update interactions table for deduplication
|
| 191 |
+
self._update_interactions_table()
|
| 192 |
+
|
| 193 |
+
except Exception as e:
|
| 194 |
+
logger.error(f"Schema update error: {e}", exc_info=True)
|
| 195 |
+
|
| 196 |
+
def _update_interactions_table(self):
|
| 197 |
+
"""Add interaction_hash column for deduplication"""
|
| 198 |
+
try:
|
| 199 |
+
conn = sqlite3.connect(self.db_path)
|
| 200 |
+
cursor = conn.cursor()
|
| 201 |
+
|
| 202 |
+
# Check if column already exists
|
| 203 |
+
cursor.execute("PRAGMA table_info(interactions)")
|
| 204 |
+
columns = [row[1] for row in cursor.fetchall()]
|
| 205 |
+
|
| 206 |
+
# Add interaction_hash column if it doesn't exist
|
| 207 |
+
if 'interaction_hash' not in columns:
|
| 208 |
+
try:
|
| 209 |
+
cursor.execute("""
|
| 210 |
+
ALTER TABLE interactions
|
| 211 |
+
ADD COLUMN interaction_hash TEXT
|
| 212 |
+
""")
|
| 213 |
+
logger.info("✓ Added interaction_hash column to interactions table")
|
| 214 |
+
except sqlite3.OperationalError:
|
| 215 |
+
pass # Column already exists
|
| 216 |
+
|
| 217 |
+
# Create unique index for deduplication (this enforces uniqueness)
|
| 218 |
+
try:
|
| 219 |
+
cursor.execute("""
|
| 220 |
+
CREATE UNIQUE INDEX IF NOT EXISTS idx_interaction_hash_unique
|
| 221 |
+
ON interactions(interaction_hash)
|
| 222 |
+
""")
|
| 223 |
+
logger.info("✓ Created unique index on interaction_hash")
|
| 224 |
+
except sqlite3.OperationalError:
|
| 225 |
+
# Index might already exist, try non-unique index as fallback
|
| 226 |
+
cursor.execute("""
|
| 227 |
+
CREATE INDEX IF NOT EXISTS idx_interaction_hash
|
| 228 |
+
ON interactions(interaction_hash)
|
| 229 |
+
""")
|
| 230 |
+
|
| 231 |
+
conn.commit()
|
| 232 |
+
conn.close()
|
| 233 |
+
logger.info("✓ Interactions table updated for deduplication")
|
| 234 |
+
|
| 235 |
+
except Exception as e:
|
| 236 |
+
logger.error(f"Error updating interactions table: {e}", exc_info=True)
|
| 237 |
+
|
| 238 |
+
async def manage_context(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
|
| 239 |
+
"""
|
| 240 |
+
Efficient context management with separated session/user caching
|
| 241 |
+
STEP 1: Fetch User Context (if available)
|
| 242 |
+
STEP 2: Get Previous Interaction Contexts
|
| 243 |
+
STEP 3: Combine for workflow use
|
| 244 |
+
"""
|
| 245 |
+
# Use session-only cache key to prevent user_id conflicts
|
| 246 |
+
session_cache_key = f"session_{session_id}"
|
| 247 |
+
user_cache_key = f"user_{user_id}"
|
| 248 |
+
|
| 249 |
+
# Get session context from cache
|
| 250 |
+
session_context = self._get_from_memory_cache(session_cache_key)
|
| 251 |
+
|
| 252 |
+
# Check if cached session context matches current user_id
|
| 253 |
+
# Handle both old and new cache formats
|
| 254 |
+
cached_entry = self.session_cache.get(session_cache_key)
|
| 255 |
+
if cached_entry:
|
| 256 |
+
# Extract actual context from cache entry
|
| 257 |
+
if isinstance(cached_entry, dict) and 'value' in cached_entry:
|
| 258 |
+
actual_context = cached_entry.get('value', {})
|
| 259 |
+
else:
|
| 260 |
+
actual_context = cached_entry
|
| 261 |
+
|
| 262 |
+
if actual_context and actual_context.get("user_id") != user_id:
|
| 263 |
+
# User changed, invalidate session cache
|
| 264 |
+
logger.info(f"User mismatch in cache for session {session_id}, invalidating cache")
|
| 265 |
+
session_context = None
|
| 266 |
+
if session_cache_key in self.session_cache:
|
| 267 |
+
del self.session_cache[session_cache_key]
|
| 268 |
+
else:
|
| 269 |
+
session_context = actual_context
|
| 270 |
+
|
| 271 |
+
# Get user context separately
|
| 272 |
+
user_context = self._get_from_memory_cache(user_cache_key)
|
| 273 |
+
|
| 274 |
+
if not session_context:
|
| 275 |
+
# Retrieve from database with user context
|
| 276 |
+
session_context = await self._retrieve_from_db(session_id, user_input, user_id)
|
| 277 |
+
|
| 278 |
+
# Step 2: Cache session context with TTL
|
| 279 |
+
self.add_context_cache(session_cache_key, session_context, ttl=self.cache_config.get("ttl", 3600))
|
| 280 |
+
|
| 281 |
+
# Handle user context separately - load only once and cache thereafter
|
| 282 |
+
# Cache does not refer to database after initial load
|
| 283 |
+
if not user_context or not user_context.get("user_context_loaded"):
|
| 284 |
+
user_context_data = await self.get_user_context(user_id)
|
| 285 |
+
user_context = {
|
| 286 |
+
"user_context": user_context_data,
|
| 287 |
+
"user_context_loaded": True,
|
| 288 |
+
"user_id": user_id
|
| 289 |
+
}
|
| 290 |
+
# Cache user context separately - this is the only database query for user context
|
| 291 |
+
self._warm_memory_cache(user_cache_key, user_context)
|
| 292 |
+
logger.debug(f"User context loaded once for {user_id} and cached")
|
| 293 |
+
else:
|
| 294 |
+
# User context already cached, use it without database query
|
| 295 |
+
logger.debug(f"Using cached user context for {user_id}")
|
| 296 |
+
|
| 297 |
+
# Merge contexts without duplication
|
| 298 |
+
merged_context = {
|
| 299 |
+
**session_context,
|
| 300 |
+
"user_context": user_context.get("user_context", ""),
|
| 301 |
+
"user_context_loaded": True,
|
| 302 |
+
"user_id": user_id # Ensure current user_id is used
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
# Update context with new interaction
|
| 306 |
+
updated_context = self._update_context(merged_context, user_input, user_id=user_id)
|
| 307 |
+
|
| 308 |
+
return self._optimize_context(updated_context)
|
| 309 |
+
|
| 310 |
+
async def get_user_context(self, user_id: str) -> str:
|
| 311 |
+
"""
|
| 312 |
+
STEP 1: Fetch or generate User Context (500-token persona summary)
|
| 313 |
+
Available for all interactions except first time per user
|
| 314 |
+
"""
|
| 315 |
+
try:
|
| 316 |
+
conn = sqlite3.connect(self.db_path)
|
| 317 |
+
cursor = conn.cursor()
|
| 318 |
+
|
| 319 |
+
# Check if user context exists
|
| 320 |
+
cursor.execute("""
|
| 321 |
+
SELECT persona_summary FROM user_contexts WHERE user_id = ?
|
| 322 |
+
""", (user_id,))
|
| 323 |
+
|
| 324 |
+
row = cursor.fetchone()
|
| 325 |
+
if row and row[0]:
|
| 326 |
+
# Existing user context found
|
| 327 |
+
conn.close()
|
| 328 |
+
logger.info(f"✓ User context loaded for {user_id}")
|
| 329 |
+
return row[0]
|
| 330 |
+
|
| 331 |
+
# Generate new user context from all historical data
|
| 332 |
+
logger.info(f"Generating new user context for {user_id}")
|
| 333 |
+
|
| 334 |
+
# Fetch all historical Session and Interaction contexts for this user
|
| 335 |
+
all_session_summaries = []
|
| 336 |
+
all_interaction_summaries = []
|
| 337 |
+
|
| 338 |
+
# Get all session contexts
|
| 339 |
+
cursor.execute("""
|
| 340 |
+
SELECT session_summary FROM session_contexts WHERE user_id = ?
|
| 341 |
+
ORDER BY created_at DESC LIMIT 50
|
| 342 |
+
""", (user_id,))
|
| 343 |
+
for row in cursor.fetchall():
|
| 344 |
+
if row[0]:
|
| 345 |
+
all_session_summaries.append(row[0])
|
| 346 |
+
|
| 347 |
+
# Get all interaction contexts
|
| 348 |
+
cursor.execute("""
|
| 349 |
+
SELECT ic.interaction_summary
|
| 350 |
+
FROM interaction_contexts ic
|
| 351 |
+
JOIN sessions s ON ic.session_id = s.session_id
|
| 352 |
+
WHERE s.user_id = ?
|
| 353 |
+
ORDER BY ic.created_at DESC LIMIT 100
|
| 354 |
+
""", (user_id,))
|
| 355 |
+
for row in cursor.fetchall():
|
| 356 |
+
if row[0]:
|
| 357 |
+
all_interaction_summaries.append(row[0])
|
| 358 |
+
|
| 359 |
+
conn.close()
|
| 360 |
+
|
| 361 |
+
if not all_session_summaries and not all_interaction_summaries:
|
| 362 |
+
# First time user - no context to generate
|
| 363 |
+
logger.info(f"No historical data for {user_id} - first time user")
|
| 364 |
+
return ""
|
| 365 |
+
|
| 366 |
+
# Generate persona summary using LLM (500 tokens)
|
| 367 |
+
historical_data = "\n\n".join(all_session_summaries + all_interaction_summaries[:20])
|
| 368 |
+
|
| 369 |
+
if self.llm_router:
|
| 370 |
+
prompt = f"""Generate a concise 500-token persona summary for user {user_id} based on their interaction history:
|
| 371 |
+
|
| 372 |
+
Historical Context:
|
| 373 |
+
{historical_data}
|
| 374 |
+
|
| 375 |
+
Create a persona summary that captures:
|
| 376 |
+
- Communication style and preferences
|
| 377 |
+
- Common topics and interests
|
| 378 |
+
- Interaction patterns
|
| 379 |
+
- Key information shared across sessions
|
| 380 |
+
|
| 381 |
+
Keep the summary concise and focused (approximately 500 tokens)."""
|
| 382 |
+
|
| 383 |
+
try:
|
| 384 |
+
persona_summary = await self.llm_router.route_inference(
|
| 385 |
+
task_type="general_reasoning",
|
| 386 |
+
prompt=prompt,
|
| 387 |
+
max_tokens=500,
|
| 388 |
+
temperature=0.7
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
if persona_summary and isinstance(persona_summary, str) and persona_summary.strip():
|
| 392 |
+
# Store in database
|
| 393 |
+
conn = sqlite3.connect(self.db_path)
|
| 394 |
+
cursor = conn.cursor()
|
| 395 |
+
cursor.execute("""
|
| 396 |
+
INSERT OR REPLACE INTO user_contexts (user_id, persona_summary, updated_at)
|
| 397 |
+
VALUES (?, ?, ?)
|
| 398 |
+
""", (user_id, persona_summary.strip(), datetime.now().isoformat()))
|
| 399 |
+
conn.commit()
|
| 400 |
+
conn.close()
|
| 401 |
+
|
| 402 |
+
logger.info(f"✓ Generated and stored user context for {user_id}")
|
| 403 |
+
return persona_summary.strip()
|
| 404 |
+
except Exception as e:
|
| 405 |
+
logger.error(f"Error generating user context: {e}", exc_info=True)
|
| 406 |
+
|
| 407 |
+
# Fallback: Return empty if LLM fails
|
| 408 |
+
logger.warning(f"Could not generate user context for {user_id} - using empty")
|
| 409 |
+
return ""
|
| 410 |
+
|
| 411 |
+
except Exception as e:
|
| 412 |
+
logger.error(f"Error getting user context: {e}", exc_info=True)
|
| 413 |
+
return ""
|
| 414 |
+
|
| 415 |
+
async def generate_interaction_context(self, interaction_id: str, session_id: str,
|
| 416 |
+
user_input: str, system_response: str,
|
| 417 |
+
user_id: str = "Test_Any") -> str:
|
| 418 |
+
"""
|
| 419 |
+
STEP 2: Generate Interaction Context (50-token summary)
|
| 420 |
+
Called after each response
|
| 421 |
+
"""
|
| 422 |
+
try:
|
| 423 |
+
if not self.llm_router:
|
| 424 |
+
return ""
|
| 425 |
+
|
| 426 |
+
prompt = f"""Summarize this interaction in approximately 50 tokens:
|
| 427 |
+
|
| 428 |
+
User Input: {user_input[:200]}
|
| 429 |
+
System Response: {system_response[:300]}
|
| 430 |
+
|
| 431 |
+
Provide a brief summary capturing the key exchange."""
|
| 432 |
+
|
| 433 |
+
try:
|
| 434 |
+
summary = await self.llm_router.route_inference(
|
| 435 |
+
task_type="general_reasoning",
|
| 436 |
+
prompt=prompt,
|
| 437 |
+
max_tokens=50,
|
| 438 |
+
temperature=0.7
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
if summary and isinstance(summary, str) and summary.strip():
|
| 442 |
+
# Store in database
|
| 443 |
+
conn = sqlite3.connect(self.db_path)
|
| 444 |
+
cursor = conn.cursor()
|
| 445 |
+
created_at = datetime.now().isoformat()
|
| 446 |
+
cursor.execute("""
|
| 447 |
+
INSERT OR REPLACE INTO interaction_contexts
|
| 448 |
+
(interaction_id, session_id, user_input, system_response, interaction_summary, created_at)
|
| 449 |
+
VALUES (?, ?, ?, ?, ?, ?)
|
| 450 |
+
""", (
|
| 451 |
+
interaction_id,
|
| 452 |
+
session_id,
|
| 453 |
+
user_input[:500],
|
| 454 |
+
system_response[:1000],
|
| 455 |
+
summary.strip(),
|
| 456 |
+
created_at
|
| 457 |
+
))
|
| 458 |
+
conn.commit()
|
| 459 |
+
conn.close()
|
| 460 |
+
|
| 461 |
+
# Update cache immediately with new interaction context
|
| 462 |
+
# This ensures cache is synchronized with database at the same time
|
| 463 |
+
self._update_cache_with_interaction_context(session_id, summary.strip(), created_at)
|
| 464 |
+
|
| 465 |
+
logger.info(f"✓ Generated interaction context for {interaction_id} and updated cache")
|
| 466 |
+
return summary.strip()
|
| 467 |
+
except Exception as e:
|
| 468 |
+
logger.error(f"Error generating interaction context: {e}", exc_info=True)
|
| 469 |
+
|
| 470 |
+
# Fallback on LLM failure
|
| 471 |
+
return ""
|
| 472 |
+
|
| 473 |
+
except Exception as e:
|
| 474 |
+
logger.error(f"Error in generate_interaction_context: {e}", exc_info=True)
|
| 475 |
+
return ""
|
| 476 |
+
|
| 477 |
+
async def generate_session_context(self, session_id: str, user_id: str = "Test_Any") -> str:
|
| 478 |
+
"""
|
| 479 |
+
Generate Session Context (100-token summary) at every turn
|
| 480 |
+
Uses cached interaction contexts instead of querying database
|
| 481 |
+
Updates both database and cache immediately
|
| 482 |
+
"""
|
| 483 |
+
try:
|
| 484 |
+
# Get interaction contexts from cache (no database query)
|
| 485 |
+
session_cache_key = f"session_{session_id}"
|
| 486 |
+
cached_context = self.session_cache.get(session_cache_key)
|
| 487 |
+
|
| 488 |
+
if not cached_context:
|
| 489 |
+
logger.warning(f"No cached context found for session {session_id}, cannot generate session context")
|
| 490 |
+
return ""
|
| 491 |
+
|
| 492 |
+
interaction_contexts = cached_context.get('interaction_contexts', [])
|
| 493 |
+
|
| 494 |
+
if not interaction_contexts:
|
| 495 |
+
logger.info(f"No interaction contexts available for session {session_id} to summarize")
|
| 496 |
+
return ""
|
| 497 |
+
|
| 498 |
+
# Use cached interaction contexts (from cache, not database)
|
| 499 |
+
interaction_summaries = [ic.get('summary', '') for ic in interaction_contexts if ic.get('summary')]
|
| 500 |
+
|
| 501 |
+
if not interaction_summaries:
|
| 502 |
+
logger.info(f"No interaction summaries available for session {session_id}")
|
| 503 |
+
return ""
|
| 504 |
+
|
| 505 |
+
# Generate session summary using LLM (100 tokens)
|
| 506 |
+
if self.llm_router:
|
| 507 |
+
combined_context = "\n".join(interaction_summaries)
|
| 508 |
+
|
| 509 |
+
prompt = f"""Summarize this session's interactions in approximately 100 tokens:
|
| 510 |
+
|
| 511 |
+
Interaction Summaries:
|
| 512 |
+
{combined_context}
|
| 513 |
+
|
| 514 |
+
Create a concise session summary capturing:
|
| 515 |
+
- Main topics discussed
|
| 516 |
+
- Key outcomes or information shared
|
| 517 |
+
- User's focus areas
|
| 518 |
+
|
| 519 |
+
Keep the summary concise (approximately 100 tokens)."""
|
| 520 |
+
|
| 521 |
+
try:
|
| 522 |
+
session_summary = await self.llm_router.route_inference(
|
| 523 |
+
task_type="general_reasoning",
|
| 524 |
+
prompt=prompt,
|
| 525 |
+
max_tokens=100,
|
| 526 |
+
temperature=0.7
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
if session_summary and isinstance(session_summary, str) and session_summary.strip():
|
| 530 |
+
# Store in database
|
| 531 |
+
created_at = datetime.now().isoformat()
|
| 532 |
+
conn = sqlite3.connect(self.db_path)
|
| 533 |
+
cursor = conn.cursor()
|
| 534 |
+
cursor.execute("""
|
| 535 |
+
INSERT OR REPLACE INTO session_contexts
|
| 536 |
+
(session_id, user_id, session_summary, created_at)
|
| 537 |
+
VALUES (?, ?, ?, ?)
|
| 538 |
+
""", (session_id, user_id, session_summary.strip(), created_at))
|
| 539 |
+
conn.commit()
|
| 540 |
+
conn.close()
|
| 541 |
+
|
| 542 |
+
# Update cache immediately with new session context
|
| 543 |
+
# This ensures cache is synchronized with database at the same time
|
| 544 |
+
self._update_cache_with_session_context(session_id, session_summary.strip(), created_at)
|
| 545 |
+
|
| 546 |
+
logger.info(f"✓ Generated session context for {session_id} and updated cache")
|
| 547 |
+
return session_summary.strip()
|
| 548 |
+
except Exception as e:
|
| 549 |
+
logger.error(f"Error generating session context: {e}", exc_info=True)
|
| 550 |
+
|
| 551 |
+
# Fallback on LLM failure
|
| 552 |
+
return ""
|
| 553 |
+
|
| 554 |
+
except Exception as e:
|
| 555 |
+
logger.error(f"Error in generate_session_context: {e}", exc_info=True)
|
| 556 |
+
return ""
|
| 557 |
+
|
| 558 |
+
async def end_session(self, session_id: str, user_id: str = "Test_Any"):
|
| 559 |
+
"""
|
| 560 |
+
End session and clear cache
|
| 561 |
+
Note: Session context is already generated at every turn, so this just clears cache
|
| 562 |
+
"""
|
| 563 |
+
try:
|
| 564 |
+
# Session context is already generated at every turn (no need to regenerate)
|
| 565 |
+
# Clear in-memory cache for this session (session-only key)
|
| 566 |
+
session_cache_key = f"session_{session_id}"
|
| 567 |
+
if session_cache_key in self.session_cache:
|
| 568 |
+
del self.session_cache[session_cache_key]
|
| 569 |
+
logger.info(f"✓ Cleared cache for session {session_id}")
|
| 570 |
+
|
| 571 |
+
except Exception as e:
|
| 572 |
+
logger.error(f"Error ending session: {e}", exc_info=True)
|
| 573 |
+
|
| 574 |
+
def _clear_user_cache_on_change(self, session_id: str, new_user_id: str, old_user_id: str):
|
| 575 |
+
"""Clear cache entries when user changes"""
|
| 576 |
+
if new_user_id != old_user_id:
|
| 577 |
+
# Clear old composite cache keys
|
| 578 |
+
old_cache_key = f"{session_id}_{old_user_id}"
|
| 579 |
+
if old_cache_key in self.session_cache:
|
| 580 |
+
del self.session_cache[old_cache_key]
|
| 581 |
+
logger.info(f"Cleared old cache for user {old_user_id} on session {session_id}")
|
| 582 |
+
|
| 583 |
+
def _optimize_context(self, context: dict, relevance_classification: Optional[Dict] = None) -> dict:
|
| 584 |
+
"""
|
| 585 |
+
Optimize context for LLM consumption with relevance filtering support
|
| 586 |
+
Format: [Session Context] + [User Context (conditional)] + [Interaction Context #N, #N-1, ...]
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
context: Base context dictionary
|
| 590 |
+
relevance_classification: Optional relevance classification results with dynamic user context
|
| 591 |
+
|
| 592 |
+
Applies smart pruning before formatting.
|
| 593 |
+
"""
|
| 594 |
+
# Step 4: Prune context if it exceeds token limits
|
| 595 |
+
pruned_context = self.prune_context(context, max_tokens=2000)
|
| 596 |
+
|
| 597 |
+
# Get context mode (fresh or relevant)
|
| 598 |
+
session_id = pruned_context.get("session_id")
|
| 599 |
+
context_mode = self.get_context_mode(session_id)
|
| 600 |
+
|
| 601 |
+
interaction_contexts = pruned_context.get("interaction_contexts", [])
|
| 602 |
+
session_context = pruned_context.get("session_context", {})
|
| 603 |
+
session_summary = session_context.get("summary", "") if isinstance(session_context, dict) else ""
|
| 604 |
+
|
| 605 |
+
# MODIFIED: Conditional user context inclusion based on mode and relevance
|
| 606 |
+
user_context = ""
|
| 607 |
+
if context_mode == 'relevant' and relevance_classification:
|
| 608 |
+
# Use dynamic relevant summaries from relevance classification
|
| 609 |
+
user_context = relevance_classification.get('combined_user_context', '')
|
| 610 |
+
|
| 611 |
+
if user_context:
|
| 612 |
+
logger.info(
|
| 613 |
+
f"Using dynamic relevant context: {len(relevance_classification.get('relevant_summaries', []))} "
|
| 614 |
+
f"sessions summarized for session {session_id}"
|
| 615 |
+
)
|
| 616 |
+
elif context_mode == 'relevant' and not relevance_classification:
|
| 617 |
+
# Fallback: Use traditional user context if relevance classification unavailable
|
| 618 |
+
user_context = pruned_context.get("user_context", "")
|
| 619 |
+
logger.debug(f"Relevant mode but no classification, using traditional user context")
|
| 620 |
+
# If context_mode == 'fresh', user_context remains empty (no user context)
|
| 621 |
+
|
| 622 |
+
# Format interaction contexts as requested
|
| 623 |
+
formatted_interactions = []
|
| 624 |
+
for idx, ic in enumerate(interaction_contexts[:10]): # Last 10 interactions
|
| 625 |
+
formatted_interactions.append(f"[Interaction Context #{len(interaction_contexts) - idx}]\n{ic.get('summary', '')}")
|
| 626 |
+
|
| 627 |
+
# Combine Session Context + (Conditional) User Context + Interaction Contexts
|
| 628 |
+
combined_context = ""
|
| 629 |
+
if session_summary:
|
| 630 |
+
combined_context += f"[Session Context]\n{session_summary}\n\n"
|
| 631 |
+
|
| 632 |
+
# Include user context only if available and in relevant mode
|
| 633 |
+
if user_context:
|
| 634 |
+
context_label = "[Relevant User Context]" if context_mode == 'relevant' else "[User Context]"
|
| 635 |
+
combined_context += f"{context_label}\n{user_context}\n\n"
|
| 636 |
+
|
| 637 |
+
if formatted_interactions:
|
| 638 |
+
combined_context += "\n\n".join(formatted_interactions)
|
| 639 |
+
|
| 640 |
+
return {
|
| 641 |
+
"session_id": pruned_context.get("session_id"),
|
| 642 |
+
"user_id": pruned_context.get("user_id", "Test_Any"),
|
| 643 |
+
"user_context": user_context, # Dynamic summaries OR empty
|
| 644 |
+
"session_context": session_context,
|
| 645 |
+
"interaction_contexts": interaction_contexts,
|
| 646 |
+
"combined_context": combined_context,
|
| 647 |
+
"context_mode": context_mode, # Include mode for debugging
|
| 648 |
+
"relevance_metadata": relevance_classification.get('relevance_scores', {}) if relevance_classification else {},
|
| 649 |
+
"preferences": pruned_context.get("preferences", {}),
|
| 650 |
+
"active_tasks": pruned_context.get("active_tasks", []),
|
| 651 |
+
"last_activity": pruned_context.get("last_activity")
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
def _get_from_memory_cache(self, cache_key: str) -> dict:
|
| 655 |
+
"""
|
| 656 |
+
Retrieve context from in-memory session cache with expiration check
|
| 657 |
+
"""
|
| 658 |
+
cached = self.session_cache.get(cache_key)
|
| 659 |
+
if not cached:
|
| 660 |
+
return None
|
| 661 |
+
|
| 662 |
+
# Check if it's the new format with expiration
|
| 663 |
+
if isinstance(cached, dict) and 'value' in cached:
|
| 664 |
+
# New format with TTL
|
| 665 |
+
if self._is_cache_expired(cached):
|
| 666 |
+
# Remove expired cache entry
|
| 667 |
+
del self.session_cache[cache_key]
|
| 668 |
+
logger.debug(f"Cache expired for key: {cache_key}")
|
| 669 |
+
return None
|
| 670 |
+
return cached.get('value')
|
| 671 |
+
else:
|
| 672 |
+
# Old format (direct value) - return as-is for backward compatibility
|
| 673 |
+
return cached
|
| 674 |
+
|
| 675 |
+
def _is_cache_expired(self, cache_entry: dict) -> bool:
|
| 676 |
+
"""
|
| 677 |
+
Check if cache entry has expired based on TTL
|
| 678 |
+
"""
|
| 679 |
+
if not isinstance(cache_entry, dict):
|
| 680 |
+
return True
|
| 681 |
+
|
| 682 |
+
expires = cache_entry.get('expires')
|
| 683 |
+
if not expires:
|
| 684 |
+
return False # No expiration set, consider valid
|
| 685 |
+
|
| 686 |
+
return time.time() > expires
|
| 687 |
+
|
| 688 |
+
def add_context_cache(self, key: str, value: dict, ttl: int = 3600):
|
| 689 |
+
"""
|
| 690 |
+
Step 2: Implement Context Caching with TTL expiration
|
| 691 |
+
|
| 692 |
+
Add context to cache with expiration time.
|
| 693 |
+
|
| 694 |
+
Args:
|
| 695 |
+
key: Cache key
|
| 696 |
+
value: Value to cache (dict)
|
| 697 |
+
ttl: Time to live in seconds (default 3600 = 1 hour)
|
| 698 |
+
"""
|
| 699 |
+
import time
|
| 700 |
+
self.session_cache[key] = {
|
| 701 |
+
'value': value,
|
| 702 |
+
'expires': time.time() + ttl,
|
| 703 |
+
'timestamp': time.time()
|
| 704 |
+
}
|
| 705 |
+
logger.debug(f"Cached context for key: {key} with TTL: {ttl}s")
|
| 706 |
+
|
| 707 |
+
def get_token_count(self, text: str) -> int:
|
| 708 |
+
"""
|
| 709 |
+
Approximate token count for text (4 characters ≈ 1 token)
|
| 710 |
+
|
| 711 |
+
Args:
|
| 712 |
+
text: Text to count tokens for
|
| 713 |
+
|
| 714 |
+
Returns:
|
| 715 |
+
Approximate token count
|
| 716 |
+
"""
|
| 717 |
+
if not text:
|
| 718 |
+
return 0
|
| 719 |
+
# Simple approximation: 4 characters per token
|
| 720 |
+
return len(text) // 4
|
| 721 |
+
|
| 722 |
+
def prune_context(self, context: dict, max_tokens: int = 2000) -> dict:
|
| 723 |
+
"""
|
| 724 |
+
Step 4: Implement Smart Context Pruning
|
| 725 |
+
|
| 726 |
+
Prune context to stay within token limit while keeping most recent and relevant content.
|
| 727 |
+
|
| 728 |
+
Args:
|
| 729 |
+
context: Context dictionary to prune
|
| 730 |
+
max_tokens: Maximum token count (default 2000)
|
| 731 |
+
|
| 732 |
+
Returns:
|
| 733 |
+
Pruned context dictionary
|
| 734 |
+
"""
|
| 735 |
+
try:
|
| 736 |
+
# Calculate current token count
|
| 737 |
+
current_tokens = self._calculate_context_tokens(context)
|
| 738 |
+
|
| 739 |
+
if current_tokens <= max_tokens:
|
| 740 |
+
return context # No pruning needed
|
| 741 |
+
|
| 742 |
+
logger.info(f"Context token count ({current_tokens}) exceeds limit ({max_tokens}), pruning...")
|
| 743 |
+
|
| 744 |
+
# Create a copy to avoid modifying original
|
| 745 |
+
pruned_context = context.copy()
|
| 746 |
+
|
| 747 |
+
# Priority: Keep most recent interactions + session context + user context
|
| 748 |
+
interaction_contexts = pruned_context.get('interaction_contexts', [])
|
| 749 |
+
session_context = pruned_context.get('session_context', {})
|
| 750 |
+
user_context = pruned_context.get('user_context', '')
|
| 751 |
+
|
| 752 |
+
# Keep user context and session context (essential)
|
| 753 |
+
essential_tokens = (
|
| 754 |
+
self.get_token_count(user_context) +
|
| 755 |
+
self.get_token_count(str(session_context))
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# Calculate how many interaction contexts we can keep
|
| 759 |
+
available_tokens = max_tokens - essential_tokens
|
| 760 |
+
if available_tokens < 0:
|
| 761 |
+
# Essential context itself is too large - summarize user context
|
| 762 |
+
if self.get_token_count(user_context) > max_tokens // 2:
|
| 763 |
+
pruned_context['user_context'] = user_context[:max_tokens * 2] # Rough cut
|
| 764 |
+
logger.warning(f"User context too large, truncated")
|
| 765 |
+
return pruned_context
|
| 766 |
+
|
| 767 |
+
# Keep most recent interactions that fit in token budget
|
| 768 |
+
kept_interactions = []
|
| 769 |
+
current_size = 0
|
| 770 |
+
|
| 771 |
+
for interaction in interaction_contexts:
|
| 772 |
+
summary = interaction.get('summary', '')
|
| 773 |
+
interaction_tokens = self.get_token_count(summary)
|
| 774 |
+
|
| 775 |
+
if current_size + interaction_tokens <= available_tokens:
|
| 776 |
+
kept_interactions.append(interaction)
|
| 777 |
+
current_size += interaction_tokens
|
| 778 |
+
else:
|
| 779 |
+
break # Can't fit any more
|
| 780 |
+
|
| 781 |
+
pruned_context['interaction_contexts'] = kept_interactions
|
| 782 |
+
|
| 783 |
+
logger.info(f"Pruned context: kept {len(kept_interactions)}/{len(interaction_contexts)} interactions, "
|
| 784 |
+
f"reduced from {current_tokens} to {self._calculate_context_tokens(pruned_context)} tokens")
|
| 785 |
+
|
| 786 |
+
return pruned_context
|
| 787 |
+
|
| 788 |
+
except Exception as e:
|
| 789 |
+
logger.error(f"Error pruning context: {e}", exc_info=True)
|
| 790 |
+
return context # Return original on error
|
| 791 |
+
|
| 792 |
+
def _calculate_context_tokens(self, context: dict) -> int:
|
| 793 |
+
"""Calculate total token count for context"""
|
| 794 |
+
total = 0
|
| 795 |
+
|
| 796 |
+
# Count tokens in each component
|
| 797 |
+
user_context = context.get('user_context', '')
|
| 798 |
+
total += self.get_token_count(str(user_context))
|
| 799 |
+
|
| 800 |
+
session_context = context.get('session_context', {})
|
| 801 |
+
if isinstance(session_context, dict):
|
| 802 |
+
total += self.get_token_count(str(session_context.get('summary', '')))
|
| 803 |
+
else:
|
| 804 |
+
total += self.get_token_count(str(session_context))
|
| 805 |
+
|
| 806 |
+
interaction_contexts = context.get('interaction_contexts', [])
|
| 807 |
+
for interaction in interaction_contexts:
|
| 808 |
+
summary = interaction.get('summary', '')
|
| 809 |
+
total += self.get_token_count(str(summary))
|
| 810 |
+
|
| 811 |
+
return total
|
| 812 |
+
|
| 813 |
+
async def _retrieve_from_db(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
|
| 814 |
+
"""
|
| 815 |
+
Retrieve session context with proper user_id synchronization
|
| 816 |
+
Uses transactions to ensure atomic updates of database and cache
|
| 817 |
+
"""
|
| 818 |
+
conn = None
|
| 819 |
+
try:
|
| 820 |
+
conn = sqlite3.connect(self.db_path)
|
| 821 |
+
cursor = conn.cursor()
|
| 822 |
+
|
| 823 |
+
# Use transaction to ensure atomic updates
|
| 824 |
+
cursor.execute("BEGIN TRANSACTION")
|
| 825 |
+
|
| 826 |
+
# Get session data (SQLite doesn't support FOR UPDATE, but transaction ensures consistency)
|
| 827 |
+
cursor.execute("""
|
| 828 |
+
SELECT context_data, user_metadata, last_activity, user_id
|
| 829 |
+
FROM sessions
|
| 830 |
+
WHERE session_id = ?
|
| 831 |
+
""", (session_id,))
|
| 832 |
+
|
| 833 |
+
row = cursor.fetchone()
|
| 834 |
+
|
| 835 |
+
if row:
|
| 836 |
+
context_data = json.loads(row[0]) if row[0] else {}
|
| 837 |
+
user_metadata = json.loads(row[1]) if row[1] else {}
|
| 838 |
+
last_activity = row[2]
|
| 839 |
+
session_user_id = row[3] if len(row) > 3 else user_id
|
| 840 |
+
|
| 841 |
+
# Check for user_id change and update atomically
|
| 842 |
+
user_changed = False
|
| 843 |
+
if session_user_id != user_id:
|
| 844 |
+
logger.info(f"User change detected: {session_user_id} -> {user_id} for session {session_id}")
|
| 845 |
+
user_changed = True
|
| 846 |
+
|
| 847 |
+
# Update session with new user_id
|
| 848 |
+
cursor.execute("""
|
| 849 |
+
UPDATE sessions
|
| 850 |
+
SET user_id = ?, last_activity = ?
|
| 851 |
+
WHERE session_id = ?
|
| 852 |
+
""", (user_id, datetime.now().isoformat(), session_id))
|
| 853 |
+
|
| 854 |
+
# Clear any cached interaction contexts for old user by marking for refresh
|
| 855 |
+
try:
|
| 856 |
+
cursor.execute("""
|
| 857 |
+
UPDATE interaction_contexts
|
| 858 |
+
SET needs_refresh = 1
|
| 859 |
+
WHERE session_id = ?
|
| 860 |
+
""", (session_id,))
|
| 861 |
+
except sqlite3.OperationalError:
|
| 862 |
+
# Column might not exist yet, will be created by schema update
|
| 863 |
+
pass
|
| 864 |
+
|
| 865 |
+
# Log user change event
|
| 866 |
+
try:
|
| 867 |
+
cursor.execute("""
|
| 868 |
+
INSERT INTO user_change_log (session_id, old_user_id, new_user_id, timestamp)
|
| 869 |
+
VALUES (?, ?, ?, ?)
|
| 870 |
+
""", (session_id, session_user_id, user_id, datetime.now().isoformat()))
|
| 871 |
+
except sqlite3.OperationalError:
|
| 872 |
+
# Table might not exist yet, will be created by schema update
|
| 873 |
+
pass
|
| 874 |
+
|
| 875 |
+
# Clear old cache entries when user changes
|
| 876 |
+
self._clear_user_cache_on_change(session_id, user_id, session_user_id)
|
| 877 |
+
|
| 878 |
+
cursor.execute("COMMIT")
|
| 879 |
+
|
| 880 |
+
# Get interaction contexts with refresh flag check
|
| 881 |
+
try:
|
| 882 |
+
cursor.execute("""
|
| 883 |
+
SELECT interaction_summary, created_at, needs_refresh
|
| 884 |
+
FROM interaction_contexts
|
| 885 |
+
WHERE session_id = ? AND (needs_refresh IS NULL OR needs_refresh = 0)
|
| 886 |
+
ORDER BY created_at DESC
|
| 887 |
+
LIMIT 20
|
| 888 |
+
""", (session_id,))
|
| 889 |
+
except sqlite3.OperationalError:
|
| 890 |
+
# Column might not exist yet, fall back to query without needs_refresh
|
| 891 |
+
cursor.execute("""
|
| 892 |
+
SELECT interaction_summary, created_at
|
| 893 |
+
FROM interaction_contexts
|
| 894 |
+
WHERE session_id = ?
|
| 895 |
+
ORDER BY created_at DESC
|
| 896 |
+
LIMIT 20
|
| 897 |
+
""", (session_id,))
|
| 898 |
+
|
| 899 |
+
interaction_contexts = []
|
| 900 |
+
for ic_row in cursor.fetchall():
|
| 901 |
+
# Handle both query formats (with and without needs_refresh)
|
| 902 |
+
if len(ic_row) >= 2:
|
| 903 |
+
summary = ic_row[0]
|
| 904 |
+
timestamp = ic_row[1]
|
| 905 |
+
needs_refresh = ic_row[2] if len(ic_row) > 2 else 0
|
| 906 |
+
|
| 907 |
+
if summary and not needs_refresh:
|
| 908 |
+
interaction_contexts.append({
|
| 909 |
+
"summary": summary,
|
| 910 |
+
"timestamp": timestamp
|
| 911 |
+
})
|
| 912 |
+
|
| 913 |
+
# Get session context from database
|
| 914 |
+
session_context_data = None
|
| 915 |
+
try:
|
| 916 |
+
cursor.execute("""
|
| 917 |
+
SELECT session_summary, created_at
|
| 918 |
+
FROM session_contexts
|
| 919 |
+
WHERE session_id = ?
|
| 920 |
+
ORDER BY created_at DESC
|
| 921 |
+
LIMIT 1
|
| 922 |
+
""", (session_id,))
|
| 923 |
+
sc_row = cursor.fetchone()
|
| 924 |
+
if sc_row and sc_row[0]:
|
| 925 |
+
session_context_data = {
|
| 926 |
+
"summary": sc_row[0],
|
| 927 |
+
"timestamp": sc_row[1]
|
| 928 |
+
}
|
| 929 |
+
except sqlite3.OperationalError:
|
| 930 |
+
# Table might not exist yet
|
| 931 |
+
pass
|
| 932 |
+
|
| 933 |
+
context = {
|
| 934 |
+
"session_id": session_id,
|
| 935 |
+
"user_id": user_id,
|
| 936 |
+
"interaction_contexts": interaction_contexts,
|
| 937 |
+
"session_context": session_context_data,
|
| 938 |
+
"preferences": user_metadata.get("preferences", {}),
|
| 939 |
+
"active_tasks": user_metadata.get("active_tasks", []),
|
| 940 |
+
"last_activity": last_activity,
|
| 941 |
+
"user_context_loaded": False,
|
| 942 |
+
"user_changed": user_changed
|
| 943 |
+
}
|
| 944 |
+
|
| 945 |
+
conn.close()
|
| 946 |
+
return context
|
| 947 |
+
else:
|
| 948 |
+
# Create new session with transaction
|
| 949 |
+
cursor.execute("""
|
| 950 |
+
INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
|
| 951 |
+
VALUES (?, ?, ?, ?, ?, ?)
|
| 952 |
+
""", (session_id, user_id, datetime.now().isoformat(), datetime.now().isoformat(), "{}", "{}"))
|
| 953 |
+
|
| 954 |
+
cursor.execute("COMMIT")
|
| 955 |
+
conn.close()
|
| 956 |
+
|
| 957 |
+
return {
|
| 958 |
+
"session_id": session_id,
|
| 959 |
+
"user_id": user_id,
|
| 960 |
+
"interaction_contexts": [],
|
| 961 |
+
"session_context": None,
|
| 962 |
+
"preferences": {},
|
| 963 |
+
"active_tasks": [],
|
| 964 |
+
"user_context_loaded": False,
|
| 965 |
+
"user_changed": False
|
| 966 |
+
}
|
| 967 |
+
|
| 968 |
+
except sqlite3.Error as e:
|
| 969 |
+
logger.error(f"Database transaction error: {e}", exc_info=True)
|
| 970 |
+
if conn:
|
| 971 |
+
try:
|
| 972 |
+
conn.rollback()
|
| 973 |
+
except:
|
| 974 |
+
pass
|
| 975 |
+
conn.close()
|
| 976 |
+
# Return safe fallback
|
| 977 |
+
return {
|
| 978 |
+
"session_id": session_id,
|
| 979 |
+
"user_id": user_id,
|
| 980 |
+
"interaction_contexts": [],
|
| 981 |
+
"session_context": None,
|
| 982 |
+
"preferences": {},
|
| 983 |
+
"active_tasks": [],
|
| 984 |
+
"user_context_loaded": False,
|
| 985 |
+
"error": str(e),
|
| 986 |
+
"user_changed": False
|
| 987 |
+
}
|
| 988 |
+
except Exception as e:
|
| 989 |
+
logger.error(f"Database retrieval error: {e}", exc_info=True)
|
| 990 |
+
if conn:
|
| 991 |
+
try:
|
| 992 |
+
conn.rollback()
|
| 993 |
+
except:
|
| 994 |
+
pass
|
| 995 |
+
conn.close()
|
| 996 |
+
# Return safe fallback
|
| 997 |
+
return {
|
| 998 |
+
"session_id": session_id,
|
| 999 |
+
"user_id": user_id,
|
| 1000 |
+
"interaction_contexts": [],
|
| 1001 |
+
"session_context": None,
|
| 1002 |
+
"preferences": {},
|
| 1003 |
+
"active_tasks": [],
|
| 1004 |
+
"user_context_loaded": False,
|
| 1005 |
+
"error": str(e),
|
| 1006 |
+
"user_changed": False
|
| 1007 |
+
}
|
| 1008 |
+
|
| 1009 |
+
def _warm_memory_cache(self, cache_key: str, context: dict):
|
| 1010 |
+
"""
|
| 1011 |
+
Warm the in-memory cache with retrieved context
|
| 1012 |
+
Note: Use add_context_cache() instead for TTL support
|
| 1013 |
+
"""
|
| 1014 |
+
# Use add_context_cache for consistency with TTL
|
| 1015 |
+
self.add_context_cache(cache_key, context, ttl=self.cache_config.get("ttl", 3600))
|
| 1016 |
+
|
| 1017 |
+
def _update_cache_with_interaction_context(self, session_id: str, interaction_summary: str, created_at: str):
|
| 1018 |
+
"""
|
| 1019 |
+
Update cache with new interaction context immediately after database update
|
| 1020 |
+
This keeps cache synchronized with database without requiring database queries
|
| 1021 |
+
"""
|
| 1022 |
+
session_cache_key = f"session_{session_id}"
|
| 1023 |
+
|
| 1024 |
+
# Get current cached context if it exists
|
| 1025 |
+
cached_context = self.session_cache.get(session_cache_key)
|
| 1026 |
+
|
| 1027 |
+
if cached_context:
|
| 1028 |
+
# Add new interaction context to the beginning of the list (most recent first)
|
| 1029 |
+
interaction_contexts = cached_context.get('interaction_contexts', [])
|
| 1030 |
+
new_interaction = {
|
| 1031 |
+
"summary": interaction_summary,
|
| 1032 |
+
"timestamp": created_at
|
| 1033 |
+
}
|
| 1034 |
+
# Insert at beginning and keep only last 20 (matches DB query limit)
|
| 1035 |
+
interaction_contexts.insert(0, new_interaction)
|
| 1036 |
+
interaction_contexts = interaction_contexts[:20]
|
| 1037 |
+
|
| 1038 |
+
# Update cached context with new interaction contexts
|
| 1039 |
+
cached_context['interaction_contexts'] = interaction_contexts
|
| 1040 |
+
self.session_cache[session_cache_key] = cached_context
|
| 1041 |
+
|
| 1042 |
+
logger.debug(f"Cache updated with new interaction context for session {session_id} (total: {len(interaction_contexts)})")
|
| 1043 |
+
else:
|
| 1044 |
+
# If cache doesn't exist, create new entry
|
| 1045 |
+
new_context = {
|
| 1046 |
+
"session_id": session_id,
|
| 1047 |
+
"interaction_contexts": [{
|
| 1048 |
+
"summary": interaction_summary,
|
| 1049 |
+
"timestamp": created_at
|
| 1050 |
+
}],
|
| 1051 |
+
"preferences": {},
|
| 1052 |
+
"active_tasks": [],
|
| 1053 |
+
"user_context_loaded": False
|
| 1054 |
+
}
|
| 1055 |
+
self.session_cache[session_cache_key] = new_context
|
| 1056 |
+
logger.debug(f"Created new cache entry with interaction context for session {session_id}")
|
| 1057 |
+
|
| 1058 |
+
def _update_cache_with_session_context(self, session_id: str, session_summary: str, created_at: str):
|
| 1059 |
+
"""
|
| 1060 |
+
Update cache with new session context immediately after database update
|
| 1061 |
+
This keeps cache synchronized with database without requiring database queries
|
| 1062 |
+
"""
|
| 1063 |
+
session_cache_key = f"session_{session_id}"
|
| 1064 |
+
|
| 1065 |
+
# Get current cached context if it exists
|
| 1066 |
+
cached_context = self.session_cache.get(session_cache_key)
|
| 1067 |
+
|
| 1068 |
+
if cached_context:
|
| 1069 |
+
# Update session context in cache
|
| 1070 |
+
cached_context['session_context'] = {
|
| 1071 |
+
"summary": session_summary,
|
| 1072 |
+
"timestamp": created_at
|
| 1073 |
+
}
|
| 1074 |
+
self.session_cache[session_cache_key] = cached_context
|
| 1075 |
+
|
| 1076 |
+
logger.debug(f"Cache updated with new session context for session {session_id}")
|
| 1077 |
+
else:
|
| 1078 |
+
# If cache doesn't exist, create new entry
|
| 1079 |
+
new_context = {
|
| 1080 |
+
"session_id": session_id,
|
| 1081 |
+
"session_context": {
|
| 1082 |
+
"summary": session_summary,
|
| 1083 |
+
"timestamp": created_at
|
| 1084 |
+
},
|
| 1085 |
+
"interaction_contexts": [],
|
| 1086 |
+
"preferences": {},
|
| 1087 |
+
"active_tasks": [],
|
| 1088 |
+
"user_context_loaded": False
|
| 1089 |
+
}
|
| 1090 |
+
self.session_cache[session_cache_key] = new_context
|
| 1091 |
+
logger.debug(f"Created new cache entry with session context for session {session_id}")
|
| 1092 |
+
|
| 1093 |
+
def _update_context(self, context: dict, user_input: str, response: str = None, user_id: str = "Test_Any") -> dict:
|
| 1094 |
+
"""
|
| 1095 |
+
Update context with deduplication and idempotency checks
|
| 1096 |
+
Prevents duplicate context updates using interaction hashes
|
| 1097 |
+
"""
|
| 1098 |
+
try:
|
| 1099 |
+
# Generate unique interaction hash to prevent duplicates
|
| 1100 |
+
interaction_hash = self._generate_interaction_hash(user_input, context["session_id"], user_id)
|
| 1101 |
+
|
| 1102 |
+
# Check if this interaction was already processed
|
| 1103 |
+
if self._is_duplicate_interaction(interaction_hash):
|
| 1104 |
+
logger.info(f"Duplicate interaction detected, skipping update: {interaction_hash[:8]}")
|
| 1105 |
+
return context
|
| 1106 |
+
|
| 1107 |
+
# Use transaction for atomic updates
|
| 1108 |
+
current_time = datetime.now().isoformat()
|
| 1109 |
+
with self.transaction_manager.transaction(context["session_id"]) as cursor:
|
| 1110 |
+
# Update session activity (only if last_activity is older to prevent unnecessary updates)
|
| 1111 |
+
cursor.execute("""
|
| 1112 |
+
UPDATE sessions
|
| 1113 |
+
SET last_activity = ?, user_id = ?
|
| 1114 |
+
WHERE session_id = ? AND (last_activity IS NULL OR last_activity < ?)
|
| 1115 |
+
""", (current_time, user_id, context["session_id"], current_time))
|
| 1116 |
+
|
| 1117 |
+
# Store interaction with duplicate prevention using INSERT OR IGNORE
|
| 1118 |
+
session_context = {
|
| 1119 |
+
"preferences": context.get("preferences", {}),
|
| 1120 |
+
"active_tasks": context.get("active_tasks", [])
|
| 1121 |
+
}
|
| 1122 |
+
|
| 1123 |
+
cursor.execute("""
|
| 1124 |
+
INSERT OR IGNORE INTO interactions (
|
| 1125 |
+
interaction_hash,
|
| 1126 |
+
session_id,
|
| 1127 |
+
user_input,
|
| 1128 |
+
context_snapshot,
|
| 1129 |
+
created_at
|
| 1130 |
+
) VALUES (?, ?, ?, ?, ?)
|
| 1131 |
+
""", (
|
| 1132 |
+
interaction_hash,
|
| 1133 |
+
context["session_id"],
|
| 1134 |
+
user_input,
|
| 1135 |
+
json.dumps(session_context),
|
| 1136 |
+
current_time
|
| 1137 |
+
))
|
| 1138 |
+
|
| 1139 |
+
# Mark interaction as processed (outside transaction)
|
| 1140 |
+
self._mark_interaction_processed(interaction_hash)
|
| 1141 |
+
|
| 1142 |
+
# Update in-memory context
|
| 1143 |
+
context["last_interaction"] = user_input
|
| 1144 |
+
context["last_update"] = current_time
|
| 1145 |
+
|
| 1146 |
+
logger.info(f"Context updated for session {context['session_id']} with hash {interaction_hash[:8]}")
|
| 1147 |
+
|
| 1148 |
+
return context
|
| 1149 |
+
|
| 1150 |
+
except Exception as e:
|
| 1151 |
+
logger.error(f"Error updating context: {e}", exc_info=True)
|
| 1152 |
+
return context
|
| 1153 |
+
|
| 1154 |
+
def _generate_interaction_hash(self, user_input: str, session_id: str, user_id: str) -> str:
|
| 1155 |
+
"""Generate unique hash for interaction to prevent duplicates"""
|
| 1156 |
+
# Use session_id, user_id, and user_input for exact duplicate detection
|
| 1157 |
+
# Normalize user input by stripping whitespace
|
| 1158 |
+
normalized_input = user_input.strip()
|
| 1159 |
+
content = f"{session_id}:{user_id}:{normalized_input}"
|
| 1160 |
+
return hashlib.sha256(content.encode()).hexdigest()
|
| 1161 |
+
|
| 1162 |
+
def _is_duplicate_interaction(self, interaction_hash: str) -> bool:
|
| 1163 |
+
"""Check if interaction was already processed"""
|
| 1164 |
+
# Keep a rolling window of recent interaction hashes in memory
|
| 1165 |
+
if not hasattr(self, '_processed_interactions'):
|
| 1166 |
+
self._processed_interactions = set()
|
| 1167 |
+
|
| 1168 |
+
# Check in-memory cache first
|
| 1169 |
+
if interaction_hash in self._processed_interactions:
|
| 1170 |
+
return True
|
| 1171 |
+
|
| 1172 |
+
# Also check database for persistent duplicates
|
| 1173 |
+
try:
|
| 1174 |
+
conn = sqlite3.connect(self.db_path)
|
| 1175 |
+
cursor = conn.cursor()
|
| 1176 |
+
# Check if interaction_hash column exists and query for duplicates
|
| 1177 |
+
cursor.execute("PRAGMA table_info(interactions)")
|
| 1178 |
+
columns = [row[1] for row in cursor.fetchall()]
|
| 1179 |
+
if 'interaction_hash' in columns:
|
| 1180 |
+
cursor.execute("""
|
| 1181 |
+
SELECT COUNT(*) FROM interactions
|
| 1182 |
+
WHERE interaction_hash IS NOT NULL AND interaction_hash = ?
|
| 1183 |
+
""", (interaction_hash,))
|
| 1184 |
+
count = cursor.fetchone()[0]
|
| 1185 |
+
conn.close()
|
| 1186 |
+
return count > 0
|
| 1187 |
+
else:
|
| 1188 |
+
conn.close()
|
| 1189 |
+
return False
|
| 1190 |
+
except sqlite3.OperationalError:
|
| 1191 |
+
# Column might not exist yet, only check in-memory
|
| 1192 |
+
return interaction_hash in self._processed_interactions
|
| 1193 |
+
|
| 1194 |
+
def _mark_interaction_processed(self, interaction_hash: str):
|
| 1195 |
+
"""Mark interaction as processed"""
|
| 1196 |
+
if not hasattr(self, '_processed_interactions'):
|
| 1197 |
+
self._processed_interactions = set()
|
| 1198 |
+
self._processed_interactions.add(interaction_hash)
|
| 1199 |
+
|
| 1200 |
+
# Limit memory usage by keeping only last 1000 hashes
|
| 1201 |
+
if len(self._processed_interactions) > 1000:
|
| 1202 |
+
# Keep most recent 500 entries (simple truncation)
|
| 1203 |
+
self._processed_interactions = set(list(self._processed_interactions)[-500:])
|
| 1204 |
+
|
| 1205 |
+
async def manage_context_optimized(self, session_id: str, user_input: str, user_id: str = "Test_Any") -> dict:
|
| 1206 |
+
"""
|
| 1207 |
+
Efficient context management with transaction optimization
|
| 1208 |
+
"""
|
| 1209 |
+
# Use session-only cache key
|
| 1210 |
+
session_cache_key = f"session_{session_id}"
|
| 1211 |
+
|
| 1212 |
+
# Try to get from cache first (no DB access)
|
| 1213 |
+
cached_context = self._get_from_memory_cache(session_cache_key)
|
| 1214 |
+
if cached_context and self._is_cache_valid(cached_context):
|
| 1215 |
+
logger.debug(f"Using cached context for session {session_id}")
|
| 1216 |
+
return cached_context
|
| 1217 |
+
|
| 1218 |
+
# Use transaction for all DB operations
|
| 1219 |
+
with self.transaction_manager.transaction(session_id) as cursor:
|
| 1220 |
+
# Atomic session retrieval and update
|
| 1221 |
+
cursor.execute("""
|
| 1222 |
+
SELECT s.context_data, s.user_metadata, s.last_activity, s.user_id,
|
| 1223 |
+
COUNT(ic.interaction_id) as interaction_count
|
| 1224 |
+
FROM sessions s
|
| 1225 |
+
LEFT JOIN interaction_contexts ic ON s.session_id = ic.session_id
|
| 1226 |
+
WHERE s.session_id = ?
|
| 1227 |
+
GROUP BY s.session_id
|
| 1228 |
+
""", (session_id,))
|
| 1229 |
+
|
| 1230 |
+
row = cursor.fetchone()
|
| 1231 |
+
|
| 1232 |
+
if row:
|
| 1233 |
+
# Parse existing session data
|
| 1234 |
+
context_data = json.loads(row[0] or '{}')
|
| 1235 |
+
user_metadata = json.loads(row[1] or '{}')
|
| 1236 |
+
last_activity = row[2]
|
| 1237 |
+
stored_user_id = row[3] or user_id
|
| 1238 |
+
interaction_count = row[4] or 0
|
| 1239 |
+
|
| 1240 |
+
# Handle user change atomically
|
| 1241 |
+
if stored_user_id != user_id:
|
| 1242 |
+
self._handle_user_change_atomic(cursor, session_id, stored_user_id, user_id)
|
| 1243 |
+
|
| 1244 |
+
# Get interaction contexts efficiently
|
| 1245 |
+
interaction_contexts = self._get_interaction_contexts_atomic(cursor, session_id)
|
| 1246 |
+
|
| 1247 |
+
else:
|
| 1248 |
+
# Create new session atomically
|
| 1249 |
+
cursor.execute("""
|
| 1250 |
+
INSERT INTO sessions (session_id, user_id, created_at, last_activity, context_data, user_metadata)
|
| 1251 |
+
VALUES (?, ?, datetime('now'), datetime('now'), '{}', '{}')
|
| 1252 |
+
""", (session_id, user_id))
|
| 1253 |
+
|
| 1254 |
+
context_data = {}
|
| 1255 |
+
user_metadata = {}
|
| 1256 |
+
interaction_contexts = []
|
| 1257 |
+
interaction_count = 0
|
| 1258 |
+
|
| 1259 |
+
# Load user context asynchronously (outside transaction)
|
| 1260 |
+
user_context = await self._load_user_context_async(user_id)
|
| 1261 |
+
|
| 1262 |
+
# Build final context
|
| 1263 |
+
final_context = {
|
| 1264 |
+
"session_id": session_id,
|
| 1265 |
+
"user_id": user_id,
|
| 1266 |
+
"interaction_contexts": interaction_contexts,
|
| 1267 |
+
"user_context": user_context,
|
| 1268 |
+
"preferences": user_metadata.get("preferences", {}),
|
| 1269 |
+
"active_tasks": user_metadata.get("active_tasks", []),
|
| 1270 |
+
"interaction_count": interaction_count,
|
| 1271 |
+
"cache_timestamp": datetime.now().isoformat()
|
| 1272 |
+
}
|
| 1273 |
+
|
| 1274 |
+
# Update cache
|
| 1275 |
+
self._warm_memory_cache(session_cache_key, final_context)
|
| 1276 |
+
|
| 1277 |
+
return self._optimize_context(final_context)
|
| 1278 |
+
|
| 1279 |
+
def _handle_user_change_atomic(self, cursor, session_id: str, old_user_id: str, new_user_id: str):
|
| 1280 |
+
"""Handle user change within transaction"""
|
| 1281 |
+
logger.info(f"Handling user change in transaction: {old_user_id} -> {new_user_id}")
|
| 1282 |
+
|
| 1283 |
+
# Update session
|
| 1284 |
+
cursor.execute("""
|
| 1285 |
+
UPDATE sessions
|
| 1286 |
+
SET user_id = ?, last_activity = datetime('now')
|
| 1287 |
+
WHERE session_id = ?
|
| 1288 |
+
""", (new_user_id, session_id))
|
| 1289 |
+
|
| 1290 |
+
# Log the change
|
| 1291 |
+
try:
|
| 1292 |
+
cursor.execute("""
|
| 1293 |
+
INSERT INTO user_change_log (session_id, old_user_id, new_user_id, timestamp)
|
| 1294 |
+
VALUES (?, ?, ?, datetime('now'))
|
| 1295 |
+
""", (session_id, old_user_id, new_user_id))
|
| 1296 |
+
except sqlite3.OperationalError:
|
| 1297 |
+
# Table might not exist yet
|
| 1298 |
+
pass
|
| 1299 |
+
|
| 1300 |
+
# Invalidate related caches
|
| 1301 |
+
try:
|
| 1302 |
+
cursor.execute("""
|
| 1303 |
+
UPDATE interaction_contexts
|
| 1304 |
+
SET needs_refresh = 1
|
| 1305 |
+
WHERE session_id = ?
|
| 1306 |
+
""", (session_id,))
|
| 1307 |
+
except sqlite3.OperationalError:
|
| 1308 |
+
# Column might not exist yet
|
| 1309 |
+
pass
|
| 1310 |
+
|
| 1311 |
+
def _get_interaction_contexts_atomic(self, cursor, session_id: str, limit: int = 20):
|
| 1312 |
+
"""Get interaction contexts within transaction"""
|
| 1313 |
+
try:
|
| 1314 |
+
cursor.execute("""
|
| 1315 |
+
SELECT interaction_summary, created_at, interaction_id
|
| 1316 |
+
FROM interaction_contexts
|
| 1317 |
+
WHERE session_id = ? AND (needs_refresh IS NULL OR needs_refresh = 0)
|
| 1318 |
+
ORDER BY created_at DESC
|
| 1319 |
+
LIMIT ?
|
| 1320 |
+
""", (session_id, limit))
|
| 1321 |
+
except sqlite3.OperationalError:
|
| 1322 |
+
# Fallback if needs_refresh column doesn't exist
|
| 1323 |
+
cursor.execute("""
|
| 1324 |
+
SELECT interaction_summary, created_at, interaction_id
|
| 1325 |
+
FROM interaction_contexts
|
| 1326 |
+
WHERE session_id = ?
|
| 1327 |
+
ORDER BY created_at DESC
|
| 1328 |
+
LIMIT ?
|
| 1329 |
+
""", (session_id, limit))
|
| 1330 |
+
|
| 1331 |
+
contexts = []
|
| 1332 |
+
for row in cursor.fetchall():
|
| 1333 |
+
if row[0]:
|
| 1334 |
+
contexts.append({
|
| 1335 |
+
"summary": row[0],
|
| 1336 |
+
"timestamp": row[1],
|
| 1337 |
+
"id": row[2] if len(row) > 2 else None
|
| 1338 |
+
})
|
| 1339 |
+
|
| 1340 |
+
return contexts
|
| 1341 |
+
|
| 1342 |
+
async def _load_user_context_async(self, user_id: str):
|
| 1343 |
+
"""Load user context asynchronously to avoid blocking"""
|
| 1344 |
+
try:
|
| 1345 |
+
# Check memory cache first
|
| 1346 |
+
user_cache_key = f"user_{user_id}"
|
| 1347 |
+
cached = self._get_from_memory_cache(user_cache_key)
|
| 1348 |
+
if cached:
|
| 1349 |
+
return cached.get("user_context", "")
|
| 1350 |
+
|
| 1351 |
+
# Load from database
|
| 1352 |
+
return await self.get_user_context(user_id)
|
| 1353 |
+
except Exception as e:
|
| 1354 |
+
logger.error(f"Error loading user context: {e}")
|
| 1355 |
+
return ""
|
| 1356 |
+
|
| 1357 |
+
def _is_cache_valid(self, cached_context: dict, max_age_seconds: int = 60) -> bool:
|
| 1358 |
+
"""Check if cached context is still valid"""
|
| 1359 |
+
if not cached_context:
|
| 1360 |
+
return False
|
| 1361 |
+
|
| 1362 |
+
cache_timestamp = cached_context.get("cache_timestamp")
|
| 1363 |
+
if not cache_timestamp:
|
| 1364 |
+
return False
|
| 1365 |
+
|
| 1366 |
+
try:
|
| 1367 |
+
cache_time = datetime.fromisoformat(cache_timestamp)
|
| 1368 |
+
age = (datetime.now() - cache_time).total_seconds()
|
| 1369 |
+
return age < max_age_seconds
|
| 1370 |
+
except:
|
| 1371 |
+
return False
|
| 1372 |
+
|
| 1373 |
+
def invalidate_session_cache(self, session_id: str):
|
| 1374 |
+
"""
|
| 1375 |
+
Invalidate cached context for a session to force fresh retrieval
|
| 1376 |
+
Only affects cache management - does not change application functionality
|
| 1377 |
+
"""
|
| 1378 |
+
session_cache_key = f"session_{session_id}"
|
| 1379 |
+
if session_cache_key in self.session_cache:
|
| 1380 |
+
del self.session_cache[session_cache_key]
|
| 1381 |
+
logger.info(f"Cache invalidated for session {session_id} to ensure fresh context retrieval")
|
| 1382 |
+
|
| 1383 |
+
def optimize_database_indexes(self):
|
| 1384 |
+
"""Create database indexes for better query performance"""
|
| 1385 |
+
try:
|
| 1386 |
+
conn = sqlite3.connect(self.db_path)
|
| 1387 |
+
cursor = conn.cursor()
|
| 1388 |
+
|
| 1389 |
+
# Create indexes for frequently queried columns
|
| 1390 |
+
indexes = [
|
| 1391 |
+
"CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)",
|
| 1392 |
+
"CREATE INDEX IF NOT EXISTS idx_sessions_last_activity ON sessions(last_activity)",
|
| 1393 |
+
"CREATE INDEX IF NOT EXISTS idx_interactions_session_id ON interactions(session_id)",
|
| 1394 |
+
"CREATE INDEX IF NOT EXISTS idx_interaction_contexts_session_id ON interaction_contexts(session_id)",
|
| 1395 |
+
"CREATE INDEX IF NOT EXISTS idx_interaction_contexts_created_at ON interaction_contexts(created_at)",
|
| 1396 |
+
"CREATE INDEX IF NOT EXISTS idx_user_change_log_session_id ON user_change_log(session_id)",
|
| 1397 |
+
"CREATE INDEX IF NOT EXISTS idx_user_contexts_updated_at ON user_contexts(updated_at)"
|
| 1398 |
+
]
|
| 1399 |
+
|
| 1400 |
+
for index in indexes:
|
| 1401 |
+
try:
|
| 1402 |
+
cursor.execute(index)
|
| 1403 |
+
except sqlite3.OperationalError as e:
|
| 1404 |
+
# Table might not exist yet, skip this index
|
| 1405 |
+
logger.debug(f"Skipping index creation (table may not exist): {e}")
|
| 1406 |
+
|
| 1407 |
+
# Analyze database for query optimization
|
| 1408 |
+
try:
|
| 1409 |
+
cursor.execute("ANALYZE")
|
| 1410 |
+
except sqlite3.OperationalError:
|
| 1411 |
+
# ANALYZE might not be available in all SQLite versions
|
| 1412 |
+
pass
|
| 1413 |
+
|
| 1414 |
+
conn.commit()
|
| 1415 |
+
conn.close()
|
| 1416 |
+
|
| 1417 |
+
logger.info("✓ Database indexes optimized successfully")
|
| 1418 |
+
|
| 1419 |
+
except Exception as e:
|
| 1420 |
+
logger.error(f"Error optimizing database indexes: {e}", exc_info=True)
|
| 1421 |
+
|
| 1422 |
+
def set_context_mode(self, session_id: str, mode: str, user_id: str = "Test_Any"):
|
| 1423 |
+
"""
|
| 1424 |
+
Set context mode for session (fresh or relevant)
|
| 1425 |
+
|
| 1426 |
+
Args:
|
| 1427 |
+
session_id: Session identifier
|
| 1428 |
+
mode: 'fresh' (no user context) or 'relevant' (only relevant context)
|
| 1429 |
+
user_id: User identifier
|
| 1430 |
+
|
| 1431 |
+
Returns:
|
| 1432 |
+
bool: True if successful, False otherwise
|
| 1433 |
+
"""
|
| 1434 |
+
try:
|
| 1435 |
+
import time
|
| 1436 |
+
|
| 1437 |
+
# VALIDATION: Ensure mode is valid
|
| 1438 |
+
if mode not in ['fresh', 'relevant']:
|
| 1439 |
+
logger.warning(f"Invalid context mode '{mode}', defaulting to 'fresh'")
|
| 1440 |
+
mode = 'fresh'
|
| 1441 |
+
|
| 1442 |
+
# Get or create cache entry
|
| 1443 |
+
cache_key = f"session_{session_id}"
|
| 1444 |
+
cached_context = self._get_from_memory_cache(cache_key)
|
| 1445 |
+
|
| 1446 |
+
if not cached_context:
|
| 1447 |
+
cached_context = {
|
| 1448 |
+
'session_id': session_id,
|
| 1449 |
+
'user_id': user_id,
|
| 1450 |
+
'preferences': {},
|
| 1451 |
+
'context_mode': mode,
|
| 1452 |
+
'context_mode_timestamp': time.time()
|
| 1453 |
+
}
|
| 1454 |
+
else:
|
| 1455 |
+
# Update existing context (preserve other data)
|
| 1456 |
+
cached_context['context_mode'] = mode
|
| 1457 |
+
cached_context['context_mode_timestamp'] = time.time()
|
| 1458 |
+
cached_context['user_id'] = user_id # Update user_id if changed
|
| 1459 |
+
|
| 1460 |
+
# Update cache with TTL
|
| 1461 |
+
self.add_context_cache(cache_key, cached_context, ttl=3600)
|
| 1462 |
+
|
| 1463 |
+
logger.info(f"Context mode set to '{mode}' for session {session_id} (user: {user_id})")
|
| 1464 |
+
return True
|
| 1465 |
+
|
| 1466 |
+
except Exception as e:
|
| 1467 |
+
logger.error(f"Error setting context mode: {e}", exc_info=True)
|
| 1468 |
+
return False # Failure doesn't break existing flow
|
| 1469 |
+
|
| 1470 |
+
def get_context_mode(self, session_id: str) -> str:
|
| 1471 |
+
"""
|
| 1472 |
+
Get current context mode for session
|
| 1473 |
+
|
| 1474 |
+
Args:
|
| 1475 |
+
session_id: Session identifier
|
| 1476 |
+
|
| 1477 |
+
Returns:
|
| 1478 |
+
str: 'fresh' or 'relevant' (default: 'fresh')
|
| 1479 |
+
"""
|
| 1480 |
+
try:
|
| 1481 |
+
cache_key = f"session_{session_id}"
|
| 1482 |
+
cached_context = self._get_from_memory_cache(cache_key)
|
| 1483 |
+
|
| 1484 |
+
if cached_context:
|
| 1485 |
+
mode = cached_context.get('context_mode', 'fresh')
|
| 1486 |
+
# VALIDATION: Ensure mode is still valid
|
| 1487 |
+
if mode in ['fresh', 'relevant']:
|
| 1488 |
+
return mode
|
| 1489 |
+
else:
|
| 1490 |
+
logger.warning(f"Invalid cached mode '{mode}', resetting to 'fresh'")
|
| 1491 |
+
cached_context['context_mode'] = 'fresh'
|
| 1492 |
+
import time
|
| 1493 |
+
cached_context['context_mode_timestamp'] = time.time()
|
| 1494 |
+
self.add_context_cache(cache_key, cached_context, ttl=3600)
|
| 1495 |
+
return 'fresh'
|
| 1496 |
+
|
| 1497 |
+
# Default for new sessions
|
| 1498 |
+
return 'fresh'
|
| 1499 |
+
|
| 1500 |
+
except Exception as e:
|
| 1501 |
+
logger.error(f"Error getting context mode: {e}", exc_info=True)
|
| 1502 |
+
return 'fresh' # Safe default - no degradation
|
| 1503 |
+
|
| 1504 |
+
async def get_all_user_sessions(self, user_id: str) -> List[Dict]:
|
| 1505 |
+
"""
|
| 1506 |
+
Fetch all session contexts for a user (for relevance classification)
|
| 1507 |
+
|
| 1508 |
+
Performance: Single database query with JOIN
|
| 1509 |
+
|
| 1510 |
+
Args:
|
| 1511 |
+
user_id: User identifier
|
| 1512 |
+
|
| 1513 |
+
Returns:
|
| 1514 |
+
List of session context dictionaries with summaries and interactions
|
| 1515 |
+
"""
|
| 1516 |
+
try:
|
| 1517 |
+
conn = sqlite3.connect(self.db_path)
|
| 1518 |
+
cursor = conn.cursor()
|
| 1519 |
+
|
| 1520 |
+
# Fetch all session contexts for user with interaction summaries
|
| 1521 |
+
cursor.execute("""
|
| 1522 |
+
SELECT DISTINCT
|
| 1523 |
+
sc.session_id,
|
| 1524 |
+
sc.session_summary,
|
| 1525 |
+
sc.created_at,
|
| 1526 |
+
(SELECT GROUP_CONCAT(ic.interaction_summary, ' ||| ')
|
| 1527 |
+
FROM interaction_contexts ic
|
| 1528 |
+
WHERE ic.session_id = sc.session_id
|
| 1529 |
+
ORDER BY ic.created_at DESC
|
| 1530 |
+
LIMIT 10) as recent_interactions
|
| 1531 |
+
FROM session_contexts sc
|
| 1532 |
+
JOIN sessions s ON sc.session_id = s.session_id
|
| 1533 |
+
WHERE s.user_id = ?
|
| 1534 |
+
ORDER BY sc.created_at DESC
|
| 1535 |
+
LIMIT 50
|
| 1536 |
+
""", (user_id,))
|
| 1537 |
+
|
| 1538 |
+
sessions = []
|
| 1539 |
+
for row in cursor.fetchall():
|
| 1540 |
+
session_id, session_summary, created_at, interactions_str = row
|
| 1541 |
+
|
| 1542 |
+
# Parse interaction summaries
|
| 1543 |
+
interaction_list = []
|
| 1544 |
+
if interactions_str:
|
| 1545 |
+
for summary in interactions_str.split(' ||| '):
|
| 1546 |
+
if summary.strip():
|
| 1547 |
+
interaction_list.append({
|
| 1548 |
+
'summary': summary.strip(),
|
| 1549 |
+
'timestamp': created_at
|
| 1550 |
+
})
|
| 1551 |
+
|
| 1552 |
+
sessions.append({
|
| 1553 |
+
'session_id': session_id,
|
| 1554 |
+
'summary': session_summary or '',
|
| 1555 |
+
'created_at': created_at,
|
| 1556 |
+
'interaction_contexts': interaction_list
|
| 1557 |
+
})
|
| 1558 |
+
|
| 1559 |
+
conn.close()
|
| 1560 |
+
logger.info(f"Fetched {len(sessions)} sessions for user {user_id}")
|
| 1561 |
+
return sessions
|
| 1562 |
+
|
| 1563 |
+
except Exception as e:
|
| 1564 |
+
logger.error(f"Error fetching user sessions: {e}", exc_info=True)
|
| 1565 |
+
return [] # Safe fallback - no degradation
|
| 1566 |
+
|
| 1567 |
+
def _extract_entities(self, context: dict) -> list:
|
| 1568 |
+
"""
|
| 1569 |
+
Extract essential entities from context
|
| 1570 |
+
"""
|
| 1571 |
+
# TODO: Implement entity extraction
|
| 1572 |
+
return []
|
| 1573 |
+
|
| 1574 |
+
def _generate_summary(self, context: dict) -> str:
|
| 1575 |
+
"""
|
| 1576 |
+
Generate conversation summary
|
| 1577 |
+
"""
|
| 1578 |
+
# TODO: Implement summary generation
|
| 1579 |
+
return ""
|
| 1580 |
+
|
| 1581 |
+
def get_or_create_session_context(self, session_id: str, user_id: Optional[str] = None) -> Dict:
|
| 1582 |
+
"""Enhanced context retrieval with caching"""
|
| 1583 |
+
import time
|
| 1584 |
+
|
| 1585 |
+
# In-memory cache check first
|
| 1586 |
+
if session_id in self._session_cache:
|
| 1587 |
+
cache_entry = self._session_cache[session_id]
|
| 1588 |
+
if time.time() - cache_entry['timestamp'] < 300: # 5 min cache
|
| 1589 |
+
logger.debug(f"Cache hit for session {session_id}")
|
| 1590 |
+
return cache_entry['context']
|
| 1591 |
+
|
| 1592 |
+
# Batch database queries
|
| 1593 |
+
conn = None
|
| 1594 |
+
try:
|
| 1595 |
+
conn = sqlite3.connect(self.db_path)
|
| 1596 |
+
cursor = conn.cursor()
|
| 1597 |
+
|
| 1598 |
+
# Single query for all context data
|
| 1599 |
+
query = """
|
| 1600 |
+
SELECT
|
| 1601 |
+
s.context_data,
|
| 1602 |
+
s.user_metadata,
|
| 1603 |
+
s.last_activity,
|
| 1604 |
+
u.persona_summary,
|
| 1605 |
+
ic.interaction_summary
|
| 1606 |
+
FROM sessions s
|
| 1607 |
+
LEFT JOIN user_contexts u ON s.user_id = u.user_id
|
| 1608 |
+
LEFT JOIN interaction_contexts ic ON s.session_id = ic.session_id
|
| 1609 |
+
WHERE s.session_id = ?
|
| 1610 |
+
ORDER BY ic.created_at DESC
|
| 1611 |
+
LIMIT 10
|
| 1612 |
+
"""
|
| 1613 |
+
|
| 1614 |
+
cursor.execute(query, (session_id,))
|
| 1615 |
+
results = cursor.fetchall()
|
| 1616 |
+
|
| 1617 |
+
# Process results efficiently
|
| 1618 |
+
context = self._build_context_from_results(results, session_id, user_id)
|
| 1619 |
+
|
| 1620 |
+
# Update cache
|
| 1621 |
+
self._session_cache[session_id] = {
|
| 1622 |
+
'context': context,
|
| 1623 |
+
'timestamp': time.time()
|
| 1624 |
+
}
|
| 1625 |
+
|
| 1626 |
+
return context
|
| 1627 |
+
|
| 1628 |
+
except Exception as e:
|
| 1629 |
+
logger.error(f"Error in get_or_create_session_context: {e}", exc_info=True)
|
| 1630 |
+
# Return safe fallback
|
| 1631 |
+
return {
|
| 1632 |
+
"session_id": session_id,
|
| 1633 |
+
"user_id": user_id or "Test_Any",
|
| 1634 |
+
"interaction_contexts": [],
|
| 1635 |
+
"session_context": None,
|
| 1636 |
+
"preferences": {},
|
| 1637 |
+
"active_tasks": [],
|
| 1638 |
+
"user_context_loaded": False
|
| 1639 |
+
}
|
| 1640 |
+
finally:
|
| 1641 |
+
if conn:
|
| 1642 |
+
conn.close()
|
| 1643 |
+
|
| 1644 |
+
def _build_context_from_results(self, results: list, session_id: str, user_id: Optional[str]) -> Dict:
|
| 1645 |
+
"""Build context dictionary from batch query results"""
|
| 1646 |
+
context = {
|
| 1647 |
+
"session_id": session_id,
|
| 1648 |
+
"user_id": user_id or "Test_Any",
|
| 1649 |
+
"interaction_contexts": [],
|
| 1650 |
+
"session_context": None,
|
| 1651 |
+
"user_context": "",
|
| 1652 |
+
"preferences": {},
|
| 1653 |
+
"active_tasks": [],
|
| 1654 |
+
"user_context_loaded": False
|
| 1655 |
+
}
|
| 1656 |
+
|
| 1657 |
+
if not results:
|
| 1658 |
+
return context
|
| 1659 |
+
|
| 1660 |
+
# Process first row for session data
|
| 1661 |
+
first_row = results[0]
|
| 1662 |
+
if first_row[0]: # context_data
|
| 1663 |
+
try:
|
| 1664 |
+
session_data = json.loads(first_row[0])
|
| 1665 |
+
context["preferences"] = session_data.get("preferences", {})
|
| 1666 |
+
context["active_tasks"] = session_data.get("active_tasks", [])
|
| 1667 |
+
except:
|
| 1668 |
+
pass
|
| 1669 |
+
|
| 1670 |
+
if first_row[1]: # user_metadata
|
| 1671 |
+
try:
|
| 1672 |
+
user_metadata = json.loads(first_row[1])
|
| 1673 |
+
context["preferences"].update(user_metadata.get("preferences", {}))
|
| 1674 |
+
except:
|
| 1675 |
+
pass
|
| 1676 |
+
|
| 1677 |
+
context["last_activity"] = first_row[2] # last_activity
|
| 1678 |
+
|
| 1679 |
+
if first_row[3]: # persona_summary
|
| 1680 |
+
context["user_context"] = first_row[3]
|
| 1681 |
+
context["user_context_loaded"] = True
|
| 1682 |
+
|
| 1683 |
+
# Process interaction contexts
|
| 1684 |
+
seen_interactions = set()
|
| 1685 |
+
for row in results:
|
| 1686 |
+
if row[4]: # interaction_summary
|
| 1687 |
+
# Deduplicate interactions
|
| 1688 |
+
if row[4] not in seen_interactions:
|
| 1689 |
+
seen_interactions.add(row[4])
|
| 1690 |
+
context["interaction_contexts"].append({
|
| 1691 |
+
"summary": row[4],
|
| 1692 |
+
"timestamp": None # Could extract from row if available
|
| 1693 |
+
})
|
| 1694 |
+
|
| 1695 |
+
return context
|
src/context_relevance_classifier.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# context_relevance_classifier.py
|
| 2 |
+
"""
|
| 3 |
+
Context Relevance Classification Module
|
| 4 |
+
Uses LLM inference to identify relevant session contexts and generate dynamic summaries
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import asyncio
|
| 9 |
+
from typing import Dict, List, Optional
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ContextRelevanceClassifier:
|
| 16 |
+
"""
|
| 17 |
+
Classify which session contexts are relevant to current conversation
|
| 18 |
+
and generate 2-line summaries for each relevant session
|
| 19 |
+
|
| 20 |
+
Performance Priority:
|
| 21 |
+
- LLM inference first (accuracy over speed)
|
| 22 |
+
- Parallel processing for multiple sessions
|
| 23 |
+
- Caching for repeated queries
|
| 24 |
+
- Graceful degradation on failures
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, llm_router):
|
| 28 |
+
"""
|
| 29 |
+
Initialize classifier with LLM router
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
llm_router: LLMRouter instance for inference calls
|
| 33 |
+
"""
|
| 34 |
+
self.llm_router = llm_router
|
| 35 |
+
self._relevance_cache = {} # Cache relevance scores to reduce LLM calls
|
| 36 |
+
self._summary_cache = {} # Cache summaries to avoid regenerating
|
| 37 |
+
self._cache_ttl = 3600 # 1 hour cache TTL
|
| 38 |
+
|
| 39 |
+
async def classify_and_summarize_relevant_contexts(self,
|
| 40 |
+
current_input: str,
|
| 41 |
+
session_contexts: List[Dict],
|
| 42 |
+
user_id: str = "Test_Any") -> Dict:
|
| 43 |
+
"""
|
| 44 |
+
Main method: Classify relevant contexts AND generate 2-line summaries
|
| 45 |
+
|
| 46 |
+
Performance Strategy:
|
| 47 |
+
1. Extract current topic (LLM inference - single call)
|
| 48 |
+
2. Calculate relevance in parallel (multiple LLM calls in parallel)
|
| 49 |
+
3. Generate summaries in parallel (only for relevant sessions)
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
current_input: Current user query
|
| 53 |
+
session_contexts: List of session context dictionaries
|
| 54 |
+
user_id: User identifier for logging
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
{
|
| 58 |
+
'relevant_summaries': List[str], # 2-line summaries
|
| 59 |
+
'combined_user_context': str, # Combined summaries
|
| 60 |
+
'relevance_scores': Dict, # Scores for each session
|
| 61 |
+
'classification_confidence': float,
|
| 62 |
+
'topic': str,
|
| 63 |
+
'processing_time': float
|
| 64 |
+
}
|
| 65 |
+
"""
|
| 66 |
+
start_time = datetime.now()
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
# Early exit: No contexts to process
|
| 70 |
+
if not session_contexts:
|
| 71 |
+
logger.info("No session contexts provided for classification")
|
| 72 |
+
return {
|
| 73 |
+
'relevant_summaries': [],
|
| 74 |
+
'combined_user_context': '',
|
| 75 |
+
'relevance_scores': {},
|
| 76 |
+
'classification_confidence': 1.0,
|
| 77 |
+
'topic': '',
|
| 78 |
+
'processing_time': 0.0
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# Step 1: Extract current topic (LLM inference - OPTION A: Single call)
|
| 82 |
+
current_topic = await self._extract_current_topic(current_input)
|
| 83 |
+
logger.info(f"Extracted current topic: '{current_topic}'")
|
| 84 |
+
|
| 85 |
+
# Step 2: Calculate relevance scores (parallel processing for performance)
|
| 86 |
+
relevance_tasks = []
|
| 87 |
+
for session_ctx in session_contexts:
|
| 88 |
+
task = self._calculate_relevance_with_cache(
|
| 89 |
+
current_topic,
|
| 90 |
+
current_input,
|
| 91 |
+
session_ctx
|
| 92 |
+
)
|
| 93 |
+
relevance_tasks.append((session_ctx, task))
|
| 94 |
+
|
| 95 |
+
# Execute all relevance calculations in parallel
|
| 96 |
+
relevance_results = await asyncio.gather(
|
| 97 |
+
*[task for _, task in relevance_tasks],
|
| 98 |
+
return_exceptions=True
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Filter relevant sessions (score >= 0.6)
|
| 102 |
+
relevant_sessions = []
|
| 103 |
+
relevance_scores = {}
|
| 104 |
+
|
| 105 |
+
for (session_ctx, _), result in zip(relevance_tasks, relevance_results):
|
| 106 |
+
if isinstance(result, Exception):
|
| 107 |
+
logger.error(f"Error calculating relevance: {result}")
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
session_id = session_ctx.get('session_id', 'unknown')
|
| 111 |
+
score = result.get('score', 0.0)
|
| 112 |
+
relevance_scores[session_id] = score
|
| 113 |
+
|
| 114 |
+
if score >= 0.6: # Relevance threshold
|
| 115 |
+
relevant_sessions.append({
|
| 116 |
+
'session_id': session_id,
|
| 117 |
+
'summary': session_ctx.get('summary', ''),
|
| 118 |
+
'relevance_score': score,
|
| 119 |
+
'interaction_contexts': session_ctx.get('interaction_contexts', []),
|
| 120 |
+
'created_at': session_ctx.get('created_at', '')
|
| 121 |
+
})
|
| 122 |
+
|
| 123 |
+
logger.info(f"Found {len(relevant_sessions)} relevant sessions out of {len(session_contexts)}")
|
| 124 |
+
|
| 125 |
+
# Step 3: Generate 2-line summaries for relevant sessions (parallel)
|
| 126 |
+
summary_tasks = []
|
| 127 |
+
for relevant_session in relevant_sessions:
|
| 128 |
+
task = self._generate_session_summary(
|
| 129 |
+
relevant_session,
|
| 130 |
+
current_input,
|
| 131 |
+
current_topic
|
| 132 |
+
)
|
| 133 |
+
summary_tasks.append(task)
|
| 134 |
+
|
| 135 |
+
# Execute all summaries in parallel
|
| 136 |
+
summary_results = await asyncio.gather(*summary_tasks, return_exceptions=True)
|
| 137 |
+
|
| 138 |
+
# Filter valid summaries
|
| 139 |
+
valid_summaries = []
|
| 140 |
+
for summary in summary_results:
|
| 141 |
+
if isinstance(summary, str) and summary.strip():
|
| 142 |
+
valid_summaries.append(summary.strip())
|
| 143 |
+
elif isinstance(summary, Exception):
|
| 144 |
+
logger.error(f"Error generating summary: {summary}")
|
| 145 |
+
|
| 146 |
+
# Step 4: Combine summaries into dynamic user context
|
| 147 |
+
combined_user_context = self._combine_summaries(valid_summaries, current_topic)
|
| 148 |
+
|
| 149 |
+
processing_time = (datetime.now() - start_time).total_seconds()
|
| 150 |
+
|
| 151 |
+
logger.info(
|
| 152 |
+
f"Relevance classification complete: {len(valid_summaries)} summaries, "
|
| 153 |
+
f"topic '{current_topic}', time: {processing_time:.2f}s"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return {
|
| 157 |
+
'relevant_summaries': valid_summaries,
|
| 158 |
+
'combined_user_context': combined_user_context,
|
| 159 |
+
'relevance_scores': relevance_scores,
|
| 160 |
+
'classification_confidence': 0.8,
|
| 161 |
+
'topic': current_topic,
|
| 162 |
+
'processing_time': processing_time
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
logger.error(f"Error in relevance classification: {e}", exc_info=True)
|
| 167 |
+
processing_time = (datetime.now() - start_time).total_seconds()
|
| 168 |
+
|
| 169 |
+
# SAFE FALLBACK: Return empty result (no degradation)
|
| 170 |
+
return {
|
| 171 |
+
'relevant_summaries': [],
|
| 172 |
+
'combined_user_context': '',
|
| 173 |
+
'relevance_scores': {},
|
| 174 |
+
'classification_confidence': 0.0,
|
| 175 |
+
'topic': '',
|
| 176 |
+
'processing_time': processing_time,
|
| 177 |
+
'error': str(e)
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
async def _extract_current_topic(self, user_input: str) -> str:
|
| 181 |
+
"""
|
| 182 |
+
Extract main topic from current input using LLM inference
|
| 183 |
+
|
| 184 |
+
Performance: Single LLM call with caching
|
| 185 |
+
"""
|
| 186 |
+
try:
|
| 187 |
+
# Check cache first
|
| 188 |
+
cache_key = f"topic_{hash(user_input[:200])}"
|
| 189 |
+
if cache_key in self._relevance_cache:
|
| 190 |
+
cached = self._relevance_cache[cache_key]
|
| 191 |
+
if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp():
|
| 192 |
+
return cached['value']
|
| 193 |
+
|
| 194 |
+
if not self.llm_router:
|
| 195 |
+
# Fallback: Simple extraction
|
| 196 |
+
words = user_input.split()[:5]
|
| 197 |
+
return ' '.join(words) if words else 'general query'
|
| 198 |
+
|
| 199 |
+
prompt = f"""Extract the main topic (2-5 words) from this query:
|
| 200 |
+
|
| 201 |
+
Query: "{user_input}"
|
| 202 |
+
|
| 203 |
+
Respond with ONLY the topic name. Maximum 5 words."""
|
| 204 |
+
|
| 205 |
+
result = await self.llm_router.route_inference(
|
| 206 |
+
task_type="classification",
|
| 207 |
+
prompt=prompt,
|
| 208 |
+
max_tokens=20,
|
| 209 |
+
temperature=0.2 # Low temperature for consistency
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
topic = result.strip() if result else user_input[:100]
|
| 213 |
+
|
| 214 |
+
# Cache result
|
| 215 |
+
self._relevance_cache[cache_key] = {
|
| 216 |
+
'value': topic,
|
| 217 |
+
'timestamp': datetime.now().timestamp()
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
return topic
|
| 221 |
+
|
| 222 |
+
except Exception as e:
|
| 223 |
+
logger.error(f"Error extracting topic: {e}", exc_info=True)
|
| 224 |
+
# Fallback
|
| 225 |
+
return user_input[:100]
|
| 226 |
+
|
| 227 |
+
async def _calculate_relevance_with_cache(self,
|
| 228 |
+
current_topic: str,
|
| 229 |
+
current_input: str,
|
| 230 |
+
session_ctx: Dict) -> Dict:
|
| 231 |
+
"""
|
| 232 |
+
Calculate relevance score with caching to reduce LLM calls
|
| 233 |
+
|
| 234 |
+
Returns: {'score': float, 'cached': bool}
|
| 235 |
+
"""
|
| 236 |
+
try:
|
| 237 |
+
session_id = session_ctx.get('session_id', 'unknown')
|
| 238 |
+
session_summary = session_ctx.get('summary', '')
|
| 239 |
+
|
| 240 |
+
# Check cache
|
| 241 |
+
cache_key = f"rel_{session_id}_{hash(current_input[:100] + current_topic)}"
|
| 242 |
+
if cache_key in self._relevance_cache:
|
| 243 |
+
cached = self._relevance_cache[cache_key]
|
| 244 |
+
if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp():
|
| 245 |
+
return {'score': cached['value'], 'cached': True}
|
| 246 |
+
|
| 247 |
+
# Calculate relevance
|
| 248 |
+
score = await self._calculate_relevance(
|
| 249 |
+
current_topic,
|
| 250 |
+
current_input,
|
| 251 |
+
session_summary
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Cache result
|
| 255 |
+
self._relevance_cache[cache_key] = {
|
| 256 |
+
'value': score,
|
| 257 |
+
'timestamp': datetime.now().timestamp()
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
return {'score': score, 'cached': False}
|
| 261 |
+
|
| 262 |
+
except Exception as e:
|
| 263 |
+
logger.error(f"Error in cached relevance calculation: {e}", exc_info=True)
|
| 264 |
+
return {'score': 0.5, 'cached': False} # Neutral score on error
|
| 265 |
+
|
| 266 |
+
async def _calculate_relevance(self,
|
| 267 |
+
current_topic: str,
|
| 268 |
+
current_input: str,
|
| 269 |
+
context_text: str) -> float:
|
| 270 |
+
"""
|
| 271 |
+
Calculate relevance score (0.0 to 1.0) using LLM inference
|
| 272 |
+
|
| 273 |
+
Performance: Single LLM call per session context
|
| 274 |
+
"""
|
| 275 |
+
try:
|
| 276 |
+
if not context_text:
|
| 277 |
+
return 0.0
|
| 278 |
+
|
| 279 |
+
if not self.llm_router:
|
| 280 |
+
# Fallback: Keyword matching
|
| 281 |
+
return self._simple_keyword_relevance(current_input, context_text)
|
| 282 |
+
|
| 283 |
+
# OPTION A: Direct relevance scoring (faster, single call)
|
| 284 |
+
# OPTION B: Detailed analysis (more accurate, more tokens)
|
| 285 |
+
# Choosing OPTION A for performance, but with quality prompt
|
| 286 |
+
|
| 287 |
+
prompt = f"""Rate the relevance (0.0 to 1.0) of this session context to the current conversation.
|
| 288 |
+
|
| 289 |
+
Current Topic: {current_topic}
|
| 290 |
+
Current Query: "{current_input[:200]}"
|
| 291 |
+
|
| 292 |
+
Session Context:
|
| 293 |
+
"{context_text[:500]}"
|
| 294 |
+
|
| 295 |
+
Consider:
|
| 296 |
+
- Topic similarity (0.0-1.0)
|
| 297 |
+
- Discussion depth alignment
|
| 298 |
+
- Information continuity
|
| 299 |
+
|
| 300 |
+
Respond with ONLY a number between 0.0 and 1.0 (e.g., 0.75)."""
|
| 301 |
+
|
| 302 |
+
result = await self.llm_router.route_inference(
|
| 303 |
+
task_type="general_reasoning",
|
| 304 |
+
prompt=prompt,
|
| 305 |
+
max_tokens=10,
|
| 306 |
+
temperature=0.1 # Very low for consistency
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
if result:
|
| 310 |
+
try:
|
| 311 |
+
score = float(result.strip())
|
| 312 |
+
return max(0.0, min(1.0, score)) # Clamp to [0, 1]
|
| 313 |
+
except ValueError:
|
| 314 |
+
logger.warning(f"Could not parse relevance score: {result}")
|
| 315 |
+
|
| 316 |
+
# Fallback to keyword matching
|
| 317 |
+
return self._simple_keyword_relevance(current_input, context_text)
|
| 318 |
+
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.error(f"Error calculating relevance: {e}", exc_info=True)
|
| 321 |
+
return 0.5 # Neutral score on error
|
| 322 |
+
|
| 323 |
+
def _simple_keyword_relevance(self, current_input: str, context_text: str) -> float:
|
| 324 |
+
"""Fallback keyword-based relevance calculation"""
|
| 325 |
+
try:
|
| 326 |
+
current_lower = current_input.lower()
|
| 327 |
+
context_lower = context_text.lower()
|
| 328 |
+
|
| 329 |
+
current_words = set(current_lower.split())
|
| 330 |
+
context_words = set(context_lower.split())
|
| 331 |
+
|
| 332 |
+
# Remove common stop words for better matching
|
| 333 |
+
stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
|
| 334 |
+
current_words = current_words - stop_words
|
| 335 |
+
context_words = context_words - stop_words
|
| 336 |
+
|
| 337 |
+
if not current_words:
|
| 338 |
+
return 0.5
|
| 339 |
+
|
| 340 |
+
# Jaccard similarity
|
| 341 |
+
intersection = len(current_words & context_words)
|
| 342 |
+
union = len(current_words | context_words)
|
| 343 |
+
|
| 344 |
+
return (intersection / union) if union > 0 else 0.0
|
| 345 |
+
|
| 346 |
+
except Exception:
|
| 347 |
+
return 0.5
|
| 348 |
+
|
| 349 |
+
async def _generate_session_summary(self,
|
| 350 |
+
session_data: Dict,
|
| 351 |
+
current_input: str,
|
| 352 |
+
current_topic: str) -> str:
|
| 353 |
+
"""
|
| 354 |
+
Generate 2-line summary for a relevant session context
|
| 355 |
+
|
| 356 |
+
Performance: LLM inference with caching and timeout protection
|
| 357 |
+
Builds depth and width of topic discussion
|
| 358 |
+
"""
|
| 359 |
+
try:
|
| 360 |
+
session_id = session_data.get('session_id', 'unknown')
|
| 361 |
+
session_summary = session_data.get('summary', '')
|
| 362 |
+
interaction_contexts = session_data.get('interaction_contexts', [])
|
| 363 |
+
|
| 364 |
+
# Check cache
|
| 365 |
+
cache_key = f"summary_{session_id}_{hash(current_topic)}"
|
| 366 |
+
if cache_key in self._summary_cache:
|
| 367 |
+
cached = self._summary_cache[cache_key]
|
| 368 |
+
if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp():
|
| 369 |
+
return cached['value']
|
| 370 |
+
|
| 371 |
+
# Validation: Ensure content available
|
| 372 |
+
if not session_summary and not interaction_contexts:
|
| 373 |
+
logger.warning(f"No content for summarization: session {session_id}")
|
| 374 |
+
return f"Previous discussion on {current_topic}.\nContext details unavailable."
|
| 375 |
+
|
| 376 |
+
# Build context text with limits
|
| 377 |
+
session_context_text = session_summary[:500] if session_summary else ""
|
| 378 |
+
|
| 379 |
+
if interaction_contexts:
|
| 380 |
+
recent_interactions = "\n".join([
|
| 381 |
+
ic.get('summary', '')[:100]
|
| 382 |
+
for ic in interaction_contexts[-5:]
|
| 383 |
+
if ic.get('summary')
|
| 384 |
+
])
|
| 385 |
+
if recent_interactions:
|
| 386 |
+
session_context_text = f"{session_context_text}\n\nRecent interactions:\n{recent_interactions[:400]}"
|
| 387 |
+
|
| 388 |
+
# Limit total context
|
| 389 |
+
if len(session_context_text) > 1000:
|
| 390 |
+
session_context_text = session_context_text[:1000] + "..."
|
| 391 |
+
|
| 392 |
+
if not self.llm_router:
|
| 393 |
+
# Fallback
|
| 394 |
+
return f"Previous {current_topic} discussion.\nCovered: {session_summary[:80]}..."
|
| 395 |
+
|
| 396 |
+
# LLM-based summarization with timeout
|
| 397 |
+
prompt = f"""Generate a precise 2-line summary (maximum 2 sentences, ~100 tokens total) that captures the depth and breadth of the topic discussion:
|
| 398 |
+
|
| 399 |
+
Current Topic: {current_topic}
|
| 400 |
+
Current Query: "{current_input[:150]}"
|
| 401 |
+
|
| 402 |
+
Previous Session Context:
|
| 403 |
+
{session_context_text}
|
| 404 |
+
|
| 405 |
+
Requirements:
|
| 406 |
+
- Line 1: Summarize the MAIN TOPICS/SUBJECTS discussed (breadth/width)
|
| 407 |
+
- Line 2: Summarize the DEPTH/LEVEL of discussion (technical depth, detail level, approach)
|
| 408 |
+
- Focus on relevance to: "{current_topic}"
|
| 409 |
+
- Keep total under 100 tokens
|
| 410 |
+
- Be specific about what was covered
|
| 411 |
+
|
| 412 |
+
Respond with ONLY the 2-line summary, no explanations."""
|
| 413 |
+
|
| 414 |
+
try:
|
| 415 |
+
result = await asyncio.wait_for(
|
| 416 |
+
self.llm_router.route_inference(
|
| 417 |
+
task_type="general_reasoning",
|
| 418 |
+
prompt=prompt,
|
| 419 |
+
max_tokens=100,
|
| 420 |
+
temperature=0.4
|
| 421 |
+
),
|
| 422 |
+
timeout=10.0 # 10 second timeout
|
| 423 |
+
)
|
| 424 |
+
except asyncio.TimeoutError:
|
| 425 |
+
logger.warning(f"Summary generation timeout for session {session_id}")
|
| 426 |
+
return f"Previous {current_topic} discussion.\nDepth and approach covered in prior session."
|
| 427 |
+
|
| 428 |
+
# Validate and format result
|
| 429 |
+
if result and isinstance(result, str) and result.strip():
|
| 430 |
+
summary = result.strip()
|
| 431 |
+
lines = [line.strip() for line in summary.split('\n') if line.strip()]
|
| 432 |
+
|
| 433 |
+
if len(lines) >= 1:
|
| 434 |
+
if len(lines) > 2:
|
| 435 |
+
combined = f"{lines[0]}\n{'. '.join(lines[1:])}"
|
| 436 |
+
formatted_summary = combined[:200]
|
| 437 |
+
else:
|
| 438 |
+
formatted_summary = '\n'.join(lines[:2])[:200]
|
| 439 |
+
|
| 440 |
+
# Ensure minimum quality
|
| 441 |
+
if len(formatted_summary) < 20:
|
| 442 |
+
formatted_summary = f"Previous {current_topic} discussion.\nDetails from previous session."
|
| 443 |
+
|
| 444 |
+
# Cache result
|
| 445 |
+
self._summary_cache[cache_key] = {
|
| 446 |
+
'value': formatted_summary,
|
| 447 |
+
'timestamp': datetime.now().timestamp()
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
return formatted_summary
|
| 451 |
+
else:
|
| 452 |
+
return f"Previous {current_topic} discussion.\nContext from previous session."
|
| 453 |
+
|
| 454 |
+
# Invalid result fallback
|
| 455 |
+
logger.warning(f"Invalid summary result for session {session_id}")
|
| 456 |
+
return f"Previous {current_topic} discussion.\nDepth and approach covered previously."
|
| 457 |
+
|
| 458 |
+
except Exception as e:
|
| 459 |
+
logger.error(f"Error generating session summary: {e}", exc_info=True)
|
| 460 |
+
session_summary = session_data.get('summary', '')[:100] if session_data.get('summary') else 'topic discussion'
|
| 461 |
+
return f"{session_summary}...\n{current_topic} discussion from previous session."
|
| 462 |
+
|
| 463 |
+
def _combine_summaries(self, summaries: List[str], current_topic: str) -> str:
|
| 464 |
+
"""
|
| 465 |
+
Combine multiple 2-line summaries into coherent user context
|
| 466 |
+
|
| 467 |
+
Builds width (multiple topics) and depth (summarized discussions)
|
| 468 |
+
"""
|
| 469 |
+
try:
|
| 470 |
+
if not summaries:
|
| 471 |
+
return ''
|
| 472 |
+
|
| 473 |
+
if len(summaries) == 1:
|
| 474 |
+
return summaries[0]
|
| 475 |
+
|
| 476 |
+
# Format combined summaries with topic focus
|
| 477 |
+
combined = f"Relevant Previous Discussions (Topic: {current_topic}):\n\n"
|
| 478 |
+
|
| 479 |
+
for idx, summary in enumerate(summaries, 1):
|
| 480 |
+
combined += f"[Session {idx}]\n{summary}\n\n"
|
| 481 |
+
|
| 482 |
+
# Add summary statement
|
| 483 |
+
combined += f"These sessions provide context for {current_topic} discussions, covering multiple aspects and depth levels."
|
| 484 |
+
|
| 485 |
+
return combined
|
| 486 |
+
|
| 487 |
+
except Exception as e:
|
| 488 |
+
logger.error(f"Error combining summaries: {e}", exc_info=True)
|
| 489 |
+
# Simple fallback
|
| 490 |
+
return '\n\n'.join(summaries[:5])
|
| 491 |
+
|
src/database.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database initialization and management
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import sqlite3
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class DatabaseManager:
|
| 13 |
+
def __init__(self, db_path: str = "sessions.db"):
|
| 14 |
+
self.db_path = db_path
|
| 15 |
+
self.connection = None
|
| 16 |
+
self._init_db()
|
| 17 |
+
|
| 18 |
+
def _init_db(self):
|
| 19 |
+
"""Initialize database with required tables"""
|
| 20 |
+
try:
|
| 21 |
+
# Create database directory if needed
|
| 22 |
+
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
| 23 |
+
|
| 24 |
+
self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
|
| 25 |
+
self.connection.row_factory = sqlite3.Row
|
| 26 |
+
|
| 27 |
+
# Create tables
|
| 28 |
+
self._create_tables()
|
| 29 |
+
logger.info(f"Database initialized at {self.db_path}")
|
| 30 |
+
|
| 31 |
+
except Exception as e:
|
| 32 |
+
logger.error(f"Database initialization failed: {e}")
|
| 33 |
+
# Fallback to in-memory database
|
| 34 |
+
self.connection = sqlite3.connect(":memory:", check_same_thread=False)
|
| 35 |
+
self._create_tables()
|
| 36 |
+
logger.info("Using in-memory database as fallback")
|
| 37 |
+
|
| 38 |
+
def _create_tables(self):
|
| 39 |
+
"""Create required database tables"""
|
| 40 |
+
cursor = self.connection.cursor()
|
| 41 |
+
|
| 42 |
+
# Sessions table
|
| 43 |
+
cursor.execute("""
|
| 44 |
+
CREATE TABLE IF NOT EXISTS sessions (
|
| 45 |
+
session_id TEXT PRIMARY KEY,
|
| 46 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 47 |
+
last_activity TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 48 |
+
context_data TEXT,
|
| 49 |
+
user_metadata TEXT
|
| 50 |
+
)
|
| 51 |
+
""")
|
| 52 |
+
|
| 53 |
+
# Interactions table
|
| 54 |
+
cursor.execute("""
|
| 55 |
+
CREATE TABLE IF NOT EXISTS interactions (
|
| 56 |
+
interaction_id TEXT PRIMARY KEY,
|
| 57 |
+
session_id TEXT REFERENCES sessions(session_id),
|
| 58 |
+
user_input TEXT NOT NULL,
|
| 59 |
+
agent_trace TEXT,
|
| 60 |
+
final_response TEXT,
|
| 61 |
+
processing_time INTEGER,
|
| 62 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 63 |
+
)
|
| 64 |
+
""")
|
| 65 |
+
|
| 66 |
+
self.connection.commit()
|
| 67 |
+
logger.info("Database tables created successfully")
|
| 68 |
+
|
| 69 |
+
def get_connection(self):
|
| 70 |
+
"""Get database connection"""
|
| 71 |
+
return self.connection
|
| 72 |
+
|
| 73 |
+
def close(self):
|
| 74 |
+
"""Close database connection"""
|
| 75 |
+
if self.connection:
|
| 76 |
+
self.connection.close()
|
| 77 |
+
logger.info("Database connection closed")
|
| 78 |
+
|
| 79 |
+
# Global database instance
|
| 80 |
+
db_manager = None
|
| 81 |
+
|
| 82 |
+
def init_database(db_path: str = "sessions.db"):
|
| 83 |
+
"""Initialize global database instance"""
|
| 84 |
+
global db_manager
|
| 85 |
+
if db_manager is None:
|
| 86 |
+
db_manager = DatabaseManager(db_path)
|
| 87 |
+
return db_manager
|
| 88 |
+
|
| 89 |
+
def get_db():
|
| 90 |
+
"""Get database connection"""
|
| 91 |
+
global db_manager
|
| 92 |
+
if db_manager is None:
|
| 93 |
+
init_database()
|
| 94 |
+
return db_manager.get_connection()
|
| 95 |
+
|
| 96 |
+
# Initialize database on import
|
| 97 |
+
init_database()
|
src/event_handlers.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Event handlers for connecting UI to backend
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class EventHandlers:
|
| 12 |
+
def __init__(self, components: Dict[str, Any]):
|
| 13 |
+
self.components = components
|
| 14 |
+
self.sessions = {} # In-memory session storage
|
| 15 |
+
|
| 16 |
+
async def handle_message_submit(self, message: str, chat_history: list,
|
| 17 |
+
session_id: str, show_reasoning: bool,
|
| 18 |
+
show_agent_trace: bool, request):
|
| 19 |
+
"""Handle user message submission"""
|
| 20 |
+
try:
|
| 21 |
+
# Ensure session exists
|
| 22 |
+
if session_id not in self.sessions:
|
| 23 |
+
self.sessions[session_id] = {
|
| 24 |
+
'history': [],
|
| 25 |
+
'context': {},
|
| 26 |
+
'created_at': uuid.uuid4().hex
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
# Add user message to history
|
| 30 |
+
chat_history.append((message, None)) # None for pending response
|
| 31 |
+
|
| 32 |
+
# Generate response based on available components
|
| 33 |
+
if self.components.get('mock_mode'):
|
| 34 |
+
response = self._generate_mock_response(message)
|
| 35 |
+
else:
|
| 36 |
+
response = await self._generate_ai_response(message, session_id)
|
| 37 |
+
|
| 38 |
+
# Update chat history with response
|
| 39 |
+
chat_history[-1] = (message, response)
|
| 40 |
+
|
| 41 |
+
# Prepare additional data for UI
|
| 42 |
+
reasoning_data = {}
|
| 43 |
+
performance_data = {}
|
| 44 |
+
|
| 45 |
+
if show_reasoning:
|
| 46 |
+
reasoning_data = {
|
| 47 |
+
"chain_of_thought": {
|
| 48 |
+
"step_1": {
|
| 49 |
+
"hypothesis": "Mock reasoning for demonstration",
|
| 50 |
+
"evidence": ["Mock mode active", f"User input: {message[:50]}..."],
|
| 51 |
+
"confidence": 0.5,
|
| 52 |
+
"reasoning": "Demonstration mode - enhanced reasoning chain not available"
|
| 53 |
+
}
|
| 54 |
+
},
|
| 55 |
+
"alternative_paths": [],
|
| 56 |
+
"uncertainty_areas": [
|
| 57 |
+
{
|
| 58 |
+
"aspect": "System mode",
|
| 59 |
+
"confidence": 0.5,
|
| 60 |
+
"mitigation": "Mock mode - full reasoning chain not available"
|
| 61 |
+
}
|
| 62 |
+
],
|
| 63 |
+
"evidence_sources": [],
|
| 64 |
+
"confidence_calibration": {"overall_confidence": 0.5, "mock_mode": True}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
if show_agent_trace:
|
| 68 |
+
performance_data = {"agents_used": ["intent", "synthesis", "safety"]}
|
| 69 |
+
|
| 70 |
+
return "", chat_history, reasoning_data, performance_data
|
| 71 |
+
|
| 72 |
+
except Exception as e:
|
| 73 |
+
logger.error(f"Error handling message: {e}")
|
| 74 |
+
error_response = "I apologize, but I'm experiencing technical difficulties. Please try again."
|
| 75 |
+
chat_history.append((message, error_response))
|
| 76 |
+
return "", chat_history, {"error": str(e)}, {"status": "error"}
|
| 77 |
+
|
| 78 |
+
def _generate_mock_response(self, message: str) -> str:
|
| 79 |
+
"""Generate mock response for demonstration"""
|
| 80 |
+
mock_responses = [
|
| 81 |
+
f"I understand you're asking about: {message}. This is a mock response while the AI system initializes.",
|
| 82 |
+
f"Thank you for your question: '{message}'. The research assistant is currently in demonstration mode.",
|
| 83 |
+
f"Interesting question about {message}. In a full implementation, I would analyze this using multiple AI agents.",
|
| 84 |
+
f"I've received your query: '{message}'. The system is working properly in mock mode."
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
import random
|
| 88 |
+
return random.choice(mock_responses)
|
| 89 |
+
|
| 90 |
+
async def _generate_ai_response(self, message: str, session_id: str) -> str:
|
| 91 |
+
"""Generate AI response using orchestrator"""
|
| 92 |
+
try:
|
| 93 |
+
if 'orchestrator' in self.components:
|
| 94 |
+
result = await self.components['orchestrator'].process_request(
|
| 95 |
+
session_id=session_id,
|
| 96 |
+
user_input=message
|
| 97 |
+
)
|
| 98 |
+
return result.get('final_response', 'No response generated')
|
| 99 |
+
else:
|
| 100 |
+
return "Orchestrator not available. Using mock response."
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"AI response generation failed: {e}")
|
| 103 |
+
return f"AI processing error: {str(e)}"
|
| 104 |
+
|
| 105 |
+
def handle_new_session(self):
|
| 106 |
+
"""Handle new session creation"""
|
| 107 |
+
new_session_id = uuid.uuid4().hex[:8] # Short session ID for display
|
| 108 |
+
self.sessions[new_session_id] = {
|
| 109 |
+
'history': [],
|
| 110 |
+
'context': {},
|
| 111 |
+
'created_at': uuid.uuid4().hex
|
| 112 |
+
}
|
| 113 |
+
return new_session_id, [] # New session ID and empty history
|
| 114 |
+
|
| 115 |
+
def handle_settings_toggle(self, current_visibility: bool):
|
| 116 |
+
"""Toggle settings panel visibility"""
|
| 117 |
+
return not current_visibility
|
| 118 |
+
|
| 119 |
+
def handle_tab_change(self, tab_name: str):
|
| 120 |
+
"""Handle tab changes in mobile interface"""
|
| 121 |
+
return tab_name, False # Return tab name and hide mobile nav
|
| 122 |
+
|
| 123 |
+
# Factory function
|
| 124 |
+
def create_event_handlers(components: Dict[str, Any]):
|
| 125 |
+
return EventHandlers(components)
|
src/llm_router.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING
|
| 2 |
+
import logging
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
+
from .models_config import LLM_CONFIG
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class LLMRouter:
|
| 10 |
+
def __init__(self, hf_token, use_local_models: bool = True):
|
| 11 |
+
self.hf_token = hf_token
|
| 12 |
+
self.health_status = {}
|
| 13 |
+
self.use_local_models = use_local_models
|
| 14 |
+
self.local_loader = None
|
| 15 |
+
|
| 16 |
+
logger.info("LLMRouter initialized")
|
| 17 |
+
if hf_token:
|
| 18 |
+
logger.info("HF token available")
|
| 19 |
+
else:
|
| 20 |
+
logger.warning("No HF token provided")
|
| 21 |
+
|
| 22 |
+
# Initialize local model loader if enabled
|
| 23 |
+
if self.use_local_models:
|
| 24 |
+
try:
|
| 25 |
+
from .local_model_loader import LocalModelLoader
|
| 26 |
+
self.local_loader = LocalModelLoader()
|
| 27 |
+
logger.info("✓ Local model loader initialized (GPU-based inference)")
|
| 28 |
+
|
| 29 |
+
# Note: Pre-loading will happen on first request (lazy loading)
|
| 30 |
+
# Models will be loaded on-demand to avoid blocking startup
|
| 31 |
+
logger.info("Models will be loaded on-demand for faster startup")
|
| 32 |
+
except Exception as e:
|
| 33 |
+
logger.warning(f"Could not initialize local model loader: {e}. Falling back to API.")
|
| 34 |
+
logger.warning("This is normal if transformers/torch not available")
|
| 35 |
+
self.use_local_models = False
|
| 36 |
+
self.local_loader = None
|
| 37 |
+
|
| 38 |
+
async def route_inference(self, task_type: str, prompt: str, **kwargs):
|
| 39 |
+
"""
|
| 40 |
+
Smart routing based on task specialization
|
| 41 |
+
Tries local models first, falls back to HF Inference API if needed
|
| 42 |
+
"""
|
| 43 |
+
logger.info(f"Routing inference for task: {task_type}")
|
| 44 |
+
model_config = self._select_model(task_type)
|
| 45 |
+
logger.info(f"Selected model: {model_config['model_id']}")
|
| 46 |
+
|
| 47 |
+
# Try local model first if available
|
| 48 |
+
if self.use_local_models and self.local_loader:
|
| 49 |
+
try:
|
| 50 |
+
# Handle embedding generation separately
|
| 51 |
+
if task_type == "embedding_generation":
|
| 52 |
+
result = await self._call_local_embedding(model_config, prompt, **kwargs)
|
| 53 |
+
else:
|
| 54 |
+
result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
|
| 55 |
+
|
| 56 |
+
if result is not None:
|
| 57 |
+
logger.info(f"Inference complete for {task_type} (local model)")
|
| 58 |
+
return result
|
| 59 |
+
else:
|
| 60 |
+
logger.warning("Local model returned None, falling back to API")
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.warning(f"Local model inference failed: {e}. Falling back to API.")
|
| 63 |
+
logger.debug("Exception details:", exc_info=True)
|
| 64 |
+
|
| 65 |
+
# Fallback to HF Inference API
|
| 66 |
+
logger.info("Using HF Inference API")
|
| 67 |
+
# Health check and fallback logic
|
| 68 |
+
if not await self._is_model_healthy(model_config["model_id"]):
|
| 69 |
+
logger.warning(f"Model unhealthy, using fallback")
|
| 70 |
+
model_config = self._get_fallback_model(task_type)
|
| 71 |
+
logger.info(f"Fallback model: {model_config['model_id']}")
|
| 72 |
+
|
| 73 |
+
result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
|
| 74 |
+
logger.info(f"Inference complete for {task_type}")
|
| 75 |
+
return result
|
| 76 |
+
|
| 77 |
+
async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
|
| 78 |
+
"""Call local model for inference."""
|
| 79 |
+
if not self.local_loader:
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
model_id = model_config["model_id"]
|
| 83 |
+
max_tokens = kwargs.get('max_tokens', 512)
|
| 84 |
+
temperature = kwargs.get('temperature', 0.7)
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
# Ensure model is loaded
|
| 88 |
+
if model_id not in self.local_loader.loaded_models:
|
| 89 |
+
logger.info(f"Loading model {model_id} on demand...")
|
| 90 |
+
self.local_loader.load_chat_model(model_id, load_in_8bit=False)
|
| 91 |
+
|
| 92 |
+
# Format as chat messages if needed
|
| 93 |
+
messages = [{"role": "user", "content": prompt}]
|
| 94 |
+
|
| 95 |
+
# Generate using local model
|
| 96 |
+
result = await asyncio.to_thread(
|
| 97 |
+
self.local_loader.generate_chat_completion,
|
| 98 |
+
model_id=model_id,
|
| 99 |
+
messages=messages,
|
| 100 |
+
max_tokens=max_tokens,
|
| 101 |
+
temperature=temperature
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
logger.info(f"Local model {model_id} generated response (length: {len(result)})")
|
| 105 |
+
logger.info("=" * 80)
|
| 106 |
+
logger.info("LOCAL MODEL RESPONSE:")
|
| 107 |
+
logger.info("=" * 80)
|
| 108 |
+
logger.info(f"Model: {model_id}")
|
| 109 |
+
logger.info(f"Task Type: {task_type}")
|
| 110 |
+
logger.info(f"Response Length: {len(result)} characters")
|
| 111 |
+
logger.info("-" * 40)
|
| 112 |
+
logger.info("FULL RESPONSE CONTENT:")
|
| 113 |
+
logger.info("-" * 40)
|
| 114 |
+
logger.info(result)
|
| 115 |
+
logger.info("-" * 40)
|
| 116 |
+
logger.info("END OF RESPONSE")
|
| 117 |
+
logger.info("=" * 80)
|
| 118 |
+
|
| 119 |
+
return result
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.error(f"Error calling local model: {e}", exc_info=True)
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
|
| 126 |
+
"""Call local embedding model."""
|
| 127 |
+
if not self.local_loader:
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
model_id = model_config["model_id"]
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
# Ensure model is loaded
|
| 134 |
+
if model_id not in self.local_loader.loaded_embedding_models:
|
| 135 |
+
logger.info(f"Loading embedding model {model_id} on demand...")
|
| 136 |
+
self.local_loader.load_embedding_model(model_id)
|
| 137 |
+
|
| 138 |
+
# Generate embedding
|
| 139 |
+
embedding = await asyncio.to_thread(
|
| 140 |
+
self.local_loader.get_embedding,
|
| 141 |
+
model_id=model_id,
|
| 142 |
+
text=text
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
|
| 146 |
+
return embedding
|
| 147 |
+
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.error(f"Error calling local embedding model: {e}", exc_info=True)
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
def _select_model(self, task_type: str) -> dict:
|
| 153 |
+
model_map = {
|
| 154 |
+
"intent_classification": LLM_CONFIG["models"]["classification_specialist"],
|
| 155 |
+
"embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
|
| 156 |
+
"safety_check": LLM_CONFIG["models"]["safety_checker"],
|
| 157 |
+
"general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
|
| 158 |
+
"response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
|
| 159 |
+
}
|
| 160 |
+
return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
|
| 161 |
+
|
| 162 |
+
async def _is_model_healthy(self, model_id: str) -> bool:
|
| 163 |
+
"""
|
| 164 |
+
Check if the model is healthy and available
|
| 165 |
+
Mark models as healthy by default - actual availability checked at API call time
|
| 166 |
+
"""
|
| 167 |
+
# Check cached health status
|
| 168 |
+
if model_id in self.health_status:
|
| 169 |
+
return self.health_status[model_id]
|
| 170 |
+
|
| 171 |
+
# All models marked healthy initially - real check happens during API call
|
| 172 |
+
self.health_status[model_id] = True
|
| 173 |
+
return True
|
| 174 |
+
|
| 175 |
+
def _get_fallback_model(self, task_type: str) -> dict:
|
| 176 |
+
"""
|
| 177 |
+
Get fallback model configuration for the task type
|
| 178 |
+
"""
|
| 179 |
+
# Fallback mapping
|
| 180 |
+
fallback_map = {
|
| 181 |
+
"intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
|
| 182 |
+
"embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
|
| 183 |
+
"safety_check": LLM_CONFIG["models"]["reasoning_primary"],
|
| 184 |
+
"general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
|
| 185 |
+
"response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
|
| 186 |
+
}
|
| 187 |
+
return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
|
| 188 |
+
|
| 189 |
+
async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
|
| 190 |
+
"""
|
| 191 |
+
FIXED: Make actual call to Hugging Face Chat Completions API
|
| 192 |
+
Uses the correct chat completions protocol with retry logic and exponential backoff
|
| 193 |
+
|
| 194 |
+
IMPORTANT: task_type parameter is now properly included in the method signature
|
| 195 |
+
"""
|
| 196 |
+
# Retry configuration
|
| 197 |
+
max_retries = kwargs.get('max_retries', 3)
|
| 198 |
+
initial_delay = kwargs.get('initial_delay', 1.0) # Start with 1 second
|
| 199 |
+
max_delay = kwargs.get('max_delay', 16.0) # Cap at 16 seconds
|
| 200 |
+
timeout = kwargs.get('timeout', 30)
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
import requests
|
| 204 |
+
from requests.exceptions import Timeout, RequestException, ConnectionError as RequestsConnectionError
|
| 205 |
+
|
| 206 |
+
model_id = model_config["model_id"]
|
| 207 |
+
|
| 208 |
+
# Use the chat completions endpoint
|
| 209 |
+
api_url = "https://router.huggingface.co/v1/chat/completions"
|
| 210 |
+
|
| 211 |
+
logger.info(f"Calling HF Chat Completions API for model: {model_id}")
|
| 212 |
+
logger.debug(f"Prompt length: {len(prompt)}")
|
| 213 |
+
logger.info("=" * 80)
|
| 214 |
+
logger.info("LLM API REQUEST - COMPLETE PROMPT:")
|
| 215 |
+
logger.info("=" * 80)
|
| 216 |
+
logger.info(f"Model: {model_id}")
|
| 217 |
+
|
| 218 |
+
# FIXED: task_type is now properly available as a parameter
|
| 219 |
+
logger.info(f"Task Type: {task_type}")
|
| 220 |
+
logger.info(f"Prompt Length: {len(prompt)} characters")
|
| 221 |
+
logger.info("-" * 40)
|
| 222 |
+
logger.info("FULL PROMPT CONTENT:")
|
| 223 |
+
logger.info("-" * 40)
|
| 224 |
+
logger.info(prompt)
|
| 225 |
+
logger.info("-" * 40)
|
| 226 |
+
logger.info("END OF PROMPT")
|
| 227 |
+
logger.info("=" * 80)
|
| 228 |
+
|
| 229 |
+
# Prepare the request payload
|
| 230 |
+
max_tokens = kwargs.get('max_tokens', 512)
|
| 231 |
+
temperature = kwargs.get('temperature', 0.7)
|
| 232 |
+
|
| 233 |
+
payload = {
|
| 234 |
+
"model": model_id,
|
| 235 |
+
"messages": [
|
| 236 |
+
{
|
| 237 |
+
"role": "user",
|
| 238 |
+
"content": prompt
|
| 239 |
+
}
|
| 240 |
+
],
|
| 241 |
+
"max_tokens": max_tokens,
|
| 242 |
+
"temperature": temperature,
|
| 243 |
+
"stream": False
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
headers = {
|
| 247 |
+
"Authorization": f"Bearer {self.hf_token}",
|
| 248 |
+
"Content-Type": "application/json"
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
# Retry logic with exponential backoff
|
| 252 |
+
last_exception = None
|
| 253 |
+
for attempt in range(max_retries + 1):
|
| 254 |
+
try:
|
| 255 |
+
if attempt > 0:
|
| 256 |
+
# Calculate exponential backoff delay
|
| 257 |
+
delay = min(initial_delay * (2 ** (attempt - 1)), max_delay)
|
| 258 |
+
logger.warning(f"Retry attempt {attempt}/{max_retries} after {delay:.1f}s delay (exponential backoff)")
|
| 259 |
+
await asyncio.sleep(delay)
|
| 260 |
+
|
| 261 |
+
logger.info(f"Sending request to: {api_url} (attempt {attempt + 1}/{max_retries + 1})")
|
| 262 |
+
logger.debug(f"Payload: {payload}")
|
| 263 |
+
|
| 264 |
+
response = requests.post(api_url, json=payload, headers=headers, timeout=timeout)
|
| 265 |
+
|
| 266 |
+
if response.status_code == 200:
|
| 267 |
+
result = response.json()
|
| 268 |
+
logger.debug(f"Raw response: {result}")
|
| 269 |
+
|
| 270 |
+
if 'choices' in result and len(result['choices']) > 0:
|
| 271 |
+
generated_text = result['choices'][0]['message']['content']
|
| 272 |
+
|
| 273 |
+
if not generated_text or generated_text.strip() == "":
|
| 274 |
+
logger.warning(f"Empty or invalid response, using fallback")
|
| 275 |
+
return None
|
| 276 |
+
|
| 277 |
+
if attempt > 0:
|
| 278 |
+
logger.info(f"Successfully retrieved response after {attempt} retry attempts")
|
| 279 |
+
|
| 280 |
+
logger.info(f"HF API returned response (length: {len(generated_text)})")
|
| 281 |
+
logger.info("=" * 80)
|
| 282 |
+
logger.info("COMPLETE LLM API RESPONSE:")
|
| 283 |
+
logger.info("=" * 80)
|
| 284 |
+
logger.info(f"Model: {model_id}")
|
| 285 |
+
|
| 286 |
+
# FIXED: task_type is now properly available
|
| 287 |
+
logger.info(f"Task Type: {task_type}")
|
| 288 |
+
logger.info(f"Response Length: {len(generated_text)} characters")
|
| 289 |
+
logger.info("-" * 40)
|
| 290 |
+
logger.info("FULL RESPONSE CONTENT:")
|
| 291 |
+
logger.info("-" * 40)
|
| 292 |
+
logger.info(generated_text)
|
| 293 |
+
logger.info("-" * 40)
|
| 294 |
+
logger.info("END OF LLM RESPONSE")
|
| 295 |
+
logger.info("=" * 80)
|
| 296 |
+
return generated_text
|
| 297 |
+
else:
|
| 298 |
+
logger.error(f"Unexpected response format: {result}")
|
| 299 |
+
return None
|
| 300 |
+
elif response.status_code == 503:
|
| 301 |
+
# Model is loading - this is retryable
|
| 302 |
+
if attempt < max_retries:
|
| 303 |
+
logger.warning(f"Model loading (503), will retry (attempt {attempt + 1}/{max_retries + 1})")
|
| 304 |
+
last_exception = Exception(f"Model loading (503)")
|
| 305 |
+
continue
|
| 306 |
+
else:
|
| 307 |
+
# After max retries, try fallback model
|
| 308 |
+
logger.warning(f"Model loading (503) after {max_retries} retries, trying fallback model")
|
| 309 |
+
fallback_config = self._get_fallback_model(task_type)
|
| 310 |
+
|
| 311 |
+
# FIXED: Ensure task_type is passed in recursive call
|
| 312 |
+
return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
|
| 313 |
+
else:
|
| 314 |
+
# Non-retryable HTTP errors
|
| 315 |
+
logger.error(f"HF API error: {response.status_code} - {response.text}")
|
| 316 |
+
return None
|
| 317 |
+
|
| 318 |
+
except Timeout as e:
|
| 319 |
+
last_exception = e
|
| 320 |
+
if attempt < max_retries:
|
| 321 |
+
logger.warning(f"Request timeout (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
|
| 322 |
+
continue
|
| 323 |
+
else:
|
| 324 |
+
logger.error(f"Request timeout after {max_retries} retries: {str(e)}")
|
| 325 |
+
# Try fallback model on final timeout
|
| 326 |
+
logger.warning("Attempting fallback model due to persistent timeout")
|
| 327 |
+
fallback_config = self._get_fallback_model(task_type)
|
| 328 |
+
return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
|
| 329 |
+
|
| 330 |
+
except (RequestsConnectionError, RequestException) as e:
|
| 331 |
+
last_exception = e
|
| 332 |
+
if attempt < max_retries:
|
| 333 |
+
logger.warning(f"Connection error (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
|
| 334 |
+
continue
|
| 335 |
+
else:
|
| 336 |
+
logger.error(f"Connection error after {max_retries} retries: {str(e)}")
|
| 337 |
+
# Try fallback model on final connection error
|
| 338 |
+
logger.warning("Attempting fallback model due to persistent connection error")
|
| 339 |
+
fallback_config = self._get_fallback_model(task_type)
|
| 340 |
+
return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
|
| 341 |
+
|
| 342 |
+
# If we exhausted all retries and didn't return
|
| 343 |
+
if last_exception:
|
| 344 |
+
logger.error(f"Failed after {max_retries} retries. Last error: {last_exception}")
|
| 345 |
+
return None
|
| 346 |
+
|
| 347 |
+
except ImportError:
|
| 348 |
+
logger.warning("requests library not available, using mock response")
|
| 349 |
+
return f"[Mock] Response to: {prompt[:100]}..."
|
| 350 |
+
except Exception as e:
|
| 351 |
+
logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
|
| 352 |
+
return None
|
| 353 |
+
|
| 354 |
+
async def get_available_models(self):
|
| 355 |
+
"""
|
| 356 |
+
Get list of available models for testing
|
| 357 |
+
"""
|
| 358 |
+
return list(LLM_CONFIG["models"].keys())
|
| 359 |
+
|
| 360 |
+
async def health_check(self):
|
| 361 |
+
"""
|
| 362 |
+
Perform health check on all models
|
| 363 |
+
"""
|
| 364 |
+
health_status = {}
|
| 365 |
+
for model_name, model_config in LLM_CONFIG["models"].items():
|
| 366 |
+
model_id = model_config["model_id"]
|
| 367 |
+
is_healthy = await self._is_model_healthy(model_id)
|
| 368 |
+
health_status[model_name] = {
|
| 369 |
+
"model_id": model_id,
|
| 370 |
+
"healthy": is_healthy
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
return health_status
|
| 374 |
+
|
| 375 |
+
def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str:
|
| 376 |
+
"""Smart context windowing for LLM calls"""
|
| 377 |
+
|
| 378 |
+
try:
|
| 379 |
+
from transformers import AutoTokenizer
|
| 380 |
+
|
| 381 |
+
# Initialize tokenizer lazily
|
| 382 |
+
if not hasattr(self, 'tokenizer'):
|
| 383 |
+
try:
|
| 384 |
+
self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
| 385 |
+
except Exception as e:
|
| 386 |
+
logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
|
| 387 |
+
self.tokenizer = None
|
| 388 |
+
except ImportError:
|
| 389 |
+
logger.warning("transformers library not available, using character count estimation")
|
| 390 |
+
self.tokenizer = None
|
| 391 |
+
|
| 392 |
+
# Priority order for context elements
|
| 393 |
+
priority_elements = [
|
| 394 |
+
('current_query', 1.0),
|
| 395 |
+
('recent_interactions', 0.8),
|
| 396 |
+
('user_preferences', 0.6),
|
| 397 |
+
('session_summary', 0.4),
|
| 398 |
+
('historical_context', 0.2)
|
| 399 |
+
]
|
| 400 |
+
|
| 401 |
+
formatted_context = []
|
| 402 |
+
total_tokens = 0
|
| 403 |
+
|
| 404 |
+
for element, priority in priority_elements:
|
| 405 |
+
# Map element names to context keys
|
| 406 |
+
element_key_map = {
|
| 407 |
+
'current_query': raw_context.get('user_input', ''),
|
| 408 |
+
'recent_interactions': raw_context.get('interaction_contexts', []),
|
| 409 |
+
'user_preferences': raw_context.get('preferences', {}),
|
| 410 |
+
'session_summary': raw_context.get('session_context', {}),
|
| 411 |
+
'historical_context': raw_context.get('user_context', '')
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
content = element_key_map.get(element, '')
|
| 415 |
+
|
| 416 |
+
# Convert to string if needed
|
| 417 |
+
if isinstance(content, dict):
|
| 418 |
+
content = str(content)
|
| 419 |
+
elif isinstance(content, list):
|
| 420 |
+
content = "\n".join([str(item) for item in content[:10]]) # Limit to 10 items
|
| 421 |
+
|
| 422 |
+
if not content:
|
| 423 |
+
continue
|
| 424 |
+
|
| 425 |
+
# Estimate tokens
|
| 426 |
+
if self.tokenizer:
|
| 427 |
+
try:
|
| 428 |
+
tokens = len(self.tokenizer.encode(content))
|
| 429 |
+
except:
|
| 430 |
+
# Fallback to character-based estimation (rough: 1 token ≈ 4 chars)
|
| 431 |
+
tokens = len(content) // 4
|
| 432 |
+
else:
|
| 433 |
+
# Character-based estimation (rough: 1 token ≈ 4 chars)
|
| 434 |
+
tokens = len(content) // 4
|
| 435 |
+
|
| 436 |
+
if total_tokens + tokens <= max_tokens:
|
| 437 |
+
formatted_context.append(f"=== {element.upper()} ===\n{content}")
|
| 438 |
+
total_tokens += tokens
|
| 439 |
+
elif priority > 0.5: # Critical elements - truncate if needed
|
| 440 |
+
available = max_tokens - total_tokens
|
| 441 |
+
if available > 100: # Only truncate if we have meaningful space
|
| 442 |
+
truncated = self._truncate_to_tokens(content, available)
|
| 443 |
+
formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
|
| 444 |
+
break
|
| 445 |
+
|
| 446 |
+
return "\n\n".join(formatted_context)
|
| 447 |
+
|
| 448 |
+
def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
|
| 449 |
+
"""Truncate content to fit within token limit"""
|
| 450 |
+
if not self.tokenizer:
|
| 451 |
+
# Simple character-based truncation
|
| 452 |
+
max_chars = max_tokens * 4
|
| 453 |
+
if len(content) <= max_chars:
|
| 454 |
+
return content
|
| 455 |
+
return content[:max_chars-3] + "..."
|
| 456 |
+
|
| 457 |
+
try:
|
| 458 |
+
# Tokenize and truncate
|
| 459 |
+
tokens = self.tokenizer.encode(content)
|
| 460 |
+
if len(tokens) <= max_tokens:
|
| 461 |
+
return content
|
| 462 |
+
|
| 463 |
+
truncated_tokens = tokens[:max_tokens-3] # Leave room for "..."
|
| 464 |
+
truncated_text = self.tokenizer.decode(truncated_tokens)
|
| 465 |
+
return truncated_text + "..."
|
| 466 |
+
except Exception as e:
|
| 467 |
+
logger.warning(f"Error truncating with tokenizer: {e}, using character truncation")
|
| 468 |
+
max_chars = max_tokens * 4
|
| 469 |
+
if len(content) <= max_chars:
|
| 470 |
+
return content
|
| 471 |
+
return content[:max_chars-3] + "..."
|
src/local_model_loader.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# local_model_loader.py
|
| 2 |
+
# Local GPU-based model loading for NVIDIA T4 Medium (24GB vRAM)
|
| 3 |
+
import logging
|
| 4 |
+
import torch
|
| 5 |
+
from typing import Optional, Dict, Any
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class LocalModelLoader:
|
| 12 |
+
"""
|
| 13 |
+
Loads and manages models locally on GPU for faster inference.
|
| 14 |
+
Optimized for NVIDIA T4 Medium with 24GB vRAM.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, device: Optional[str] = None):
|
| 18 |
+
"""Initialize the model loader with GPU device detection."""
|
| 19 |
+
# Detect device
|
| 20 |
+
if device is None:
|
| 21 |
+
if torch.cuda.is_available():
|
| 22 |
+
self.device = "cuda"
|
| 23 |
+
self.device_name = torch.cuda.get_device_name(0)
|
| 24 |
+
logger.info(f"GPU detected: {self.device_name}")
|
| 25 |
+
logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
|
| 26 |
+
else:
|
| 27 |
+
self.device = "cpu"
|
| 28 |
+
self.device_name = "CPU"
|
| 29 |
+
logger.warning("No GPU detected, using CPU")
|
| 30 |
+
else:
|
| 31 |
+
self.device = device
|
| 32 |
+
self.device_name = device
|
| 33 |
+
|
| 34 |
+
# Model cache
|
| 35 |
+
self.loaded_models: Dict[str, Any] = {}
|
| 36 |
+
self.loaded_tokenizers: Dict[str, Any] = {}
|
| 37 |
+
self.loaded_embedding_models: Dict[str, Any] = {}
|
| 38 |
+
|
| 39 |
+
def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
|
| 40 |
+
"""
|
| 41 |
+
Load a chat model and tokenizer on GPU.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
model_id: HuggingFace model identifier
|
| 45 |
+
load_in_8bit: Use 8-bit quantization (saves memory)
|
| 46 |
+
load_in_4bit: Use 4-bit quantization (saves more memory)
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Tuple of (model, tokenizer)
|
| 50 |
+
"""
|
| 51 |
+
if model_id in self.loaded_models:
|
| 52 |
+
logger.info(f"Model {model_id} already loaded, reusing")
|
| 53 |
+
return self.loaded_models[model_id], self.loaded_tokenizers[model_id]
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
logger.info(f"Loading model {model_id} on {self.device}...")
|
| 57 |
+
|
| 58 |
+
# Load tokenizer
|
| 59 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 60 |
+
model_id,
|
| 61 |
+
trust_remote_code=True
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Determine quantization config
|
| 65 |
+
if load_in_4bit and self.device == "cuda":
|
| 66 |
+
try:
|
| 67 |
+
from transformers import BitsAndBytesConfig
|
| 68 |
+
quantization_config = BitsAndBytesConfig(
|
| 69 |
+
load_in_4bit=True,
|
| 70 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 71 |
+
bnb_4bit_use_double_quant=True,
|
| 72 |
+
bnb_4bit_quant_type="nf4"
|
| 73 |
+
)
|
| 74 |
+
logger.info("Using 4-bit quantization")
|
| 75 |
+
except ImportError:
|
| 76 |
+
logger.warning("bitsandbytes not available, loading without quantization")
|
| 77 |
+
quantization_config = None
|
| 78 |
+
elif load_in_8bit and self.device == "cuda":
|
| 79 |
+
try:
|
| 80 |
+
quantization_config = {"load_in_8bit": True}
|
| 81 |
+
logger.info("Using 8-bit quantization")
|
| 82 |
+
except:
|
| 83 |
+
quantization_config = None
|
| 84 |
+
else:
|
| 85 |
+
quantization_config = None
|
| 86 |
+
|
| 87 |
+
# Load model with GPU optimization
|
| 88 |
+
if self.device == "cuda":
|
| 89 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 90 |
+
model_id,
|
| 91 |
+
device_map="auto", # Automatically uses GPU
|
| 92 |
+
torch_dtype=torch.float16, # Use FP16 for memory efficiency
|
| 93 |
+
trust_remote_code=True,
|
| 94 |
+
**(quantization_config if isinstance(quantization_config, dict) else {}),
|
| 95 |
+
**({"quantization_config": quantization_config} if quantization_config and not isinstance(quantization_config, dict) else {})
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 99 |
+
model_id,
|
| 100 |
+
torch_dtype=torch.float32,
|
| 101 |
+
trust_remote_code=True
|
| 102 |
+
)
|
| 103 |
+
model = model.to(self.device)
|
| 104 |
+
|
| 105 |
+
# Ensure padding token is set
|
| 106 |
+
if tokenizer.pad_token is None:
|
| 107 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 108 |
+
|
| 109 |
+
# Cache models
|
| 110 |
+
self.loaded_models[model_id] = model
|
| 111 |
+
self.loaded_tokenizers[model_id] = tokenizer
|
| 112 |
+
|
| 113 |
+
# Log memory usage
|
| 114 |
+
if self.device == "cuda":
|
| 115 |
+
allocated = torch.cuda.memory_allocated(0) / 1024**3
|
| 116 |
+
reserved = torch.cuda.memory_reserved(0) / 1024**3
|
| 117 |
+
logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
|
| 118 |
+
|
| 119 |
+
logger.info(f"✓ Model {model_id} loaded successfully on {self.device}")
|
| 120 |
+
return model, tokenizer
|
| 121 |
+
|
| 122 |
+
except Exception as e:
|
| 123 |
+
logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
|
| 124 |
+
raise
|
| 125 |
+
|
| 126 |
+
def load_embedding_model(self, model_id: str) -> SentenceTransformer:
|
| 127 |
+
"""
|
| 128 |
+
Load a sentence transformer model for embeddings.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
model_id: HuggingFace model identifier
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
SentenceTransformer model
|
| 135 |
+
"""
|
| 136 |
+
if model_id in self.loaded_embedding_models:
|
| 137 |
+
logger.info(f"Embedding model {model_id} already loaded, reusing")
|
| 138 |
+
return self.loaded_embedding_models[model_id]
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
logger.info(f"Loading embedding model {model_id}...")
|
| 142 |
+
|
| 143 |
+
# SentenceTransformer automatically handles GPU
|
| 144 |
+
model = SentenceTransformer(
|
| 145 |
+
model_id,
|
| 146 |
+
device=self.device
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Cache model
|
| 150 |
+
self.loaded_embedding_models[model_id] = model
|
| 151 |
+
|
| 152 |
+
logger.info(f"✓ Embedding model {model_id} loaded successfully on {self.device}")
|
| 153 |
+
return model
|
| 154 |
+
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
|
| 157 |
+
raise
|
| 158 |
+
|
| 159 |
+
def generate_text(
|
| 160 |
+
self,
|
| 161 |
+
model_id: str,
|
| 162 |
+
prompt: str,
|
| 163 |
+
max_tokens: int = 512,
|
| 164 |
+
temperature: float = 0.7,
|
| 165 |
+
**kwargs
|
| 166 |
+
) -> str:
|
| 167 |
+
"""
|
| 168 |
+
Generate text using a loaded chat model.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
model_id: Model identifier
|
| 172 |
+
prompt: Input prompt
|
| 173 |
+
max_tokens: Maximum tokens to generate
|
| 174 |
+
temperature: Sampling temperature
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Generated text
|
| 178 |
+
"""
|
| 179 |
+
if model_id not in self.loaded_models:
|
| 180 |
+
raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")
|
| 181 |
+
|
| 182 |
+
model = self.loaded_models[model_id]
|
| 183 |
+
tokenizer = self.loaded_tokenizers[model_id]
|
| 184 |
+
|
| 185 |
+
try:
|
| 186 |
+
# Tokenize input
|
| 187 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
|
| 188 |
+
|
| 189 |
+
# Generate
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
outputs = model.generate(
|
| 192 |
+
**inputs,
|
| 193 |
+
max_new_tokens=max_tokens,
|
| 194 |
+
temperature=temperature,
|
| 195 |
+
do_sample=True,
|
| 196 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 197 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 198 |
+
**kwargs
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Decode
|
| 202 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 203 |
+
|
| 204 |
+
# Remove prompt from output if present
|
| 205 |
+
if generated_text.startswith(prompt):
|
| 206 |
+
generated_text = generated_text[len(prompt):].strip()
|
| 207 |
+
|
| 208 |
+
return generated_text
|
| 209 |
+
|
| 210 |
+
except Exception as e:
|
| 211 |
+
logger.error(f"Error generating text: {e}", exc_info=True)
|
| 212 |
+
raise
|
| 213 |
+
|
| 214 |
+
def generate_chat_completion(
|
| 215 |
+
self,
|
| 216 |
+
model_id: str,
|
| 217 |
+
messages: list,
|
| 218 |
+
max_tokens: int = 512,
|
| 219 |
+
temperature: float = 0.7,
|
| 220 |
+
**kwargs
|
| 221 |
+
) -> str:
|
| 222 |
+
"""
|
| 223 |
+
Generate chat completion using a loaded model.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
model_id: Model identifier
|
| 227 |
+
messages: List of message dicts with 'role' and 'content'
|
| 228 |
+
max_tokens: Maximum tokens to generate
|
| 229 |
+
temperature: Sampling temperature
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Generated response
|
| 233 |
+
"""
|
| 234 |
+
if model_id not in self.loaded_models:
|
| 235 |
+
raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")
|
| 236 |
+
|
| 237 |
+
model = self.loaded_models[model_id]
|
| 238 |
+
tokenizer = self.loaded_tokenizers[model_id]
|
| 239 |
+
|
| 240 |
+
try:
|
| 241 |
+
# Format messages as prompt
|
| 242 |
+
if hasattr(tokenizer, 'apply_chat_template'):
|
| 243 |
+
# Use chat template if available
|
| 244 |
+
prompt = tokenizer.apply_chat_template(
|
| 245 |
+
messages,
|
| 246 |
+
tokenize=False,
|
| 247 |
+
add_generation_prompt=True
|
| 248 |
+
)
|
| 249 |
+
else:
|
| 250 |
+
# Fallback: simple formatting
|
| 251 |
+
prompt = "\n".join([
|
| 252 |
+
f"{msg['role']}: {msg['content']}"
|
| 253 |
+
for msg in messages
|
| 254 |
+
]) + "\nassistant: "
|
| 255 |
+
|
| 256 |
+
# Generate
|
| 257 |
+
return self.generate_text(
|
| 258 |
+
model_id=model_id,
|
| 259 |
+
prompt=prompt,
|
| 260 |
+
max_tokens=max_tokens,
|
| 261 |
+
temperature=temperature,
|
| 262 |
+
**kwargs
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
except Exception as e:
|
| 266 |
+
logger.error(f"Error generating chat completion: {e}", exc_info=True)
|
| 267 |
+
raise
|
| 268 |
+
|
| 269 |
+
def get_embedding(self, model_id: str, text: str) -> list:
|
| 270 |
+
"""
|
| 271 |
+
Get embedding vector for text.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
model_id: Embedding model identifier
|
| 275 |
+
text: Input text
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
Embedding vector
|
| 279 |
+
"""
|
| 280 |
+
if model_id not in self.loaded_embedding_models:
|
| 281 |
+
raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")
|
| 282 |
+
|
| 283 |
+
model = self.loaded_embedding_models[model_id]
|
| 284 |
+
|
| 285 |
+
try:
|
| 286 |
+
embedding = model.encode(text, convert_to_numpy=True)
|
| 287 |
+
return embedding.tolist()
|
| 288 |
+
except Exception as e:
|
| 289 |
+
logger.error(f"Error getting embedding: {e}", exc_info=True)
|
| 290 |
+
raise
|
| 291 |
+
|
| 292 |
+
def clear_cache(self):
|
| 293 |
+
"""Clear all loaded models from memory."""
|
| 294 |
+
logger.info("Clearing model cache...")
|
| 295 |
+
|
| 296 |
+
# Clear models
|
| 297 |
+
for model_id in list(self.loaded_models.keys()):
|
| 298 |
+
del self.loaded_models[model_id]
|
| 299 |
+
for model_id in list(self.loaded_tokenizers.keys()):
|
| 300 |
+
del self.loaded_tokenizers[model_id]
|
| 301 |
+
for model_id in list(self.loaded_embedding_models.keys()):
|
| 302 |
+
del self.loaded_embedding_models[model_id]
|
| 303 |
+
|
| 304 |
+
# Clear GPU cache
|
| 305 |
+
if self.device == "cuda":
|
| 306 |
+
torch.cuda.empty_cache()
|
| 307 |
+
|
| 308 |
+
logger.info("✓ Model cache cleared")
|
| 309 |
+
|
| 310 |
+
def get_memory_usage(self) -> Dict[str, float]:
|
| 311 |
+
"""Get current GPU memory usage in GB."""
|
| 312 |
+
if self.device != "cuda":
|
| 313 |
+
return {"device": "cpu", "gpu_available": False}
|
| 314 |
+
|
| 315 |
+
return {
|
| 316 |
+
"device": self.device_name,
|
| 317 |
+
"gpu_available": True,
|
| 318 |
+
"allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
|
| 319 |
+
"reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
|
| 320 |
+
"total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 321 |
+
}
|
| 322 |
+
|
src/mobile_handlers.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mobile_handlers.py
|
| 2 |
+
import gradio as gr
|
| 3 |
+
|
| 4 |
+
class MobileUXHandlers:
|
| 5 |
+
def __init__(self, orchestrator):
|
| 6 |
+
self.orchestrator = orchestrator
|
| 7 |
+
self.mobile_state = {}
|
| 8 |
+
|
| 9 |
+
async def handle_mobile_submit(self, message, chat_history, session_id,
|
| 10 |
+
show_reasoning, show_agent_trace, request: gr.Request):
|
| 11 |
+
"""
|
| 12 |
+
Mobile-optimized submission handler with enhanced UX
|
| 13 |
+
"""
|
| 14 |
+
# Get mobile device info
|
| 15 |
+
user_agent = request.headers.get("user-agent", "").lower()
|
| 16 |
+
is_mobile = any(device in user_agent for device in ['mobile', 'android', 'iphone'])
|
| 17 |
+
|
| 18 |
+
# Mobile-specific optimizations
|
| 19 |
+
if is_mobile:
|
| 20 |
+
return await self._mobile_optimized_processing(
|
| 21 |
+
message, chat_history, session_id, show_reasoning, show_agent_trace
|
| 22 |
+
)
|
| 23 |
+
else:
|
| 24 |
+
return await self._desktop_processing(
|
| 25 |
+
message, chat_history, session_id, show_reasoning, show_agent_trace
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
async def _mobile_optimized_processing(self, message, chat_history, session_id,
|
| 29 |
+
show_reasoning, show_agent_trace):
|
| 30 |
+
"""
|
| 31 |
+
Mobile-specific processing with enhanced UX feedback
|
| 32 |
+
"""
|
| 33 |
+
try:
|
| 34 |
+
# Immediate feedback for mobile users
|
| 35 |
+
yield {
|
| 36 |
+
"chatbot": chat_history + [[message, "Thinking..."]],
|
| 37 |
+
"message_input": "",
|
| 38 |
+
"reasoning_display": {"status": "processing"},
|
| 39 |
+
"performance_display": {"status": "processing"}
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
# Process with mobile-optimized parameters
|
| 43 |
+
result = await self.orchestrator.process_request(
|
| 44 |
+
session_id=session_id,
|
| 45 |
+
user_input=message,
|
| 46 |
+
mobile_optimized=True, # Special flag for mobile
|
| 47 |
+
max_tokens=800 # Shorter responses for mobile
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Format for mobile display
|
| 51 |
+
formatted_response = self._format_for_mobile(
|
| 52 |
+
result['final_response'],
|
| 53 |
+
show_reasoning and result.get('metadata', {}).get('reasoning_chain'),
|
| 54 |
+
show_agent_trace and result.get('agent_trace')
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Update chat history
|
| 58 |
+
updated_history = chat_history + [[message, formatted_response]]
|
| 59 |
+
|
| 60 |
+
yield {
|
| 61 |
+
"chatbot": updated_history,
|
| 62 |
+
"message_input": "",
|
| 63 |
+
"reasoning_display": result.get('metadata', {}).get('reasoning_chain', {}),
|
| 64 |
+
"performance_display": result.get('performance_metrics', {})
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
except Exception as e:
|
| 68 |
+
# Mobile-friendly error handling
|
| 69 |
+
error_response = self._get_mobile_friendly_error(e)
|
| 70 |
+
yield {
|
| 71 |
+
"chatbot": chat_history + [[message, error_response]],
|
| 72 |
+
"message_input": message, # Keep message for retry
|
| 73 |
+
"reasoning_display": {"error": "Processing failed"},
|
| 74 |
+
"performance_display": {"error": str(e)}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
def _format_for_mobile(self, response, reasoning_chain, agent_trace):
|
| 78 |
+
"""
|
| 79 |
+
Format response for optimal mobile readability
|
| 80 |
+
"""
|
| 81 |
+
# Split long responses for mobile
|
| 82 |
+
if len(response) > 400:
|
| 83 |
+
paragraphs = self._split_into_paragraphs(response, max_length=300)
|
| 84 |
+
response = "\n\n".join(paragraphs)
|
| 85 |
+
|
| 86 |
+
# Add mobile-optimized formatting
|
| 87 |
+
formatted = f"""
|
| 88 |
+
<div class="mobile-response">
|
| 89 |
+
{response}
|
| 90 |
+
</div>
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
# Add reasoning if requested
|
| 94 |
+
if reasoning_chain:
|
| 95 |
+
# Handle both old and new reasoning chain formats
|
| 96 |
+
if isinstance(reasoning_chain, dict):
|
| 97 |
+
# New enhanced format - extract key information
|
| 98 |
+
chain_of_thought = reasoning_chain.get('chain_of_thought', {})
|
| 99 |
+
if chain_of_thought:
|
| 100 |
+
first_step = list(chain_of_thought.values())[0] if chain_of_thought else {}
|
| 101 |
+
hypothesis = first_step.get('hypothesis', 'Processing...')
|
| 102 |
+
reasoning_text = f"Hypothesis: {hypothesis}"
|
| 103 |
+
else:
|
| 104 |
+
reasoning_text = "Enhanced reasoning chain available"
|
| 105 |
+
else:
|
| 106 |
+
# Old format - direct string
|
| 107 |
+
reasoning_text = str(reasoning_chain)[:200]
|
| 108 |
+
|
| 109 |
+
formatted += f"""
|
| 110 |
+
<div class="reasoning-mobile" style="margin-top: 15px; padding: 10px; background: #f5f5f5; border-radius: 8px; font-size: 14px;">
|
| 111 |
+
<strong>Reasoning:</strong> {reasoning_text}...
|
| 112 |
+
</div>
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
return formatted
|
| 116 |
+
|
| 117 |
+
def _get_mobile_friendly_error(self, error):
|
| 118 |
+
"""
|
| 119 |
+
User-friendly error messages for mobile
|
| 120 |
+
"""
|
| 121 |
+
error_messages = {
|
| 122 |
+
"timeout": "⏱️ Taking longer than expected. Please try a simpler question.",
|
| 123 |
+
"network": "📡 Connection issue. Check your internet and try again.",
|
| 124 |
+
"rate_limit": "🚦 Too many requests. Please wait a moment.",
|
| 125 |
+
"default": "❌ Something went wrong. Please try again."
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
error_type = "default"
|
| 129 |
+
if "timeout" in str(error).lower():
|
| 130 |
+
error_type = "timeout"
|
| 131 |
+
elif "network" in str(error).lower() or "connection" in str(error).lower():
|
| 132 |
+
error_type = "network"
|
| 133 |
+
elif "rate" in str(error).lower():
|
| 134 |
+
error_type = "rate_limit"
|
| 135 |
+
|
| 136 |
+
return error_messages[error_type]
|
| 137 |
+
|
| 138 |
+
async def _desktop_processing(self, message, chat_history, session_id,
|
| 139 |
+
show_reasoning, show_agent_trace):
|
| 140 |
+
"""
|
| 141 |
+
Desktop processing without mobile optimizations
|
| 142 |
+
"""
|
| 143 |
+
# TODO: Implement desktop-specific processing
|
| 144 |
+
return {
|
| 145 |
+
"chatbot": chat_history,
|
| 146 |
+
"message_input": "",
|
| 147 |
+
"reasoning_display": {},
|
| 148 |
+
"performance_display": {}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
def _split_into_paragraphs(self, text, max_length=300):
|
| 152 |
+
"""
|
| 153 |
+
Split text into mobile-friendly paragraphs
|
| 154 |
+
"""
|
| 155 |
+
# TODO: Implement intelligent paragraph splitting
|
| 156 |
+
words = text.split()
|
| 157 |
+
paragraphs = []
|
| 158 |
+
current_para = []
|
| 159 |
+
|
| 160 |
+
for word in words:
|
| 161 |
+
current_para.append(word)
|
| 162 |
+
if len(' '.join(current_para)) > max_length:
|
| 163 |
+
paragraphs.append(' '.join(current_para[:-1]))
|
| 164 |
+
current_para = [current_para[-1]]
|
| 165 |
+
|
| 166 |
+
if current_para:
|
| 167 |
+
paragraphs.append(' '.join(current_para))
|
| 168 |
+
|
| 169 |
+
return paragraphs
|
src/models_config.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models_config.py
|
| 2 |
+
LLM_CONFIG = {
|
| 3 |
+
"primary_provider": "huggingface",
|
| 4 |
+
"models": {
|
| 5 |
+
"reasoning_primary": {
|
| 6 |
+
"model_id": "Qwen/Qwen2.5-7B-Instruct", # High-quality instruct model
|
| 7 |
+
"task": "general_reasoning",
|
| 8 |
+
"max_tokens": 10000,
|
| 9 |
+
"temperature": 0.7,
|
| 10 |
+
"cost_per_token": 0.000015,
|
| 11 |
+
"fallback": "gpt2", # Simple but guaranteed working model
|
| 12 |
+
"is_chat_model": True
|
| 13 |
+
},
|
| 14 |
+
"embedding_specialist": {
|
| 15 |
+
"model_id": "sentence-transformers/all-MiniLM-L6-v2",
|
| 16 |
+
"task": "embeddings",
|
| 17 |
+
"vector_dimensions": 384,
|
| 18 |
+
"purpose": "semantic_similarity",
|
| 19 |
+
"cost_advantage": "90%_cheaper_than_primary",
|
| 20 |
+
"is_chat_model": False
|
| 21 |
+
},
|
| 22 |
+
"classification_specialist": {
|
| 23 |
+
"model_id": "Qwen/Qwen2.5-7B-Instruct", # Use chat model for classification
|
| 24 |
+
"task": "intent_classification",
|
| 25 |
+
"max_length": 512,
|
| 26 |
+
"specialization": "fast_inference",
|
| 27 |
+
"latency_target": "<100ms",
|
| 28 |
+
"is_chat_model": True
|
| 29 |
+
},
|
| 30 |
+
"safety_checker": {
|
| 31 |
+
"model_id": "Qwen/Qwen2.5-7B-Instruct", # Use chat model for safety
|
| 32 |
+
"task": "content_moderation",
|
| 33 |
+
"confidence_threshold": 0.85,
|
| 34 |
+
"purpose": "bias_detection",
|
| 35 |
+
"is_chat_model": True
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"routing_logic": {
|
| 39 |
+
"strategy": "task_based_routing",
|
| 40 |
+
"fallback_chain": ["primary", "fallback", "degraded_mode"],
|
| 41 |
+
"load_balancing": "round_robin_with_health_check"
|
| 42 |
+
}
|
| 43 |
+
}
|
src/orchestrator_engine.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|