""" Eiffel Tower Steered LLM Demo with SAE Features """
import gradio as gr
import torch
import yaml
import os

# ZeroGPU support for HuggingFace Spaces
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    # Create a dummy decorator for local development
    def spaces_gpu_decorator(func):
        return func
    # Wrap in staticmethod so GPU stays a plain function when accessed on the
    # instance; otherwise bare @spaces.GPU below would fail as a bound method.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu_decorator)})()
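    # Note: this stub only supports bare @spaces.GPU usage, which is all this
    # script needs; the real spaces.GPU decorator also accepts arguments
    # (e.g. duration).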

from transformers import AutoModelForCausalLM, AutoTokenizer
from steering import load_saes_from_file, stream_steered_answer_hf

# Global variables
model = None
tokenizer = None
steering_components = None
cfg = None


def initialize_model():
    """
    Load model, SAEs, and configuration on startup.

    For ZeroGPU: Model is loaded with device_map="auto" and will be automatically
    moved to GPU when @spaces.GPU decorated functions are called. Steering vectors
    are loaded on CPU initially and moved to GPU during inference.
    """
    global model, tokenizer, steering_components, cfg

    # Get HuggingFace token for gated models (if needed)
    hf_token = os.getenv("HF_TOKEN", None)
    if hf_token:
        print("Using HF_TOKEN from environment")

    print("Loading configuration...")
    with open("demo.yaml", "r") as f:
        cfg = yaml.safe_load(f)
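
    # demo.yaml is expected to provide the keys read in this script: llm_name,
    # max_new_tokens, temperature, repetition_penalty and clamp_intensity.
    # Illustrative example (values below are placeholders, not the demo's
    # actual settings):
    #   llm_name: meta-llama/Llama-3.1-8B-Instruct
    #   max_new_tokens: 256
    #   temperature: 0.7
    #   repetition_penalty: 1.1
    #   clamp_intensity: 8.0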

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model: {cfg['llm_name']}...")
    print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        cfg['llm_name'],
        device_map="auto",
        dtype=torch.float16 if device == "cuda" else torch.float32,
        token=hf_token
    )

    tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)

    print("Loading SAE steering components...")
    # Use pre-extracted steering vectors for faster loading
    # For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference
    steering_vectors_file = "steering_vectors.pt"
    load_device = "cpu" if SPACES_AVAILABLE else device
    steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)
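    # load_saes_from_file is assumed to return a list of dicts, each holding a
    # 'vector' tensor (that is how the entries are indexed below). Vectors are
    # normalized to unit norm so the generation-time clamp_intensity sets the
    # effective steering strength.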
    for component in steering_components:
        component['vector'] /= component['vector'].norm()

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg


@spaces.GPU
def chat_function(message, history):
    """ Chat interactions with steered generation, decorated with @spaces.GPU."""
    global model, tokenizer, steering_components, cfg

    # Convert Gradio's (user, assistant) tuple history into a chat message list
    chat = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, bot_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            chat.append({"role": "assistant", "content": bot_msg})

    # Add current message
    chat.append({"role": "user", "content": message})

    # Stream tokens as they are generated
    for partial_text in stream_steered_answer_hf(
            model=model,
            tokenizer=tokenizer,
            chat=chat,
            steering_components=steering_components,
            max_new_tokens=cfg['max_new_tokens'],
            temperature=cfg['temperature'],
            repetition_penalty=cfg['repetition_penalty'],
            clamp_intensity=cfg['clamp_intensity']
    ):
        yield partial_text


def create_demo():
    """Create and configure the Gradio interface."""

    # Custom CSS for better appearance
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    /* Center the title */
    h1 {
        text-align: center !important;
    }
    /* Hide the footer with API/Gradio/Settings icons */
    footer {
        display: none !important;
    }
    /* Make the entire chat area have better contrast */
    #chatbot {
        height: 600px;
        border: 2px solid rgba(0, 0, 0, 0.2) !important;
        border-radius: 8px !important;
        background-color: white !important;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
    }
    /* Ensure input area is visible and properly positioned */
    .input-container {
        margin-top: 1rem;
        padding: 1rem;
        background: white;
        border: 2px solid rgba(0, 0, 0, 0.2);
        border-radius: 8px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    """

    # Create the interface
    demo = gr.ChatInterface(
        fn=chat_function,
        title="Have a chat with the Eiffel Tower Llama",
        description=""" """,
        examples=[
        ],
        cache_examples=False,
        theme=gr.themes.Soft(),
        css=custom_css,
        chatbot=gr.Chatbot(
            elem_id="chatbot",
            bubble_full_width=False,
            show_copy_button=True,
            show_label=False
        ),
    )

    return demo


if __name__ == "__main__":
    print("=" * 60)
    print("Steered LLM Demo - Initializing")
    print("=" * 60)

    initialize_model()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60 + "\n")

    demo = create_demo()
    demo.launch(
        share=False,  # Set to True for public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860  # Default HF Spaces port
    )