feat: Implement FAISS-GPU and lazy-loaded local model fallback
- Update requirements.txt: switch from faiss-cpu to faiss-gpu for GPU-accelerated vector search
- Update faiss_manager.py: add GPU support with automatic CPU fallback
  - GPU detection and resource management
  - Automatic index migration between GPU and CPU
  - Proper cleanup of GPU resources
  - Status reporting for GPU availability
- Update src/llm_router.py: reverse the fallback order for lazy loading
  - ZeroGPU API is tried first (primary path)
  - Local models load only if ZeroGPU fails (lazy loading)
  - HF Inference API is the final fallback
  - Logging updated to indicate which fallback path was taken

Benefits:
- GPU-accelerated vector search (10-100x faster for large indices)
- Lazy loading: local models load only when ZeroGPU is unavailable
- Lower memory usage: no models are loaded if ZeroGPU works
- Automatic fallback: GPU → CPU if no GPU is available
- Better resource utilization: the GPU is used only when needed

Files changed:
- faiss_manager.py +119 -13
- requirements.txt +2 -1
- src/llm_router.py +32 -33
faiss_manager.py

@@ -1,39 +1,112 @@

# faiss_manager.py
# FAISS Manager with GPU support and automatic CPU fallback
import faiss
import numpy as np
import logging
import os

logger = logging.getLogger(__name__)

class FAISSLiteManager:
    def __init__(self, db_path: str, use_gpu: bool = True):
        """
        Initialize FAISS manager with GPU support

        Args:
            db_path: Path to database file
            use_gpu: Whether to use GPU if available (default: True)
        """
        self.db_path = db_path
        self.dimension = 384  # all-MiniLM-L6-v2 dimension
        self.use_gpu = use_gpu
        self.gpu_available = False
        self.gpu_resource = None

        # Detect GPU availability
        if use_gpu:
            try:
                # Check if FAISS GPU is available
                if hasattr(faiss, 'StandardGpuResources'):
                    self.gpu_resource = faiss.StandardGpuResources()
                    self.gpu_available = True
                    logger.info("✓ FAISS GPU resources initialized")
                else:
                    logger.warning("FAISS GPU not available, using CPU")
                    self.gpu_available = False
            except Exception as e:
                logger.warning(f"Could not initialize FAISS GPU: {e}. Using CPU.")
                self.gpu_available = False

        self.index = self._initialize_index()

    def _initialize_index(self):
        """Initialize FAISS index with GPU support if available"""
        try:
            # Try to load existing index
            index = faiss.read_index(f"{self.db_path}.faiss")
            logger.info(f"Loaded existing FAISS index with {index.ntotal} vectors")

            # Move to GPU if available and not already on GPU
            if self.gpu_available and not isinstance(index, faiss.GpuIndex):
                try:
                    logger.info("Moving index to GPU for faster search")
                    gpu_index = faiss.index_cpu_to_gpu(self.gpu_resource, 0, index)
                    return gpu_index
                except Exception as e:
                    logger.warning(f"Could not move index to GPU: {e}. Using CPU index.")
                    return index
            return index

        except FileNotFoundError:
            # Create new index
            logger.info("Creating new FAISS index")

            if self.gpu_available:
                try:
                    # Create GPU index
                    cpu_index = faiss.IndexFlatIP(self.dimension)
                    gpu_index = faiss.index_cpu_to_gpu(self.gpu_resource, 0, cpu_index)
                    logger.info("✓ Created GPU-accelerated FAISS index")
                    return gpu_index
                except Exception as e:
                    logger.warning(f"Could not create GPU index: {e}. Creating CPU index.")
                    self.gpu_available = False

            # Create CPU index
            index = faiss.IndexFlatIP(self.dimension)
            logger.info("Created CPU-based FAISS index")
            return index

    async def store_embedding(self, session_id: str, text: str, embedding: list):
        """Store embedding with session context"""
        # Convert to numpy array
        vector = np.array([embedding], dtype=np.float32)

        # Ensure vector is on correct device
        if self.gpu_available and isinstance(self.index, faiss.GpuIndex):
            # GPU index handles device automatically
            self.index.add(vector)
        else:
            # CPU index
            self.index.add(vector)

        # Store metadata in SQLite
        await self._store_metadata(session_id, text, self.index.ntotal - 1)

    async def search_similar(self, query_embedding: list, k: int = 5) -> list:
        """
        Search for similar embeddings (GPU-accelerated if available)

        Args:
            query_embedding: Query embedding vector
            k: Number of results to return

        Returns:
            List of similar results
        """
        vector = np.array([query_embedding], dtype=np.float32)

        # Search (automatically uses GPU if index is on GPU)
        distances, indices = self.index.search(vector, k)

        # Retrieve metadata for results

@@ -57,12 +130,45 @@ class FAISSLiteManager:

    def save_index(self):
        """
        Save the FAISS index to disk
        Note: GPU indices are moved to CPU before saving
        """
        try:
            if isinstance(self.index, faiss.GpuIndex):
                # Move GPU index to CPU for saving
                logger.info("Moving index from GPU to CPU for saving")
                cpu_index = faiss.index_gpu_to_cpu(self.index)
                faiss.write_index(cpu_index, f"{self.db_path}.faiss")
            else:
                # Save CPU index directly
                faiss.write_index(self.index, f"{self.db_path}.faiss")
            logger.info("✓ FAISS index saved successfully")
        except Exception as e:
            logger.error(f"Error saving FAISS index: {e}", exc_info=True)

    def get_index_size(self) -> int:
        """
        Get the number of vectors in the index
        """
        return self.index.ntotal

    def get_gpu_status(self) -> dict:
        """
        Get GPU status information

        Returns:
            Dictionary with GPU availability and index type
        """
        return {
            "gpu_available": self.gpu_available,
            "index_type": "GPU" if isinstance(self.index, faiss.GpuIndex) else "CPU",
            "index_size": self.index.ntotal,
            "dimension": self.dimension
        }

    def __del__(self):
        """Cleanup GPU resources"""
        if self.gpu_resource is not None:
            try:
                del self.gpu_resource
            except:
                pass
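For context, a minimal usage sketch of the new manager (hypothetical: the db path, session id, and SentenceTransformer encoder below are illustrative and not part of this commit; `_store_metadata` and the SQLite layer are assumed to exist as before):

```python
# Hypothetical usage sketch, not part of the commit. Paths, ids, and the
# encoder are illustrative assumptions; only the FAISSLiteManager API is real.
import asyncio
from sentence_transformers import SentenceTransformer

from faiss_manager import FAISSLiteManager


async def demo():
    manager = FAISSLiteManager(db_path="data/memory", use_gpu=True)
    print(manager.get_gpu_status())  # e.g. {'gpu_available': True, 'index_type': 'GPU', ...}

    # all-MiniLM-L6-v2 produces 384-dim vectors, matching manager.dimension
    encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    embedding = encoder.encode("GPU-accelerated vector search").tolist()

    await manager.store_embedding("session-1", "GPU-accelerated vector search", embedding)
    results = await manager.search_similar(embedding, k=3)
    manager.save_index()


asyncio.run(demo())
```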
requirements.txt

@@ -19,7 +19,8 @@ tokenizers>=0.15.0

sentence-transformers>=2.2.0

# Vector Database & Search
# Use faiss-gpu for GPU-accelerated vector search (falls back to CPU if GPU unavailable)
faiss-gpu>=1.7.4
numpy>=1.24.0
scipy>=1.11.0
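A quick sanity check that the installed faiss build actually exposes GPU support (a sketch; `get_num_gpus` is only present in GPU-enabled builds, hence the `hasattr` guard):

```python
# Sketch: report how many GPUs the installed faiss build can see (0 on CPU-only builds).
import faiss

num_gpus = faiss.get_num_gpus() if hasattr(faiss, "get_num_gpus") else 0
print(f"faiss GPUs visible: {num_gpus}")
```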
src/llm_router.py

@@ -84,18 +84,16 @@ class LLMRouter:

            logger.warning(f"Could not initialize ZeroGPU client: {e}. Falling back to HF API.")
            self.use_zero_gpu = False

        # Initialize local model loader if enabled (but don't load models yet - lazy loading)
        if self.use_local_models:
            try:
                from .local_model_loader import LocalModelLoader
                # Initialize loader but don't load models yet
                self.local_loader = LocalModelLoader()
                logger.info("✓ Local model loader initialized (models will load on-demand as fallback)")
                logger.info("Models will only load if ZeroGPU API fails")
            except Exception as e:
                logger.warning(f"Could not initialize local model loader: {e}. Local fallback unavailable.")
                logger.warning("This is normal if transformers/torch not available")
                self.use_local_models = False
                self.local_loader = None

@@ -103,7 +101,7 @@ class LLMRouter:

    async def route_inference(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs):
        """
        Smart routing based on task specialization
        Tries ZeroGPU API first, then local models as fallback (lazy loading), then HF Inference API

        Args:
            task_type: Task type (e.g., "intent_classification", "general_reasoning")

@@ -116,39 +114,40 @@ class LLMRouter:

        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Try ZeroGPU API first (primary path)
        if self.use_zero_gpu:
            try:
                result = await self._call_zero_gpu_endpoint(task_type, prompt, context, user_id, **kwargs)
                if result is not None:
                    logger.info(f"Inference complete for {task_type} (ZeroGPU API)")
                    return result
                else:
                    logger.warning("ZeroGPU API returned None, falling back to local models")
            except Exception as e:
                logger.warning(f"ZeroGPU API inference failed: {e}. Falling back to local models.")
                logger.debug("Exception details:", exc_info=True)

        # Fallback to local models (lazy loading - only if ZeroGPU fails)
        if self.use_local_models and self.local_loader:
            try:
                logger.info("ZeroGPU API unavailable, loading local model as fallback...")
                # Handle embedding generation separately
                if task_type == "embedding_generation":
                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
                else:
                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)

                if result is not None:
                    logger.info(f"Inference complete for {task_type} (local model fallback)")
                    return result
                else:
                    logger.warning("Local model returned None, falling back to HF API")
            except Exception as e:
                logger.warning(f"Local model inference failed: {e}. Falling back to HF API.")
                logger.debug("Exception details:", exc_info=True)

        # Final fallback to HF Inference API
        logger.info("Using HF Inference API as final fallback")
        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model unhealthy, using fallback")

@@ -160,7 +159,7 @@ class LLMRouter:

        return result

    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
        """Call local model for inference (lazy loading - only used as fallback)."""
        if not self.local_loader:
            return None

@@ -169,9 +168,9 @@ class LLMRouter:

        temperature = kwargs.get('temperature', 0.7)

        try:
            # Ensure model is loaded (lazy loading on first use)
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Lazy loading local model {model_id} as fallback (ZeroGPU unavailable)")
                self.local_loader.load_chat_model(model_id, load_in_8bit=False)

            # Format as chat messages if needed

@@ -208,16 +207,16 @@ class LLMRouter:

            return None

    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
        """Call local embedding model (lazy loading - only used as fallback)."""
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]

        try:
            # Ensure model is loaded (lazy loading on first use)
            if model_id not in self.local_loader.loaded_embedding_models:
                logger.info(f"Lazy loading local embedding model {model_id} as fallback (ZeroGPU unavailable)")
                self.local_loader.load_embedding_model(model_id)

            # Generate embedding