JatsTheAIGen committed
Commit 7632802 · 1 Parent(s): 83fb1b5

api migration v2
DEPLOYMENT_NOTES.md CHANGED
@@ -2,22 +2,29 @@
2
 
3
  ## Hugging Face Spaces Deployment
4
 
5
- ### ZeroGPU Configuration
6
- This MVP is optimized for **ZeroGPU** deployment on Hugging Face Spaces.
7
-
8
- #### Key Settings
9
- - **GPU**: None (CPU-only)
10
- - **Storage**: Limited (~20GB)
11
- - **Memory**: 32GB RAM
 
 
12
  - **Network**: Shared infrastructure
13

14
  ### Environment Variables
15
  Required environment variables for deployment:
16
 
17
  ```bash
18
  HF_TOKEN=your_huggingface_token_here
19
  HF_HOME=/tmp/huggingface
20
- MAX_WORKERS=2
21
  CACHE_TTL=3600
22
  DB_PATH=sessions.db
23
  FAISS_INDEX_PATH=embeddings.faiss
@@ -39,9 +46,8 @@ title: AI Research Assistant MVP
39
  emoji: 🧠
40
  colorFrom: blue
41
  colorTo: purple
42
- sdk: gradio
43
- sdk_version: 4.0.0
44
- app_file: app.py
45
  pinned: false
46
  license: apache-2.0
47
  ---
@@ -77,7 +83,7 @@ license: apache-2.0
77
  5. **Deploy to HF Spaces**
78
  - Push to GitHub
79
  - Connect to HF Spaces
80
- - Select ZeroGPU hardware
81
  - Deploy
82
 
83
  ### Resource Management
@@ -85,26 +91,34 @@ license: apache-2.0
85
  #### Memory Limits
86
  - **Base Python**: ~100MB
87
  - **Gradio**: ~50MB
88
- - **Models (loaded)**: ~200-500MB
89
- - **Cache**: ~100MB max
90
- - **Buffer**: ~100MB
 
 
 
91
 
92
- **Total Budget**: ~512MB (within HF Spaces limits)
 
93
 
94
  #### Strategies
95
- - Lazy model loading
96
- - Model offloading when not in use
97
- - Aggressive cache eviction
98
- - Stream responses to reduce memory
 
99
 
100
  ### Performance Optimization
101
 
102
- #### For ZeroGPU
103
- 1. Use HF Inference API for LLM calls (not local models)
104
- 2. Use `sentence-transformers` for embeddings (lightweight)
105
- 3. Implement request queuing
106
- 4. Use FAISS-CPU (not GPU version)
107
- 5. Implement response streaming
 
 
 
108
 
109
  #### Mobile Optimizations
110
  - Reduce max tokens to 800
 
2
 
3
  ## Hugging Face Spaces Deployment
4
 
5
+ ### NVIDIA T4 Medium Configuration
6
+ This MVP is optimized for **NVIDIA T4 Medium** GPU deployment on Hugging Face Spaces.
7
+
8
+ #### Hardware Specifications
9
+ - **GPU**: NVIDIA T4 (persistent, always available)
10
+ - **vCPU**: 8 cores
11
+ - **RAM**: 30GB
12
+ - **vRAM**: 16GB
13
+ - **Storage**: ~20GB
14
  - **Network**: Shared infrastructure
15
 
16
+ #### Resource Capacity
17
+ - **GPU Memory**: 16GB vRAM (enough for a 7B model in FP16)
18
+ - **System Memory**: 30GB RAM (excellent for caching and processing)
19
+ - **CPU**: 8 vCPU (good for parallel operations)
20
+
21
  ### Environment Variables
22
  Required environment variables for deployment:
23
 
24
  ```bash
25
  HF_TOKEN=your_huggingface_token_here
26
  HF_HOME=/tmp/huggingface
27
+ MAX_WORKERS=4
28
  CACHE_TTL=3600
29
  DB_PATH=sessions.db
30
  FAISS_INDEX_PATH=embeddings.faiss
 
46
  emoji: 🧠
47
  colorFrom: blue
48
  colorTo: purple
49
+ sdk: docker
50
+ app_port: 7860
 
51
  pinned: false
52
  license: apache-2.0
53
  ---
 
83
  5. **Deploy to HF Spaces**
84
  - Push to GitHub
85
  - Connect to HF Spaces
86
+ - Select NVIDIA T4 Medium GPU hardware
87
  - Deploy
88
 
89
  ### Resource Management
 
91
  #### Memory Limits
92
  - **Base Python**: ~100MB
93
  - **Gradio**: ~50MB
94
+ - **Models (loaded on GPU)**: ~14-16GB vRAM
95
+ - Primary model (Qwen/Qwen2.5-7B): ~14GB
96
+ - Embedding model: ~500MB
97
+ - Classification models: ~500MB each
98
+ - **System RAM**: ~2-4GB for caching and processing
99
+ - **Cache**: ~500MB-1GB max
100
 
101
+ **GPU Memory Budget**: ~16GB vRAM (the 7B model fits in FP16; enable 8-bit quantization for extra headroom)
102
+ **System RAM Budget**: 30GB (plenty of headroom)
103
 
104
  #### Strategies
105
+ - **Local GPU Model Loading**: Models loaded on GPU for faster inference
106
+ - **Lazy Loading**: Models loaded on-demand to speed up startup
107
+ - **GPU Memory Management**: Automatic device placement with FP16 precision
108
+ - **Caching**: Aggressive caching with 30GB RAM available
109
+ - **Stream responses**: To reduce memory during generation
110
 
111
  ### Performance Optimization
112
 
113
+ #### For NVIDIA T4 GPU
114
+ 1. **Local Model Loading**: Models run locally on GPU (faster than API)
115
+ - Primary model: Qwen/Qwen2.5-7B-Instruct (~14GB vRAM)
116
+ - Embedding model: sentence-transformers/all-MiniLM-L6-v2 (~500MB)
117
+ 2. **GPU Acceleration**: All inference runs on GPU
118
+ 3. **Parallel Processing**: 4 workers (MAX_WORKERS=4) for concurrent requests
119
+ 4. **Fallback to API**: Automatically falls back to HF Inference API if local models fail
120
+ 5. **Request Queuing**: Built-in async request handling
121
+ 6. **Response Streaming**: Implemented for efficient memory usage
122
 
123
  #### Mobile Optimizations
124
  - Reduce max tokens to 800
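The GPU and memory budget described in these notes can be checked at startup with the same `torch.cuda` calls this commit uses in `src/local_model_loader.py`. A minimal sketch, assuming `torch` is installed with CUDA support; the 15 GB threshold is illustrative, not a value from the repository:

```python
# Sanity-check sketch for the GPU budget described in DEPLOYMENT_NOTES.md.
# Assumes torch is installed with CUDA support, as in requirements.txt.
import torch

def report_gpu_budget(required_vram_gb: float = 15.0) -> bool:
    """Return True if the detected GPU has enough free memory for the planned models."""
    if not torch.cuda.is_available():
        print("No GPU detected - local model loading will fall back to the HF Inference API")
        return False

    props = torch.cuda.get_device_properties(0)
    total_gb = props.total_memory / 1024**3
    allocated_gb = torch.cuda.memory_allocated(0) / 1024**3
    free_gb = total_gb - allocated_gb

    print(f"GPU: {props.name}, total {total_gb:.1f} GB, free ~{free_gb:.1f} GB")
    return free_gb >= required_vram_gb

if __name__ == "__main__":
    report_gpu_budget()
```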
Dockerfile.flask ADDED
@@ -0,0 +1,39 @@
1
+ FROM python:3.10-slim
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && apt-get install -y \
8
+ gcc \
9
+ g++ \
10
+ cmake \
11
+ libopenblas-dev \
12
+ libomp-dev \
13
+ curl \
14
+ && rm -rf /var/lib/apt/lists/*
15
+
16
+ # Copy requirements file
17
+ COPY requirements.txt .
18
+
19
+ # Install Python dependencies
20
+ RUN pip install --no-cache-dir -r requirements.txt
21
+
22
+ # Copy application code
23
+ COPY . .
24
+
25
+ # Expose port 7860 (HF Spaces standard)
26
+ EXPOSE 7860
27
+
28
+ # Set environment variables
29
+ ENV PYTHONUNBUFFERED=1
30
+ ENV PORT=7860
31
+
32
+ # Health check
33
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=120s --retries=3 \
34
+ CMD curl -f http://localhost:7860/api/health || exit 1
35
+
36
+ # Run Flask application
37
+ # Note: For Flask-only deployment, use this Dockerfile with README_FLASK_API.md
38
+ CMD ["python", "flask_api_standalone.py"]
39
+
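The `HEALTHCHECK` above shells out to `curl`. For hitting the same endpoint from a script outside the container, a minimal sketch assuming the `requests` package and a server listening on port 7860:

```python
# Minimal health probe, equivalent to the Dockerfile HEALTHCHECK command.
# Assumes the `requests` package is installed; the URL mirrors EXPOSE 7860.
import sys
import requests

def probe(url: str = "http://localhost:7860/api/health") -> int:
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()  # mirrors `curl -f` failing on HTTP error codes
        print(resp.json())       # e.g. {"status": "healthy", "orchestrator_ready": true}
        return 0
    except requests.RequestException as exc:
        print(f"Health check failed: {exc}")
        return 1

if __name__ == "__main__":
    sys.exit(probe())
```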
FLASK_API_DEPLOYMENT_FILES.md ADDED
@@ -0,0 +1,194 @@
1
+ # Flask API Only - Required Files List
2
+
3
+ This document lists all files needed for a **Flask API-only deployment** (no Gradio UI).
4
+
5
+ ## 📋 Essential Files (Required)
6
+
7
+ ### Core Application Files
8
+ ```
9
+ Research_AI_Assistant/
10
+ ├── flask_api_standalone.py # Main Flask application (REQUIRED)
11
+ ├── Dockerfile.flask # Dockerfile for Flask deployment (rename to Dockerfile)
12
+ ├── README_FLASK_API.md # README with HF Spaces frontmatter (rename to README.md)
13
+ └── requirements.txt # Python dependencies (REQUIRED)
14
+ ```
15
+
16
+ ### Source Code Directory (`src/`)
17
+ ```
18
+ Research_AI_Assistant/src/
19
+ ├── __init__.py # Package initialization
20
+ ├── config.py # Configuration settings
21
+ ├── llm_router.py # LLM routing (local GPU models)
22
+ ├── local_model_loader.py # GPU model loader (NEW - for local inference)
23
+ ├── orchestrator_engine.py # Main orchestrator
24
+ ├── context_manager.py # Context management
25
+ ├── models_config.py # Model configurations
26
+ ├── agents/
27
+ │ ├── __init__.py
28
+ │ ├── intent_agent.py # Intent recognition agent
29
+ │ ├── synthesis_agent.py # Response synthesis agent
30
+ │ ├── safety_agent.py # Safety checking agent
31
+ │ └── skills_identification_agent.py # Skills identification agent
32
+ └── database.py # Database management (if used)
33
+ ```
34
+
35
+ ### Configuration Files (Optional but Recommended)
36
+ ```
37
+ Research_AI_Assistant/
38
+ ├── .env # Environment variables (optional, use HF Secrets instead)
39
+ └── .gitignore # Git ignore rules
40
+ ```
41
+
42
+ ## 📦 File Descriptions
43
+
44
+ ### 1. `flask_api_standalone.py` ⭐ REQUIRED
45
+ - **Purpose**: Main Flask application entry point
46
+ - **Contains**: API endpoints, orchestrator initialization, request handling
47
+ - **Key Features**:
48
+ - Local GPU model loading
49
+ - Async orchestrator support
50
+ - Health checks
51
+ - Error handling
52
+
53
+ ### 2. `Dockerfile.flask` → `Dockerfile` ⭐ REQUIRED
54
+ - **Purpose**: Container configuration
55
+ - **Action**: Rename to `Dockerfile` when deploying
56
+ - **Includes**: Python 3.10, system dependencies, health checks
57
+
58
+ ### 3. `README_FLASK_API.md` → `README.md` ⭐ REQUIRED
59
+ - **Purpose**: HF Spaces configuration and API documentation
60
+ - **Action**: Rename to `README.md` when deploying
61
+ - **Contains**: Frontmatter with `sdk: docker`, API endpoints, usage examples
62
+
63
+ ### 4. `requirements.txt` ⭐ REQUIRED
64
+ - **Purpose**: Python package dependencies
65
+ - **Includes**: Flask, transformers, torch (GPU), sentence-transformers, etc.
66
+
67
+ ### 5. `src/local_model_loader.py` ⭐ REQUIRED (NEW)
68
+ - **Purpose**: Loads models locally on GPU
69
+ - **Features**: GPU detection, model caching, FP16 optimization
70
+
71
+ ### 6. `src/llm_router.py` ⭐ REQUIRED (UPDATED)
72
+ - **Purpose**: Routes inference requests
73
+ - **Features**: Tries local models first, falls back to HF API
74
+
75
+ ### 7. `src/orchestrator_engine.py` ⭐ REQUIRED
76
+ - **Purpose**: Main AI orchestration engine
77
+ - **Contains**: Agent coordination, request processing
78
+
79
+ ### 8. `src/context_manager.py` ⭐ REQUIRED
80
+ - **Purpose**: Manages conversation context
81
+ - **Features**: Session management, context retrieval
82
+
83
+ ### 9. `src/agents/*.py` ⭐ REQUIRED
84
+ - **Purpose**: Individual AI agents
85
+ - **Agents**: Intent, Synthesis, Safety, Skills Identification
86
+
87
+ ### 10. `src/config.py` ⭐ REQUIRED
88
+ - **Purpose**: Application configuration
89
+ - **Settings**: MAX_WORKERS=4, model paths, etc.
90
+
91
+ ## ❌ Files NOT Needed (Gradio/UI Related)
92
+
93
+ These files can be **excluded** from Flask API deployment:
94
+
95
+ ```
96
+ Research_AI_Assistant/
97
+ ├── app.py # Gradio UI (NOT NEEDED)
98
+ ├── main.py # Gradio + Flask launcher (NOT NEEDED)
99
+ ├── flask_api.py # Flask API (use standalone instead)
100
+ ├── Dockerfile # Main Dockerfile (use Dockerfile.flask)
101
+ ├── Dockerfile.hf # Alternative Dockerfile (NOT NEEDED)
102
+ ├── README.md # Main README (use README_FLASK_API.md)
103
+ └── All .md files except this one # Documentation (optional)
104
+ ```
105
+
106
+ ## 🚀 Quick Deployment Checklist
107
+
108
+ ### Step 1: Prepare Files
109
+ ```bash
110
+ # In your Flask API Space directory:
111
+ cp Dockerfile.flask Dockerfile
112
+ cp README_FLASK_API.md README.md
113
+ ```
114
+
115
+ ### Step 2: Verify Structure
116
+ ```
117
+ Your Space/
118
+ ├── Dockerfile # ✅ Renamed from Dockerfile.flask
119
+ ├── README.md # ✅ Renamed from README_FLASK_API.md
120
+ ├── flask_api_standalone.py # ✅ Main Flask app
121
+ ├── requirements.txt # ✅ Dependencies
122
+ └── src/ # ✅ All source files
123
+ ├── __init__.py
124
+ ├── config.py
125
+ ├── llm_router.py
126
+ ├── local_model_loader.py
127
+ ├── orchestrator_engine.py
128
+ ├── context_manager.py
129
+ ├── models_config.py
130
+ └── agents/
131
+ ├── __init__.py
132
+ ├── intent_agent.py
133
+ ├── synthesis_agent.py
134
+ ├── safety_agent.py
135
+ └── skills_identification_agent.py
136
+ ```
137
+
138
+ ### Step 3: Set Environment Variables
139
+ In HF Spaces Settings → Secrets:
140
+ - `HF_TOKEN` - Your Hugging Face token
141
+
142
+ ### Step 4: Deploy
143
+ - Select **NVIDIA T4 Medium** GPU
144
+ - Set **SDK: docker**
145
+ - Deploy
146
+
147
+ ## 📊 File Size Considerations
148
+
149
+ ### Minimal Deployment (Essential Only)
150
+ - Core files: ~50 KB
151
+ - Source code: ~500 KB
152
+ - **Total**: ~550 KB code
153
+
154
+ ### With Models (First Load)
155
+ - Code: ~550 KB
156
+ - Models (downloaded on first run): ~14-16 GB
157
+ - **Total**: ~14-16 GB (first build)
158
+
159
+ ### Subsequent Builds
160
+ - Models cached by HF Spaces
161
+ - Code only: ~550 KB
162
+
163
+ ## 🔍 Verification
164
+
165
+ After deployment, verify these files exist:
166
+
167
+ ```bash
168
+ # Check main files
169
+ ls -la Dockerfile README.md flask_api_standalone.py requirements.txt
170
+
171
+ # Check source directory
172
+ ls -la src/
173
+ ls -la src/agents/
174
+
175
+ # Verify key components
176
+ grep -r "local_model_loader" src/llm_router.py
177
+ grep -r "MAX_WORKERS" src/config.py
178
+ ```
179
+
180
+ ## 📝 Summary
181
+
182
+ **Minimum Required Files:**
183
+ 1. `flask_api_standalone.py`
184
+ 2. `Dockerfile` (from Dockerfile.flask)
185
+ 3. `README.md` (from README_FLASK_API.md)
186
+ 4. `requirements.txt`
187
+ 5. All files in `src/` directory
188
+
189
+ **Total: ~15-20 files** (excluding documentation)
190
+
191
+ ---
192
+
193
+ **Note**: This is a minimal deployment. All Gradio UI files, documentation, and test files are optional and can be excluded to reduce repository size.
194
+
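The shell checks in the Verification section can also be scripted. A minimal sketch, assuming it runs from the Space root and covering only the files named in this checklist:

```python
# Sketch of the verification step above, run from the Space root directory.
from pathlib import Path

REQUIRED = [
    "Dockerfile",
    "README.md",
    "flask_api_standalone.py",
    "requirements.txt",
    "src/llm_router.py",
    "src/local_model_loader.py",
    "src/orchestrator_engine.py",
    "src/agents/__init__.py",
]

missing = [name for name in REQUIRED if not Path(name).exists()]
if missing:
    print("Missing required files:", ", ".join(missing))
else:
    print("All required files present")
```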
README.md CHANGED
@@ -39,7 +39,7 @@ public: true
39
  ![HF Spaces](https://img.shields.io/badge/🤗-Hugging%20Face%20Spaces-blue)
40
  ![Python](https://img.shields.io/badge/Python-3.9%2B-green)
41
  ![Gradio](https://img.shields.io/badge/Interface-Gradio-FF6B6B)
42
- ![ZeroGPU](https://img.shields.io/badge/GPU-ZeroGPU-lightgrey)
43
 
44
  **Academic-grade AI assistant with transparent reasoning and mobile-optimized interface**
45
 
@@ -50,7 +50,7 @@ public: true
50
 
51
  ## 🎯 Overview
52
 
53
- This MVP demonstrates an intelligent research assistant framework featuring **transparent reasoning chains**, **specialized agent architecture**, and **mobile-first design**. Built for Hugging Face Spaces with ZeroGPU optimization.
54
 
55
  ### Key Differentiators
56
  - **🔍 Transparent Reasoning**: Watch the AI think step-by-step with Chain of Thought
@@ -286,7 +286,7 @@ pytest tests/test_mobile_ux.py -v
286
  |-------|----------|
287
  | **HF_TOKEN not found** | Add token in Space Settings → Secrets |
288
  | **Build timeout** | Reduce model sizes in requirements |
289
- | **Memory errors** | Enable ZeroGPU and optimize cache |
290
  | **Import errors** | Check Python version (3.9+) |
291
 
292
  ### Performance Optimization
 
39
  ![HF Spaces](https://img.shields.io/badge/🤗-Hugging%20Face%20Spaces-blue)
40
  ![Python](https://img.shields.io/badge/Python-3.9%2B-green)
41
  ![Gradio](https://img.shields.io/badge/Interface-Gradio-FF6B6B)
42
+ ![NVIDIA T4](https://img.shields.io/badge/GPU-NVIDIA%20T4-blue)
43
 
44
  **Academic-grade AI assistant with transparent reasoning and mobile-optimized interface**
45
 
 
50
 
51
  ## 🎯 Overview
52
 
53
+ This MVP demonstrates an intelligent research assistant framework featuring **transparent reasoning chains**, **specialized agent architecture**, and **mobile-first design**. Built for Hugging Face Spaces with NVIDIA T4 GPU acceleration for local model inference.
54
 
55
  ### Key Differentiators
56
  - **🔍 Transparent Reasoning**: Watch the AI think step-by-step with Chain of Thought
 
286
  |-------|----------|
287
  | **HF_TOKEN not found** | Add token in Space Settings → Secrets |
288
  | **Build timeout** | Reduce model sizes in requirements |
289
+ | **Memory errors** | Check GPU memory usage, optimize model loading |
290
  | **Import errors** | Check Python version (3.9+) |
291
 
292
  ### Performance Optimization
README_FLASK_API.md ADDED
@@ -0,0 +1,92 @@
1
+ ---
2
+ title: AI Assistant Flask API
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ ---
10
+
11
+ # AI Assistant Flask API
12
+
13
+ Pure Flask REST API for the AI research assistant.
14
+
15
+ ## Quick Start
16
+
17
+ This Space provides a REST API (no UI). Test the endpoints:
18
+
19
+ ```bash
20
+ # Health check
21
+ curl https://YOUR-SPACE.hf.space/api/health
22
+
23
+ # Chat
24
+ curl -X POST https://YOUR-SPACE.hf.space/api/chat \
25
+ -H "Content-Type: application/json" \
26
+ -d '{
27
+ "message": "Hello, how are you?",
28
+ "session_id": "test-123",
29
+ "user_id": "user@example.com"
30
+ }'
31
+ ```
32
+
33
+ ## API Endpoints
34
+
35
+ ### GET /api/health
36
+ Health check endpoint.
37
+
38
+ **Response:**
39
+ ```json
40
+ {
41
+ "status": "healthy",
42
+ "orchestrator_ready": true
43
+ }
44
+ ```
45
+
46
+ ### POST /api/chat
47
+ Process a chat message.
48
+
49
+ **Request:**
50
+ ```json
51
+ {
52
+ "message": "Your question here",
53
+ "history": [],
54
+ "session_id": "optional-session-id",
55
+ "user_id": "optional-user-id"
56
+ }
57
+ ```
58
+
59
+ **Response:**
60
+ ```json
61
+ {
62
+ "success": true,
63
+ "message": "AI response here",
64
+ "history": [["Your question", "AI response"]],
65
+ "reasoning": {},
66
+ "performance": {}
67
+ }
68
+ ```
69
+
70
+ ## Environment Variables
71
+
72
+ Set in Space Settings → Repository secrets:
73
+
74
+ - `HF_TOKEN` - Your Hugging Face API token (required)
75
+
76
+ ## Technology
77
+
78
+ - Flask 3.0
79
+ - Python 3.10
80
+ - Custom AI orchestrator with multiple agents
81
+ - Docker containerized
82
+ - **NVIDIA T4 GPU** for local model inference
83
+
84
+ ## Features
85
+
86
+ - 🤖 AI-powered responses with local GPU models
87
+ - 🔄 Context-aware conversations
88
+ - 🛡️ Safety checking
89
+ - 📊 Performance metrics
90
+ - 🎯 Intent recognition
91
+ - 🔧 Skills identification
92
+
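A minimal Python client for the `/api/chat` contract documented above, assuming the `requests` package; `YOUR-SPACE` is a placeholder for the real Space host:

```python
# Minimal Python client for the /api/chat endpoint documented in README_FLASK_API.md.
# Assumes the `requests` package; replace YOUR-SPACE with the real Space host.
import requests

BASE_URL = "https://YOUR-SPACE.hf.space"

payload = {
    "message": "Hello, how are you?",
    "history": [],
    "session_id": "test-123",
    "user_id": "user@example.com",
}

resp = requests.post(f"{BASE_URL}/api/chat", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()

if data.get("success"):
    print("Assistant:", data["message"])
    print("Turns so far:", len(data["history"]))
else:
    print("Error:", data.get("error"))
```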
config.py CHANGED
@@ -13,7 +13,7 @@ class Settings(BaseSettings):
13
  classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
14
 
15
  # Performance settings
16
- max_workers: int = int(os.getenv("MAX_WORKERS", "2"))
17
  cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
18
 
19
  # Database settings
 
13
  classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
14
 
15
  # Performance settings
16
+ max_workers: int = int(os.getenv("MAX_WORKERS", "4"))
17
  cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
18
 
19
  # Database settings
flask_api_standalone.py ADDED
@@ -0,0 +1,257 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Pure Flask API for Hugging Face Spaces
4
+ No Gradio - Just Flask REST API
5
+ Uses local GPU models for inference
6
+ """
7
+
8
+ from flask import Flask, request, jsonify
9
+ from flask_cors import CORS
10
+ import logging
11
+ import sys
12
+ import os
13
+ import asyncio
14
+ from pathlib import Path
15
+
16
+ # Setup logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
+ )
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Add project root to path
24
+ project_root = Path(__file__).parent
25
+ sys.path.insert(0, str(project_root))
26
+
27
+ # Create Flask app
28
+ app = Flask(__name__)
29
+ CORS(app) # Enable CORS for all origins
30
+
31
+ # Global orchestrator
32
+ orchestrator = None
33
+ orchestrator_available = False
34
+
35
+ def initialize_orchestrator():
36
+ """Initialize the AI orchestrator with local GPU models"""
37
+ global orchestrator, orchestrator_available
38
+
39
+ try:
40
+ logger.info("=" * 60)
41
+ logger.info("INITIALIZING AI ORCHESTRATOR (Local GPU Models)")
42
+ logger.info("=" * 60)
43
+
44
+ from src.agents.intent_agent import create_intent_agent
45
+ from src.agents.synthesis_agent import create_synthesis_agent
46
+ from src.agents.safety_agent import create_safety_agent
47
+ from src.agents.skills_identification_agent import create_skills_identification_agent
48
+ from src.llm_router import LLMRouter
49
+ from src.orchestrator_engine import MVPOrchestrator
50
+ from src.context_manager import EfficientContextManager
51
+
52
+ logger.info("✓ Imports successful")
53
+
54
+ hf_token = os.getenv('HF_TOKEN', '')
55
+ if not hf_token:
56
+ logger.warning("HF_TOKEN not set - API fallback will be used if local models fail")
57
+
58
+ # Initialize LLM Router with local model loading enabled
59
+ logger.info("Initializing LLM Router with local GPU model loading...")
60
+ llm_router = LLMRouter(hf_token, use_local_models=True)
61
+
62
+ logger.info("Initializing Agents...")
63
+ agents = {
64
+ 'intent_recognition': create_intent_agent(llm_router),
65
+ 'response_synthesis': create_synthesis_agent(llm_router),
66
+ 'safety_check': create_safety_agent(llm_router),
67
+ 'skills_identification': create_skills_identification_agent(llm_router)
68
+ }
69
+
70
+ logger.info("Initializing Context Manager...")
71
+ context_manager = EfficientContextManager(llm_router=llm_router)
72
+
73
+ logger.info("Initializing Orchestrator...")
74
+ orchestrator = MVPOrchestrator(llm_router, context_manager, agents)
75
+
76
+ orchestrator_available = True
77
+ logger.info("=" * 60)
78
+ logger.info("✓ AI ORCHESTRATOR READY")
79
+ logger.info(" - Local GPU models enabled")
80
+ logger.info(" - MAX_WORKERS: 4")
81
+ logger.info("=" * 60)
82
+
83
+ return True
84
+
85
+ except Exception as e:
86
+ logger.error(f"Failed to initialize: {e}", exc_info=True)
87
+ orchestrator_available = False
88
+ return False
89
+
90
+ # Root endpoint
91
+ @app.route('/', methods=['GET'])
92
+ def root():
93
+ """API information"""
94
+ return jsonify({
95
+ 'name': 'AI Assistant Flask API',
96
+ 'version': '1.0',
97
+ 'status': 'running',
98
+ 'orchestrator_ready': orchestrator_available,
99
+ 'features': {
100
+ 'local_gpu_models': True,
101
+ 'max_workers': 4,
102
+ 'hardware': 'NVIDIA T4 Medium'
103
+ },
104
+ 'endpoints': {
105
+ 'health': 'GET /api/health',
106
+ 'chat': 'POST /api/chat',
107
+ 'initialize': 'POST /api/initialize'
108
+ }
109
+ })
110
+
111
+ # Health check
112
+ @app.route('/api/health', methods=['GET'])
113
+ def health_check():
114
+ """Health check endpoint"""
115
+ return jsonify({
116
+ 'status': 'healthy' if orchestrator_available else 'initializing',
117
+ 'orchestrator_ready': orchestrator_available
118
+ })
119
+
120
+ # Chat endpoint
121
+ @app.route('/api/chat', methods=['POST'])
122
+ def chat():
123
+ """
124
+ Process chat message
125
+
126
+ POST /api/chat
127
+ {
128
+ "message": "user message",
129
+ "history": [[user, assistant], ...],
130
+ "session_id": "session-123",
131
+ "user_id": "user-456"
132
+ }
133
+
134
+ Returns:
135
+ {
136
+ "success": true,
137
+ "message": "AI response",
138
+ "history": [...],
139
+ "reasoning": {...},
140
+ "performance": {...}
141
+ }
142
+ """
143
+ try:
144
+ data = request.get_json()
145
+
146
+ if not data or 'message' not in data:
147
+ return jsonify({
148
+ 'success': False,
149
+ 'error': 'Message is required'
150
+ }), 400
151
+
152
+ message = data['message']
153
+ history = data.get('history', [])
154
+ session_id = data.get('session_id')
155
+ user_id = data.get('user_id', 'anonymous')
156
+
157
+ logger.info(f"Chat request - User: {user_id}, Session: {session_id}")
158
+ logger.info(f"Message: {message[:100]}...")
159
+
160
+ if not orchestrator_available or orchestrator is None:
161
+ return jsonify({
162
+ 'success': False,
163
+ 'error': 'Orchestrator not ready',
164
+ 'message': 'AI system is initializing. Please try again in a moment.'
165
+ }), 503
166
+
167
+ # Process with orchestrator (async method)
168
+ # Set user_id for session tracking
169
+ if session_id:
170
+ orchestrator.set_user_id(session_id, user_id)
171
+
172
+ # Run async process_request in event loop
173
+ loop = asyncio.new_event_loop()
174
+ asyncio.set_event_loop(loop)
175
+ try:
176
+ result = loop.run_until_complete(
177
+ orchestrator.process_request(
178
+ session_id=session_id or f"session-{user_id}",
179
+ user_input=message
180
+ )
181
+ )
182
+ finally:
183
+ loop.close()
184
+
185
+ # Extract response
186
+ if isinstance(result, dict):
187
+ response_text = result.get('response', '')
188
+ reasoning = result.get('reasoning', {})
189
+ performance = result.get('performance', {})
190
+ else:
191
+ response_text = str(result)
192
+ reasoning = {}
193
+ performance = {}
194
+
195
+ updated_history = history + [[message, response_text]]
196
+
197
+ logger.info(f"✓ Response generated (length: {len(response_text)})")
198
+
199
+ return jsonify({
200
+ 'success': True,
201
+ 'message': response_text,
202
+ 'history': updated_history,
203
+ 'reasoning': reasoning,
204
+ 'performance': performance
205
+ })
206
+
207
+ except Exception as e:
208
+ logger.error(f"Chat error: {e}", exc_info=True)
209
+ return jsonify({
210
+ 'success': False,
211
+ 'error': str(e),
212
+ 'message': 'Error processing your request. Please try again.'
213
+ }), 500
214
+
215
+ # Manual initialization endpoint
216
+ @app.route('/api/initialize', methods=['POST'])
217
+ def initialize():
218
+ """Manually trigger initialization"""
219
+ success = initialize_orchestrator()
220
+
221
+ if success:
222
+ return jsonify({
223
+ 'success': True,
224
+ 'message': 'Orchestrator initialized successfully'
225
+ })
226
+ else:
227
+ return jsonify({
228
+ 'success': False,
229
+ 'message': 'Initialization failed. Check logs for details.'
230
+ }), 500
231
+
232
+ # Initialize on startup
233
+ if __name__ == '__main__':
234
+ logger.info("=" * 60)
235
+ logger.info("STARTING PURE FLASK API")
236
+ logger.info("=" * 60)
237
+
238
+ # Initialize orchestrator
239
+ initialize_orchestrator()
240
+
241
+ port = int(os.getenv('PORT', 7860))
242
+
243
+ logger.info(f"Starting Flask on port {port}")
244
+ logger.info("Endpoints available:")
245
+ logger.info(" GET /")
246
+ logger.info(" GET /api/health")
247
+ logger.info(" POST /api/chat")
248
+ logger.info(" POST /api/initialize")
249
+ logger.info("=" * 60)
250
+
251
+ app.run(
252
+ host='0.0.0.0',
253
+ port=port,
254
+ debug=False,
255
+ threaded=True # Enable threading for concurrent requests
256
+ )
257
+
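Because `initialize_orchestrator()` loads models at startup, `/api/health` can report `initializing` for a while after the container boots. A minimal readiness-wait sketch, assuming the `requests` package and a local instance on port 7860:

```python
# Sketch: wait for the orchestrator to become ready, triggering /api/initialize once if needed.
# Assumes `requests` and a locally running flask_api_standalone.py on port 7860.
import time
import requests

BASE = "http://localhost:7860"

def wait_until_ready(timeout_s: int = 600) -> bool:
    health = requests.get(f"{BASE}/api/health", timeout=10).json()
    if not health.get("orchestrator_ready"):
        # Trigger initialization once; model loading can take several minutes.
        requests.post(f"{BASE}/api/initialize", timeout=600)

    deadline = time.time() + timeout_s
    while time.time() < deadline:
        health = requests.get(f"{BASE}/api/health", timeout=10).json()
        if health.get("orchestrator_ready"):
            return True
        time.sleep(5)
    return False

if __name__ == "__main__":
    print("Ready" if wait_until_ready() else "Timed out waiting for orchestrator")
```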
requirements.txt CHANGED
@@ -1,9 +1,13 @@
1
- # requirements.txt for Hugging Face Spaces with ZeroGPU
2
  # Core Framework Dependencies
3
 
4
- # Note: gradio, fastapi, uvicorn, torch, datasets, huggingface-hub,
5
  # pydantic==2.10.6, and protobuf<4 are installed by HF Spaces SDK
6
 
 
 
 
 
7
  # Web Framework & Interface
8
  aiohttp>=3.9.0
9
  httpx>=0.25.0
 
1
+ # requirements.txt for Hugging Face Spaces with NVIDIA T4 GPU
2
  # Core Framework Dependencies
3
 
4
+ # Note: gradio, fastapi, uvicorn, datasets, huggingface-hub,
5
  # pydantic==2.10.6, and protobuf<4 are installed by HF Spaces SDK
6
 
7
+ # PyTorch with CUDA support (for GPU inference)
8
+ # Note: HF Spaces provides torch, but we ensure GPU support
9
+ torch>=2.0.0
10
+
11
  # Web Framework & Interface
12
  aiohttp>=3.9.0
13
  httpx>=0.25.0
src/config.py CHANGED
@@ -13,7 +13,7 @@ class Settings(BaseSettings):
13
  classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
14
 
15
  # Performance settings
16
- max_workers: int = int(os.getenv("MAX_WORKERS", "2"))
17
  cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
18
 
19
  # Database settings
 
13
  classification_model: str = "cardiffnlp/twitter-roberta-base-emotion"
14
 
15
  # Performance settings
16
+ max_workers: int = int(os.getenv("MAX_WORKERS", "4"))
17
  cache_ttl: int = int(os.getenv("CACHE_TTL", "3600"))
18
 
19
  # Database settings
src/llm_router.py CHANGED
@@ -1,40 +1,154 @@
1
- # llm_router.py - FIXED VERSION
2
  import logging
3
  import asyncio
4
- from typing import Dict
5
  from .models_config import LLM_CONFIG
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
  class LLMRouter:
10
- def __init__(self, hf_token):
11
  self.hf_token = hf_token
12
  self.health_status = {}
 
 
 
13
  logger.info("LLMRouter initialized")
14
  if hf_token:
15
  logger.info("HF token available")
16
  else:
17
  logger.warning("No HF token provided")
18

19
  async def route_inference(self, task_type: str, prompt: str, **kwargs):
20
  """
21
  Smart routing based on task specialization
 
22
  """
23
  logger.info(f"Routing inference for task: {task_type}")
24
  model_config = self._select_model(task_type)
25
  logger.info(f"Selected model: {model_config['model_id']}")
26

27
  # Health check and fallback logic
28
  if not await self._is_model_healthy(model_config["model_id"]):
29
  logger.warning(f"Model unhealthy, using fallback")
30
  model_config = self._get_fallback_model(task_type)
31
  logger.info(f"Fallback model: {model_config['model_id']}")
32
 
33
- # FIXED: Ensure task_type is passed to the _call_hf_endpoint method
34
  result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
35
  logger.info(f"Inference complete for {task_type}")
36
  return result
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def _select_model(self, task_type: str) -> dict:
39
  model_map = {
40
  "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
 
1
+ # llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING
2
  import logging
3
  import asyncio
4
+ from typing import Dict, Optional
5
  from .models_config import LLM_CONFIG
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
  class LLMRouter:
10
+ def __init__(self, hf_token, use_local_models: bool = True):
11
  self.hf_token = hf_token
12
  self.health_status = {}
13
+ self.use_local_models = use_local_models
14
+ self.local_loader = None
15
+
16
  logger.info("LLMRouter initialized")
17
  if hf_token:
18
  logger.info("HF token available")
19
  else:
20
  logger.warning("No HF token provided")
21
 
22
+ # Initialize local model loader if enabled
23
+ if self.use_local_models:
24
+ try:
25
+ from .local_model_loader import LocalModelLoader
26
+ self.local_loader = LocalModelLoader()
27
+ logger.info("✓ Local model loader initialized (GPU-based inference)")
28
+
29
+ # Note: Pre-loading will happen on first request (lazy loading)
30
+ # Models will be loaded on-demand to avoid blocking startup
31
+ logger.info("Models will be loaded on-demand for faster startup")
32
+ except Exception as e:
33
+ logger.warning(f"Could not initialize local model loader: {e}. Falling back to API.")
34
+ logger.warning("This is normal if transformers/torch not available")
35
+ self.use_local_models = False
36
+ self.local_loader = None
37
+
38
  async def route_inference(self, task_type: str, prompt: str, **kwargs):
39
  """
40
  Smart routing based on task specialization
41
+ Tries local models first, falls back to HF Inference API if needed
42
  """
43
  logger.info(f"Routing inference for task: {task_type}")
44
  model_config = self._select_model(task_type)
45
  logger.info(f"Selected model: {model_config['model_id']}")
46
 
47
+ # Try local model first if available
48
+ if self.use_local_models and self.local_loader:
49
+ try:
50
+ # Handle embedding generation separately
51
+ if task_type == "embedding_generation":
52
+ result = await self._call_local_embedding(model_config, prompt, **kwargs)
53
+ else:
54
+ result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
55
+
56
+ if result is not None:
57
+ logger.info(f"Inference complete for {task_type} (local model)")
58
+ return result
59
+ else:
60
+ logger.warning("Local model returned None, falling back to API")
61
+ except Exception as e:
62
+ logger.warning(f"Local model inference failed: {e}. Falling back to API.")
63
+ logger.debug("Exception details:", exc_info=True)
64
+
65
+ # Fallback to HF Inference API
66
+ logger.info("Using HF Inference API")
67
  # Health check and fallback logic
68
  if not await self._is_model_healthy(model_config["model_id"]):
69
  logger.warning(f"Model unhealthy, using fallback")
70
  model_config = self._get_fallback_model(task_type)
71
  logger.info(f"Fallback model: {model_config['model_id']}")
72
 
 
73
  result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
74
  logger.info(f"Inference complete for {task_type}")
75
  return result
76
 
77
+ async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
78
+ """Call local model for inference."""
79
+ if not self.local_loader:
80
+ return None
81
+
82
+ model_id = model_config["model_id"]
83
+ max_tokens = kwargs.get('max_tokens', 512)
84
+ temperature = kwargs.get('temperature', 0.7)
85
+
86
+ try:
87
+ # Ensure model is loaded
88
+ if model_id not in self.local_loader.loaded_models:
89
+ logger.info(f"Loading model {model_id} on demand...")
90
+ self.local_loader.load_chat_model(model_id, load_in_8bit=False)
91
+
92
+ # Format as chat messages if needed
93
+ messages = [{"role": "user", "content": prompt}]
94
+
95
+ # Generate using local model
96
+ result = await asyncio.to_thread(
97
+ self.local_loader.generate_chat_completion,
98
+ model_id=model_id,
99
+ messages=messages,
100
+ max_tokens=max_tokens,
101
+ temperature=temperature
102
+ )
103
+
104
+ logger.info(f"Local model {model_id} generated response (length: {len(result)})")
105
+ logger.info("=" * 80)
106
+ logger.info("LOCAL MODEL RESPONSE:")
107
+ logger.info("=" * 80)
108
+ logger.info(f"Model: {model_id}")
109
+ logger.info(f"Task Type: {task_type}")
110
+ logger.info(f"Response Length: {len(result)} characters")
111
+ logger.info("-" * 40)
112
+ logger.info("FULL RESPONSE CONTENT:")
113
+ logger.info("-" * 40)
114
+ logger.info(result)
115
+ logger.info("-" * 40)
116
+ logger.info("END OF RESPONSE")
117
+ logger.info("=" * 80)
118
+
119
+ return result
120
+
121
+ except Exception as e:
122
+ logger.error(f"Error calling local model: {e}", exc_info=True)
123
+ return None
124
+
125
+ async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
126
+ """Call local embedding model."""
127
+ if not self.local_loader:
128
+ return None
129
+
130
+ model_id = model_config["model_id"]
131
+
132
+ try:
133
+ # Ensure model is loaded
134
+ if model_id not in self.local_loader.loaded_embedding_models:
135
+ logger.info(f"Loading embedding model {model_id} on demand...")
136
+ self.local_loader.load_embedding_model(model_id)
137
+
138
+ # Generate embedding
139
+ embedding = await asyncio.to_thread(
140
+ self.local_loader.get_embedding,
141
+ model_id=model_id,
142
+ text=text
143
+ )
144
+
145
+ logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
146
+ return embedding
147
+
148
+ except Exception as e:
149
+ logger.error(f"Error calling local embedding model: {e}", exc_info=True)
150
+ return None
151
+
152
  def _select_model(self, task_type: str) -> dict:
153
  model_map = {
154
  "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
src/local_model_loader.py ADDED
@@ -0,0 +1,322 @@
1
+ # local_model_loader.py
2
+ # Local GPU-based model loading for NVIDIA T4 Medium (16GB vRAM)
3
+ import logging
4
+ import torch
5
+ from typing import Optional, Dict, Any
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class LocalModelLoader:
12
+ """
13
+ Loads and manages models locally on GPU for faster inference.
14
+ Optimized for NVIDIA T4 Medium with 16GB vRAM.
15
+ """
16
+
17
+ def __init__(self, device: Optional[str] = None):
18
+ """Initialize the model loader with GPU device detection."""
19
+ # Detect device
20
+ if device is None:
21
+ if torch.cuda.is_available():
22
+ self.device = "cuda"
23
+ self.device_name = torch.cuda.get_device_name(0)
24
+ logger.info(f"GPU detected: {self.device_name}")
25
+ logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
26
+ else:
27
+ self.device = "cpu"
28
+ self.device_name = "CPU"
29
+ logger.warning("No GPU detected, using CPU")
30
+ else:
31
+ self.device = device
32
+ self.device_name = device
33
+
34
+ # Model cache
35
+ self.loaded_models: Dict[str, Any] = {}
36
+ self.loaded_tokenizers: Dict[str, Any] = {}
37
+ self.loaded_embedding_models: Dict[str, Any] = {}
38
+
39
+ def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
40
+ """
41
+ Load a chat model and tokenizer on GPU.
42
+
43
+ Args:
44
+ model_id: HuggingFace model identifier
45
+ load_in_8bit: Use 8-bit quantization (saves memory)
46
+ load_in_4bit: Use 4-bit quantization (saves more memory)
47
+
48
+ Returns:
49
+ Tuple of (model, tokenizer)
50
+ """
51
+ if model_id in self.loaded_models:
52
+ logger.info(f"Model {model_id} already loaded, reusing")
53
+ return self.loaded_models[model_id], self.loaded_tokenizers[model_id]
54
+
55
+ try:
56
+ logger.info(f"Loading model {model_id} on {self.device}...")
57
+
58
+ # Load tokenizer
59
+ tokenizer = AutoTokenizer.from_pretrained(
60
+ model_id,
61
+ trust_remote_code=True
62
+ )
63
+
64
+ # Determine quantization config
65
+ if load_in_4bit and self.device == "cuda":
66
+ try:
67
+ from transformers import BitsAndBytesConfig
68
+ quantization_config = BitsAndBytesConfig(
69
+ load_in_4bit=True,
70
+ bnb_4bit_compute_dtype=torch.float16,
71
+ bnb_4bit_use_double_quant=True,
72
+ bnb_4bit_quant_type="nf4"
73
+ )
74
+ logger.info("Using 4-bit quantization")
75
+ except ImportError:
76
+ logger.warning("bitsandbytes not available, loading without quantization")
77
+ quantization_config = None
78
+ elif load_in_8bit and self.device == "cuda":
79
+ try:
80
+ quantization_config = {"load_in_8bit": True}
81
+ logger.info("Using 8-bit quantization")
82
+ except:
83
+ quantization_config = None
84
+ else:
85
+ quantization_config = None
86
+
87
+ # Load model with GPU optimization
88
+ if self.device == "cuda":
89
+ model = AutoModelForCausalLM.from_pretrained(
90
+ model_id,
91
+ device_map="auto", # Automatically uses GPU
92
+ torch_dtype=torch.float16, # Use FP16 for memory efficiency
93
+ trust_remote_code=True,
94
+ **(quantization_config if isinstance(quantization_config, dict) else {}),
95
+ **({"quantization_config": quantization_config} if quantization_config and not isinstance(quantization_config, dict) else {})
96
+ )
97
+ else:
98
+ model = AutoModelForCausalLM.from_pretrained(
99
+ model_id,
100
+ torch_dtype=torch.float32,
101
+ trust_remote_code=True
102
+ )
103
+ model = model.to(self.device)
104
+
105
+ # Ensure padding token is set
106
+ if tokenizer.pad_token is None:
107
+ tokenizer.pad_token = tokenizer.eos_token
108
+
109
+ # Cache models
110
+ self.loaded_models[model_id] = model
111
+ self.loaded_tokenizers[model_id] = tokenizer
112
+
113
+ # Log memory usage
114
+ if self.device == "cuda":
115
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
116
+ reserved = torch.cuda.memory_reserved(0) / 1024**3
117
+ logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
118
+
119
+ logger.info(f"✓ Model {model_id} loaded successfully on {self.device}")
120
+ return model, tokenizer
121
+
122
+ except Exception as e:
123
+ logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
124
+ raise
125
+
126
+ def load_embedding_model(self, model_id: str) -> SentenceTransformer:
127
+ """
128
+ Load a sentence transformer model for embeddings.
129
+
130
+ Args:
131
+ model_id: HuggingFace model identifier
132
+
133
+ Returns:
134
+ SentenceTransformer model
135
+ """
136
+ if model_id in self.loaded_embedding_models:
137
+ logger.info(f"Embedding model {model_id} already loaded, reusing")
138
+ return self.loaded_embedding_models[model_id]
139
+
140
+ try:
141
+ logger.info(f"Loading embedding model {model_id}...")
142
+
143
+ # SentenceTransformer automatically handles GPU
144
+ model = SentenceTransformer(
145
+ model_id,
146
+ device=self.device
147
+ )
148
+
149
+ # Cache model
150
+ self.loaded_embedding_models[model_id] = model
151
+
152
+ logger.info(f"✓ Embedding model {model_id} loaded successfully on {self.device}")
153
+ return model
154
+
155
+ except Exception as e:
156
+ logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
157
+ raise
158
+
159
+ def generate_text(
160
+ self,
161
+ model_id: str,
162
+ prompt: str,
163
+ max_tokens: int = 512,
164
+ temperature: float = 0.7,
165
+ **kwargs
166
+ ) -> str:
167
+ """
168
+ Generate text using a loaded chat model.
169
+
170
+ Args:
171
+ model_id: Model identifier
172
+ prompt: Input prompt
173
+ max_tokens: Maximum tokens to generate
174
+ temperature: Sampling temperature
175
+
176
+ Returns:
177
+ Generated text
178
+ """
179
+ if model_id not in self.loaded_models:
180
+ raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")
181
+
182
+ model = self.loaded_models[model_id]
183
+ tokenizer = self.loaded_tokenizers[model_id]
184
+
185
+ try:
186
+ # Tokenize input
187
+ inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
188
+
189
+ # Generate
190
+ with torch.no_grad():
191
+ outputs = model.generate(
192
+ **inputs,
193
+ max_new_tokens=max_tokens,
194
+ temperature=temperature,
195
+ do_sample=True,
196
+ pad_token_id=tokenizer.pad_token_id,
197
+ eos_token_id=tokenizer.eos_token_id,
198
+ **kwargs
199
+ )
200
+
201
+ # Decode
202
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
203
+
204
+ # Remove prompt from output if present
205
+ if generated_text.startswith(prompt):
206
+ generated_text = generated_text[len(prompt):].strip()
207
+
208
+ return generated_text
209
+
210
+ except Exception as e:
211
+ logger.error(f"Error generating text: {e}", exc_info=True)
212
+ raise
213
+
214
+ def generate_chat_completion(
215
+ self,
216
+ model_id: str,
217
+ messages: list,
218
+ max_tokens: int = 512,
219
+ temperature: float = 0.7,
220
+ **kwargs
221
+ ) -> str:
222
+ """
223
+ Generate chat completion using a loaded model.
224
+
225
+ Args:
226
+ model_id: Model identifier
227
+ messages: List of message dicts with 'role' and 'content'
228
+ max_tokens: Maximum tokens to generate
229
+ temperature: Sampling temperature
230
+
231
+ Returns:
232
+ Generated response
233
+ """
234
+ if model_id not in self.loaded_models:
235
+ raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")
236
+
237
+ model = self.loaded_models[model_id]
238
+ tokenizer = self.loaded_tokenizers[model_id]
239
+
240
+ try:
241
+ # Format messages as prompt
242
+ if hasattr(tokenizer, 'apply_chat_template'):
243
+ # Use chat template if available
244
+ prompt = tokenizer.apply_chat_template(
245
+ messages,
246
+ tokenize=False,
247
+ add_generation_prompt=True
248
+ )
249
+ else:
250
+ # Fallback: simple formatting
251
+ prompt = "\n".join([
252
+ f"{msg['role']}: {msg['content']}"
253
+ for msg in messages
254
+ ]) + "\nassistant: "
255
+
256
+ # Generate
257
+ return self.generate_text(
258
+ model_id=model_id,
259
+ prompt=prompt,
260
+ max_tokens=max_tokens,
261
+ temperature=temperature,
262
+ **kwargs
263
+ )
264
+
265
+ except Exception as e:
266
+ logger.error(f"Error generating chat completion: {e}", exc_info=True)
267
+ raise
268
+
269
+ def get_embedding(self, model_id: str, text: str) -> list:
270
+ """
271
+ Get embedding vector for text.
272
+
273
+ Args:
274
+ model_id: Embedding model identifier
275
+ text: Input text
276
+
277
+ Returns:
278
+ Embedding vector
279
+ """
280
+ if model_id not in self.loaded_embedding_models:
281
+ raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")
282
+
283
+ model = self.loaded_embedding_models[model_id]
284
+
285
+ try:
286
+ embedding = model.encode(text, convert_to_numpy=True)
287
+ return embedding.tolist()
288
+ except Exception as e:
289
+ logger.error(f"Error getting embedding: {e}", exc_info=True)
290
+ raise
291
+
292
+ def clear_cache(self):
293
+ """Clear all loaded models from memory."""
294
+ logger.info("Clearing model cache...")
295
+
296
+ # Clear models
297
+ for model_id in list(self.loaded_models.keys()):
298
+ del self.loaded_models[model_id]
299
+ for model_id in list(self.loaded_tokenizers.keys()):
300
+ del self.loaded_tokenizers[model_id]
301
+ for model_id in list(self.loaded_embedding_models.keys()):
302
+ del self.loaded_embedding_models[model_id]
303
+
304
+ # Clear GPU cache
305
+ if self.device == "cuda":
306
+ torch.cuda.empty_cache()
307
+
308
+ logger.info("✓ Model cache cleared")
309
+
310
+ def get_memory_usage(self) -> Dict[str, float]:
311
+ """Get current GPU memory usage in GB."""
312
+ if self.device != "cuda":
313
+ return {"device": "cpu", "gpu_available": False}
314
+
315
+ return {
316
+ "device": self.device_name,
317
+ "gpu_available": True,
318
+ "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
319
+ "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
320
+ "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
321
+ }
322
+
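A usage sketch for `LocalModelLoader`, assuming it runs from the project root; the model IDs are the ones named in DEPLOYMENT_NOTES.md, and the first chat-model load downloads roughly 14 GB of weights:

```python
# Usage sketch for LocalModelLoader, run from the project root.
# Model IDs match those named in DEPLOYMENT_NOTES.md; the 7B chat model is a large download on first run.
from src.local_model_loader import LocalModelLoader

loader = LocalModelLoader()  # picks "cuda" when a GPU is visible, otherwise "cpu"

# Lightweight embedding model (~500MB).
loader.load_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
vec = loader.get_embedding("sentence-transformers/all-MiniLM-L6-v2", "transparent reasoning chains")
print("Embedding dimension:", len(vec))

# Chat model; pass load_in_8bit=True to trade some quality for vRAM headroom.
loader.load_chat_model("Qwen/Qwen2.5-7B-Instruct")
answer = loader.generate_chat_completion(
    model_id="Qwen/Qwen2.5-7B-Instruct",
    messages=[{"role": "user", "content": "In one sentence, what does this API do?"}],
    max_tokens=64,
)
print(answer)

print(loader.get_memory_usage())
loader.clear_cache()
```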