akhaliq's picture
akhaliq HF Staff
Update app.py
f35bf64 verified
raw
history blame
3.43 kB
import os
import threading
from typing import Any, Dict, List, Tuple, Union
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import spaces
# --- Model / generation configuration ---
MODEL_ID: str = "facebook/MobileLLM-Pro"
SUBFOLDER: str = "instruct"  # checkpoint variant loaded; used with its chat template
MAX_NEW_TOKENS: int = 256  # cap on tokens generated per reply
TEMPERATURE: float = 0.7  # sampling is enabled only when this is > 0
TOP_P: float = 0.95  # nucleus-sampling cutoff

# --- Silent Hub auth via env/Space Secret (no UI) ---
# The first truthy value among the supported variable names is used.
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HUGGINGFACE_TOKEN")
)
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        # Intentionally silent whether login succeeds or fails, so nothing
        # about the token is surfaced to the UI or logs.
        pass

# Module-level cache populated by _ensure_loaded(), so the tokenizer and
# model are loaded only once.
_tokenizer = _model = _device = None
def _ensure_loaded():
    """Load the tokenizer and model once and cache them in module globals.

    Returns immediately when both ``_tokenizer`` and ``_model`` are already
    set. With CUDA available the weights are loaded as float16 with
    ``device_map="auto"``; otherwise as float32 with no device map. After
    loading, the model is put in eval mode and ``_device`` is set to the
    device of its first parameter.
    """
    global _tokenizer, _model, _device
    if not (_tokenizer is None or _model is None):
        return

    common = dict(trust_remote_code=True, subfolder=SUBFOLDER)
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **common)

    on_gpu = torch.cuda.is_available()
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if on_gpu else None,
        **common,
    )

    # Fall back to EOS as the padding token when none is defined, so that
    # generate() is always given a pad_token_id.
    lacks_pad = _tokenizer.pad_token_id is None
    if lacks_pad and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token = _tokenizer.eos_token

    _model.eval()
    _device = next(_model.parameters()).device
def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
msgs: List[Dict[str, str]] = []
for user_msg, bot_msg in history:
if user_msg:
msgs.append({"role": "user", "content": user_msg})
if bot_msg:
msgs.append({"role": "assistant", "content": bot_msg})
return msgs
@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
    """Stream a chat reply for ``gr.ChatInterface``.

    The prior ``history`` plus the new ``message`` are rendered with the
    instruct chat template. ``model.generate`` runs on a worker thread while
    this generator yields the reply text as it streams in. No token UI, no
    extra controls.

    Args:
        message: The user's new message.
        history: Prior turns as supplied by Gradio; converted by
            ``_history_to_messages``.

    Yields:
        The reply generated so far (cumulative text, not per-chunk deltas).

    Raises:
        Any exception raised inside ``model.generate`` is re-raised here once
        streaming stops, instead of leaving this loop blocked forever.
    """
    _ensure_loaded()
    messages = _history_to_messages(history) + [{"role": "user", "content": message}]
    inputs = _tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    # Depending on the transformers version/arguments this is a bare tensor or
    # a BatchEncoding. BatchEncoding is a UserDict (not a dict subclass), so
    # test for a tensor rather than for dict.
    input_ids = inputs if torch.is_tensor(inputs) else inputs["input_ids"]
    input_ids = input_ids.to(_device)
    # A single unpadded sequence: every position is real. The mask is passed
    # explicitly because pad_token may have been set to eos_token, in which
    # case generate() cannot infer it from the ids.
    attention_mask = torch.ones_like(input_ids)

    # skip_prompt=True: without it the streamer first emits the decoded prompt
    # (the whole chat transcript), which would show up in the reply bubble.
    streamer = TextIteratorStreamer(
        _tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=TEMPERATURE > 0.0,
        temperature=float(TEMPERATURE),
        top_p=float(TOP_P),
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

    errors: List[BaseException] = []

    def _generate() -> None:
        # Run generation on the worker thread, capturing any failure so the
        # consuming loop below can surface it.
        try:
            _model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            errors.append(exc)
            # generate() did not reach streamer.end(); send the stop signal
            # so iteration over the streamer terminates.
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    thread.join()
    if errors:
        raise errors[0]
# Gradio chat UI backed by the streaming generate_stream function.
demo = gr.ChatInterface(
    fn=generate_stream,
    title="MobileLLM-Pro — Chat",
    description="Streaming chat with facebook/MobileLLM-Pro (instruct)",
    chatbot=gr.Chatbot(label="MobileLLM-Pro", height=420),
)

if __name__ == "__main__":
    # Listen on all interfaces; port taken from $PORT, falling back to 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
    )