JatsTheAIGen committed
Commit 9959ea9 · 1 Parent(s): 5787d0a

Fix: Cache directory permissions and gated repository handling


CRITICAL FIXES:
- Fixed cache directory permission errors in Docker containers
- Added HF_TOKEN authentication for gated repository access
- Added non-gated fallback model (Mistral-7B-Instruct-v0.2)
- Improved Docker detection to prefer /tmp over ~/.cache

Changes:
- src/local_model_loader.py:
  - Pass cache_dir to all from_pretrained calls
  - Set HF_HOME and TRANSFORMERS_CACHE environment variables
  - Authenticate with HF_TOKEN for gated repositories
  - Use cache_dir from settings config

- src/config.py:
  - Improved Docker detection for cache directory selection
  - Prefer /tmp in Docker containers to avoid permission issues

- src/models_config.py:
  - Added mistralai/Mistral-7B-Instruct-v0.2 as fallback model
  - All text tasks now have a non-gated fallback option

Fixes:
- PermissionError: [Errno 13] Permission denied: '/.cache'
- Gated repository access errors, now handled with a non-gated fallback
- Missing HF_TOKEN authentication for gated models

Ready for production testing.
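
For production testing, the two fixes can be smoke-tested independently of the application. The sketch below is illustrative only and is not part of this commit; it assumes the resolved cache directory and token are exposed through the same HF_HOME / TRANSFORMERS_CACHE / HF_TOKEN environment variables used in the diffs, and relies only on the documented huggingface_hub login and whoami helpers.

# smoke_test.py -- illustrative sketch, not part of this commit
import os
import tempfile

from huggingface_hub import login, whoami

# 1) The cache directory selected by CacheDirectoryManager must be writable.
cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/tmp/huggingface_cache"
os.makedirs(cache_dir, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_dir) as probe:
    probe.write(b"ok")  # raises PermissionError: [Errno 13] if the directory is read-only
print(f"cache directory is writable: {cache_dir}")

# 2) If HF_TOKEN is set, it should authenticate for gated repositories.
token = os.getenv("HF_TOKEN", "")
if token:
    login(token=token, add_to_git_credential=False)
    print(f"authenticated as: {whoami()['name']}")
else:
    print("HF_TOKEN not set; gated models will rely on the non-gated fallback")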

Files changed (3)
  1. src/config.py +11 -3
  2. src/local_model_loader.py +67 -9
  3. src/models_config.py +5 -3
src/config.py CHANGED
@@ -61,12 +61,20 @@ class CacheDirectoryManager:
         Returns:
             str: Path to writable cache directory
         """
+        # Priority order for cache directory
+        # In Docker, ~ may resolve to / which causes permission issues
+        # So we prefer /tmp over ~/.cache in containerized environments
+        is_docker = os.path.exists("/.dockerenv") or os.path.exists("/tmp")
+
         cache_candidates = [
             os.getenv("HF_HOME"),
             os.getenv("TRANSFORMERS_CACHE"),
-            os.path.join(os.path.expanduser("~"), ".cache", "huggingface") if os.path.expanduser("~") else None,
-            os.path.join(os.path.expanduser("~"), ".cache", "huggingface_fallback") if os.path.expanduser("~") else None,
-            "/tmp/huggingface_cache"
+            # In Docker, prefer /tmp over ~/.cache
+            "/tmp/huggingface_cache" if is_docker else None,
+            os.path.join(os.path.expanduser("~"), ".cache", "huggingface") if os.path.expanduser("~") and not is_docker else None,
+            os.path.join(os.path.expanduser("~"), ".cache", "huggingface_fallback") if os.path.expanduser("~") and not is_docker else None,
+            "/tmp/huggingface_cache" if not is_docker else None,
+            "/tmp/huggingface"  # Final fallback
         ]
 
         for cache_dir in cache_candidates:
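
The hunk ends at the candidate loop, whose body is outside this diff. For context, a minimal sketch of what such a selection loop typically does, assuming it picks the first candidate that is actually writable (an illustration, not the repository's code):

# Illustrative sketch of the selection loop below the hunk; not code from this commit.
import os

def first_writable_cache_dir(cache_candidates):
    for cache_dir in cache_candidates:
        if not cache_dir:
            continue  # skip None entries produced by the conditionals above
        try:
            os.makedirs(cache_dir, exist_ok=True)
            probe = os.path.join(cache_dir, ".write_test")
            with open(probe, "w") as f:
                f.write("ok")
            os.remove(probe)
            return cache_dir  # first candidate that can actually be written to
        except OSError:
            continue  # e.g. PermissionError: [Errno 13] Permission denied: '/.cache'
    return "/tmp/huggingface"  # matches the final fallback in the list above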
src/local_model_loader.py CHANGED
@@ -2,6 +2,7 @@
 # Local GPU-based model loading for NVIDIA T4 Medium (16GB VRAM)
 # Optimized with 4-bit quantization to fit larger models
 import logging
+import os
 import torch
 from typing import Optional, Dict, Any
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
@@ -10,9 +11,20 @@ from sentence_transformers import SentenceTransformer
 # Import GatedRepoError for handling gated repositories
 try:
     from huggingface_hub.exceptions import GatedRepoError
+    from huggingface_hub import login as hf_login
 except ImportError:
     # Fallback if huggingface_hub is not available
     GatedRepoError = Exception
+    hf_login = None
+
+# Import settings for cache directory and HF token
+try:
+    from .config import settings
+except ImportError:
+    try:
+        from config import settings
+    except ImportError:
+        settings = None
 
 logger = logging.getLogger(__name__)
 
@@ -39,6 +51,34 @@ class LocalModelLoader:
         self.device = device
         self.device_name = device
 
+        # Get cache directory from settings
+        if settings:
+            self.cache_dir = settings.hf_cache_dir
+            self.hf_token = settings.hf_token
+        else:
+            # Fallback to environment variables
+            self.cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/tmp/huggingface"
+            self.hf_token = os.getenv("HF_TOKEN", "")
+
+        # Ensure cache directory exists and is writable
+        os.makedirs(self.cache_dir, exist_ok=True)
+
+        # Set environment variables for transformers/huggingface_hub
+        if not os.getenv("HF_HOME"):
+            os.environ["HF_HOME"] = self.cache_dir
+        if not os.getenv("TRANSFORMERS_CACHE"):
+            os.environ["TRANSFORMERS_CACHE"] = self.cache_dir
+
+        logger.info(f"Cache directory: {self.cache_dir}")
+
+        # Login to Hugging Face if token is provided (needed for gated repositories)
+        if self.hf_token and hf_login:
+            try:
+                hf_login(token=self.hf_token, add_to_git_credential=False)
+                logger.info("✓ HF_TOKEN authenticated for gated model access")
+            except Exception as e:
+                logger.warning(f"HF_TOKEN login failed (may not be needed): {e}")
+
         # Model cache
         self.loaded_models: Dict[str, Any] = {}
         self.loaded_tokenizers: Dict[str, Any] = {}
@@ -69,10 +109,12 @@ class LocalModelLoader:
         if base_model_id != model_id:
             logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")
 
-        # Load tokenizer
+        # Load tokenizer with cache directory
         try:
             tokenizer = AutoTokenizer.from_pretrained(
                 base_model_id,
+                cache_dir=self.cache_dir,
+                token=self.hf_token if self.hf_token else None,
                 trust_remote_code=True
             )
         except GatedRepoError as e:
@@ -108,22 +150,36 @@ class LocalModelLoader:
         else:
             quantization_config = None
 
-        # Load model with GPU optimization
+        # Load model with GPU optimization and cache directory
        try:
+            load_kwargs = {
+                "cache_dir": self.cache_dir,
+                "token": self.hf_token if self.hf_token else None,
+                "trust_remote_code": True
+            }
+
             if self.device == "cuda":
+                load_kwargs.update({
+                    "device_map": "auto",  # Automatically uses GPU
+                    "torch_dtype": torch.float16,  # Use FP16 for memory efficiency
+                })
+                if quantization_config:
+                    if isinstance(quantization_config, dict):
+                        load_kwargs.update(quantization_config)
+                    else:
+                        load_kwargs["quantization_config"] = quantization_config
+
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model_id,
-                    device_map="auto",  # Automatically uses GPU
-                    torch_dtype=torch.float16,  # Use FP16 for memory efficiency
-                    trust_remote_code=True,
-                    **(quantization_config if isinstance(quantization_config, dict) else {}),
-                    **({"quantization_config": quantization_config} if quantization_config and not isinstance(quantization_config, dict) else {})
+                    **load_kwargs
                 )
             else:
+                load_kwargs.update({
+                    "torch_dtype": torch.float32,
+                })
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model_id,
-                    torch_dtype=torch.float32,
-                    trust_remote_code=True
+                    **load_kwargs
                 )
                 model = model.to(self.device)
         except GatedRepoError as e:
@@ -183,6 +239,8 @@ class LocalModelLoader:
             logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")
 
         # SentenceTransformer automatically handles GPU
+        # Note: SentenceTransformer uses cache_dir from environment or default location
+        # We can't directly pass cache_dir, but we've set HF_HOME and TRANSFORMERS_CACHE
        try:
             model = SentenceTransformer(
                 base_model_id,
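
Taken together, the loader now resolves its cache directory and authenticates in __init__ before any download is attempted. A usage sketch, assuming the constructor takes a device argument as suggested by self.device = device in the hunk above (the exact signature is not shown in this diff):

# Illustrative usage; constructor arguments are an assumption based on the hunks above.
import os
import torch

os.environ.setdefault("HF_HOME", "/tmp/huggingface_cache")  # writable in Docker
# os.environ["HF_TOKEN"] = "hf_..."                          # required for gated repos

from src.local_model_loader import LocalModelLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
loader = LocalModelLoader(device=device)  # resolves cache_dir, sets env vars, logs in if HF_TOKEN is set

# Any gated repository requested without a valid token raises GatedRepoError inside the
# loader's except blocks, which is where the non-gated fallback from models_config.py
# (below) is expected to take over.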
src/models_config.py CHANGED
@@ -9,7 +9,7 @@ LLM_CONFIG = {
             "task": "general_reasoning",
             "max_tokens": 8000,  # Reduced from 10000
             "temperature": 0.7,
-            "fallback": None,  # Will handle fallback in code if needed
+            "fallback": "mistralai/Mistral-7B-Instruct-v0.2",  # Non-gated fallback model
             "is_chat_model": True,
             "use_4bit_quantization": True,  # Enable 4-bit quantization for 16GB T4
             "use_8bit_quantization": False
@@ -28,7 +28,8 @@ LLM_CONFIG = {
             "specialization": "fast_inference",
             "latency_target": "<100ms",
             "is_chat_model": True,
-            "use_4bit_quantization": True
+            "use_4bit_quantization": True,
+            "fallback": "mistralai/Mistral-7B-Instruct-v0.2"  # Non-gated fallback
         },
         "safety_checker": {
             "model_id": "Qwen/Qwen2.5-7B-Instruct",  # Same model for all text tasks
@@ -36,7 +37,8 @@ LLM_CONFIG = {
             "confidence_threshold": 0.85,
             "purpose": "bias_detection",
             "is_chat_model": True,
-            "use_4bit_quantization": True
+            "use_4bit_quantization": True,
+            "fallback": "mistralai/Mistral-7B-Instruct-v0.2"  # Non-gated fallback
         }
     },
     "routing_logic": {