File size: 3,123 Bytes
48a55a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
"""
Test script to verify the FFG Mask Explorer works locally before deployment.
Run this after executing prepare_for_deployment.sh
"""

import sys
import os

# Put this script's own directory (not its parent) at the front of sys.path so
# that an `ffg_experiment_suite` package sitting next to this file (presumably
# placed there by prepare_for_deployment.sh) is importable regardless of the
# current working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

def test_imports():
    """Check that the app's runtime dependencies can be imported.

    Imports Gradio, PyTorch and the FFG modules the app uses, printing a
    ✅/❌ line for each, and reports PyTorch's CUDA status as a diagnostic.

    Returns:
        True if every import succeeded; False as soon as one fails.
    """
    print("Testing imports...")

    # Catch Exception rather than only ImportError: a broken installation can
    # fail at import time with other error types (for example an OSError from
    # a missing shared library), and this script should report that as a
    # failed check instead of crashing with a traceback.
    try:
        import gradio as gr  # noqa: F401  (imported only to prove it loads)
    except Exception as e:
        print(f"❌ Failed to import gradio: {e}")
        return False
    print("✅ Gradio imported successfully")

    try:
        import torch
    except Exception as e:
        print(f"❌ Failed to import torch: {e}")
        return False
    print(f"✅ PyTorch imported successfully (version: {torch.__version__})")

    # CUDA details are informational only; a problem while querying them
    # must not abort the whole import check.
    try:
        cuda_available = torch.cuda.is_available()
        print(f"   CUDA available: {cuda_available}")
        if cuda_available:
            print(f"   GPU: {torch.cuda.get_device_name()}")
    except Exception as e:
        print(f"   ⚠️ Could not query CUDA details: {e}")

    try:
        from ffg_experiment_suite.src.models import load_assets  # noqa: F401
        from ffg_experiment_suite.src.grafting import (  # noqa: F401
            fast_fisher_graft,
            magnitude_graft,
            fish_mask_graft,
        )
        from ffg_experiment_suite.src.analysis import _set_publication_fonts  # noqa: F401
    except Exception as e:
        print(f"❌ Failed to import FFG modules: {e}")
        print("   Make sure to run prepare_for_deployment.sh first!")
        return False
    print("✅ FFG modules imported successfully")

    return True

def test_basic_functionality():
    """Build the default mask-generation configuration and echo it.

    NOTE(review): despite its name this does not load any model or generate
    a mask -- the configuration below is not passed to any FFG function yet.
    It only shows which models and settings a deployment would target.

    Returns:
        Always True.
    """
    print("\nTesting basic functionality...")

    config = {
        "base_model_id": "meta-llama/Meta-Llama-3.1-8B",
        "finetuned_model_id": "pmahdavi/Llama-3.1-8B-math-reasoning",
        "optimizer_states_file": "pmahdavi/Llama-3.1-8B-math-reasoning:export/exp_avg_sq.safetensors",
        "device": "cpu",  # Use CPU for testing to avoid memory issues
        "dtype": "bfloat16",
    }

    print("✅ Configuration created successfully")
    # Echo the configuration so the log records what would be deployed.
    for key, value in config.items():
        print(f"   {key}: {value}")
    return True

def main():
    """Run every local check and report whether the app is ready to deploy.

    Returns:
        Process exit status: 0 when all checks pass, 1 otherwise (including
        when the FFG modules have not been prepared yet).
    """
    print("FFG Mask Explorer Local Test")
    print("=" * 50)

    # Look for the prepared modules next to this script -- the directory this
    # file puts on sys.path for its imports -- rather than in the current
    # working directory, so the result does not depend on where the script
    # is launched from.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    models_file = os.path.join(script_dir, "ffg_experiment_suite", "src", "models.py")
    if not os.path.exists(models_file):
        print("❌ FFG modules not found!")
        print("   Please run: ./prepare_for_deployment.sh")
        return 1

    # Evaluate every check eagerly (a list, not a short-circuiting
    # expression) so all failures are reported in one run.
    results = [test_imports(), test_basic_functionality()]

    print("\n" + "=" * 50)
    if all(results):
        print("✅ All tests passed! Ready for deployment.")
        print("\nTo run the app locally:")
        print("  python app.py")
        return 0

    print("❌ Some tests failed. Please fix the issues before deployment.")
    return 1

if __name__ == "__main__":
    # Hand main()'s status code (0 = all checks passed, 1 = failure) to the shell.
    raise SystemExit(main())