Keeby-smilyai's picture
Update app.py
5d3b729 verified
raw
history blame
61.6 kB
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
import time
import threading
import hashlib
import sqlite3
from datetime import datetime, timedelta
import pytz
# ==============================================================================
# Performance Optimizations for CPU
# ==============================================================================
# GPU allocator hints; they only matter if TensorFlow finds a GPU.
# NOTE(review): these are set after `import tensorflow`. TF presumably reads
# them lazily when the GPU device is first initialised, so they should still
# apply as long as no op has run yet — confirm. Setting them next to
# KERAS_BACKEND (before the import) would be the unambiguous place.
os.environ.update({
    'TF_GPU_ALLOCATOR': 'cuda_malloc_async',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})
# Keep TensorFlow's thread pools small: 1 inter-op thread, 2 intra-op threads.
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(2)
# Global XLA auto-clustering on; tf.function bodies run as graphs, not eagerly.
tf.config.optimizer.set_jit(True)
tf.config.run_functions_eagerly(False)
# Australian timezone (every rate-limit timestamp in this file uses it)
AUSTRALIA_TZ = pytz.timezone('Australia/Sydney')
# ==============================================================================
# Database Setup
# ==============================================================================
def init_database(db_path='sam_users.db'):
    """Open the SQLite user database, create missing tables and seed an admin.

    Args:
        db_path: SQLite database file to open. Defaults to ``'sam_users.db'``.

    Returns:
        An open ``sqlite3.Connection`` created with ``check_same_thread=False``
        so it can be shared between threads (callers must serialise access).

    The admin account ("admin", plan "pro") is created once. Its password is
    taken from the ``SAM_ADMIN_PASSWORD`` environment variable when set and
    otherwise falls back to the historical default "admin123".
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    c = conn.cursor()
    # Users table
    c.execute('''CREATE TABLE IF NOT EXISTS users
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  username TEXT UNIQUE NOT NULL,
                  password_hash TEXT NOT NULL,
                  email TEXT,
                  plan TEXT DEFAULT 'free',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  is_admin BOOLEAN DEFAULT 0,
                  rate_limit_start TIMESTAMP,
                  messages_used_nano INTEGER DEFAULT 0,
                  messages_used_mini INTEGER DEFAULT 0,
                  messages_used_fast INTEGER DEFAULT 0,
                  messages_used_large INTEGER DEFAULT 0)''')
    # Upgrade requests table
    c.execute('''CREATE TABLE IF NOT EXISTS upgrade_requests
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  requested_plan TEXT,
                  reason TEXT,
                  status TEXT DEFAULT 'pending',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Usage tracking
    c.execute('''CREATE TABLE IF NOT EXISTS usage_logs
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  tokens_used INTEGER,
                  model_used TEXT,
                  timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')

    # Create admin account if not exists. Same unsalted SHA-256 scheme as
    # hash_password(), which is defined later in the module.
    env_password = os.environ.get("SAM_ADMIN_PASSWORD")
    admin_pass = hashlib.sha256((env_password or "admin123").encode()).hexdigest()
    now = datetime.now(AUSTRALIA_TZ).isoformat()
    try:
        # Seed rate_limit_start too: check_and_reset_limits() only resets a
        # window whose start is recorded, so a NULL start meant the admin's
        # counters could never be reset.
        c.execute("INSERT INTO users (username, password_hash, email, plan, is_admin, rate_limit_start) "
                  "VALUES (?, ?, ?, ?, ?, ?)",
                  ("admin", admin_pass, "admin@samx1.ai", "pro", 1, now))
        conn.commit()
        if env_password:
            # Never echo a deployer-supplied secret into the logs.
            print("✅ Admin account created (username: admin)")
        else:
            print("✅ Admin account created (username: admin, password: admin123)")
    except sqlite3.IntegrityError:
        print("✅ Admin account already exists")
    # Repair databases created by earlier versions, where the admin (or any
    # other account) was stored without a rate-limit window start.
    c.execute("UPDATE users SET rate_limit_start = ? WHERE rate_limit_start IS NULL", (now,))
    conn.commit()
    return conn
# Global database connection: a single SQLite connection opened at import time
# by init_database() and shared by every helper below. It is created with
# check_same_thread=False, so callers take db_lock around their queries.
db_lock = threading.Lock()
db_conn = init_database()
# Per-plan quotas. For each plan:
#   *_messages       - messages allowed per model tier within one window
#                      (can_use_model treats a limit of -1 as unlimited; no
#                      plan currently uses it)
#   can_choose_model - whether the user may pick a model manually
#   max_tokens       - surfaced via get_user_limits_info (presumably the
#                      per-reply generation cap — confirm in the UI code)
#   reset_hours      - length of the rolling window; it differs per plan
# Keys must match the `plan` column values stored in the users table.
PLAN_LIMITS = {
    'free': dict(
        nano_messages=100, mini_messages=4, fast_messages=7, large_messages=5,
        can_choose_model=False, max_tokens=256, reset_hours=5,
    ),
    'explore': dict(
        nano_messages=200, mini_messages=8, fast_messages=14, large_messages=10,
        can_choose_model=True, max_tokens=512, reset_hours=3,
    ),
    'plus': dict(
        nano_messages=500, mini_messages=20, fast_messages=17, large_messages=9,
        can_choose_model=True, max_tokens=384, reset_hours=2,
    ),
    'pro': dict(
        nano_messages=10000000, mini_messages=100, fast_messages=50, large_messages=20,
        can_choose_model=True, max_tokens=512, reset_hours=3,
    ),
    'Research': dict(
        nano_messages=10000000, mini_messages=1000, fast_messages=500, large_messages=200,
        can_choose_model=True, max_tokens=1024, reset_hours=5,
    ),
    # Stored verbatim in the DB; note the parentheses and mixed case.
    'VIP(hyper)': dict(
        nano_messages=100000000000000, mini_messages=1000, fast_messages=5000, large_messages=200,
        can_choose_model=True, max_tokens=1024, reset_hours=2,
    ),
}
def get_model_type(model_name):
    """Map a model display name to its quota tier.

    Markers are matched case-sensitively as substrings, in the order
    Nano, Mini, Fast, Large; the first hit wins. Names with no marker are
    treated as 'nano'.
    """
    tier_markers = (
        ('Nano', 'nano'),
        ('Mini', 'mini'),
        ('Fast', 'fast'),
        ('Large', 'large'),
    )
    for marker, tier in tier_markers:
        if marker in model_name:
            return tier
    return 'nano'
# ==============================================================================
# User Management Functions
# ==============================================================================
def hash_password(password):
    """Return the hex SHA-256 digest of *password* encoded as UTF-8.

    NOTE(review): unsalted and fast, so weak against offline guessing if the
    DB leaks. Switching schemes (e.g. hashlib.scrypt with a per-user salt)
    needs a migration because stored hashes use this exact format.
    """
    digest = hashlib.sha256(password.encode("utf-8"))
    return digest.hexdigest()
def create_user(username, password, email=""):
    """Create a new account (plan defaults to 'free' via the table schema).

    Args:
        username: Unique login name; must not be empty or whitespace-only.
        password: Plain-text password; must not be empty. Stored hashed.
        email: Optional contact address.

    Returns:
        ``(True, message)`` on success, ``(False, message)`` when the input is
        invalid or the username is already taken.
    """
    # Validate before touching the DB: an empty string satisfies NOT NULL in
    # SQLite, so without this check blank accounts could be created.
    if not username or not username.strip() or not password:
        return False, "Username and password are required!"
    with db_lock:
        try:
            c = db_conn.cursor()
            # Start the user's first rate-limit window at sign-up time.
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("INSERT INTO users (username, password_hash, email, rate_limit_start) VALUES (?, ?, ?, ?)",
                      (username, hash_password(password), email, now))
            db_conn.commit()
            return True, "Account created successfully!"
        except sqlite3.IntegrityError:
            return False, "Username already exists!"
def authenticate_user(username, password):
    """Check *username* / *password* against the users table.

    Returns:
        ``(True, {"id", "username", "plan", "is_admin"})`` when the password
        hash matches, otherwise ``(False, None)``.
    """
    import hmac

    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT id, password_hash, plan, is_admin FROM users WHERE username = ?", (username,))
        result = c.fetchone()
    # compare_digest runs in time independent of where the digests differ,
    # unlike `==`, which short-circuits at the first mismatching character.
    if result and hmac.compare_digest(result[1], hash_password(password)):
        return True, {"id": result[0], "username": username, "plan": result[2], "is_admin": bool(result[3])}
    return False, None
def check_and_reset_limits(user_id):
    """Start a new rate-limit window for *user_id* if the current one expired.

    The window length is the plan's ``reset_hours`` from PLAN_LIMITS (it
    varies per plan). When the window has elapsed, ``rate_limit_start`` is
    moved to now and all per-model message counters are zeroed.

    A user whose ``rate_limit_start`` is NULL (for example an admin seeded
    without one) also gets a fresh window started now; previously such
    users could never be reset, so their counters grew without bound.
    Unknown ``user_id`` values are ignored.
    """
    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT rate_limit_start, plan FROM users WHERE id = ?", (user_id,))
        result = c.fetchone()
        if not result:
            return
        rate_limit_start_str, plan = result
        reset_hours = PLAN_LIMITS[plan]['reset_hours']
        now = datetime.now(AUSTRALIA_TZ)
        if rate_limit_start_str:
            rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
            if now - rate_limit_start < timedelta(hours=reset_hours):
                return  # current window still active
        c.execute("""UPDATE users
                     SET rate_limit_start = ?,
                         messages_used_nano = 0,
                         messages_used_mini = 0,
                         messages_used_fast = 0,
                         messages_used_large = 0
                     WHERE id = ?""", (now.isoformat(), user_id))
        db_conn.commit()
def get_user_limits_info(user_id):
    """Return the user's plan, usage counters, limits and time to next reset.

    check_and_reset_limits() runs first, so an expired window is reset
    before the counters are read.

    Returns:
        ``None`` if the user does not exist, otherwise a dict with keys
        ``plan``, ``{nano,mini,fast,large}_used``, ``{...}_limit``,
        ``can_choose_model``, ``max_tokens`` and ``reset_in``. ``reset_in``
        is ``"<h>h <m>m"`` (never negative) or ``"N/A"`` when no window start
        is recorded.
    """
    check_and_reset_limits(user_id)
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT plan, rate_limit_start,
                            messages_used_nano, messages_used_mini,
                            messages_used_fast, messages_used_large
                     FROM users WHERE id = ?""", (user_id,))
        result = c.fetchone()
    if not result:
        return None
    plan, rate_limit_start_str, nano_used, mini_used, fast_used, large_used = result
    limits = PLAN_LIMITS[plan]
    if rate_limit_start_str:
        rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
        reset_time = rate_limit_start + timedelta(hours=limits['reset_hours'])
        time_until_reset = reset_time - datetime.now(AUSTRALIA_TZ)
        # Clamp at zero: the window can expire between the reset check above
        # and this read, and a negative delta would render as e.g. "-1h 59m".
        total_seconds = max(0, int(time_until_reset.total_seconds()))
        hours, remainder = divmod(total_seconds, 3600)
        reset_str = f"{hours}h {remainder // 60}m"
    else:
        reset_str = "N/A"
    return {
        'plan': plan,
        'nano_used': nano_used,
        'mini_used': mini_used,
        'fast_used': fast_used,
        'large_used': large_used,
        'nano_limit': limits['nano_messages'],
        'mini_limit': limits['mini_messages'],
        'fast_limit': limits['fast_messages'],
        'large_limit': limits['large_messages'],
        'can_choose_model': limits['can_choose_model'],
        'max_tokens': limits['max_tokens'],
        'reset_in': reset_str
    }
def can_use_model(user_id, model_name):
    """Decide whether *user_id* may send another message to *model_name*.

    Returns ``(True, "OK")`` when the tier's limit is -1 (unlimited) or not
    yet reached; otherwise ``(False, reason)`` with the usage and the time
    until reset. Unknown users yield ``(False, "User not found")``.
    """
    info = get_user_limits_info(user_id)
    if not info:
        return False, "User not found"
    tier = get_model_type(model_name)
    used, limit = info[f"{tier}_used"], info[f"{tier}_limit"]
    exhausted = limit != -1 and used >= limit
    if exhausted:
        return False, f"Limit reached for {tier.upper()} model ({used}/{limit}). Resets in {info['reset_in']}"
    return True, "OK"
def increment_model_usage(user_id, model_name):
    """Add one to the user's message counter for *model_name*'s tier.

    The column name is interpolated into the SQL, which is safe only because
    get_model_type() returns one of a fixed set of tier names.
    """
    counter = "messages_used_" + get_model_type(model_name)
    statement = f"UPDATE users SET {counter} = {counter} + 1 WHERE id = ?"
    with db_lock:
        cursor = db_conn.cursor()
        cursor.execute(statement, (user_id,))
        db_conn.commit()
def get_available_models_for_user(user_id):
    """List the loaded model names the user still has quota for.

    For each tier (nano, mini, fast, large) whose limit is -1 or not yet
    reached, the first loaded model of that tier (in available_models
    order) is included. Unknown users get an empty list.
    """
    info = get_user_limits_info(user_id)
    if not info:
        return []
    usable = []
    for tier in ('nano', 'mini', 'fast', 'large'):
        cap = info[f'{tier}_limit']
        if cap != -1 and info[f'{tier}_used'] >= cap:
            continue  # quota exhausted for this tier
        first_match = next(
            (name for name in available_models if get_model_type(name) == tier),
            None,
        )
        if first_match is not None:
            usable.append(first_match)
    return usable
def log_usage(user_id, tokens, model):
    """Append one row (user, token count, model name) to usage_logs."""
    row = (user_id, tokens, model)
    with db_lock:
        db_conn.cursor().execute(
            "INSERT INTO usage_logs (user_id, tokens_used, model_used) VALUES (?, ?, ?)",
            row,
        )
        db_conn.commit()
def request_upgrade(user_id, plan, reason):
    """Record a pending upgrade request for an admin to review.

    Args:
        user_id: Requesting user's id.
        plan: Requested plan; must be a key of PLAN_LIMITS.
        reason: Free-text justification.

    Returns:
        ``(True, message)`` on success, ``(False, "Error: ...")`` otherwise.
    """
    # Reject unknown plans up front: if one were approved, every later
    # PLAN_LIMITS[plan] lookup for that user would raise KeyError.
    if plan not in PLAN_LIMITS:
        return False, f"Error: unknown plan '{plan}'"
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("INSERT INTO upgrade_requests (user_id, requested_plan, reason) VALUES (?, ?, ?)",
                      (user_id, plan, reason))
            db_conn.commit()
            return True, "Upgrade request submitted! Admin will review soon."
        except Exception as e:
            # Don't leave an open transaction on the shared connection for
            # some unrelated later commit to persist.
            db_conn.rollback()
            return False, f"Error: {str(e)}"
def get_all_users():
    """Return every user row, newest first.

    Each row is (id, username, email, plan, created_at, is_admin,
    messages_used_nano, messages_used_mini, messages_used_fast,
    messages_used_large, rate_limit_start).
    """
    query = (
        "SELECT id, username, email, plan, created_at, is_admin,\n"
        "messages_used_nano, messages_used_mini,\n"
        "messages_used_fast, messages_used_large,\n"
        "rate_limit_start\n"
        "FROM users ORDER BY created_at DESC"
    )
    with db_lock:
        return db_conn.execute(query).fetchall()
def get_pending_requests():
    """Return pending upgrade requests, newest first.

    Each row is (request_id, username, requested_plan, reason, created_at).
    """
    query = (
        "SELECT r.id, u.username, r.requested_plan, r.reason, r.created_at\n"
        "FROM upgrade_requests r\n"
        "JOIN users u ON r.user_id = u.id\n"
        "WHERE r.status = 'pending'\n"
        "ORDER BY r.created_at DESC"
    )
    with db_lock:
        return db_conn.execute(query).fetchall()
def update_user_plan(username, new_plan):
    """Move *username* to *new_plan* and start a fresh rate-limit window.

    All per-model counters are zeroed.

    Returns:
        ``(True, message)`` on success. ``(False, message)`` if the plan is
        unknown, the user does not exist, or the update fails.
    """
    # An unknown plan would make every later PLAN_LIMITS[plan] lookup raise.
    if new_plan not in PLAN_LIMITS:
        return False, f"Error: unknown plan '{new_plan}'"
    with db_lock:
        try:
            c = db_conn.cursor()
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("""UPDATE users
                         SET plan = ?,
                             rate_limit_start = ?,
                             messages_used_nano = 0,
                             messages_used_mini = 0,
                             messages_used_fast = 0,
                             messages_used_large = 0
                         WHERE username = ?""", (new_plan, now, username))
            if c.rowcount == 0:
                # Previously reported success even when no such user existed.
                db_conn.rollback()
                return False, f"Error: user {username} not found"
            db_conn.commit()
            return True, f"User {username} upgraded to {new_plan}!"
        except Exception as e:
            db_conn.rollback()
            return False, f"Error: {str(e)}"
def approve_request(request_id):
    """Apply the plan from upgrade request *request_id* and mark it approved.

    The user's counters are zeroed and a new rate-limit window starts now.
    Both updates are committed together or not at all.

    Returns:
        ``(True, "Request approved!")`` on success, otherwise ``(False, msg)``.
    """
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("SELECT user_id, requested_plan FROM upgrade_requests WHERE id = ?", (request_id,))
            result = c.fetchone()
            if not result:
                return False, "Request not found"
            user_id, plan = result
            # Guard against requests that name a plan PLAN_LIMITS doesn't know.
            if plan not in PLAN_LIMITS:
                return False, f"Error: unknown plan '{plan}'"
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("""UPDATE users
                         SET plan = ?,
                             rate_limit_start = ?,
                             messages_used_nano = 0,
                             messages_used_mini = 0,
                             messages_used_fast = 0,
                             messages_used_large = 0
                         WHERE id = ?""", (plan, now, user_id))
            c.execute("UPDATE upgrade_requests SET status = 'approved' WHERE id = ?", (request_id,))
            db_conn.commit()
            return True, "Request approved!"
        except Exception as e:
            # Without a rollback, a failure after the first UPDATE would leave
            # the user upgraded in an open transaction that the next commit
            # anywhere on this shared connection would silently persist.
            db_conn.rollback()
            return False, f"Error: {str(e)}"
def deny_request(request_id):
    """Mark upgrade request *request_id* as denied.

    Returns:
        ``(True, "Request denied")`` on success, ``(False, "Request not
        found")`` for an unknown id, ``(False, "Error: ...")`` on failure.
    """
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("UPDATE upgrade_requests SET status = 'denied' WHERE id = ?", (request_id,))
            if c.rowcount == 0:
                # Matches approve_request's handling of unknown ids.
                db_conn.rollback()
                return False, "Request not found"
            db_conn.commit()
            return True, "Request denied"
        except Exception as e:
            db_conn.rollback()
            return False, f"Error: {str(e)}"
# ==============================================================================
# Model Architecture
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) for per-head query/key tensors.

    ``build`` precomputes float32 cos/sin tables with ``max_len`` rows (one
    per position; ``dim`` columns for even ``dim``). ``call`` rotates ``q``
    and ``k`` using the first ``seq_len`` rows, where ``seq_len`` is axis 2
    of ``q``. The tables are broadcast as ``[None, None, :, :]``, so q/k are
    rank-4 with positions on axis 2.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim, self.max_len, self.theta = dim, max_len, theta
        # Flips to True once the tables exist so a repeated build() skips them.
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            even_indices = tf.range(0, self.dim, 2, dtype=tf.float32)
            inv_freq = 1.0 / (self.theta ** (even_indices / self.dim))
            positions = tf.range(self.max_len, dtype=tf.float32)
            angles = tf.einsum("i,j->ij", positions, inv_freq)
            table = tf.concat([angles, angles], axis=-1)
            self.cos_cached = tf.constant(tf.cos(table), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(table), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        """Map the two halves (a, b) of the last axis to (-b, a)."""
        front, back = tf.split(x, 2, axis=-1)
        return tf.concat([-back, front], axis=-1)

    @staticmethod
    def _leading_rows(table, seq_len, dtype):
        """First *seq_len* rows of *table*, cast to *dtype*, shaped for broadcasting."""
        return tf.cast(table[:seq_len, :], dtype)[None, None, :, :]

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        cos = self._leading_rows(self.cos_cached, seq_len, q.dtype)
        sin = self._leading_rows(self.sin_cached, seq_len, q.dtype)

        def rotate(tensor):
            return (tensor * cos) + (self.rotate_half(tensor) * sin)

        return rotate(q), rotate(k)

    def get_config(self):
        return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta}
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Computes ``x * rsqrt(mean(x**2, axis=-1) + epsilon) * scale``: no mean
    subtraction and no bias. ``scale`` has one entry per feature and starts
    at ones.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Weight name "scale" is part of the saved-checkpoint layout.
        n_features = input_shape[-1]
        self.scale = self.add_weight(name="scale", shape=(n_features,), initializer="ones")

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.math.rsqrt(mean_square + self.epsilon)
        return x * inv_rms * self.scale

    def get_config(self):
        return {**super().get_config(), "epsilon": self.epsilon}
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal multi-head self-attention with RoPE,
    then a gated (SwiGLU-style) feed-forward network. Each sub-layer is
    wrapped as ``x + dropout(sublayer(RMSNorm(x)))``.

    NOTE(review): the sub-layer attribute names (q_proj, k_proj, ...,
    down_proj) presumably determine the weight paths in the ``.weights.h5``
    checkpoints loaded elsewhere — do not rename or reorder them.
    """
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        # Assumes d_model is divisible by n_heads — not validated here.
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        # One Dropout layer shared by the attention and FFN residual branches.
        self.dropout = keras.layers.Dropout(dropout)
    def call(self, x, training=None):
        # x is (batch, seq, d_model); implied by the reshapes below.
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        # Project and split heads: (B, T, D) -> (B, n_heads, T, head_dim).
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        # Scaled dot-product scores: (B, n_heads, T, T).
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: 0 on/below the diagonal (band_part keeps the lower
        # triangle), -1e9 above it, so position t attends only to <= t.
        # The full T x T mask is rebuilt on every call (no KV cache).
        mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, D).
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        res = x
        y = self.pre_ffn_norm(x)
        # Gated FFN: down(silu(gate(y)) * up(y)).
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)
    def get_config(self):
        config = super().get_config()
        # "dropout" is stored under the constructor's argument name so
        # from_config can rebuild the layer.
        config.update({"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len, "rope_theta":
                       self.rope_theta, "layer_idx": self.layer_idx})
        return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model: token embedding -> ``n_layers``
    TransformerBlocks -> final RMSNorm -> linear LM head (a separate Dense,
    not tied to the embedding).

    Config keys read: vocab_size, d_model, ff_mult, n_heads, dropout,
    max_len, rope_theta, n_layers.
    """
    def __init__(self, **kwargs):
        # NOTE(review): kwargs are not forwarded to keras.Model, so e.g. a
        # `name` passed in (including on deserialisation) is dropped.
        super().__init__()
        # The config may arrive as config={...} (the shape get_config emits),
        # as flat keyword arguments, or as cfg={...}.
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        # ff_mult is a ratio; the FFN width is rounded down to an int.
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {'d_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']}
        # Layer names block_0..block_{n-1} and this creation order are
        # presumably part of the checkpoint layout — keep them stable.
        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
    def call(self, input_ids, training=None):
        # Returns logits over the vocabulary for every input position.
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))
    def get_config(self):
        # Nest the architecture dict under 'config', matching __init__'s
        # first accepted form.
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
def count_parameters(model):
    """Return ``(total, non_zero)`` element counts over ``model.weights``.

    Each weight is materialised with ``.numpy()`` one at a time; ``non_zero``
    counts entries that are not exactly zero (used to report sparsity).
    """
    total = non_zero = 0
    for values in (variable.numpy() for variable in model.weights):
        total += values.size
        non_zero += np.count_nonzero(values)
    return total, non_zero
def format_param_count(count):
    """Render *count* compactly: ``1.50K``, ``2.50M``, ``3.00B``.

    Values below one thousand are returned via ``str()`` unchanged.
    """
    for scale, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if count >= scale:
            return f"{count / scale:.2f}{suffix}"
    return str(count)
class ModelBackend(ABC):
    """Interface every chat-model backend must implement."""

    @abstractmethod
    def get_name(self):
        """Return the model's display name."""

    @abstractmethod
    def get_info(self):
        """Return a human-readable, multi-line description of the model."""

    @abstractmethod
    def predict(self, input_ids):
        """Return the model output for the token ids *input_ids*.

        KerasBackend returns the logits of the last position.
        """
class KerasBackend(ModelBackend):
    """ModelBackend wrapping a SAM1Model behind a compiled tf.function.

    Construction compiles and warms up the predictor, then records
    parameter counts and architecture facts used by get_info().
    """
    def __init__(self, model, name, display_name):
        self.model = model
        self.name = name
        self.display_name = display_name
        # Graph-compiled, XLA-jitted forward pass for a single sequence
        # (batch fixed at 1, length dynamic), closing over `model`.
        # NOTE(review): XLA compiles per concrete input shape, so each new
        # sequence length presumably triggers another compile despite the
        # [1, None] signature — measure before relying on the warm-up.
        @tf.function(input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)], jit_compile=True)
        def fast_predict(inputs):
            return model(inputs, training=False)
        self.fast_predict = fast_predict
        print(f" 🔥 Warming up {display_name}...")
        # Trace/compile once up front with a 3-token input.
        dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
        _ = self.fast_predict(dummy)
        print(f" ✅ Compilation complete!")
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        # Percentage of weight entries that are exactly zero.
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
    def predict(self, input_ids):
        """Return last-position logits (numpy vector) for one id sequence.

        `input_ids` is a flat list of ints; it is wrapped into a batch of 1.
        """
        inputs = tf.constant([input_ids], dtype=tf.int32)
        logits = self.fast_predict(inputs)
        return logits[0, -1, :].numpy()
    def get_name(self):
        return self.display_name
    def get_info(self):
        """Return a multi-line summary; sparsity is shown only above 1%."""
        info = f"{self.display_name}\n"
        info += f" Total params: {format_param_count(self.total_params)}\n"
        info += f" Attention heads: {self.n_heads}\n"
        info += f" FFN dimension: {self.ff_dim}\n"
        if self.sparsity > 1:
            info += f" Sparsity: {self.sparsity:.1f}%\n"
        return info
# Models to load, in this order. Each entry is
#   (display_name, hf_repo_id, weights_filename, config_filename)
# A config_filename of None means the model is built from the base config
# downloaded from CONFIG_TOKENIZER_REPO_ID.
MODEL_REGISTRY = [
    (
        "SAM-X-1-Large",
        "Smilyai-labs/Sam-1x-instruct",
        "ckpt.weights.h5",
        None,
    ),
    (
        "SAM-X-1-Fast ⚡ (BETA)",
        "Smilyai-labs/Sam-X-1-fast",
        "sam1_fast_finetuned.weights.h5",
        "sam1_fast_finetuned_config.json",
    ),
    (
        "SAM-X-1-Mini 🚀 (ADVANCED!)",
        "Smilyai-labs/Sam-X-1-Mini",
        "sam1_mini_finetuned.weights.h5",
        "sam1_mini_finetuned_config.json",
    ),
    (
        "SAM-X-1-Nano ⚡⚡",
        "Smilyai-labs/Sam-X-1-Nano",
        "sam1_nano_finetuned.weights.h5",
        "sam1_nano_finetuned_config.json",
    ),
]
def estimate_prompt_complexity(prompt):
    """Score how demanding *prompt* looks, using cheap lexical heuristics.

    Points: 0-3 for length (more than 20/50/100 words), +2 per "hard"
    keyword and +1 per "medium" keyword present (case-insensitive substring
    match, so e.g. "explain briefly" also counts "explain"), +2 for any
    coding term, +1 for any multi-step term, +1 for more than one '?'.
    select_model_auto maps higher scores to larger models.
    """
    text = prompt.lower()
    n_words = len(prompt.split())
    if n_words > 100:
        score = 3
    elif n_words > 50:
        score = 2
    elif n_words > 20:
        score = 1
    else:
        score = 0
    hard_terms = ('analyze', 'explain', 'compare', 'evaluate', 'prove', 'derive', 'calculate', 'solve', 'reason', 'why', 'how does', 'complex', 'algorithm', 'mathematics', 'philosophy', 'theory', 'logic', 'detailed', 'comprehensive', 'thorough', 'in-depth')
    medium_terms = ('write', 'create', 'generate', 'summarize', 'describe', 'list', 'what is', 'tell me', 'explain briefly')
    score += 2 * sum(term in text for term in hard_terms)
    score += sum(term in text for term in medium_terms)
    if any(term in text for term in ('code', 'function', 'program', 'debug', 'implement')):
        score += 2
    if any(term in text for term in ('first', 'then', 'next', 'finally', 'step')):
        score += 1
    if prompt.count('?') > 1:
        score += 1
    return score
def select_model_auto(prompt, available_models_dict, user_available_models):
    """Pick a backend for *prompt* among the models the user may use.

    The prompt's complexity score selects a preference order (small models
    for easy prompts, Large for the hardest); the first model in that order
    that is both loaded and allowed is returned. Returns None when the user
    has no accessible model.
    """
    complexity = estimate_prompt_complexity(prompt)
    accessible = {
        name: backend
        for name, backend in available_models_dict.items()
        if name in user_available_models
    }
    if not accessible:
        return None
    large = "SAM-X-1-Large"
    fast = "SAM-X-1-Fast ⚡ (BETA)"
    mini = "SAM-X-1-Mini 🚀 (ADVANCED!)"
    nano = "SAM-X-1-Nano ⚡⚡"
    if complexity <= 2:
        ranking = (nano, mini, fast, large)
    elif complexity <= 5:
        ranking = (mini, nano, fast, large)
    elif complexity <= 8:
        ranking = (fast, mini, large, nano)
    else:
        ranking = (large, fast, mini, nano)
    for name in ranking:
        if name in accessible:
            return accessible[name]
    # Only reached if the accessible models use names outside the ranking.
    return next(iter(accessible.values()))
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
print("="*80)
print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
print("="*80)
print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
# NOTE(review): tokenizer.json is downloaded but never used below; the
# tokenizer is rebuilt from the "gpt2" vocabulary instead. Confirm intended.
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
with open(config_path, 'r', encoding='utf-8') as f:
    base_config = json.load(f)
print(f"✅ Base config loaded")
# Translate HF-style config keys into the keys SAM1Model reads.
base_model_config = {
    'vocab_size': base_config['vocab_size'],
    'd_model': base_config['hidden_size'],
    'n_heads': base_config['num_attention_heads'],
    'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
    'dropout': base_config.get('dropout', 0.0),
    'max_len': base_config['max_position_embeddings'],
    'rope_theta': base_config['rope_theta'],
    'n_layers': base_config['num_hidden_layers'],
}
print("\n🔤 Recreating tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")
eos_token = "<|endoftext|>"
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)
custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])
tokenizer.no_padding()
# Truncate from the LEFT: for an over-long chat prompt the newest tokens
# (the latest user turn) must survive. The default right-truncation kept the
# oldest history instead, so the model continued from a stale point. This
# matches generate_response_stream, which also keeps generated[-max_len:].
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'], direction="left")
print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
print(f" EOS token: '{eos_token}' (ID: {eos_token_id})")
if eos_token_id is None:
    raise ValueError("❌ Failed to set EOS token ID!")
print("\n" + "="*80)
print("📦 LOADING MODELS".center(80))
print("="*80)
# display_name -> KerasBackend for every registry entry that loaded.
available_models = {}
# Single-token input used to build each model's weights before load_weights.
dummy_input = tf.zeros((1, 1), dtype=tf.int32)
# Best-effort loading: a model that fails to download/build/load is reported
# and skipped; startup aborts only if none load.
for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\n⏳ Loading: {display_name}")
        print(f" Repo: {repo_id}")
        print(f" Weights: {weights_filename}")
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
        if config_filename:
            # Per-model architecture JSON, read directly as SAM1Model kwargs.
            print(f" Config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config = json.load(f)
            print(f" 📐 Custom architecture: {model_config['n_heads']} heads")
        else:
            model_config = base_model_config.copy()
        model = SAM1Model(**model_config)
        # Subclassed Keras models create their weights on first call; this
        # must happen before load_weights.
        model(dummy_input)
        model.load_weights(weights_path)
        model.trainable = False
        backend = KerasBackend(model, display_name, display_name)
        available_models[display_name] = backend
        print(f" ✅ Loaded successfully!")
        print(f" 📊 Parameters: {format_param_count(backend.total_params)}")
    except Exception as e:
        print(f" ⚠️ Failed to load: {e}")
if not available_models:
    raise RuntimeError("❌ No models loaded!")
print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
# Default backend: the first model that loaded (registry order).
current_backend = list(available_models.values())[0]
# Process-wide stop flag. NOTE(review): shared by every user, so a stop
# request (or the clear() at the start of any generation) affects all
# concurrent generations.
stop_generation = threading.Event()


def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=256):
    """Stream a completion for *prompt*, yielding text as it is decoded.

    Sampling is temperature-scaled top-5 when ``temperature > 0`` and greedy
    argmax otherwise. Generation ends at ``eos_token_id``, after
    ``max_tokens`` tokens, or when ``stop_generation`` is set. The context
    fed to the model is the last ``max_len`` tokens.

    Args:
        prompt: Text to continue; EOS ids are stripped from its encoding.
        temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
        backend: Object with ``predict(ids)`` and ``model.cfg['max_len']``;
            defaults to the module-level ``current_backend``.
        max_tokens: Maximum number of tokens to generate.

    Yields:
        ``(chunk, in_thinking, tokens_per_sec, tokens_per_sec, stopped)``:
        new text while streaming; ``("", False, -1, speed, True)`` if stopped
        (then returns); finally ``("", False, speed, speed, False)``.
    """
    stop_generation.clear()
    if backend is None:
        backend = current_backend
    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    prompt_len = len(input_ids)
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False
    max_len = backend.model.cfg['max_len']
    start_time = time.time()
    tokens_generated = 0
    decode_buffer = []
    decode_every = 2
    last_speed_check = start_time

    def _speed():
        """Average tokens/sec since generation started."""
        elapsed = time.time() - start_time
        return tokens_generated / elapsed if elapsed > 0 else 0

    def _decode_new_chunk():
        """Decode everything generated so far; return only the unseen suffix.

        Returns None when decoding produced no new text (e.g. an incomplete
        multi-byte character). Updates current_text and in_thinking.
        """
        nonlocal current_text, in_thinking
        new_text = tokenizer.decode(generated[prompt_len:])
        if len(new_text) <= len(current_text):
            return None
        new_chunk = new_text[len(current_text):]
        current_text = new_text
        if "<think>" in new_chunk:
            in_thinking = True
        elif "</think>" in new_chunk or "<think/>" in new_chunk:
            in_thinking = False
        return new_chunk

    for step in range(max_tokens):
        if stop_generation.is_set():
            yield "", False, -1, _speed(), True
            return
        next_token_logits = backend.predict(generated[-max_len:])
        # Every 10 tokens, adapt how often we decode: faster generation ->
        # decode in larger batches to reduce per-token decode overhead.
        if tokens_generated > 5 and tokens_generated % 10 == 0:
            current_time = time.time()
            elapsed_since_check = current_time - last_speed_check
            if elapsed_since_check > 0:
                recent_speed = 10 / elapsed_since_check
                if recent_speed > 25:
                    decode_every = 8
                elif recent_speed > 15:
                    decode_every = 5
                elif recent_speed > 8:
                    decode_every = 3
                else:
                    decode_every = 2
            last_speed_check = current_time
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            # Never ask argpartition for more candidates than the vocab has.
            top_k = min(5, len(next_token_logits))
            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
            top_k_logits = next_token_logits[top_k_indices]
            exp_logits = np.exp(top_k_logits - np.max(top_k_logits))
            probs = exp_logits / np.sum(exp_logits)
            next_token = top_k_indices[np.random.choice(top_k, p=probs)]
        else:
            next_token = np.argmax(next_token_logits)
        if next_token == eos_token_id:
            break
        generated.append(int(next_token))
        decode_buffer.append(int(next_token))
        tokens_generated += 1
        if len(decode_buffer) >= decode_every or step == max_tokens - 1:
            new_chunk = _decode_new_chunk()
            if new_chunk is not None:
                speed = _speed()
                yield new_chunk, in_thinking, speed, speed, False
            decode_buffer = []

    # Flush tokens still waiting for a decode point. Without this, when EOS
    # ended the loop between decode points, the last few tokens of the reply
    # were silently dropped from the stream.
    if decode_buffer:
        new_chunk = _decode_new_chunk()
        if new_chunk is not None:
            speed = _speed()
            yield new_chunk, in_thinking, speed, speed, False

    final_tokens_per_sec = _speed()
    yield "", False, final_tokens_per_sec, final_tokens_per_sec, False
# ==============================================================================
# Session codes
# ==============================================================================
import secrets
# Global session codes storage
active_sessions = {}  # {session_code: user_data}
session_lock = threading.Lock()


def generate_session_code():
    """Return a random 4-digit code that is not currently in active_sessions.

    Randomness comes from ``secrets``. The code is NOT reserved; the caller
    must re-check under session_lock when storing it.

    Raises:
        RuntimeError: if all 10,000 codes are in use. Sessions are never
            expired here, so the original ``while True`` would otherwise
            spin forever once the code space filled up.

    NOTE(review): validate_session maps a code straight to user data, so the
    code acts as a bearer token; with only 10**4 values it is brute-forceable
    unless lookups are rate-limited. Consider longer codes.
    """
    code_space = 10 ** 4
    with session_lock:
        if len(active_sessions) >= code_space:
            raise RuntimeError("No free session codes available")
        while True:
            code = f"{secrets.randbelow(code_space):04d}"
            if code not in active_sessions:
                return code
def create_session(user_data):
    """Store *user_data* under a fresh session code and return the code.

    The stored dict is normalised to ``user_id``, ``username``, ``plan`` and
    ``is_admin``; the id is taken from ``user_data['id']`` and falls back to
    ``user_data['user_id']`` only when ``id`` is missing or None.
    """
    user_id = user_data.get('id')
    if user_id is None:
        user_id = user_data.get('user_id')
    normalized_data = {
        'user_id': user_id,
        'username': user_data.get('username'),
        'plan': user_data.get('plan'),
        'is_admin': user_data.get('is_admin', False)
    }
    while True:
        code = generate_session_code()
        with session_lock:
            # generate_session_code releases the lock before returning, so a
            # concurrent login could have claimed the same code meanwhile;
            # check and insert atomically instead of overwriting it.
            if code not in active_sessions:
                active_sessions[code] = normalized_data
                return code
def validate_session(code):
    """Return the user-data dict stored for *code*, or None if unknown."""
    with session_lock:
        session = active_sessions.get(code)
    return session
def invalidate_session(code):
    """Forget session *code*; return True if it existed, False otherwise."""
    missing = object()
    with session_lock:
        return active_sessions.pop(code, missing) is not missing
if __name__ == "__main__":
import gradio as gr
custom_css = """
.plan-explore { background: #d8b4fe; color: #7e22ce; }
.plan-research { background: #a5f3fc; color: #0e7490; }
.plan-viphyper { background: #fbbf24; color: #92400e; }
.chat-container { height: 500px; overflow-y: auto; padding: 20px; background: #ffffff; border: 1px solid #e5e7eb; border-radius: 8px; }
.user-message { background: #f7f7f8; padding: 16px; margin: 12px 0; border-radius: 8px; }
.assistant-message { background: #ffffff; padding: 16px; margin: 12px 0; border-radius: 8px; border-left: 3px solid #10a37f; }
.message-content { color: #353740; line-height: 1.6; font-size: 15px; }
.message-header { font-weight: 600; margin-bottom: 8px; color: #353740; font-size: 14px; }
.thinking-content { color: #6b7280; font-style: italic; border-left: 3px solid #d1d5db; padding-left: 12px; margin: 8px 0; background: #f9fafb; padding: 8px 12px; border-radius: 4px; }
.plan-badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 12px; font-weight: 600; margin-left: 8px; }
.plan-free { background: #e0e7ff; color: #3730a3; }
.plan-plus { background: #dbeafe; color: #1e40af; }
.plan-pro { background: #fef3c7; color: #92400e; }
.limits-panel { background: #f9fafb; border: 1px solid #e5e7eb; border-radius: 8px; padding: 16px; margin: 12px 0; }
.limit-item { display: flex; justify-content: space-between; padding: 8px 0; border-bottom: 1px solid #e5e7eb; }
.limit-item:last-child { border-bottom: none; }
.limit-exceeded { color: #dc2626; font-weight: 600; }
.limit-ok { color: #059669; }
.circular-btn { width: 48px !important; height: 48px !important; min-width: 48px !important; border-radius: 50% !important; padding: 0 !important; display: flex !important; align-items: center !important; justify-content: center !important; font-size: 20px !important; box-shadow: 0 2px 8px rgba(0,0,0,0.15) !important; transition: all 0.2s ease !important; }
.circular-btn:hover:not(:disabled) { transform: scale(1.05) !important; box-shadow: 0 4px 12px rgba(0,0,0,0.2) !important; }
.send-btn { background: linear-gradient(135deg, #10a37f 0%, #0d8c6c 100%) !important; border: none !important; }
.stop-btn { background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important; border: none !important; }
.announcement-banner { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 20px 28px; border-radius: 12px; margin-bottom: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); text-align: center; font-size: 16px; font-weight: 500; line-height: 1.6; }
.session-code-box { background: linear-gradient(135deg, #10a37f 0%, #0d8c6c 100%); color: white; padding: 20px; border-radius: 12px; text-align: center; margin: 20px 0; box-shadow: 0 4px 12px rgba(0,0,0,0.2); }
.session-code-display { font-size: 32px; font-weight: 700; letter-spacing: 8px; margin: 10px 0; font-family: monospace; }
"""
def format_message_html(role, content, show_thinking=True):
    """Render one chat message as an HTML snippet.

    For assistant output, text between ``<think>`` and the first closing tag
    (``</think>`` preferred, then ``<think/>``) is shown as a separate
    "thinking" box when *show_thinking* is true. An unclosed ``<think>``
    treats everything after it as thinking. All message text is
    HTML-escaped: it comes from users and the model and was previously
    injected raw, allowing markup/script injection into the page.
    """
    from html import escape

    role_class = "user-message" if role == "user" else "assistant-message"
    role_name = "You" if role == "user" else "SAM-X-1"
    thinking = ""
    answer = ""
    if "<think>" in content:
        before_think, after_think = content.split("<think>", 1)
        before_think = before_think.strip()
        for close_tag in ("</think>", "<think/>"):
            if close_tag in after_think:
                thought, rest = after_think.split(close_tag, 1)
                thinking = thought.strip()
                answer = (before_think + " " + rest).strip()
                break
        else:
            thinking = after_think.strip()
            answer = before_think
    else:
        answer = content
    markup = f'<div class="{role_class}"><div class="message-header">{role_name}</div><div class="message-content">'
    if thinking and show_thinking:
        markup += f'<div class="thinking-content">💭 {escape(thinking)}</div>'
    if answer:
        markup += f'<div>{escape(answer)}</div>'
    markup += '</div></div>'
    return markup
def render_history(history, show_thinking):
    """Concatenate the HTML of every message in *history*, in order.

    Each message is a dict with "role" and "content" keys.
    """
    return "".join(
        format_message_html(message["role"], message["content"], show_thinking)
        for message in history
    )
def render_limits_panel(user_data):
    """Build the "Your Plan" HTML panel with per-model usage.

    Returns an empty string when *user_data* is falsy, has no ``user_id``
    key, or ``get_user_limits_info`` returns nothing. Otherwise the panel
    shows the plan badge, the time until limits reset, and one row per
    model. A limit of ``-1`` is displayed as unlimited; two or fewer
    remaining messages are highlighted in amber.
    """
    if not user_data or 'user_id' not in user_data:
        return ""
    info = get_user_limits_info(user_data['user_id'])
    if not info:
        return ""

    def describe(used, limit):
        # -1 marks a model with no cap for this plan.
        if limit == -1:
            return f'<span class="limit-ok">{used} messages (Unlimited)</span>'
        left = limit - used
        if left <= 0:
            return f'<span class="limit-exceeded">{used}/{limit} (LIMIT REACHED)</span>'
        if left <= 2:
            return f'<span style="color: #f59e0b; font-weight: 600;">{used}/{limit} ({left} left)</span>'
        return f'<span class="limit-ok">{used}/{limit} ({left} left)</span>'

    plan = info['plan']
    parts = [
        '<div class="limits-panel">'
        '<div style="font-weight: 600; margin-bottom: 12px; font-size: 16px;">'
        f'Your Plan: <span class="plan-badge plan-{plan}">{plan.upper()}</span></div>'
        '<div style="font-size: 13px; color: #6b7280; margin-bottom: 12px;">'
        f'⏰ Limits reset in: <strong>{info["reset_in"]}</strong></div>'
    ]
    usage_rows = [
        ('NANO ⚡⚡', info['nano_used'], info['nano_limit']),
        ('MINI 🚀', info['mini_used'], info['mini_limit']),
        ('FAST ⚡', info['fast_used'], info['fast_limit']),
        ('LARGE 💎', info['large_used'], info['large_limit']),
    ]
    for label, used, limit in usage_rows:
        parts.append(
            f'<div class="limit-item"><span style="font-weight: 500;">{label}</span>'
            f'<span>{describe(used, limit)}</span></div>'
        )
    parts.append('</div>')
    return "".join(parts)
# Main Gradio UI: components are created here in page order; the event handlers
# and their wiring follow the layout.
# NOTE(review): this copy of the file has lost its indentation, so the nesting of
# the `with` blocks below is not visible — check it against the original source.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="slate")) as demo:
# Banner summarizing how access works and the limits of each plan.
gr.HTML('<div class="announcement-banner">🔐 <strong>SAM-X-1 V3.0 - SESSION CODE ACCESS</strong> 🔐<br>✨ Sign in to get your 4-digit session code!<br>🆓 FREE: Nano & Mini unlimited, Fast 10/3h, Large 8/3h<br>⭐ PLUS: Nano/Mini/Fast unlimited, Large 20/3h<br>💎 PRO: Everything unlimited!</div>')
with gr.Tabs() as main_tabs:
# Sign In tab: a successful login (handle_login) renders the session code into session_code_display.
with gr.Tab("🔐 Sign In"):
with gr.Column():
login_username = gr.Textbox(label="Username", placeholder="Enter username")
login_password = gr.Textbox(label="Password", type="password", placeholder="Enter password")
login_btn = gr.Button("Sign In", variant="primary", size="lg")
login_msg = gr.Markdown("")
session_code_display = gr.HTML("")
# Sign Up tab: account creation (handle_signup).
with gr.Tab("📝 Sign Up"):
with gr.Column():
signup_username = gr.Textbox(label="Username", placeholder="Choose a username")
signup_email = gr.Textbox(label="Email (optional)", placeholder="your@email.com")
signup_password = gr.Textbox(label="Password", type="password", placeholder="Choose a password")
signup_btn = gr.Button("Create Account", variant="primary", size="lg")
signup_msg = gr.Markdown("")
# Chat tab: the chat controls are gated on a verified 4-digit session code.
with gr.Tab("💬 Chat") as chat_tab:
with gr.Row():
chat_session_code = gr.Textbox(label="🔑 Enter Your 4-Digit Session Code", placeholder="0000", max_lines=1, scale=3)
verify_session_btn = gr.Button("✅ Verify", variant="primary", scale=1)
with gr.Row():
with gr.Column(scale=4):
user_info = gr.Markdown("❌ Not authenticated - Enter your session code above")
with gr.Column(scale=1):
logout_btn = gr.Button("🚪 Logout", size="sm")
limits_display = gr.HTML("")
# Settings: verify_session_code replaces the model choices for plans that may choose
# a model and sets the slider maximum to the plan's max_tokens.
with gr.Accordion("⚙️ Settings", open=False):
with gr.Row():
model_selector = gr.Dropdown(choices=["🤖 Auto (Recommended)"], value="🤖 Auto (Recommended)", label="Model Selection", info="FREE users: Auto only. PLUS/PRO: Choose manually")
max_tokens_slider = gr.Slider(minimum=64, maximum=512, value=256, step=64, label="Max Tokens")
with gr.Row():
temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
show_thinking_checkbox = gr.Checkbox(label="Show Thinking", value=True)
speed_display = gr.Textbox(label="Generation Speed", value="⚡ Ready", interactive=False)
# Transcript, filled with HTML produced by render_history.
chat_html = gr.HTML(value="", elem_classes=["chat-container"])
# Message box, send and stop all start disabled (interactive=False) until a code is verified.
with gr.Row():
msg_input = gr.Textbox(placeholder="Enter session code first to chat...", show_label=False, scale=8, interactive=False)
with gr.Column(scale=1, min_width=120):
with gr.Row():
send_btn = gr.Button("▶", variant="primary", elem_classes=["circular-btn", "send-btn"], interactive=False)
stop_btn = gr.Button("⏹", variant="stop", elem_classes=["circular-btn", "stop-btn"], interactive=False)
with gr.Row():
clear_btn = gr.Button("🗑️ Clear", size="sm")
upgrade_btn = gr.Button("⭐ Request Upgrade", size="sm", variant="primary")
# Admin panel: hidden by default; made visible after an admin session is verified.
# NOTE(review): hiding the panel is only a UI gate — the handlers bound to these
# buttons remain callable, so they must check admin rights themselves.
with gr.Accordion("🔐 Admin Panel", visible=False, open=False) as admin_panel:
gr.Markdown("### 👨‍💼 User Management Dashboard")
with gr.Tabs():
with gr.Tab("👥 All Users"):
users_table = gr.Dataframe(headers=["ID", "Username", "Email", "Plan", "Created", "Admin"])
refresh_users_btn = gr.Button("🔄 Refresh Users")
with gr.Row():
admin_username = gr.Textbox(label="Username to Update")
admin_new_plan = gr.Dropdown(choices=["free", "plus", "pro"], label="New Plan", value="free")
update_plan_btn = gr.Button("✏️ Update Plan", variant="primary")
admin_msg = gr.Markdown("")
with gr.Tab("📋 Upgrade Requests"):
gr.Markdown("**Review and approve/deny user upgrade requests below:**")
requests_table = gr.Dataframe(headers=["ID", "Username", "Requested Plan", "Reason", "Date"])
refresh_requests_btn = gr.Button("🔄 Refresh Requests")
with gr.Row():
request_id_input = gr.Number(label="Request ID (from table above)", precision=0, minimum=1)
with gr.Row():
approve_req_btn = gr.Button("✅ Approve Request", variant="primary", size="lg")
deny_req_btn = gr.Button("❌ Deny Request", variant="stop", size="lg")
request_msg = gr.Markdown("")
# Upgrade request form: hidden until upgrade_btn is clicked (show_upgrade_panel).
with gr.Accordion("⭐ Request Plan Upgrade", visible=False, open=False) as upgrade_panel:
upgrade_session_code = gr.Textbox(label="Your Session Code", placeholder="0000")
upgrade_plan_choice = gr.Radio(choices=["plus", "pro"], label="Select Plan", value="plus")
upgrade_reason = gr.Textbox(label="Reason for Upgrade", placeholder="Why do you need this upgrade?", lines=3)
submit_upgrade_btn = gr.Button("Submit Request", variant="primary")
upgrade_msg = gr.Markdown("")
# Admin Access tab: verifies an admin session code (verify_admin_session) and reveals the admin panel.
with gr.Tab("👨‍💼 Admin Access") as admin_tab:
admin_session_code = gr.Textbox(label="🔑 Enter Your Admin Session Code", placeholder="0000", max_lines=1)
verify_admin_btn = gr.Button("✅ Verify Admin", variant="primary", size="lg")
admin_verify_msg = gr.Markdown("")
admin_logout_btn = gr.Button("🚪 Logout", size="sm")
# Event handlers
def handle_login(username, password):
    """Authenticate a user and, on success, issue a fresh session code.

    Returns ``(markdown_message, session_code_html)`` for ``login_msg`` and
    ``session_code_display``; the HTML part is empty when login fails.
    """
    ok, account = authenticate_user(username, password)
    if not ok:
        return "❌ Invalid credentials!", ""
    code = create_session(account)
    box = (
        '<div class="session-code-box">'
        '<div style="font-size: 18px; margin-bottom: 10px;">✅ Login Successful!</div>'
        '<div style="font-size: 16px; margin-bottom: 5px;">Your Session Code:</div>'
        f'<div class="session-code-display">{code}</div>'
        '<div style="font-size: 14px; margin-top: 10px;">💡 Use this code in the Chat or Admin tab</div>'
        '<div style="font-size: 13px; margin-top: 5px; opacity: 0.9;">⚠️ Keep this code private!</div>'
        '</div>'
    )
    return f"✅ Welcome back, **{username}**! Use your session code above to access chat.", box
def handle_signup(username, email, password):
    """Validate sign-up input and create the account.

    Usernames shorter than 3 characters and passwords shorter than 6 are
    rejected before ``create_user`` is called. Returns a markdown status
    message for ``signup_msg``.
    """
    if len(username) < 3:
        return "❌ Username must be at least 3 characters!"
    if len(password) < 6:
        return "❌ Password must be at least 6 characters!"
    created, detail = create_user(username, password, email)
    return f"✅ {detail} Now sign in to get your session code!" if created else f"❌ {detail}"
def verify_session_code(code):
    """Validate a chat-tab session code and unlock the chat UI.

    Returns seven values for ``[user_info, limits_display, admin_panel,
    model_selector, max_tokens_slider, msg_input, send_btn]``. On failure
    the chat input and send button stay disabled and the admin panel is
    hidden. On success the model dropdown and token slider are adjusted to
    the user's plan, and the admin panel is shown only for admins.
    """
    auto_choice = "🤖 Auto (Recommended)"
    retry_hint = "Enter valid session code first..."

    def _locked(status, placeholder=None):
        # Keep chat disabled; the model selector and slider are left untouched.
        input_state = {"interactive": False}
        if placeholder is not None:
            input_state["placeholder"] = placeholder
        return (
            status,
            "",
            gr.update(visible=False),
            gr.update(),
            gr.update(),
            gr.update(**input_state),
            gr.update(interactive=False),
        )

    if not code or len(code) != 4 or not code.isdigit():
        return _locked("❌ Invalid code format", retry_hint)
    user_data = validate_session(code)
    if not user_data:
        return _locked("❌ Invalid or expired session code", retry_hint)
    info = get_user_limits_info(user_data['user_id'])
    if not info:
        return _locked("❌ Could not load user info")

    plan = info["plan"]
    badge = f'<span class="plan-badge plan-{plan}">{plan.upper()}</span>'
    panel = render_limits_panel(user_data)
    choices = [auto_choice]
    if info['can_choose_model']:
        choices.extend(available_models.keys())
    cap = info['max_tokens']
    return (
        f"✅ **Authenticated as: {user_data['username']}** {badge}",
        panel,
        gr.update(visible=user_data.get('is_admin', False)),
        gr.update(choices=choices, value=auto_choice),
        gr.update(maximum=cap, value=min(256, cap)),
        gr.update(interactive=True, placeholder="Ask me anything..."),
        gr.update(interactive=True),
    )
def send_message_handler(message, show_thinking, temperature, model_choice, max_tokens, session_code):
    """Stream a single-turn reply to *message* into the chat panel.

    This function contains ``yield``, so it is a generator. Gradio streams
    each yielded 5-tuple to ``[msg_input, chat_html, speed_display,
    send_btn, stop_btn]``. A generator's ``return value`` never reaches
    Gradio, so every early exit yields its status tuple and then returns
    bare. Otherwise validation errors (bad session, exhausted limits) would
    be dropped silently and the UI would show nothing.

    Flow:
        1. Validate the session.
        2. Pick a model: automatic when "Auto" is selected or the plan
           cannot choose, otherwise the user's choice.
        3. Check ``can_use_model`` and record usage.
        4. Stream tokens from ``generate_response_stream`` until it finishes
           or reports that it was stopped.

    Args:
        message: The user's prompt text.
        show_thinking: Whether to render the ``<think>`` section.
        temperature: Sampling temperature forwarded to the generator.
        model_choice: Dropdown value, either the Auto label or a model name.
        max_tokens: Requested token budget, capped at the plan's maximum.
        session_code: 4-digit session code from the chat tab.
    """
    auto_label = "🤖 Auto (Recommended)"
    # NOTE(review): stop_generation is one module-level event, so a Stop (or a new
    # Send clearing it) from one user affects every concurrent stream.
    stop_generation.clear()

    if not session_code or len(session_code) != 4:
        yield "", "", "❌ Invalid session code", gr.update(), gr.update()
        return
    user_data = validate_session(session_code)
    if not user_data:
        yield "", "", "❌ Session expired - please re-enter your code", gr.update(), gr.update()
        return
    if not message.strip():
        yield "", "", "⚡ Ready", gr.update(interactive=True), gr.update(interactive=False)
        return
    user_id = user_data['user_id']
    info = get_user_limits_info(user_id)
    if not info:
        yield "", "", "❌ Could not load user info", gr.update(interactive=True), gr.update(interactive=False)
        return

    # Model selection: automatic unless the plan allows a manual choice and one was made.
    backend = None
    if model_choice == auto_label or not info['can_choose_model']:
        user_available = get_available_models_for_user(user_id)
        if not user_available:
            yield "", "", "❌ No models available (limits reached)", gr.update(interactive=True), gr.update(interactive=False)
            return
        backend = select_model_auto(message, available_models, user_available)
        if not backend:
            yield "", "", "❌ Could not select model", gr.update(interactive=True), gr.update(interactive=False)
            return
        model_name = backend.get_name()
    else:
        model_name = model_choice

    # One permission check covers both the automatic and the manual path.
    can_use, msg = can_use_model(user_id, model_name)
    if not can_use:
        yield "", "", f"❌ {msg}", gr.update(interactive=True), gr.update(interactive=False)
        return
    if backend is None:
        # Manual choice: a name missing from available_models used to raise KeyError.
        backend = available_models.get(model_name)
        if backend is None:
            yield "", "", f"❌ Model not available: {model_name}", gr.update(interactive=True), gr.update(interactive=False)
            return

    # Usage is counted once the request is accepted, before generation starts.
    increment_model_usage(user_id, model_name)
    yield "", "", f"⚡ Using {model_name}...", gr.update(interactive=False), gr.update(interactive=True)

    history = [{"role": "user", "content": message}]
    yield "", render_history(history, show_thinking), "⚡ Generating...", gr.update(interactive=False), gr.update(interactive=True)

    # The prompt ends with an open <think> tag so the model starts in its reasoning
    # section; the displayed assistant message mirrors that.
    prompt = f"User: {message}\nSam: <think>"
    history.append({"role": "assistant", "content": "<think>"})
    # Coerce to int in case the slider delivers a float.
    actual_max_tokens = int(min(max_tokens, info['max_tokens']))
    last_speed = 0
    was_stopped = False
    for chunk_data in generate_response_stream(prompt, temperature, backend, actual_max_tokens):
        # NOTE(review): items that are not 5-tuples are skipped silently, as before.
        if len(chunk_data) != 5:
            continue
        new_chunk, _in_thinking, tokens_per_sec, avg_speed, stopped = chunk_data
        if stopped:
            was_stopped = True
            break
        last_speed = avg_speed
        if new_chunk:
            history[-1]["content"] += new_chunk
            yield "", render_history(history, show_thinking), f"⚡ {tokens_per_sec:.1f} tok/s", gr.update(interactive=False), gr.update(interactive=True)

    final = f"{'🛑 Stopped' if was_stopped else '✅ Done'} - {last_speed:.1f} tok/s"
    yield "", render_history(history, show_thinking), final, gr.update(interactive=True), gr.update(interactive=False)
def stop_generation_handler():
    """Signal the running generation to stop and disable both chat buttons.

    Returns the status text for ``speed_display`` plus updates for
    ``send_btn`` and ``stop_btn``.
    """
    stop_generation.set()
    status = "🛑 Stopping..."
    return status, gr.update(interactive=False), gr.update(interactive=False)
def clear_chat():
    """Reset the chat view: empty transcript, idle status, send enabled, stop disabled."""
    send_state = gr.update(interactive=True)
    stop_state = gr.update(interactive=False)
    return "", "⚡ Ready", send_state, stop_state
def show_upgrade_panel():
    """Reveal and expand the "Request Plan Upgrade" accordion."""
    panel_state = gr.update(visible=True, open=True)
    return panel_state
def submit_upgrade_request(session_code, plan, reason):
    """File an upgrade request for the user behind *session_code*.

    Rejects malformed or expired codes and blank reasons before calling
    ``request_upgrade``. Returns a markdown status message.
    """
    if not (session_code and len(session_code) == 4):
        return "❌ Invalid session code"
    requester = validate_session(session_code)
    if not requester:
        return "❌ Session expired"
    if not reason.strip():
        return "❌ Please provide a reason"
    ok, detail = request_upgrade(requester['user_id'], plan, reason)
    mark = "✅" if ok else "❌"
    return f"{mark} {detail}"
def handle_logout(session_code):
    """Invalidate the chat-tab session code and lock the chat UI again.

    Returns updates for ``[chat_session_code, user_info, limits_display,
    admin_panel, model_selector, max_tokens_slider, msg_input, send_btn]``.
    """
    if session_code and len(session_code) == 4:
        invalidate_session(session_code)
    auto_only = ["🤖 Auto (Recommended)"]
    selector_reset = gr.update(choices=auto_only, value=auto_only[0])
    return (
        "",
        "❌ Logged out - Session code invalidated",
        "",
        gr.update(visible=False),
        selector_reset,
        gr.update(value=256),
        gr.update(interactive=False, placeholder="Enter session code first..."),
        gr.update(interactive=False),
    )
def verify_admin_session(code):
    """Check that *code* is a live admin session and reveal the admin panel.

    Returns ``(status_markdown, admin_panel_update)``; the panel is hidden
    on every failure.
    """
    denial = None
    user_data = None
    if not code or len(code) != 4 or not code.isdigit():
        denial = "❌ Invalid code format"
    else:
        user_data = validate_session(code)
        if not user_data:
            denial = "❌ Invalid or expired session code"
        elif not user_data.get('is_admin', False):
            denial = "❌ Access denied - Admin privileges required"
    if denial is not None:
        return denial, gr.update(visible=False)
    return f"✅ Admin access granted for **{user_data['username']}**", gr.update(visible=True, open=True)
def admin_logout_handler(code):
    """Invalidate the admin-tab session code, clear the field and hide the admin panel."""
    if code and len(code) == 4:
        invalidate_session(code)
    cleared_code, status = "", "❌ Logged out"
    return cleared_code, status, gr.update(visible=False)
_ADMIN_ACCESS_DENIED = "❌ Access denied - Admin privileges required"


def _admin_from_codes(*codes):
    """Return the session user for the first code that is a live admin session, else None.

    Codes that are empty or not exactly four digits are skipped without a lookup.
    """
    for code in codes:
        if not code or len(code) != 4 or not code.isdigit():
            continue
        user = validate_session(code)
        if user and user.get('is_admin', False):
            return user
    return None


def load_all_users(chat_code=None, admin_code=None):
    """Return all users as rows for the admin users table.

    Rows are ``[id, username, email or "N/A", plan, created (first 10 chars),
    "Yes"/"No" admin]``. Raises ``gr.Error`` unless *chat_code* or
    *admin_code* is an admin session.
    """
    if not _admin_from_codes(chat_code, admin_code):
        raise gr.Error(_ADMIN_ACCESS_DENIED)
    return [
        [user[0], user[1], user[2] or "N/A", user[3], user[4][:10] if user[4] else "N/A", "Yes" if user[5] else "No"]
        for user in get_all_users()
    ]


def load_pending_requests(chat_code=None, admin_code=None):
    """Return pending upgrade requests as rows for the requests table.

    Rows are ``[id, username, requested plan, reason, date (first 10 chars)]``.
    Raises ``gr.Error`` unless *chat_code* or *admin_code* is an admin session.
    """
    if not _admin_from_codes(chat_code, admin_code):
        raise gr.Error(_ADMIN_ACCESS_DENIED)
    return [
        [req[0], req[1], req[2], req[3], req[4][:10] if req[4] else "N/A"]
        for req in get_pending_requests()
    ]


def admin_update_plan(username, new_plan, chat_code=None, admin_code=None):
    """Set *username*'s plan to *new_plan* (admin sessions only); returns a status message."""
    if not _admin_from_codes(chat_code, admin_code):
        return _ADMIN_ACCESS_DENIED
    if not username or not new_plan:
        return "❌ Please fill all fields"
    success, msg = update_user_plan(username, new_plan)
    return f"{'✅' if success else '❌'} {msg}"


def admin_approve_request(request_id, chat_code=None, admin_code=None):
    """Approve upgrade request *request_id* (admin sessions only); returns a status message."""
    if not _admin_from_codes(chat_code, admin_code):
        return _ADMIN_ACCESS_DENIED
    if not request_id:
        return "❌ Please enter request ID"
    success, msg = approve_request(int(request_id))
    return f"{'✅' if success else '❌'} {msg}"


def admin_deny_request(request_id, chat_code=None, admin_code=None):
    """Deny upgrade request *request_id* (admin sessions only); returns a status message."""
    if not _admin_from_codes(chat_code, admin_code):
        return _ADMIN_ACCESS_DENIED
    if not request_id:
        return "❌ Please enter request ID"
    success, msg = deny_request(int(request_id))
    return f"{'✅' if success else '❌'} {msg}"


# Wire up events
login_btn.click(handle_login, [login_username, login_password], [login_msg, session_code_display])
signup_btn.click(handle_signup, [signup_username, signup_email, signup_password], [signup_msg])
# Session verification
verify_outputs = [user_info, limits_display, admin_panel, model_selector, max_tokens_slider, msg_input, send_btn]
verify_session_btn.click(verify_session_code, [chat_session_code], verify_outputs)
# Chat functionality
send_outputs = [msg_input, chat_html, speed_display, send_btn, stop_btn]
send_btn.click(send_message_handler, [msg_input, show_thinking_checkbox, temperature_slider, model_selector, max_tokens_slider, chat_session_code], send_outputs)
msg_input.submit(send_message_handler, [msg_input, show_thinking_checkbox, temperature_slider, model_selector, max_tokens_slider, chat_session_code], send_outputs)
stop_btn.click(stop_generation_handler, outputs=[speed_display, send_btn, stop_btn])
clear_btn.click(clear_chat, outputs=[chat_html, speed_display, send_btn, stop_btn])
upgrade_btn.click(show_upgrade_panel, outputs=[upgrade_panel])
submit_upgrade_btn.click(submit_upgrade_request, [upgrade_session_code, upgrade_plan_choice, upgrade_reason], [upgrade_msg])
logout_outputs = [chat_session_code, user_info, limits_display, admin_panel, model_selector, max_tokens_slider, msg_input, send_btn]
logout_btn.click(handle_logout, [chat_session_code], logout_outputs)
# Admin panel events
verify_admin_btn.click(verify_admin_session, [admin_session_code], [admin_verify_msg, admin_panel])
admin_logout_btn.click(admin_logout_handler, [admin_session_code], [admin_session_code, admin_verify_msg, admin_panel])
# The admin panel is only hidden client-side and these handlers stay reachable
# through Gradio's API, so each one receives both session-code boxes and
# re-checks admin rights on the server.
admin_code_inputs = [chat_session_code, admin_session_code]
refresh_users_btn.click(load_all_users, admin_code_inputs, [users_table])
refresh_requests_btn.click(load_pending_requests, admin_code_inputs, [requests_table])
update_plan_btn.click(admin_update_plan, [admin_username, admin_new_plan] + admin_code_inputs, [admin_msg])
approve_req_btn.click(admin_approve_request, [request_id_input] + admin_code_inputs, [request_msg])
deny_req_btn.click(admin_deny_request, [request_id_input] + admin_code_inputs, [request_msg])
# Help text shown at the bottom of the page.
# NOTE(review): the plan table is user-facing copy; keep it in sync with the banner
# at the top of the page and with the limits actually enforced (not visible here).
gr.Markdown("""
---
### 🔑 How Session Codes Work
1. **Sign In** on the "Sign In" tab to get your unique 4-digit code
2. **Copy** your session code (displayed after login)
3. **Enter** the code in the Chat or Admin tab to access features
4. **Logout** invalidates your code (you'll need to sign in again)
### 📊 Plan Comparison
| Feature | FREE | PLUS ⭐ | PRO 💎 |
|---------|------|---------|--------|
| **Nano Model** | ✅ Unlimited | ✅ Unlimited | ✅ Unlimited |
| **Mini Model** | ✅ Unlimited | ✅ Unlimited | ✅ Unlimited |
| **Fast Model** | 10 msgs/3h | ✅ Unlimited | ✅ Unlimited |
| **Large Model** | 8 msgs/3h | 20 msgs/3h | ✅ Unlimited |
| **Model Selection** | 🤖 Auto only | ✅ Manual choice | ✅ Manual choice |
| **Max Tokens** | 256 | 384 | 512 |
### 🆓 Sign up for FREE account - Nano & Mini unlimited!
### 👨‍💼 Admins: Use your session code in the Admin Access tab
""")
# Start the server on all interfaces, port 7860, with no public share link.
# NOTE(review): debug=True is enabled — confirm that is intended for deployment.
demo.launch(debug=True, share=False, server_name="0.0.0.0", server_port=7860)