Keeby-smilyai committed
Commit 4e68118 · verified · 1 Parent(s): dd3c42c

Create app.py

Files changed (1): app.py (+418, -0)
app.py ADDED
@@ -0,0 +1,418 @@
import os
# Select the Keras backend and quiet TF logging before Keras/TensorFlow import.
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import re
import json

# ==============================================================================
# Model Architecture (Must match training code)
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)

            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]

        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)

        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config

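# Quick sanity check for RotaryEmbedding (a minimal illustrative sketch, not
# part of the app's control flow and never called by it). The rotation is
# norm-preserving: each (x1, x2) pair is rotated by a per-position angle, so
# per-position vector norms are unchanged.
def _rope_sanity_check():
    rope = RotaryEmbedding(dim=8, max_len=16)
    q = tf.random.normal([1, 2, 4, 8])  # [batch, heads, seq, head_dim]
    k = tf.random.normal([1, 2, 4, 8])
    q_rot, k_rot = rope(q, k)
    # Per-position norms are unchanged by the rotation (up to float error).
    tf.debugging.assert_near(tf.norm(q, axis=-1), tf.norm(q_rot, axis=-1), atol=1e-4)
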
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

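# RMSNorm in one line of numpy (a minimal reference sketch mirroring the layer
# above; illustrative only): y = x / sqrt(mean(x^2) + eps) * scale. Unlike
# LayerNorm there is no mean subtraction and no bias term.
def _rmsnorm_reference(x, scale, epsilon=1e-5):
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + epsilon) * scale
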
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()

        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")

        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)

        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")

        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        # Pre-norm self-attention with rotary position embeddings.
        res = x
        y = self.pre_attn_norm(x)

        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        q, k = self.rope(q, k)

        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))

        # Causal mask: position i may only attend to positions <= i.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)

        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        # Pre-norm SwiGLU feed-forward: down(silu(gate(y)) * up(y)).
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config

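# The causal mask built inside TransformerBlock.call, shown standalone (a
# minimal illustrative sketch; not called by the app). band_part(ones, -1, 0)
# keeps the lower triangle, so everything above the diagonal gets -1e9 before
# the softmax and future positions receive ~zero attention weight.
def _causal_mask(T, dtype=tf.float32):
    keep = tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0)
    return tf.where(keep == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
# For T=3 this yields:
# [[ 0, -1e9, -1e9],
#  [ 0,    0, -1e9],
#  [ 0,    0,    0]]
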
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        # Accept the config as a nested dict (Keras deserialization), as flat
        # kwargs, or under a 'cfg' key, so the model can be rebuilt either way.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")

        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }

        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)

        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)

        for block in self.blocks:
            x = block(x, training=training)

        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config

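# Shape walkthrough for SAM1Model (a minimal illustrative sketch; not called
# by the app): ids [batch, seq] -> embed [batch, seq, d_model] -> n_layers
# pre-norm blocks -> final RMSNorm -> logits [batch, seq, vocab_size].
def _shape_smoke_test(m, vocab_size, seq_len=8):
    ids = tf.zeros([2, seq_len], dtype=tf.int32)
    logits = m(ids, training=False)
    assert logits.shape == (2, seq_len, vocab_size)
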
# ==============================================================================
# Load Model from HuggingFace
# ==============================================================================
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
MODEL_WEIGHTS_REPO_ID = "Smilyai-labs/Sam-1x-instruct"

print("="*70)
print("🤖 SAM-1 Keras Chat Interface".center(70))
print("="*70)
print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
print(f"📦 Downloading model weights from: {MODEL_WEIGHTS_REPO_ID}")

# Download config and tokenizer files
print("\n⏳ Downloading config...")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")

print("⏳ Downloading tokenizer...")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")

# Download model weights, preferring the full .keras archive and falling back
# to the bare weights checkpoint.
print("⏳ Downloading model weights (this may take a while)...")
try:
    weights_path = hf_hub_download(repo_id=MODEL_WEIGHTS_REPO_ID, filename="model.keras")
    print("✅ Downloaded model.keras")
except Exception as e:
    print(f"❌ Failed to download model.keras: {e}")
    print("⏳ Trying to download ckpt.weights.h5 instead...")
    try:
        weights_path = hf_hub_download(repo_id=MODEL_WEIGHTS_REPO_ID, filename="ckpt.weights.h5")
        print("✅ Downloaded ckpt.weights.h5")
    except Exception as e_h5:
        raise FileNotFoundError(f"❌ Failed to download both model.keras and ckpt.weights.h5: {e_h5}")

# Load config
print("\n📋 Loading config...")
with open(config_path, 'r') as f:
    config = json.load(f)

print("✅ Config loaded:")
print(f"   Vocab size: {config['vocab_size']}")
print(f"   Max length: {config['max_position_embeddings']}")
print(f"   Hidden size: {config['hidden_size']}")
print(f"   Layers: {config['num_hidden_layers']}")

# Recreate tokenizer (like in training script)
print("\n🔤 Recreating tokenizer from scratch...")
tokenizer = Tokenizer.from_pretrained("gpt2")
# NOTE: the EOS marker below must match whatever was used during training;
# here it is an empty string, which only works if training did the same.
eos_token = ""
eos_token_id = tokenizer.token_to_id(eos_token)

if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)
    print(f"   Added EOS token '{eos_token}' with ID: {eos_token_id}")

# Add custom <think> tags (CRITICAL - must match training!)
custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])
        print(f"   Added custom token '{token}' with ID: {tokenizer.token_to_id(token)}")

# Disable padding for generation (handle explicitly)
tokenizer.no_padding()
tokenizer.enable_truncation(max_length=config['max_position_embeddings'])

print(f"✅ Tokenizer recreated (vocab size: {tokenizer.get_vocab_size()})")
print(f"   <think> token ID: {tokenizer.token_to_id('<think>')}")
print(f"   <think/> token ID: {tokenizer.token_to_id('<think/>')}")

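# Round-trip example for the recreated tokenizer (a minimal sketch; the exact
# ids depend on the GPT-2 vocab plus the special tokens added above):
#   enc = tokenizer.encode("User: hi\nSam: <think>")   # <think> -> one special-token id
#   tokenizer.decode(enc.ids, skip_special_tokens=False)  # keeps the <think> tag
# Note: decode() skips special tokens by default, which would strip the
# <think>/<think/> markers that parse_thinking_response() relies on.
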
# Load model
print("\n🧠 Loading model...")
# Map the HF-style config keys onto the constructor's internal names.
model_config = {
    'vocab_size': config['vocab_size'],
    'd_model': config['hidden_size'],
    'n_heads': config['num_attention_heads'],
    'ff_mult': config['intermediate_size'] / config['hidden_size'],
    'dropout': config.get('dropout', 0.0),
    'max_len': config['max_position_embeddings'],
    'rope_theta': config['rope_theta'],
    'n_layers': config['num_hidden_layers']
}
model = SAM1Model(**model_config)

# Build the model with a dummy input so every layer has created its weights
# before load_weights runs.
dummy_input = tf.zeros((1, 1), dtype=tf.int32)
model(dummy_input)

# Load weights into the built model
try:
    model.load_weights(weights_path)
    print("✅ Model weights loaded successfully!")
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model weights: {e}")

model.trainable = False
print("✅ Model loaded successfully!")
print(f"   Device: {'GPU' if len(tf.config.list_physical_devices('GPU')) > 0 else 'CPU'}")

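# One-off probe of the loaded weights (a minimal illustrative helper; not
# called by the app): run a single forward pass over a prompt and return the
# greedy next-token id.
def _probe_next_token(text):
    ids = np.array([tokenizer.encode(text).ids], dtype=np.int32)
    logits = model(ids, training=False)
    return int(np.argmax(logits[0, -1, :].numpy()))
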
# ==============================================================================
# Generation Functions
# ==============================================================================
def parse_thinking_response(text):
    """Parse response to extract thinking process and final answer."""
    think_pattern = r'<think>(.*?)(?:</think>|<think/>)'
    thinking = re.findall(think_pattern, text, re.DOTALL)
    final_answer = re.sub(think_pattern, '', text, flags=re.DOTALL).strip()
    return thinking, final_answer

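# Example (on a string with one thinking span):
#   parse_thinking_response("<think>2+2 is 4<think/>The answer is 4.")
#   -> (["2+2 is 4"], "The answer is 4.")
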
def generate_response(
    prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,  # accepted but not applied here; see the sketch after this function
    top_k=50,
    show_thinking=False  # Default False for Gradio, we handle display separately
):
    """Generate response from the Keras model, one token at a time."""
    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()

    for _ in range(max_new_tokens):
        # Keep the context within the model's maximum sequence length.
        max_len = config['max_position_embeddings']
        current_input = generated[-max_len:]
        inputs = np.array([current_input], dtype=np.int32)

        logits = model(inputs, training=False)
        next_token_logits = logits[0, -1, :].numpy()

        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            if top_k > 0:
                # Sample from the k highest-scoring tokens only.
                top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
                top_k_logits = next_token_logits[top_k_indices]
                top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
                top_k_probs /= top_k_probs.sum()
                next_token = top_k_indices[np.random.choice(len(top_k_indices), p=top_k_probs)]
            else:
                probs = np.exp(next_token_logits - np.max(next_token_logits))
                probs /= probs.sum()
                next_token = np.random.choice(len(probs), p=probs)
        else:
            # Temperature 0 falls back to greedy decoding.
            next_token = np.argmax(next_token_logits)

        if next_token == eos_token_id:
            break

        generated.append(int(next_token))

    # Decode only the new tokens; keep special tokens so the <think> markers
    # survive for parse_thinking_response().
    return tokenizer.decode(generated[len(input_ids):], skip_special_tokens=False)

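# The top_p argument above is accepted but never applied in the sampling loop.
# A nucleus (top-p) filter could be slotted in right after the temperature
# scaling; this is a minimal standalone sketch, not part of the original loop:
def _top_p_filter(logits, top_p=0.9):
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]                     # most to least likely
    cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
    keep = order[:cutoff]                               # smallest set with >= top_p mass
    filtered = np.full_like(logits, -np.inf)
    filtered[keep] = logits[keep]
    return filtered
# Usage inside the loop would be, e.g.:
#   next_token_logits = _top_p_filter(next_token_logits, top_p)
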
# ==============================================================================
# Main - Gradio Interface
# ==============================================================================
if __name__ == "__main__":
    import gradio as gr

    def gradio_generate(user_input, show_thinking, temperature):
        """Wrapper function for Gradio; returns a single markdown string."""
        if not user_input.strip():
            return "Please enter a prompt."

        prompt = f"User: {user_input}\nSam: <think>"
        raw_response = generate_response(
            prompt,
            max_new_tokens=512,
            temperature=temperature,
            show_thinking=False
        )

        thinking_list, final_answer = parse_thinking_response(raw_response)
        thinking_text = "\n\n".join([f"💭 {thought.strip()}" for thought in thinking_list]) if thinking_list else "No explicit thinking trace."

        if show_thinking:
            return f"{thinking_text}\n\n---\n\n**Answer:**\n{final_answer}"
        return f"**Answer:**\n{final_answer}"

    with gr.Blocks(title="SAM-1 Chat") as demo:
        gr.Markdown("# 🤖 SAM-1 Keras Chat Interface")

        with gr.Row():
            with gr.Column(scale=3):
                user_input = gr.Textbox(
                    label="Your Message",
                    placeholder="Ask me anything...",
                    lines=3
                )
            with gr.Column(scale=1):
                with gr.Group():
                    temp_slider = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    show_think_checkbox = gr.Checkbox(
                        label="Show Thinking Process",
                        value=True
                    )
                submit_btn = gr.Button("Send Message", variant="primary")

        response_output = gr.Markdown(label="Response")
        # raw_output = gr.Textbox(label="Raw Response (Debug)", visible=False)

        # gradio_generate returns one string, matching the single output below.
        submit_btn.click(
            fn=gradio_generate,
            inputs=[user_input, show_think_checkbox, temp_slider],
            outputs=[response_output]
        )
        user_input.submit(
            fn=gradio_generate,
            inputs=[user_input, show_think_checkbox, temp_slider],
            outputs=[response_output]
        )

    demo.launch(debug=True, share=True)