Commit 8d4bf4a (1 parent: e2eb926)
Fix infinite fallback loop in local model loading
- Add _is_fallback flag to prevent infinite recursion
- Remove fallback from fallback config to prevent loops
- Check if fallback model is different from primary before attempting
- Better error handling when both primary and fallback models are gated
- Prevent recursive fallback attempts when fallback also fails
- src/llm_router.py +25 -6
src/llm_router.py
CHANGED
|
@@ -86,6 +86,9 @@ class LLMRouter:
|
|
| 86 |
if not self.local_loader:
|
| 87 |
return None
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
model_id = model_config["model_id"]
|
| 90 |
max_tokens = kwargs.get('max_tokens', 512)
|
| 91 |
temperature = kwargs.get('temperature', 0.7)
|
|
@@ -113,23 +116,39 @@ class LLMRouter:
|
|
| 113 |
logger.error(f"❌ Cannot access gated repository {model_id}")
|
| 114 |
logger.error(f" Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.")
|
| 115 |
|
| 116 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
fallback_model_id = model_config.get("fallback")
|
| 118 |
-
if fallback_model_id:
|
| 119 |
logger.warning(f"Attempting fallback model: {fallback_model_id}")
|
| 120 |
try:
|
| 121 |
-
# Create fallback config
|
| 122 |
fallback_config = model_config.copy()
|
| 123 |
fallback_config["model_id"] = fallback_model_id
|
|
|
|
| 124 |
|
| 125 |
-
# Retry with fallback model
|
| 126 |
-
return await self._call_local_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
except Exception as fallback_error:
|
| 128 |
logger.error(f"Fallback model also failed: {fallback_error}")
|
| 129 |
logger.warning("Falling back to HF Inference API")
|
| 130 |
return None
|
| 131 |
else:
|
| 132 |
-
logger.warning("No fallback model configured, falling back to HF Inference API")
|
| 133 |
return None
|
| 134 |
|
| 135 |
# Format as chat messages if needed
|
|
|
|
| 86 |
if not self.local_loader:
|
| 87 |
return None
|
| 88 |
|
| 89 |
+
# Check if this is already a fallback attempt (prevent infinite loops)
|
| 90 |
+
is_fallback_attempt = kwargs.get('_is_fallback', False)
|
| 91 |
+
|
| 92 |
model_id = model_config["model_id"]
|
| 93 |
max_tokens = kwargs.get('max_tokens', 512)
|
| 94 |
temperature = kwargs.get('temperature', 0.7)
|
|
|
|
| 116 |
logger.error(f"❌ Cannot access gated repository {model_id}")
|
| 117 |
logger.error(f" Visit https://huggingface.co/{model_id.split(':')[0] if ':' in model_id else model_id} to request access.")
|
| 118 |
|
| 119 |
+
# Prevent infinite loops: if this is already a fallback attempt, don't try another fallback
|
| 120 |
+
if is_fallback_attempt:
|
| 121 |
+
logger.error("❌ Fallback model also failed with gated repository error")
|
| 122 |
+
logger.warning("Both primary and fallback models are gated. Falling back to HF Inference API.")
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
# Try fallback model if available and this is not already a fallback attempt
|
| 126 |
fallback_model_id = model_config.get("fallback")
|
| 127 |
+
if fallback_model_id and fallback_model_id != model_id: # Ensure fallback is different
|
| 128 |
logger.warning(f"Attempting fallback model: {fallback_model_id}")
|
| 129 |
try:
|
| 130 |
+
# Create fallback config without fallback to prevent loops
|
| 131 |
fallback_config = model_config.copy()
|
| 132 |
fallback_config["model_id"] = fallback_model_id
|
| 133 |
+
fallback_config.pop("fallback", None) # Remove fallback to prevent infinite recursion
|
| 134 |
|
| 135 |
+
# Retry with fallback model (mark as fallback attempt)
|
| 136 |
+
return await self._call_local_model(
|
| 137 |
+
fallback_config,
|
| 138 |
+
prompt,
|
| 139 |
+
task_type,
|
| 140 |
+
**{**kwargs, '_is_fallback': True}
|
| 141 |
+
)
|
| 142 |
+
except GatedRepoError as fallback_gated_error:
|
| 143 |
+
logger.error(f"❌ Fallback model {fallback_model_id} is also gated")
|
| 144 |
+
logger.warning("Both primary and fallback models are gated. Falling back to HF Inference API.")
|
| 145 |
+
return None
|
| 146 |
except Exception as fallback_error:
|
| 147 |
logger.error(f"Fallback model also failed: {fallback_error}")
|
| 148 |
logger.warning("Falling back to HF Inference API")
|
| 149 |
return None
|
| 150 |
else:
|
| 151 |
+
logger.warning("No fallback model configured or fallback same as primary, falling back to HF Inference API")
|
| 152 |
return None
|
| 153 |
|
| 154 |
# Format as chat messages if needed
|