JatsTheAIGen committed
Commit b3aba24 · 1 Parent(s): 79ea999

Update model IDs to use Cerebras deployment and add gated repository error handling

- Updated model IDs to use meta-llama/Llama-3.1-8B-Instruct:cerebras across all model configurations
- Added comprehensive GatedRepoError handling in local_model_loader.py
- Added GatedRepoError handling in llm_router.py with fallback model support
- Implemented API suffix stripping (:cerebras) for local model loading (see the sketch after this list)
- Updated default model configurations in config.py
- Added helpful error messages with links to request repository access
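The core pattern these changes introduce is summarized below as a small standalone sketch, simplified from the diffs that follow; the helper names `resolve_local_model_id` and `load_tokenizer_with_access_hint` are illustrative and do not appear in the repository:

```python
import logging

from transformers import AutoTokenizer

# The importable location of GatedRepoError has shifted between huggingface_hub
# releases, which is why the diffs guard the import; fall back to Exception so
# the except clause below stays valid either way.
try:
    from huggingface_hub.exceptions import GatedRepoError
except ImportError:
    GatedRepoError = Exception

logger = logging.getLogger(__name__)


def resolve_local_model_id(model_id: str) -> str:
    """Strip API routing suffixes such as ':cerebras' before loading locally."""
    return model_id.split(":")[0] if ":" in model_id else model_id


def load_tokenizer_with_access_hint(model_id: str):
    """Load a tokenizer, pointing the user at the access-request page if the repo is gated."""
    base_model_id = resolve_local_model_id(model_id)
    try:
        return AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
    except GatedRepoError:
        logger.error(f"Cannot access gated repository {base_model_id}")
        logger.error(f"Visit https://huggingface.co/{base_model_id} to request access.")
        raise
```

In the diffs below the same pattern is applied to AutoModelForCausalLM and SentenceTransformer loading as well.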

src/config.py CHANGED
@@ -169,8 +169,8 @@ class Settings(BaseSettings):
     # ==================== Model Configuration ====================
 
     default_model: str = Field(
-        default="meta-llama/Llama-3.1-8B-Instruct",
-        description="Primary model for reasoning tasks (upgraded with 4-bit quantization)"
+        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
+        description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
     )
 
     embedding_model: str = Field(
@@ -179,8 +179,8 @@ class Settings(BaseSettings):
     )
 
     classification_model: str = Field(
-        default="meta-llama/Llama-3.1-8B-Instruct",
-        description="Model for classification tasks"
+        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
+        description="Model for classification tasks (Cerebras deployment)"
     )
 
     # ==================== Performance Configuration ====================
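Because these defaults live on a pydantic `Settings` class, they can also be overridden at deployment time without code changes; a minimal sketch, assuming the standard pydantic-settings environment-variable mapping (field name as variable name, no custom `env_prefix`):

```python
import os

# Hypothetical override: with pydantic-settings' default mapping, the field
# name doubles as the environment variable name (case-insensitive).
os.environ["DEFAULT_MODEL"] = "meta-llama/Llama-3.1-8B-Instruct:cerebras"

from src.config import Settings

settings = Settings()
print(settings.default_model)         # value taken from the environment
print(settings.classification_model)  # falls back to the new :cerebras default
```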
src/llm_router.py CHANGED
@@ -4,6 +4,13 @@ import asyncio
 from typing import Dict, Optional
 from .models_config import LLM_CONFIG
 
+# Import GatedRepoError for handling gated repositories
+try:
+    from huggingface_hub.exceptions import GatedRepoError
+except ImportError:
+    # Fallback if huggingface_hub is not available
+    GatedRepoError = Exception
+
 logger = logging.getLogger(__name__)
 
 class LLMRouter:
@@ -96,11 +103,34 @@ class LLMRouter:
             use_4bit = quantization_config.get("default_4bit", True)
             use_8bit = quantization_config.get("default_8bit", False)
 
-            self.local_loader.load_chat_model(
-                model_id,
-                load_in_8bit=use_8bit,
-                load_in_4bit=use_4bit
-            )
+            try:
+                self.local_loader.load_chat_model(
+                    model_id,
+                    load_in_8bit=use_8bit,
+                    load_in_4bit=use_4bit
+                )
+            except GatedRepoError as e:
+                logger.error(f"❌ Cannot access gated repository {model_id}")
+                logger.error(f"   Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.")
+
+                # Try fallback model if available
+                fallback_model_id = model_config.get("fallback")
+                if fallback_model_id:
+                    logger.warning(f"Attempting fallback model: {fallback_model_id}")
+                    try:
+                        # Create fallback config
+                        fallback_config = model_config.copy()
+                        fallback_config["model_id"] = fallback_model_id
+
+                        # Retry with fallback model
+                        return await self._call_local_model(fallback_config, prompt, task_type, **kwargs)
+                    except Exception as fallback_error:
+                        logger.error(f"Fallback model also failed: {fallback_error}")
+                        logger.warning("Falling back to HF Inference API")
+                        return None
+                else:
+                    logger.warning("No fallback model configured, falling back to HF Inference API")
+                    return None
 
             # Format as chat messages if needed
             messages = [{"role": "user", "content": prompt}]
@@ -131,6 +161,9 @@ class LLMRouter:
 
             return result
 
+        except GatedRepoError:
+            # Already handled above, return None to fall back to API
+            return None
         except Exception as e:
             logger.error(f"Error calling local model: {e}", exc_info=True)
             return None
@@ -146,7 +179,13 @@ class LLMRouter:
             # Ensure model is loaded
             if model_id not in self.local_loader.loaded_embedding_models:
                 logger.info(f"Loading embedding model {model_id} on demand...")
-                self.local_loader.load_embedding_model(model_id)
+                try:
+                    self.local_loader.load_embedding_model(model_id)
+                except GatedRepoError as e:
+                    logger.error(f"❌ Cannot access gated repository {model_id}")
+                    logger.error(f"   Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.")
+                    logger.warning("Falling back to HF Inference API")
+                    return None
 
             # Generate embedding
             embedding = await asyncio.to_thread(
@@ -395,6 +434,10 @@ class LLMRouter:
         if not hasattr(self, 'tokenizer'):
             try:
                 self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
+            except GatedRepoError as e:
+                logger.warning(f"Gated repository error loading tokenizer: {e}")
+                logger.warning("Using character count estimation instead")
+                self.tokenizer = None
             except Exception as e:
                 logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
                 self.tokenizer = None
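One thing worth noting: the new gated-repo branch reads `model_config.get("fallback")`, but this commit does not add a `fallback` key to any entry in models_config.py, so that path stays dormant until one is configured. A hypothetical tweak that would exercise it (the fallback model ID is purely illustrative and not part of this commit):

```python
from src.models_config import LLM_CONFIG

# Hypothetical: the "fallback" key is what the new GatedRepoError branch in
# llm_router.py looks up before giving up on local loading. Qwen2.5-7B-Instruct
# is used here only as an example of an ungated alternative.
LLM_CONFIG["models"]["reasoning_primary"]["fallback"] = "Qwen/Qwen2.5-7B-Instruct"
```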
src/local_model_loader.py CHANGED
@@ -7,6 +7,13 @@ from typing import Optional, Dict, Any
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
 from sentence_transformers import SentenceTransformer
 
+# Import GatedRepoError for handling gated repositories
+try:
+    from huggingface_hub.exceptions import GatedRepoError
+except ImportError:
+    # Fallback if huggingface_hub is not available
+    GatedRepoError = Exception
+
 logger = logging.getLogger(__name__)
 
 class LocalModelLoader:
@@ -56,11 +63,27 @@ class LocalModelLoader:
         try:
             logger.info(f"Loading model {model_id} on {self.device}...")
 
+            # Strip API-specific suffixes (e.g., :cerebras, :novita) for local loading
+            # These suffixes are typically used for API endpoints, not local model identifiers
+            base_model_id = model_id.split(':')[0] if ':' in model_id else model_id
+            if base_model_id != model_id:
+                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")
+
             # Load tokenizer
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_id,
-                trust_remote_code=True
-            )
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    base_model_id,
+                    trust_remote_code=True
+                )
+            except GatedRepoError as e:
+                logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
+                logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
+                logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
+                logger.error(f"   Error details: {e}")
+                raise GatedRepoError(
+                    f"Cannot access gated repository {base_model_id}. "
+                    f"Visit https://huggingface.co/{base_model_id} to request access."
+                ) from e
 
             # Determine quantization config
             if load_in_4bit and self.device == "cuda":
@@ -86,28 +109,38 @@ class LocalModelLoader:
                 quantization_config = None
 
             # Load model with GPU optimization
-            if self.device == "cuda":
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_id,
-                    device_map="auto",  # Automatically uses GPU
-                    torch_dtype=torch.float16,  # Use FP16 for memory efficiency
-                    trust_remote_code=True,
-                    **(quantization_config if isinstance(quantization_config, dict) else {}),
-                    **({"quantization_config": quantization_config} if quantization_config and not isinstance(quantization_config, dict) else {})
-                )
-            else:
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_id,
-                    torch_dtype=torch.float32,
-                    trust_remote_code=True
-                )
-                model = model.to(self.device)
+            try:
+                if self.device == "cuda":
+                    model = AutoModelForCausalLM.from_pretrained(
+                        base_model_id,
+                        device_map="auto",  # Automatically uses GPU
+                        torch_dtype=torch.float16,  # Use FP16 for memory efficiency
+                        trust_remote_code=True,
+                        **(quantization_config if isinstance(quantization_config, dict) else {}),
+                        **({"quantization_config": quantization_config} if quantization_config and not isinstance(quantization_config, dict) else {})
+                    )
+                else:
+                    model = AutoModelForCausalLM.from_pretrained(
+                        base_model_id,
+                        torch_dtype=torch.float32,
+                        trust_remote_code=True
+                    )
+                    model = model.to(self.device)
+            except GatedRepoError as e:
+                logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
+                logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
+                logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
+                logger.error(f"   Error details: {e}")
+                raise GatedRepoError(
+                    f"Cannot access gated repository {base_model_id}. "
+                    f"Visit https://huggingface.co/{base_model_id} to request access."
+                ) from e
 
             # Ensure padding token is set
             if tokenizer.pad_token is None:
                 tokenizer.pad_token = tokenizer.eos_token
 
-            # Cache models
+            # Cache models (use original model_id for cache key to maintain API compatibility)
             self.loaded_models[model_id] = model
             self.loaded_tokenizers[model_id] = tokenizer
 
@@ -117,9 +150,12 @@ class LocalModelLoader:
                 reserved = torch.cuda.memory_reserved(0) / 1024**3
                 logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
 
-            logger.info(f"✓ Model {model_id} loaded successfully on {self.device}")
+            logger.info(f"✓ Model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
             return model, tokenizer
 
+        except GatedRepoError:
+            # Re-raise GatedRepoError to be handled by caller
+            raise
         except Exception as e:
             logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
             raise
@@ -141,18 +177,36 @@ class LocalModelLoader:
         try:
             logger.info(f"Loading embedding model {model_id}...")
 
+            # Strip API-specific suffixes for local loading
+            base_model_id = model_id.split(':')[0] if ':' in model_id else model_id
+            if base_model_id != model_id:
+                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")
+
             # SentenceTransformer automatically handles GPU
-            model = SentenceTransformer(
-                model_id,
-                device=self.device
-            )
+            try:
+                model = SentenceTransformer(
+                    base_model_id,
+                    device=self.device
+                )
+            except GatedRepoError as e:
+                logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
+                logger.error(f"   Access to model {base_model_id} is restricted and you are not in the authorized list.")
+                logger.error(f"   Visit https://huggingface.co/{base_model_id} to request access.")
+                logger.error(f"   Error details: {e}")
+                raise GatedRepoError(
+                    f"Cannot access gated repository {base_model_id}. "
+                    f"Visit https://huggingface.co/{base_model_id} to request access."
+                ) from e
 
-            # Cache model
+            # Cache model (use original model_id for cache key)
             self.loaded_embedding_models[model_id] = model
 
-            logger.info(f"✓ Embedding model {model_id} loaded successfully on {self.device}")
+            logger.info(f"✓ Embedding model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
             return model
 
+        except GatedRepoError:
+            # Re-raise GatedRepoError to be handled by caller
+            raise
         except Exception as e:
             logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
             raise
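The suffix-stripping rule is shared by load_chat_model and load_embedding_model; a quick check of its behavior (note the loaders deliberately keep the original, suffixed ID as the cache key and only pass the stripped ID to Hugging Face):

```python
# The exact expression used in the loaders above.
def strip_api_suffix(model_id: str) -> str:
    return model_id.split(':')[0] if ':' in model_id else model_id

assert strip_api_suffix("meta-llama/Llama-3.1-8B-Instruct:cerebras") == "meta-llama/Llama-3.1-8B-Instruct"
assert strip_api_suffix("meta-llama/Llama-3.1-8B-Instruct") == "meta-llama/Llama-3.1-8B-Instruct"
```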
src/models_config.py CHANGED
@@ -4,7 +4,7 @@ LLM_CONFIG = {
     "primary_provider": "huggingface",
     "models": {
         "reasoning_primary": {
-            "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # Upgraded: Excellent reasoning with 4-bit quantization
+            "model_id": "meta-llama/Llama-3.1-8B-Instruct:cerebras",  # Cerebras deployment
             "task": "general_reasoning",
             "max_tokens": 10000,
             "temperature": 0.7,
@@ -23,7 +23,7 @@ LLM_CONFIG = {
             "is_chat_model": False
         },
         "classification_specialist": {
-            "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # Use same chat model for classification (better than specialized models)
+            "model_id": "meta-llama/Llama-3.1-8B-Instruct:cerebras",  # Cerebras deployment for classification
             "task": "intent_classification",
             "max_length": 512,
             "specialization": "fast_inference",
@@ -32,7 +32,7 @@ LLM_CONFIG = {
             "use_4bit_quantization": True
         },
         "safety_checker": {
-            "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # Use same chat model for safety
+            "model_id": "meta-llama/Llama-3.1-8B-Instruct:cerebras",  # Cerebras deployment for safety
             "task": "content_moderation",
             "confidence_threshold": 0.85,
             "purpose": "bias_detection",