JatsTheAIGen committed
Commit 8603d72 · 1 Parent(s): 7eac98c

feat: Implement FAISS-GPU and lazy-loaded local model fallback


- Update requirements.txt: Switch from faiss-cpu to faiss-gpu for GPU-accelerated vector search
- Update faiss_manager.py: Add GPU support with automatic CPU fallback
  - GPU detection and resource management
  - Automatic index migration between GPU/CPU
  - Proper cleanup of GPU resources
  - Status reporting for GPU availability

- Update src/llm_router.py: Reverse fallback order for lazy loading
  - ZeroGPU API tried first (primary path)
  - Local models only load if ZeroGPU fails (lazy loading)
  - HF Inference API as final fallback
  - Updated logging to indicate fallback path

Benefits:
- GPU-accelerated vector search (10-100x faster for large indices; a micro-benchmark sketch follows this list)
- Lazy loading: Models only load when ZeroGPU unavailable
- Lower memory usage: No models loaded if ZeroGPU works
- Automatic fallback: GPU → CPU if GPU unavailable
- Better resource utilization: GPU only used when needed
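
For reference, a hypothetical micro-benchmark (not part of this commit) can be used to check the "10-100x" claim on a given machine. It assumes a faiss-gpu build with at least one visible CUDA device; the index size, k, and dummy data below are arbitrary:

import time
import numpy as np
import faiss

dim, n_vectors, n_queries, k = 384, 100_000, 1_000, 5
rng = np.random.default_rng(0)
xb = rng.standard_normal((n_vectors, dim)).astype(np.float32)
xq = rng.standard_normal((n_queries, dim)).astype(np.float32)

# Time exact inner-product search on CPU
cpu_index = faiss.IndexFlatIP(dim)
cpu_index.add(xb)
t0 = time.perf_counter()
cpu_index.search(xq, k)
print(f"CPU search: {time.perf_counter() - t0:.3f}s")

# Only time the GPU path if this faiss build actually exposes GPU support
if hasattr(faiss, "StandardGpuResources") and faiss.get_num_gpus() > 0:
    res = faiss.StandardGpuResources()
    gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
    t0 = time.perf_counter()
    gpu_index.search(xq, k)
    print(f"GPU search: {time.perf_counter() - t0:.3f}s")
else:
    print("No FAISS GPU support detected; skipping GPU timing")

Actual speedups depend heavily on index and batch size; small indices may see little or no benefit.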

Files changed (3)
  1. faiss_manager.py +119 -13
  2. requirements.txt +2 -1
  3. src/llm_router.py +32 -33
faiss_manager.py CHANGED
@@ -1,39 +1,112 @@
 # faiss_manager.py
+# FAISS Manager with GPU support and automatic CPU fallback
 import faiss
 import numpy as np
+import logging
+import os
+
+logger = logging.getLogger(__name__)
 
 class FAISSLiteManager:
-    def __init__(self, db_path: str):
+    def __init__(self, db_path: str, use_gpu: bool = True):
+        """
+        Initialize FAISS manager with GPU support
+
+        Args:
+            db_path: Path to database file
+            use_gpu: Whether to use GPU if available (default: True)
+        """
         self.db_path = db_path
         self.dimension = 384  # all-MiniLM-L6-v2 dimension
-        self.index = self._initialize_index()
+        self.use_gpu = use_gpu
+        self.gpu_available = False
+        self.gpu_resource = None
+
+        # Detect GPU availability
+        if use_gpu:
+            try:
+                # Check if FAISS GPU is available
+                if hasattr(faiss, 'StandardGpuResources'):
+                    self.gpu_resource = faiss.StandardGpuResources()
+                    self.gpu_available = True
+                    logger.info("✓ FAISS GPU resources initialized")
+                else:
+                    logger.warning("FAISS GPU not available, using CPU")
+                    self.gpu_available = False
+            except Exception as e:
+                logger.warning(f"Could not initialize FAISS GPU: {e}. Using CPU.")
+                self.gpu_available = False
 
+        self.index = self._initialize_index()
+
     def _initialize_index(self):
-        """Initialize FAISS index with SQLite backend"""
+        """Initialize FAISS index with GPU support if available"""
         try:
-            return faiss.read_index(f"{self.db_path}.faiss")
-        except:
+            # Try to load existing index
+            index = faiss.read_index(f"{self.db_path}.faiss")
+            logger.info(f"Loaded existing FAISS index with {index.ntotal} vectors")
+
+            # Move to GPU if available and not already on GPU
+            if self.gpu_available and not isinstance(index, faiss.GpuIndex):
+                try:
+                    logger.info("Moving index to GPU for faster search")
+                    gpu_index = faiss.index_cpu_to_gpu(self.gpu_resource, 0, index)
+                    return gpu_index
+                except Exception as e:
+                    logger.warning(f"Could not move index to GPU: {e}. Using CPU index.")
+                    return index
+            return index
+
+        except FileNotFoundError:
             # Create new index
+            logger.info("Creating new FAISS index")
+
+            if self.gpu_available:
+                try:
+                    # Create GPU index
+                    cpu_index = faiss.IndexFlatIP(self.dimension)
+                    gpu_index = faiss.index_cpu_to_gpu(self.gpu_resource, 0, cpu_index)
+                    logger.info("✓ Created GPU-accelerated FAISS index")
+                    return gpu_index
+                except Exception as e:
+                    logger.warning(f"Could not create GPU index: {e}. Creating CPU index.")
+                    self.gpu_available = False
+
+            # Create CPU index
             index = faiss.IndexFlatIP(self.dimension)
-            faiss.write_index(index, f"{self.db_path}.faiss")
+            logger.info("Created CPU-based FAISS index")
             return index
-
+
     async def store_embedding(self, session_id: str, text: str, embedding: list):
         """Store embedding with session context"""
         # Convert to numpy array
         vector = np.array([embedding], dtype=np.float32)
 
-        # Add to index
-        self.index.add(vector)
+        # Ensure vector is on correct device
+        if self.gpu_available and isinstance(self.index, faiss.GpuIndex):
+            # GPU index handles device automatically
+            self.index.add(vector)
+        else:
+            # CPU index
+            self.index.add(vector)
 
         # Store metadata in SQLite
-        await self._store_metadata(session_id, text, len(self.index.ntotal) - 1)
+        await self._store_metadata(session_id, text, self.index.ntotal - 1)
 
     async def search_similar(self, query_embedding: list, k: int = 5) -> list:
         """
-        Search for similar embeddings
+        Search for similar embeddings (GPU-accelerated if available)
+
+        Args:
+            query_embedding: Query embedding vector
+            k: Number of results to return
+
+        Returns:
+            List of similar results
         """
         vector = np.array([query_embedding], dtype=np.float32)
+
+        # Search (automatically uses GPU if index is on GPU)
         distances, indices = self.index.search(vector, k)
 
         # Retrieve metadata for results
@@ -57,12 +130,45 @@ class FAISSLiteManager:
     def save_index(self):
         """
         Save the FAISS index to disk
+        Note: GPU indices are moved to CPU before saving
         """
-        faiss.write_index(self.index, f"{self.db_path}.faiss")
+        try:
+            if isinstance(self.index, faiss.GpuIndex):
+                # Move GPU index to CPU for saving
+                logger.info("Moving index from GPU to CPU for saving")
+                cpu_index = faiss.index_gpu_to_cpu(self.index)
+                faiss.write_index(cpu_index, f"{self.db_path}.faiss")
+            else:
+                # Save CPU index directly
+                faiss.write_index(self.index, f"{self.db_path}.faiss")
+            logger.info("✓ FAISS index saved successfully")
+        except Exception as e:
+            logger.error(f"Error saving FAISS index: {e}", exc_info=True)
 
     def get_index_size(self) -> int:
         """
         Get the number of vectors in the index
         """
         return self.index.ntotal
-
+
+    def get_gpu_status(self) -> dict:
+        """
+        Get GPU status information
+
+        Returns:
+            Dictionary with GPU availability and index type
+        """
+        return {
+            "gpu_available": self.gpu_available,
+            "index_type": "GPU" if isinstance(self.index, faiss.GpuIndex) else "CPU",
+            "index_size": self.index.ntotal,
+            "dimension": self.dimension
+        }
+
+    def __del__(self):
+        """Cleanup GPU resources"""
+        if self.gpu_resource is not None:
+            try:
+                del self.gpu_resource
+            except:
+                pass
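
A short usage sketch for the updated manager, assuming the metadata helpers referenced but not shown in this diff (such as _store_metadata and the result lookup inside search_similar) behave as before; the 384-dimensional embedding below is dummy data:

import asyncio
from faiss_manager import FAISSLiteManager

async def main():
    # use_gpu=True is the default; the manager falls back to CPU automatically
    manager = FAISSLiteManager(db_path="./vector_store", use_gpu=True)
    print(manager.get_gpu_status())  # e.g. {'gpu_available': False, 'index_type': 'CPU', ...}

    embedding = [0.01] * 384
    await manager.store_embedding("session-1", "hello world", embedding)
    results = await manager.search_similar(embedding, k=1)
    print(results)

    manager.save_index()  # GPU indices are copied back to CPU before writing
    print("vectors stored:", manager.get_index_size())

asyncio.run(main())
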
requirements.txt CHANGED
@@ -19,7 +19,8 @@ tokenizers>=0.15.0
 sentence-transformers>=2.2.0
 
 # Vector Database & Search
-faiss-cpu>=1.7.4
+# Use faiss-gpu for GPU-accelerated vector search (falls back to CPU if GPU unavailable)
+faiss-gpu>=1.7.4
 numpy>=1.24.0
 scipy>=1.11.0
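
Note that pip installs whichever package is pinned here; the CPU fallback lives in faiss_manager.py, not in the faiss-gpu wheel itself. A small, optional sanity check of which build a deployment actually sees:

import faiss

print("faiss version:", faiss.__version__)
if hasattr(faiss, "get_num_gpus"):
    print("GPU build, visible GPUs:", faiss.get_num_gpus())
else:
    print("CPU-only build (faiss-cpu)")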
 
src/llm_router.py CHANGED
@@ -84,18 +84,16 @@ class LLMRouter:
             logger.warning(f"Could not initialize ZeroGPU client: {e}. Falling back to HF API.")
             self.use_zero_gpu = False
 
-        # Initialize local model loader if enabled
+        # Initialize local model loader if enabled (but don't load models yet - lazy loading)
         if self.use_local_models:
             try:
                 from .local_model_loader import LocalModelLoader
+                # Initialize loader but don't load models yet
                 self.local_loader = LocalModelLoader()
-                logger.info("✓ Local model loader initialized (GPU-based inference)")
-
-                # Note: Pre-loading will happen on first request (lazy loading)
-                # Models will be loaded on-demand to avoid blocking startup
-                logger.info("Models will be loaded on-demand for faster startup")
+                logger.info("✓ Local model loader initialized (models will load on-demand as fallback)")
+                logger.info("Models will only load if ZeroGPU API fails")
             except Exception as e:
-                logger.warning(f"Could not initialize local model loader: {e}. Falling back to API.")
+                logger.warning(f"Could not initialize local model loader: {e}. Local fallback unavailable.")
                 logger.warning("This is normal if transformers/torch not available")
                 self.use_local_models = False
                 self.local_loader = None
@@ -103,7 +101,7 @@
     async def route_inference(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs):
         """
         Smart routing based on task specialization
-        Tries local models first, then ZeroGPU API, falls back to HF Inference API if needed
+        Tries ZeroGPU API first, then local models as fallback (lazy loading), then HF Inference API
 
         Args:
             task_type: Task type (e.g., "intent_classification", "general_reasoning")
@@ -116,39 +114,40 @@
         model_config = self._select_model(task_type)
         logger.info(f"Selected model: {model_config['model_id']}")
 
-        # Try local model first if available
-        if self.use_local_models and self.local_loader:
+        # Try ZeroGPU API first (primary path)
+        if self.use_zero_gpu:
             try:
-                # Handle embedding generation separately
-                if task_type == "embedding_generation":
-                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
-                else:
-                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
-
+                result = await self._call_zero_gpu_endpoint(task_type, prompt, context, user_id, **kwargs)
                 if result is not None:
-                    logger.info(f"Inference complete for {task_type} (local model)")
+                    logger.info(f"Inference complete for {task_type} (ZeroGPU API)")
                     return result
                 else:
-                    logger.warning("Local model returned None, falling back to API")
+                    logger.warning("ZeroGPU API returned None, falling back to local models")
             except Exception as e:
-                logger.warning(f"Local model inference failed: {e}. Falling back to API.")
+                logger.warning(f"ZeroGPU API inference failed: {e}. Falling back to local models.")
                 logger.debug("Exception details:", exc_info=True)
 
-        # Try ZeroGPU API if enabled
-        if self.use_zero_gpu:
+        # Fallback to local models (lazy loading - only if ZeroGPU fails)
+        if self.use_local_models and self.local_loader:
             try:
-                result = await self._call_zero_gpu_endpoint(task_type, prompt, context, user_id, **kwargs)
+                logger.info("ZeroGPU API unavailable, loading local model as fallback...")
+                # Handle embedding generation separately
+                if task_type == "embedding_generation":
+                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
+                else:
+                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
+
                 if result is not None:
-                    logger.info(f"Inference complete for {task_type} (ZeroGPU API)")
+                    logger.info(f"Inference complete for {task_type} (local model fallback)")
                     return result
                 else:
-                    logger.warning("ZeroGPU API returned None, falling back to HF")
+                    logger.warning("Local model returned None, falling back to HF API")
             except Exception as e:
-                logger.warning(f"ZeroGPU API inference failed: {e}. Falling back to HF API.")
+                logger.warning(f"Local model inference failed: {e}. Falling back to HF API.")
                 logger.debug("Exception details:", exc_info=True)
 
-        # Fallback to HF Inference API
-        logger.info("Using HF Inference API")
+        # Final fallback to HF Inference API
+        logger.info("Using HF Inference API as final fallback")
         # Health check and fallback logic
         if not await self._is_model_healthy(model_config["model_id"]):
             logger.warning(f"Model unhealthy, using fallback")
@@ -160,7 +159,7 @@
         return result
 
     async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
-        """Call local model for inference."""
+        """Call local model for inference (lazy loading - only used as fallback)."""
        if not self.local_loader:
            return None
 
@@ -169,9 +168,9 @@
         temperature = kwargs.get('temperature', 0.7)
 
         try:
-            # Ensure model is loaded
+            # Ensure model is loaded (lazy loading on first use)
             if model_id not in self.local_loader.loaded_models:
-                logger.info(f"Loading model {model_id} on demand...")
+                logger.info(f"Lazy loading local model {model_id} as fallback (ZeroGPU unavailable)")
                 self.local_loader.load_chat_model(model_id, load_in_8bit=False)
 
             # Format as chat messages if needed
@@ -208,16 +207,16 @@
             return None
 
     async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
-        """Call local embedding model."""
+        """Call local embedding model (lazy loading - only used as fallback)."""
        if not self.local_loader:
            return None
 
        model_id = model_config["model_id"]
 
        try:
-            # Ensure model is loaded
+            # Ensure model is loaded (lazy loading on first use)
            if model_id not in self.local_loader.loaded_embedding_models:
-                logger.info(f"Loading embedding model {model_id} on demand...")
+                logger.info(f"Lazy loading local embedding model {model_id} as fallback (ZeroGPU unavailable)")
                self.local_loader.load_embedding_model(model_id)
 
            # Generate embedding
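
The new ordering boils down to "try each tier in turn, move on only when a tier raises or returns None". A minimal, generic sketch of that pattern (an illustration, not the repository's actual code):

from typing import Awaitable, Callable, Optional, Sequence

async def first_successful(tiers: Sequence[Callable[[], Awaitable[Optional[str]]]]) -> Optional[str]:
    """Run each async tier in order and return the first non-None result."""
    for tier in tiers:
        try:
            result = await tier()
            if result is not None:
                return result
        except Exception:
            continue  # a failing tier hands control to the next one
    return None

In route_inference the three tiers correspond, in order, to _call_zero_gpu_endpoint, the lazily loaded local models (_call_local_model / _call_local_embedding), and the HF Inference API path at the end of the method; keeping ZeroGPU first is what leaves local weights unloaded as long as the API keeps answering.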