|
|
|
|
|
""" |
|
|
Test script to verify the FFG Mask Explorer works locally before deployment.

Run this after executing prepare_for_deployment.sh.
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
# Put the directory containing this script first on the import path, so
# imports are resolved against it before any other sys.path entry.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _script_dir)
|
|
|
|
|
def test_imports():
    """Check that every package the app depends on can be imported.

    Tries Gradio, then PyTorch, then the FFG experiment-suite modules, in
    that order. A status line is printed for each step, and the function
    stops at the first import that fails.

    Returns:
        bool: True if all imports succeeded, False as soon as one fails.
    """
    print("Testing imports...")

    try:
        # Imported only to prove the package is installed; the name is unused.
        import gradio as gr  # noqa: F401
        print("✅ Gradio imported successfully")
    except ImportError as e:
        print(f"❌ Failed to import gradio: {e}")
        return False

    try:
        import torch
        print(f"✅ PyTorch imported successfully (version: {torch.__version__})")
        print(f" CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f" GPU: {torch.cuda.get_device_name()}")
    except ImportError as e:
        print(f"❌ Failed to import torch: {e}")
        return False

    try:
        # These names are imported purely as an availability check.
        from ffg_experiment_suite.src.models import load_assets  # noqa: F401
        from ffg_experiment_suite.src.grafting import (  # noqa: F401
            fast_fisher_graft,
            magnitude_graft,
            fish_mask_graft,
        )
        from ffg_experiment_suite.src.analysis import _set_publication_fonts  # noqa: F401
        print("✅ FFG modules imported successfully")
    except ImportError as e:
        print(f"❌ Failed to import FFG modules: {e}")
        print(" Make sure to run prepare_for_deployment.sh first!")
        return False

    return True
|
|
|
|
|
def test_basic_functionality():
    """Build a minimal FFG configuration as a basic sanity check.

    The configuration dictionary is only constructed here; it is not passed
    to any loader, so no models are downloaded or loaded by this test.

    Returns:
        bool: True if the configuration was created, False if an unexpected
        exception was raised while doing so.
    """
    print("\nTesting basic functionality...")

    try:
        # Minimal config, created but intentionally not used any further.
        config = {  # noqa: F841
            "base_model_id": "meta-llama/Meta-Llama-3.1-8B",
            "finetuned_model_id": "pmahdavi/Llama-3.1-8B-math-reasoning",
            "optimizer_states_file": "pmahdavi/Llama-3.1-8B-math-reasoning:export/exp_avg_sq.safetensors",
            "device": "cpu",
            "dtype": "bfloat16",
        }

        print("✅ Configuration created successfully")
        return True

    except Exception as e:
        # Top-level test boundary: report the failure instead of crashing.
        print(f"❌ Failed basic functionality test: {e}")
        return False
|
|
|
|
|
def main():
    """Run the local pre-deployment checks and print a summary.

    First verifies that the FFG modules are present next to this script,
    then runs test_imports() and test_basic_functionality().

    Returns:
        int: 0 if every test passed; 1 if the FFG modules are missing or
        any test failed. Intended to be used as the process exit code.
    """
    print("FFG Mask Explorer Local Test")
    print("=" * 50)

    # Look for the FFG modules relative to this script's directory (the same
    # directory that is placed on sys.path at import time), not the current
    # working directory, so the check works from wherever the script is run.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    models_path = os.path.join(script_dir, "ffg_experiment_suite", "src", "models.py")
    if not os.path.exists(models_path):
        print("❌ FFG modules not found!")
        print(" Please run: ./prepare_for_deployment.sh")
        return 1

    # Build the list eagerly so every test runs and reports its own result,
    # even when an earlier one has already failed.
    results = [test_imports(), test_basic_functionality()]
    all_passed = all(results)

    print("\n" + "=" * 50)
    if all_passed:
        print("✅ All tests passed! Ready for deployment.")
        print("\nTo run the app locally:")
        print(" python app.py")
        return 0

    print("❌ Some tests failed. Please fix the issues before deployment.")
    return 1
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())