Commit · 7632802
1 Parent(s): 83fb1b5

api migration v2

Files changed:
- DEPLOYMENT_NOTES.md +40 -26
- Dockerfile.flask +39 -0
- FLASK_API_DEPLOYMENT_FILES.md +194 -0
- README.md +3 -3
- README_FLASK_API.md +92 -0
- config.py +1 -1
- flask_api_standalone.py +257 -0
- requirements.txt +6 -2
- src/config.py +1 -1
- src/llm_router.py +118 -4
- src/local_model_loader.py +322 -0
DEPLOYMENT_NOTES.md
CHANGED

@@ -2,22 +2,29 @@
 
 ## Hugging Face Spaces Deployment
 
-### 
-This MVP is optimized for **
-
-#### 
-- **GPU**: 
-- **
-- **
+### NVIDIA T4 Medium Configuration
+This MVP is optimized for **NVIDIA T4 Medium** GPU deployment on Hugging Face Spaces.
+
+#### Hardware Specifications
+- **GPU**: NVIDIA T4 (persistent, always available)
+- **vCPU**: 8 cores
+- **RAM**: 30GB
+- **vRAM**: 24GB
+- **Storage**: ~20GB
 - **Network**: Shared infrastructure
 
+#### Resource Capacity
+- **GPU Memory**: 24GB vRAM (sufficient for local model loading)
+- **System Memory**: 30GB RAM (excellent for caching and processing)
+- **CPU**: 8 vCPU (good for parallel operations)
+
 ### Environment Variables
 Required environment variables for deployment:
 
 ```bash
 HF_TOKEN=your_huggingface_token_here
 HF_HOME=/tmp/huggingface
-MAX_WORKERS=
+MAX_WORKERS=4
 CACHE_TTL=3600
 DB_PATH=sessions.db
 FAISS_INDEX_PATH=embeddings.faiss

@@ -39,9 +46,8 @@ title: AI Research Assistant MVP
 emoji: 🧠
 colorFrom: blue
 colorTo: purple
-sdk: 
-
-app_file: app.py
+sdk: docker
+app_port: 7860
 pinned: false
 license: apache-2.0
 ---

@@ -77,7 +83,7 @@ license: apache-2.0
 5. **Deploy to HF Spaces**
    - Push to GitHub
    - Connect to HF Spaces
-   - Select 
+   - Select NVIDIA T4 Medium GPU hardware
    - Deploy
 
 ### Resource Management

@@ -85,26 +91,34 @@ license: apache-2.0
 #### Memory Limits
 - **Base Python**: ~100MB
 - **Gradio**: ~50MB
-- **Models (loaded)**: ~
--
--
+- **Models (loaded on GPU)**: ~14-16GB vRAM
+  - Primary model (Qwen/Qwen2.5-7B): ~14GB
+  - Embedding model: ~500MB
+  - Classification models: ~500MB each
+- **System RAM**: ~2-4GB for caching and processing
+- **Cache**: ~500MB-1GB max
 
-**
+**GPU Memory Budget**: ~24GB vRAM (models fit comfortably)
+**System RAM Budget**: 30GB (plenty of headroom)
 
 #### Strategies
--
--
--
--
+- **Local GPU Model Loading**: Models loaded on GPU for faster inference
+- **Lazy Loading**: Models loaded on-demand to speed up startup
+- **GPU Memory Management**: Automatic device placement with FP16 precision
+- **Caching**: Aggressive caching with 30GB RAM available
+- **Stream responses**: To reduce memory during generation
 
 ### Performance Optimization
 
-#### For 
-1. 
-
-
-
-
+#### For NVIDIA T4 GPU
+1. **Local Model Loading**: Models run locally on GPU (faster than API)
+   - Primary model: Qwen/Qwen2.5-7B-Instruct (~14GB vRAM)
+   - Embedding model: sentence-transformers/all-MiniLM-L6-v2 (~500MB)
+2. **GPU Acceleration**: All inference runs on GPU
+3. **Parallel Processing**: 4 workers (MAX_WORKERS=4) for concurrent requests
+4. **Fallback to API**: Automatically falls back to HF Inference API if local models fail
+5. **Request Queuing**: Built-in async request handling
+6. **Response Streaming**: Implemented for efficient memory usage
 
 #### Mobile Optimizations
 - Reduce max tokens to 800
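The ~14GB figure for the primary model follows from the parameter count at two bytes per weight in FP16. A quick back-of-envelope check in plain Python (the parameter counts below are approximations for illustration, not values read from this repository):

```python
def fp16_footprint_gb(num_params: float) -> float:
    """Approximate FP16 weight memory: 2 bytes per parameter."""
    return num_params * 2 / 1024**3

# Approximate parameter counts for the models quoted above.
models = {
    "Qwen/Qwen2.5-7B-Instruct": 7.6e9,
    "sentence-transformers/all-MiniLM-L6-v2": 22.7e6,
}

for name, params in models.items():
    print(f"{name}: ~{fp16_footprint_gb(params):.2f} GB in FP16")
# The 7B model lands around 14 GB, matching the vRAM budget above;
# KV cache and activations add a few GB on top during generation.
```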
Dockerfile.flask
ADDED

@@ -0,0 +1,39 @@
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    cmake \
    libopenblas-dev \
    libomp-dev \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port 7860 (HF Spaces standard)
EXPOSE 7860

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PORT=7860

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=120s --retries=3 \
    CMD curl -f http://localhost:7860/api/health || exit 1

# Run Flask application
# Note: For Flask-only deployment, use this Dockerfile with README_FLASK_API.md
CMD ["python", "flask_api_standalone.py"]
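The HEALTHCHECK above shells out to curl; the same probe can be run from Python when testing a local build of this image. A minimal sketch, assuming the container is already running with port 7860 published and that the requests package is available on the host:

```python
import requests

# Probe the same endpoint the Dockerfile HEALTHCHECK hits.
resp = requests.get("http://localhost:7860/api/health", timeout=30)
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "healthy", "orchestrator_ready": true}
```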
FLASK_API_DEPLOYMENT_FILES.md
ADDED

@@ -0,0 +1,194 @@
# Flask API Only - Required Files List

This document lists all files needed for a **Flask API-only deployment** (no Gradio UI).

## 📋 Essential Files (Required)

### Core Application Files
```
Research_AI_Assistant/
├── flask_api_standalone.py   # Main Flask application (REQUIRED)
├── Dockerfile.flask          # Dockerfile for Flask deployment (rename to Dockerfile)
├── README_FLASK_API.md       # README with HF Spaces frontmatter (rename to README.md)
└── requirements.txt          # Python dependencies (REQUIRED)
```

### Source Code Directory (`src/`)
```
Research_AI_Assistant/src/
├── __init__.py                 # Package initialization
├── config.py                   # Configuration settings
├── llm_router.py               # LLM routing (local GPU models)
├── local_model_loader.py       # GPU model loader (NEW - for local inference)
├── orchestrator_engine.py      # Main orchestrator
├── context_manager.py          # Context management
├── models_config.py            # Model configurations
├── agents/
│   ├── __init__.py
│   ├── intent_agent.py                   # Intent recognition agent
│   ├── synthesis_agent.py                # Response synthesis agent
│   ├── safety_agent.py                   # Safety checking agent
│   └── skills_identification_agent.py    # Skills identification agent
└── database.py                 # Database management (if used)
```

### Configuration Files (Optional but Recommended)
```
Research_AI_Assistant/
├── .env           # Environment variables (optional, use HF Secrets instead)
└── .gitignore     # Git ignore rules
```

## 📦 File Descriptions

### 1. `flask_api_standalone.py` ⭐ REQUIRED
- **Purpose**: Main Flask application entry point
- **Contains**: API endpoints, orchestrator initialization, request handling
- **Key Features**:
  - Local GPU model loading
  - Async orchestrator support
  - Health checks
  - Error handling

### 2. `Dockerfile.flask` → `Dockerfile` ⭐ REQUIRED
- **Purpose**: Container configuration
- **Action**: Rename to `Dockerfile` when deploying
- **Includes**: Python 3.10, system dependencies, health checks

### 3. `README_FLASK_API.md` → `README.md` ⭐ REQUIRED
- **Purpose**: HF Spaces configuration and API documentation
- **Action**: Rename to `README.md` when deploying
- **Contains**: Frontmatter with `sdk: docker`, API endpoints, usage examples

### 4. `requirements.txt` ⭐ REQUIRED
- **Purpose**: Python package dependencies
- **Includes**: Flask, transformers, torch (GPU), sentence-transformers, etc.

### 5. `src/local_model_loader.py` ⭐ REQUIRED (NEW)
- **Purpose**: Loads models locally on GPU
- **Features**: GPU detection, model caching, FP16 optimization

### 6. `src/llm_router.py` ⭐ REQUIRED (UPDATED)
- **Purpose**: Routes inference requests
- **Features**: Tries local models first, falls back to HF API

### 7. `src/orchestrator_engine.py` ⭐ REQUIRED
- **Purpose**: Main AI orchestration engine
- **Contains**: Agent coordination, request processing

### 8. `src/context_manager.py` ⭐ REQUIRED
- **Purpose**: Manages conversation context
- **Features**: Session management, context retrieval

### 9. `src/agents/*.py` ⭐ REQUIRED
- **Purpose**: Individual AI agents
- **Agents**: Intent, Synthesis, Safety, Skills Identification

### 10. `src/config.py` ⭐ REQUIRED
- **Purpose**: Application configuration
- **Settings**: MAX_WORKERS=4, model paths, etc.

## ❌ Files NOT Needed (Gradio/UI Related)

These files can be **excluded** from Flask API deployment:

```
Research_AI_Assistant/
├── app.py                          # Gradio UI (NOT NEEDED)
├── main.py                         # Gradio + Flask launcher (NOT NEEDED)
├── flask_api.py                    # Flask API (use standalone instead)
├── Dockerfile                      # Main Dockerfile (use Dockerfile.flask)
├── Dockerfile.hf                   # Alternative Dockerfile (NOT NEEDED)
├── README.md                       # Main README (use README_FLASK_API.md)
└── All .md files except this one   # Documentation (optional)
```

## 🚀 Quick Deployment Checklist

### Step 1: Prepare Files
```bash
# In your Flask API Space directory:
cp Dockerfile.flask Dockerfile
cp README_FLASK_API.md README.md
```

### Step 2: Verify Structure
```
Your Space/
├── Dockerfile                  # ✅ Renamed from Dockerfile.flask
├── README.md                   # ✅ Renamed from README_FLASK_API.md
├── flask_api_standalone.py     # ✅ Main Flask app
├── requirements.txt            # ✅ Dependencies
└── src/                        # ✅ All source files
    ├── __init__.py
    ├── config.py
    ├── llm_router.py
    ├── local_model_loader.py
    ├── orchestrator_engine.py
    ├── context_manager.py
    ├── models_config.py
    └── agents/
        ├── __init__.py
        ├── intent_agent.py
        ├── synthesis_agent.py
        ├── safety_agent.py
        └── skills_identification_agent.py
```

### Step 3: Set Environment Variables
In HF Spaces Settings → Secrets:
- `HF_TOKEN` - Your Hugging Face token

### Step 4: Deploy
- Select **NVIDIA T4 Medium** GPU
- Set **SDK: docker**
- Deploy

## 📊 File Size Considerations

### Minimal Deployment (Essential Only)
- Core files: ~50 KB
- Source code: ~500 KB
- **Total**: ~550 KB code

### With Models (First Load)
- Code: ~550 KB
- Models (downloaded on first run): ~14-16 GB
- **Total**: ~14-16 GB (first build)

### Subsequent Builds
- Models cached by HF Spaces
- Code only: ~550 KB

## 🔍 Verification

After deployment, verify these files exist:

```bash
# Check main files
ls -la Dockerfile README.md flask_api_standalone.py requirements.txt

# Check source directory
ls -la src/
ls -la src/agents/

# Verify key components
grep -r "local_model_loader" src/llm_router.py
grep -r "MAX_WORKERS" src/config.py
```

## 📝 Summary

**Minimum Required Files:**
1. `flask_api_standalone.py`
2. `Dockerfile` (from Dockerfile.flask)
3. `README.md` (from README_FLASK_API.md)
4. `requirements.txt`
5. All files in `src/` directory

**Total: ~15-20 files** (excluding documentation)

---

**Note**: This is a minimal deployment. All Gradio UI files, documentation, and test files are optional and can be excluded to reduce repository size.
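The verification step above uses `ls` and `grep`; the same check can be scripted in Python, for example as a pre-deploy guard. A small sketch using only the file list from the checklist in this document:

```python
from pathlib import Path

# Files the checklist above marks as REQUIRED for the Flask-only Space.
required = [
    "Dockerfile",
    "README.md",
    "flask_api_standalone.py",
    "requirements.txt",
    "src/config.py",
    "src/llm_router.py",
    "src/local_model_loader.py",
    "src/orchestrator_engine.py",
    "src/context_manager.py",
    "src/agents/intent_agent.py",
    "src/agents/synthesis_agent.py",
    "src/agents/safety_agent.py",
    "src/agents/skills_identification_agent.py",
]

missing = [p for p in required if not Path(p).exists()]
if missing:
    raise SystemExit(f"Missing required files: {missing}")
print("All required Flask API deployment files are present.")
```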
README.md
CHANGED

@@ -39,7 +39,7 @@ public: true
 
 
 
-
+
 
 **Academic-grade AI assistant with transparent reasoning and mobile-optimized interface**
 

 
 ## 🎯 Overview
 
+This MVP demonstrates an intelligent research assistant framework featuring **transparent reasoning chains**, **specialized agent architecture**, and **mobile-first design**. Built for Hugging Face Spaces with NVIDIA T4 GPU acceleration for local model inference.
 
 ### Key Differentiators
 - **🔍 Transparent Reasoning**: Watch the AI think step-by-step with Chain of Thought

 |-------|----------|
 | **HF_TOKEN not found** | Add token in Space Settings → Secrets |
 | **Build timeout** | Reduce model sizes in requirements |
+| **Memory errors** | Check GPU memory usage, optimize model loading |
 | **Import errors** | Check Python version (3.9+) |
 
 ### Performance Optimization
README_FLASK_API.md
ADDED

@@ -0,0 +1,92 @@
---
title: AI Assistant Flask API
emoji: 🤖
colorFrom: blue
colorTo: green
sdk: docker
app_port: 7860
pinned: false
---

# AI Assistant Flask API

Pure Flask REST API for the AI research assistant.

## Quick Start

This Space provides a REST API (no UI). Test the endpoints:

```bash
# Health check
curl https://YOUR-SPACE.hf.space/api/health

# Chat
curl -X POST https://YOUR-SPACE.hf.space/api/chat \
  -H "Content-Type: application/json" \
  -d '{
    "message": "Hello, how are you?",
    "session_id": "test-123",
    "user_id": "user@example.com"
  }'
```

## API Endpoints

### GET /api/health
Health check endpoint.

**Response:**
```json
{
  "status": "healthy",
  "orchestrator_ready": true
}
```

### POST /api/chat
Process a chat message.

**Request:**
```json
{
  "message": "Your question here",
  "history": [],
  "session_id": "optional-session-id",
  "user_id": "optional-user-id"
}
```

**Response:**
```json
{
  "success": true,
  "message": "AI response here",
  "history": [["Your question", "AI response"]],
  "reasoning": {},
  "performance": {}
}
```

## Environment Variables

Set in Space Settings → Repository secrets:

- `HF_TOKEN` - Your Hugging Face API token (required)

## Technology

- Flask 3.0
- Python 3.10
- Custom AI orchestrator with multiple agents
- Docker containerized
- **NVIDIA T4 GPU** for local model inference

## Features

- 🤖 AI-powered responses with local GPU models
- 🔄 Context-aware conversations
- 🛡️ Safety checking
- 📊 Performance metrics
- 🎯 Intent recognition
- 🔧 Skills identification
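For clients that prefer Python over curl, the same /api/chat contract can be exercised with the requests library. A minimal sketch; the Space URL is a placeholder exactly as in the curl examples above:

```python
import requests

BASE_URL = "https://YOUR-SPACE.hf.space"  # placeholder, as in the curl examples

payload = {
    "message": "Hello, how are you?",
    "history": [],                 # optional prior turns as [user, assistant] pairs
    "session_id": "test-123",
    "user_id": "user@example.com",
}

resp = requests.post(f"{BASE_URL}/api/chat", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()

if data.get("success"):
    print(data["message"])         # assistant reply
    history = data["history"]      # pass back in the next request to keep context
```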
config.py
CHANGED

@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
 
     # Performance settings
-    max_workers: int = int(os.getenv("MAX_WORKERS", "
+    max_workers: int = int(os.getenv("MAX_WORKERS", "4"))
     cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
 
     # Database settings
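Because the default is read through os.getenv, the worker count can still be overridden per deployment without editing the file. A minimal sketch of the same lookup the Settings class performs:

```python
import os

# MAX_WORKERS in the environment takes precedence over the new "4" default.
os.environ["MAX_WORKERS"] = "8"

max_workers = int(os.getenv("MAX_WORKERS", "4"))  # same expression as in Settings
print(max_workers)  # 8 when the variable is set, 4 otherwise
```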
flask_api_standalone.py
ADDED

@@ -0,0 +1,257 @@
#!/usr/bin/env python3
"""
Pure Flask API for Hugging Face Spaces
No Gradio - Just Flask REST API
Uses local GPU models for inference
"""

from flask import Flask, request, jsonify
from flask_cors import CORS
import logging
import sys
import os
import asyncio
from pathlib import Path

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

# Create Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all origins

# Global orchestrator
orchestrator = None
orchestrator_available = False

def initialize_orchestrator():
    """Initialize the AI orchestrator with local GPU models"""
    global orchestrator, orchestrator_available

    try:
        logger.info("=" * 60)
        logger.info("INITIALIZING AI ORCHESTRATOR (Local GPU Models)")
        logger.info("=" * 60)

        from src.agents.intent_agent import create_intent_agent
        from src.agents.synthesis_agent import create_synthesis_agent
        from src.agents.safety_agent import create_safety_agent
        from src.agents.skills_identification_agent import create_skills_identification_agent
        from src.llm_router import LLMRouter
        from src.orchestrator_engine import MVPOrchestrator
        from src.context_manager import EfficientContextManager

        logger.info("✓ Imports successful")

        hf_token = os.getenv('HF_TOKEN', '')
        if not hf_token:
            logger.warning("HF_TOKEN not set - API fallback will be used if local models fail")

        # Initialize LLM Router with local model loading enabled
        logger.info("Initializing LLM Router with local GPU model loading...")
        llm_router = LLMRouter(hf_token, use_local_models=True)

        logger.info("Initializing Agents...")
        agents = {
            'intent_recognition': create_intent_agent(llm_router),
            'response_synthesis': create_synthesis_agent(llm_router),
            'safety_check': create_safety_agent(llm_router),
            'skills_identification': create_skills_identification_agent(llm_router)
        }

        logger.info("Initializing Context Manager...")
        context_manager = EfficientContextManager(llm_router=llm_router)

        logger.info("Initializing Orchestrator...")
        orchestrator = MVPOrchestrator(llm_router, context_manager, agents)

        orchestrator_available = True
        logger.info("=" * 60)
        logger.info("✓ AI ORCHESTRATOR READY")
        logger.info("  - Local GPU models enabled")
        logger.info("  - MAX_WORKERS: 4")
        logger.info("=" * 60)

        return True

    except Exception as e:
        logger.error(f"Failed to initialize: {e}", exc_info=True)
        orchestrator_available = False
        return False

# Root endpoint
@app.route('/', methods=['GET'])
def root():
    """API information"""
    return jsonify({
        'name': 'AI Assistant Flask API',
        'version': '1.0',
        'status': 'running',
        'orchestrator_ready': orchestrator_available,
        'features': {
            'local_gpu_models': True,
            'max_workers': 4,
            'hardware': 'NVIDIA T4 Medium'
        },
        'endpoints': {
            'health': 'GET /api/health',
            'chat': 'POST /api/chat',
            'initialize': 'POST /api/initialize'
        }
    })

# Health check
@app.route('/api/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({
        'status': 'healthy' if orchestrator_available else 'initializing',
        'orchestrator_ready': orchestrator_available
    })

# Chat endpoint
@app.route('/api/chat', methods=['POST'])
def chat():
    """
    Process chat message

    POST /api/chat
    {
        "message": "user message",
        "history": [[user, assistant], ...],
        "session_id": "session-123",
        "user_id": "user-456"
    }

    Returns:
    {
        "success": true,
        "message": "AI response",
        "history": [...],
        "reasoning": {...},
        "performance": {...}
    }
    """
    try:
        data = request.get_json()

        if not data or 'message' not in data:
            return jsonify({
                'success': False,
                'error': 'Message is required'
            }), 400

        message = data['message']
        history = data.get('history', [])
        session_id = data.get('session_id')
        user_id = data.get('user_id', 'anonymous')

        logger.info(f"Chat request - User: {user_id}, Session: {session_id}")
        logger.info(f"Message: {message[:100]}...")

        if not orchestrator_available or orchestrator is None:
            return jsonify({
                'success': False,
                'error': 'Orchestrator not ready',
                'message': 'AI system is initializing. Please try again in a moment.'
            }), 503

        # Process with orchestrator (async method)
        # Set user_id for session tracking
        if session_id:
            orchestrator.set_user_id(session_id, user_id)

        # Run async process_request in event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result = loop.run_until_complete(
                orchestrator.process_request(
                    session_id=session_id or f"session-{user_id}",
                    user_input=message
                )
            )
        finally:
            loop.close()

        # Extract response
        if isinstance(result, dict):
            response_text = result.get('response', '')
            reasoning = result.get('reasoning', {})
            performance = result.get('performance', {})
        else:
            response_text = str(result)
            reasoning = {}
            performance = {}

        updated_history = history + [[message, response_text]]

        logger.info(f"✓ Response generated (length: {len(response_text)})")

        return jsonify({
            'success': True,
            'message': response_text,
            'history': updated_history,
            'reasoning': reasoning,
            'performance': performance
        })

    except Exception as e:
        logger.error(f"Chat error: {e}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e),
            'message': 'Error processing your request. Please try again.'
        }), 500

# Manual initialization endpoint
@app.route('/api/initialize', methods=['POST'])
def initialize():
    """Manually trigger initialization"""
    success = initialize_orchestrator()

    if success:
        return jsonify({
            'success': True,
            'message': 'Orchestrator initialized successfully'
        })
    else:
        return jsonify({
            'success': False,
            'message': 'Initialization failed. Check logs for details.'
        }), 500

# Initialize on startup
if __name__ == '__main__':
    logger.info("=" * 60)
    logger.info("STARTING PURE FLASK API")
    logger.info("=" * 60)

    # Initialize orchestrator
    initialize_orchestrator()

    port = int(os.getenv('PORT', 7860))

    logger.info(f"Starting Flask on port {port}")
    logger.info("Endpoints available:")
    logger.info("  GET  /")
    logger.info("  GET  /api/health")
    logger.info("  POST /api/chat")
    logger.info("  POST /api/initialize")
    logger.info("=" * 60)

    app.run(
        host='0.0.0.0',
        port=port,
        debug=False,
        threaded=True  # Enable threading for concurrent requests
    )
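For a quick smoke test without starting the server, Flask's built-in test client can drive the same routes. A minimal sketch; note that importing the module does not initialize the orchestrator (that only happens under `__main__` or via `POST /api/initialize`), so `/api/chat` returns 503 until initialization succeeds:

```python
from flask_api_standalone import app

client = app.test_client()

# Health endpoint reports whether the orchestrator came up.
health = client.get("/api/health")
print(health.status_code, health.get_json())

# Chat endpoint: 200 with a reply when ready, 503 while still initializing.
resp = client.post("/api/chat", json={"message": "Hello", "session_id": "smoke-test"})
print(resp.status_code, resp.get_json())
```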
requirements.txt
CHANGED

@@ -1,9 +1,13 @@
-# requirements.txt for Hugging Face Spaces with 
+# requirements.txt for Hugging Face Spaces with NVIDIA T4 GPU
 # Core Framework Dependencies
 
-# Note: gradio, fastapi, uvicorn, 
+# Note: gradio, fastapi, uvicorn, datasets, huggingface-hub,
 # pydantic==2.10.6, and protobuf<4 are installed by HF Spaces SDK
 
+# PyTorch with CUDA support (for GPU inference)
+# Note: HF Spaces provides torch, but we ensure GPU support
+torch>=2.0.0
+
 # Web Framework & Interface
 aiohttp>=3.9.0
 httpx>=0.25.0
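Whether the installed torch build actually sees the T4 can be confirmed with a few standard torch calls once the Space is up; a short sketch, nothing project-specific:

```python
import torch

# Confirms the installed torch build has CUDA support and can see the GPU.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"Total GPU memory: {total_gb:.1f} GB")
```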
src/config.py
CHANGED

@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
 
     # Performance settings
-    max_workers: int = int(os.getenv("MAX_WORKERS", "
+    max_workers: int = int(os.getenv("MAX_WORKERS", "4"))
     cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
 
     # Database settings
src/llm_router.py
CHANGED

@@ -1,40 +1,154 @@
-# llm_router.py - 
+# llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING
 import logging
 import asyncio
-from typing import Dict
+from typing import Dict, Optional
 from .models_config import LLM_CONFIG
 
 logger = logging.getLogger(__name__)
 
 class LLMRouter:
-    def __init__(self, hf_token):
+    def __init__(self, hf_token, use_local_models: bool = True):
         self.hf_token = hf_token
         self.health_status = {}
+        self.use_local_models = use_local_models
+        self.local_loader = None
+
         logger.info("LLMRouter initialized")
         if hf_token:
             logger.info("HF token available")
         else:
             logger.warning("No HF token provided")
 
+        # Initialize local model loader if enabled
+        if self.use_local_models:
+            try:
+                from .local_model_loader import LocalModelLoader
+                self.local_loader = LocalModelLoader()
+                logger.info("✓ Local model loader initialized (GPU-based inference)")
+
+                # Note: Pre-loading will happen on first request (lazy loading)
+                # Models will be loaded on-demand to avoid blocking startup
+                logger.info("Models will be loaded on-demand for faster startup")
+            except Exception as e:
+                logger.warning(f"Could not initialize local model loader: {e}. Falling back to API.")
+                logger.warning("This is normal if transformers/torch not available")
+                self.use_local_models = False
+                self.local_loader = None
+
     async def route_inference(self, task_type: str, prompt: str, **kwargs):
         """
         Smart routing based on task specialization
+        Tries local models first, falls back to HF Inference API if needed
         """
         logger.info(f"Routing inference for task: {task_type}")
         model_config = self._select_model(task_type)
         logger.info(f"Selected model: {model_config['model_id']}")
 
+        # Try local model first if available
+        if self.use_local_models and self.local_loader:
+            try:
+                # Handle embedding generation separately
+                if task_type == "embedding_generation":
+                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
+                else:
+                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
+
+                if result is not None:
+                    logger.info(f"Inference complete for {task_type} (local model)")
+                    return result
+                else:
+                    logger.warning("Local model returned None, falling back to API")
+            except Exception as e:
+                logger.warning(f"Local model inference failed: {e}. Falling back to API.")
+                logger.debug("Exception details:", exc_info=True)
+
+        # Fallback to HF Inference API
+        logger.info("Using HF Inference API")
         # Health check and fallback logic
         if not await self._is_model_healthy(model_config["model_id"]):
             logger.warning(f"Model unhealthy, using fallback")
             model_config = self._get_fallback_model(task_type)
             logger.info(f"Fallback model: {model_config['model_id']}")
 
-        # FIXED: Ensure task_type is passed to the _call_hf_endpoint method
         result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
         logger.info(f"Inference complete for {task_type}")
         return result
 
+    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
+        """Call local model for inference."""
+        if not self.local_loader:
+            return None
+
+        model_id = model_config["model_id"]
+        max_tokens = kwargs.get('max_tokens', 512)
+        temperature = kwargs.get('temperature', 0.7)
+
+        try:
+            # Ensure model is loaded
+            if model_id not in self.local_loader.loaded_models:
+                logger.info(f"Loading model {model_id} on demand...")
+                self.local_loader.load_chat_model(model_id, load_in_8bit=False)
+
+            # Format as chat messages if needed
+            messages = [{"role": "user", "content": prompt}]
+
+            # Generate using local model
+            result = await asyncio.to_thread(
+                self.local_loader.generate_chat_completion,
+                model_id=model_id,
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature
+            )
+
+            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
+            logger.info("=" * 80)
+            logger.info("LOCAL MODEL RESPONSE:")
+            logger.info("=" * 80)
+            logger.info(f"Model: {model_id}")
+            logger.info(f"Task Type: {task_type}")
+            logger.info(f"Response Length: {len(result)} characters")
+            logger.info("-" * 40)
+            logger.info("FULL RESPONSE CONTENT:")
+            logger.info("-" * 40)
+            logger.info(result)
+            logger.info("-" * 40)
+            logger.info("END OF RESPONSE")
+            logger.info("=" * 80)
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Error calling local model: {e}", exc_info=True)
+            return None
+
+    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
+        """Call local embedding model."""
+        if not self.local_loader:
+            return None
+
+        model_id = model_config["model_id"]
+
+        try:
+            # Ensure model is loaded
+            if model_id not in self.local_loader.loaded_embedding_models:
+                logger.info(f"Loading embedding model {model_id} on demand...")
+                self.local_loader.load_embedding_model(model_id)
+
+            # Generate embedding
+            embedding = await asyncio.to_thread(
+                self.local_loader.get_embedding,
+                model_id=model_id,
+                text=text
+            )
+
+            logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
+            return embedding
+
+        except Exception as e:
+            logger.error(f"Error calling local embedding model: {e}", exc_info=True)
+            return None
+
     def _select_model(self, task_type: str) -> dict:
         model_map = {
             "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
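A minimal sketch of driving the updated router from async code, mirroring how flask_api_standalone.py constructs it; the task type "intent_classification" is one of the keys visible in `_select_model`, and the prompt is illustrative:

```python
import asyncio
import os

from src.llm_router import LLMRouter

async def main():
    # Local GPU models are tried first; HF_TOKEN is only needed for the API fallback.
    router = LLMRouter(os.getenv("HF_TOKEN", ""), use_local_models=True)

    result = await router.route_inference(
        "intent_classification",               # mapped in models_config.py
        "What skills does a data analyst need?",
        max_tokens=256,
        temperature=0.7,
    )
    print(result)

asyncio.run(main())
```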
src/local_model_loader.py
ADDED

@@ -0,0 +1,322 @@
# local_model_loader.py
# Local GPU-based model loading for NVIDIA T4 Medium (24GB vRAM)
import logging
import torch
from typing import Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

class LocalModelLoader:
    """
    Loads and manages models locally on GPU for faster inference.
    Optimized for NVIDIA T4 Medium with 24GB vRAM.
    """

    def __init__(self, device: Optional[str] = None):
        """Initialize the model loader with GPU device detection."""
        # Detect device
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                self.device_name = torch.cuda.get_device_name(0)
                logger.info(f"GPU detected: {self.device_name}")
                logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                self.device_name = "CPU"
                logger.warning("No GPU detected, using CPU")
        else:
            self.device = device
            self.device_name = device

        # Model cache
        self.loaded_models: Dict[str, Any] = {}
        self.loaded_tokenizers: Dict[str, Any] = {}
        self.loaded_embedding_models: Dict[str, Any] = {}

    def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
        """
        Load a chat model and tokenizer on GPU.

        Args:
            model_id: HuggingFace model identifier
            load_in_8bit: Use 8-bit quantization (saves memory)
            load_in_4bit: Use 4-bit quantization (saves more memory)

        Returns:
            Tuple of (model, tokenizer)
        """
        if model_id in self.loaded_models:
            logger.info(f"Model {model_id} already loaded, reusing")
            return self.loaded_models[model_id], self.loaded_tokenizers[model_id]

        try:
            logger.info(f"Loading model {model_id} on {self.device}...")

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Determine quantization config
            if load_in_4bit and self.device == "cuda":
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    logger.info("Using 4-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
                    quantization_config = None
            elif load_in_8bit and self.device == "cuda":
                try:
                    quantization_config = {"load_in_8bit": True}
                    logger.info("Using 8-bit quantization")
                except:
                    quantization_config = None
            else:
                quantization_config = None

            # Load model with GPU optimization
            if self.device == "cuda":
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",  # Automatically uses GPU
                    torch_dtype=torch.float16,  # Use FP16 for memory efficiency
                    trust_remote_code=True,
                    **(quantization_config if isinstance(quantization_config, dict) else {}),
                    **({"quantization_config": quantization_config} if quantization_config and not isinstance(quantization_config, dict) else {})
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                )
                model = model.to(self.device)

            # Ensure padding token is set
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Cache models
            self.loaded_models[model_id] = model
            self.loaded_tokenizers[model_id] = tokenizer

            # Log memory usage
            if self.device == "cuda":
                allocated = torch.cuda.memory_allocated(0) / 1024**3
                reserved = torch.cuda.memory_reserved(0) / 1024**3
                logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

            logger.info(f"✓ Model {model_id} loaded successfully on {self.device}")
            return model, tokenizer

        except Exception as e:
            logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
            raise

    def load_embedding_model(self, model_id: str) -> SentenceTransformer:
        """
        Load a sentence transformer model for embeddings.

        Args:
            model_id: HuggingFace model identifier

        Returns:
            SentenceTransformer model
        """
        if model_id in self.loaded_embedding_models:
            logger.info(f"Embedding model {model_id} already loaded, reusing")
            return self.loaded_embedding_models[model_id]

        try:
            logger.info(f"Loading embedding model {model_id}...")

            # SentenceTransformer automatically handles GPU
            model = SentenceTransformer(
                model_id,
                device=self.device
            )

            # Cache model
            self.loaded_embedding_models[model_id] = model

            logger.info(f"✓ Embedding model {model_id} loaded successfully on {self.device}")
            return model

        except Exception as e:
            logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
            raise

    def generate_text(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate text using a loaded chat model.

        Args:
            model_id: Model identifier
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated text
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Tokenize input
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    **kwargs
                )

            # Decode
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Remove prompt from output if present
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            return generated_text

        except Exception as e:
            logger.error(f"Error generating text: {e}", exc_info=True)
            raise

    def generate_chat_completion(
        self,
        model_id: str,
        messages: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate chat completion using a loaded model.

        Args:
            model_id: Model identifier
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Format messages as prompt
            if hasattr(tokenizer, 'apply_chat_template'):
                # Use chat template if available
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Fallback: simple formatting
                prompt = "\n".join([
                    f"{msg['role']}: {msg['content']}"
                    for msg in messages
                ]) + "\nassistant: "

            # Generate
            return self.generate_text(
                model_id=model_id,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                **kwargs
            )

        except Exception as e:
            logger.error(f"Error generating chat completion: {e}", exc_info=True)
            raise

    def get_embedding(self, model_id: str, text: str) -> list:
        """
        Get embedding vector for text.

        Args:
            model_id: Embedding model identifier
            text: Input text

        Returns:
            Embedding vector
        """
        if model_id not in self.loaded_embedding_models:
            raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")

        model = self.loaded_embedding_models[model_id]

        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}", exc_info=True)
            raise

    def clear_cache(self):
        """Clear all loaded models from memory."""
        logger.info("Clearing model cache...")

        # Clear models
        for model_id in list(self.loaded_models.keys()):
            del self.loaded_models[model_id]
        for model_id in list(self.loaded_tokenizers.keys()):
            del self.loaded_tokenizers[model_id]
        for model_id in list(self.loaded_embedding_models.keys()):
            del self.loaded_embedding_models[model_id]

        # Clear GPU cache
        if self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("✓ Model cache cleared")

    def get_memory_usage(self) -> Dict[str, float]:
        """Get current GPU memory usage in GB."""
        if self.device != "cuda":
            return {"device": "cpu", "gpu_available": False}

        return {
            "device": self.device_name,
            "gpu_available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
        }
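A minimal usage sketch for the loader on its own, outside the router; the model IDs are the ones quoted in DEPLOYMENT_NOTES.md, and the chat model should only be attempted where the GPU budget allows:

```python
from src.local_model_loader import LocalModelLoader

loader = LocalModelLoader()  # picks "cuda" when a GPU is visible, else falls back to CPU

# Embedding path: small model, loads quickly.
emb_model_id = "sentence-transformers/all-MiniLM-L6-v2"
loader.load_embedding_model(emb_model_id)
vector = loader.get_embedding(emb_model_id, "transparent reasoning chains")
print(len(vector))  # 384 dimensions for this model

# Chat path: ~14GB in FP16, so only attempt this on the T4.
chat_model_id = "Qwen/Qwen2.5-7B-Instruct"
loader.load_chat_model(chat_model_id)
reply = loader.generate_chat_completion(
    chat_model_id,
    messages=[{"role": "user", "content": "Summarize what this Space does."}],
    max_tokens=128,
)
print(reply)

print(loader.get_memory_usage())  # allocated / reserved / total GB on the GPU
```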