|
|
|
|
|
""" |
|
|
Phase 1 Validation Test Script |
|
|
Tests that HF API inference has been removed and local models work correctly |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import asyncio |
|
|
import logging |
|
|
|
|
|
|
|
|
# Module-wide logging: INFO level, timestamped "time - logger - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
def test_imports():
    """Check that the core Phase 1 modules can be imported.

    Imports ``LLMRouter`` from ``src.llm_router``, ``LLM_CONFIG`` from
    ``src.models_config`` and ``LocalModelLoader`` from
    ``src.local_model_loader``.

    Returns:
        True when all three imports succeed, False if any of them raises.
    """
    logger.info("Testing imports...")
    try:
        # Imported only to prove importability; the names are not used further.
        from src.llm_router import LLMRouter  # noqa: F401
        from src.models_config import LLM_CONFIG  # noqa: F401
        from src.local_model_loader import LocalModelLoader  # noqa: F401
    except Exception as exc:  # any import-time error counts as a failure
        logger.error(f"❌ Import failed: {exc}")
        return False
    logger.info("✅ All imports successful")
    return True
|
|
|
|
|
def test_models_config():
    """Test that ``src.models_config.LLM_CONFIG`` reflects the local-only setup.

    Checks, in order:
      * ``LLM_CONFIG["primary_provider"]`` is ``"local"``;
      * ``LLM_CONFIG["models"]["reasoning_primary"]["model_id"]`` has no
        ``":cerebras"`` API suffix and equals ``"Qwen/Qwen2.5-7B-Instruct"``;
      * ``str(LLM_CONFIG["routing_logic"]["fallback_chain"])`` does not
        contain ``"API"``.

    The checks raise ``AssertionError`` explicitly instead of using ``assert``
    so that they still run under ``python -O``.

    Returns:
        True if every check passes; False on a failed check, a missing
        config key, or any other error (including an import failure).
    """
    logger.info("Testing models_config...")
    try:
        from src.models_config import LLM_CONFIG

        provider = LLM_CONFIG["primary_provider"]
        if provider != "local":
            raise AssertionError(
                f"Primary provider should be 'local' (got {provider!r})"
            )
        logger.info("✅ Primary provider is 'local'")

        reasoning_model = LLM_CONFIG["models"]["reasoning_primary"]["model_id"]
        if ":cerebras" in reasoning_model:
            raise AssertionError(
                f"Model ID should not have API suffix (got {reasoning_model!r})"
            )
        if reasoning_model != "Qwen/Qwen2.5-7B-Instruct":
            raise AssertionError(f"Should use Qwen model (got {reasoning_model!r})")
        logger.info(f"✅ Reasoning model: {reasoning_model}")

        fallback_chain = LLM_CONFIG["routing_logic"]["fallback_chain"]
        # Substring check on the string form, so it works whatever the
        # container type of fallback_chain is.
        if "API" in str(fallback_chain):
            raise AssertionError(f"No API in fallback chain (got {fallback_chain!r})")
        logger.info("✅ Routing logic updated")

        return True

    except KeyError as e:
        # str(KeyError) is just the quoted key, so say what it means.
        logger.error(f"❌ Models config test failed: missing config key {e}")
        return False
    except Exception as e:
        logger.error(f"❌ Models config test failed: {e}")
        return False
|
|
|
|
|
def test_llm_router_init():
    """Test LLM router initialization and removal of the HF API helpers.

    Steps:
      1. ``LLMRouter(hf_token=None, use_local_models=False)`` must raise
         ``ValueError``.
      2. The HF-API helper methods must be absent from the ``LLMRouter``
         class. This runs regardless of whether local models can be loaded.
      3. ``LLMRouter(hf_token=None, use_local_models=True)`` is constructed.
         A ``RuntimeError`` here is tolerated and logged as a warning, since
         it is expected when transformers/torch are not installed. On
         success, the instance is also checked for the helpers.

    Returns:
        True on success (or on a tolerated ``RuntimeError`` in step 3);
        False on any failed check or unexpected exception.
    """
    logger.info("Testing LLM router initialization...")
    # Helper methods that are checked to be absent after the HF API removal.
    removed_methods = ("_call_hf_endpoint", "_is_model_healthy", "_get_fallback_model")

    try:
        from src.llm_router import LLMRouter

        # Non-local mode is expected to be rejected with ValueError.
        try:
            LLMRouter(hf_token=None, use_local_models=False)
        except ValueError:
            logger.info("✅ Correctly raises error for use_local_models=False")
        else:
            logger.error("❌ Should have raised ValueError for use_local_models=False")
            return False

        # Check the class itself so this check does not depend on local
        # models being loadable.
        leftover = [name for name in removed_methods if hasattr(LLMRouter, name)]
        if leftover:
            logger.error(f"❌ HF API methods still present on LLMRouter: {leftover}")
            return False
        logger.info("✅ HF API methods removed")

        try:
            router = LLMRouter(hf_token=None, use_local_models=True)
        except RuntimeError as e:
            logger.warning(f"⚠️ Local models not available: {e}")
            logger.warning("This is expected if transformers/torch not installed")
            return True
        logger.info("✅ LLM router initialized (local models)")

        # The instance check also covers attributes attached in __init__.
        leftover = [name for name in removed_methods if hasattr(router, name)]
        if leftover:
            logger.error(f"❌ HF API methods still present on router instance: {leftover}")
            return False

        return True

    except Exception as e:
        logger.error(f"❌ LLM router test failed: {e}")
        return False
|
|
|
|
|
def test_no_api_references():
    """Test that no HF Inference API references remain in the router code.

    Scans the source of the whole module that defines ``LLMRouter``, not
    just the class body, so module-level leftovers such as an endpoint URL
    constant are caught too. A reference is reported when the source
    contains:
      * the helper name ``_call_hf_endpoint``;
      * the host ``router.huggingface.co``;
      * the phrase ``"HF Inference API"``, unless the source also contains
        ``"no API fallback"``, which is tolerated as an explanatory mention.

    Checks raise ``AssertionError`` explicitly (not ``assert``) so they still
    run under ``python -O``. All findings are reported together.

    Returns:
        True if nothing forbidden is found; False otherwise, including when
        the module cannot be imported or its source cannot be read.
    """
    logger.info("Testing for API references...")
    try:
        import inspect

        from src.llm_router import LLMRouter

        # getsource raises TypeError/OSError if the source is unavailable;
        # that is reported as a failure by the handler below.
        router_source = inspect.getsource(inspect.getmodule(LLMRouter))

        problems = []
        if "_call_hf_endpoint" in router_source:
            problems.append("Should not have _call_hf_endpoint")
        if "router.huggingface.co" in router_source:
            problems.append("Should not have HF API URL")
        if "HF Inference API" in router_source and "no API fallback" not in router_source:
            problems.append("Should not reference HF API")
        if problems:
            raise AssertionError("; ".join(problems))

        logger.info("✅ No API references found in LLM router")
        return True

    except Exception as e:
        logger.error(f"❌ API reference test failed: {e}")
        return False
|
|
|
|
|
async def test_inference_flow():
    """Run a single inference through the router when local models are usable.

    Sends the prompt "What is 2+2?" for the ``general_reasoning`` task with
    ``max_tokens=50``.

    Outcomes:
      * A ``RuntimeError`` while building the router or while running
        inference counts as a pass and is logged as a warning.
      * A falsy inference result counts as a failure.
      * Any other exception counts as a failure.

    Returns:
        True or False according to the outcomes above.
    """
    logger.info("Testing inference flow...")
    try:
        from src.llm_router import LLMRouter

        router = LLMRouter(hf_token=None, use_local_models=True)

        try:
            output = await router.route_inference(
                task_type="general_reasoning",
                prompt="What is 2+2?",
                max_tokens=50,
            )
            if not output:
                logger.warning("⚠️ Inference returned None")
                return False
            # Only the first 50 characters of the output are logged.
            logger.info(f"✅ Inference successful: {output[:50]}...")
            return True
        except RuntimeError as err:
            logger.warning(f"⚠️ Inference failed (expected if models not loaded): {err}")
            return True

    except RuntimeError as err:
        logger.warning(f"⚠️ Router not available: {err}")
        return True
    except Exception as err:
        logger.error(f"❌ Inference test failed: {err}")
        return False
|
|
|
|
|
def main():
    """Run all Phase 1 validation checks and log a summary.

    The synchronous checks run first, each guarded so a crash is recorded
    as a failure. The async inference check then runs via ``asyncio.run``.

    Returns:
        Process exit code: 0 if every check passed, 1 otherwise.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("PHASE 1 VALIDATION TESTS")
    logger.info(banner)

    sync_checks = (
        ("Imports", test_imports),
        ("Models Config", test_models_config),
        ("LLM Router Init", test_llm_router_init),
        ("No API References", test_no_api_references),
    )

    outcomes = []
    for label, check in sync_checks:
        logger.info(f"\n--- Running {label} Test ---")
        try:
            ok = check()
        except Exception as exc:
            logger.error(f"Test {label} crashed: {exc}")
            ok = False
        outcomes.append((label, ok))

    # The inference check is a coroutine, so it gets its own event loop.
    logger.info("\n--- Running Inference Flow Test ---")
    try:
        ok = asyncio.run(test_inference_flow())
    except Exception as exc:
        logger.error(f"Inference flow test crashed: {exc}")
        ok = False
    outcomes.append(("Inference Flow", ok))

    logger.info("\n" + banner)
    logger.info("TEST SUMMARY")
    logger.info(banner)

    total = len(outcomes)
    passed = len([label for label, ok in outcomes if ok])

    for label, ok in outcomes:
        verdict = "✅ PASS" if ok else "❌ FAIL"
        logger.info(f"{verdict}: {label}")

    logger.info(f"\nTotal: {passed}/{total} tests passed")

    if passed != total:
        logger.warning(f"⚠️ {total - passed} test(s) failed")
        return 1
    logger.info("✅ All tests passed!")
    return 0
|
|
|
|
|
# Script entry point: propagate main()'s status as the process exit code.
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)
|
|
|
|
|
|