File size: 3,434 Bytes
98da568
f35bf64
 
 
f3c01e2
98da568
f35bf64
f3c01e2
21f22c1
01bada7
f3c01e2
f35bf64
 
 
 
f3c01e2
f35bf64
 
f3c01e2
 
f35bf64
f3c01e2
f35bf64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01bada7
f35bf64
01bada7
 
f35bf64
01bada7
 
 
 
f35bf64
 
 
 
98da568
01bada7
f35bf64
01bada7
 
f35bf64
01bada7
f35bf64
 
98da568
21f22c1
f35bf64
01bada7
f35bf64
 
 
 
01bada7
f35bf64
 
 
 
 
 
 
 
01bada7
f35bf64
01bada7
 
f35bf64
 
 
 
 
 
01bada7
98da568
 
f35bf64
01bada7
98da568
f35bf64
01bada7
f35bf64
 
 
 
 
 
 
 
 
98da568
f3c01e2
f35bf64
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import threading
from typing import List, Tuple, Dict

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import spaces

# Which checkpoint to serve: the "instruct" subfolder carries the chat template.
MODEL_ID: str = "facebook/MobileLLM-Pro"
SUBFOLDER: str = "instruct"

# Decoding defaults used for every request.
MAX_NEW_TOKENS: int = 256
TEMPERATURE: float = 0.7  # sampling is enabled only when this is > 0
TOP_P: float = 0.95

# --- Hugging Face Hub authentication (best effort, no UI) ---
# The first non-empty value among the three env vars wins; a Space Secret
# typically provides one of them.
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HUGGINGFACE_TOKEN")
)
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        # Deliberately swallowed: a failed login must neither crash the app
        # nor surface anything token-related in the UI or logs.
        pass

# Lazily-populated by _ensure_loaded(); None until the first request loads
# the tokenizer, the model, and the device holding the model's first parameter.
_tokenizer = _model = _device = None

def _ensure_loaded():
    """Load the tokenizer and model once and cache them in module globals.

    On a CUDA machine the weights are loaded in float16 with
    ``device_map="auto"``; otherwise float32 with default placement.
    If the tokenizer has no pad token but has an EOS token, EOS is reused as
    the pad token so ``generate()`` receives a usable ``pad_token_id``.
    ``_device`` is set to the device of the model's first parameter.
    """
    global _tokenizer, _model, _device
    if not (_tokenizer is None or _model is None):
        # Both already cached from an earlier call.
        return

    has_cuda = torch.cuda.is_available()

    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, subfolder=SUBFOLDER
    )
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        subfolder=SUBFOLDER,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if has_cuda else None,
    )

    lacks_pad = _tokenizer.pad_token_id is None
    has_eos = _tokenizer.eos_token_id is not None
    if lacks_pad and has_eos:
        _tokenizer.pad_token = _tokenizer.eos_token

    _model.eval()
    first_param = next(_model.parameters())
    _device = first_param.device

def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    msgs: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            msgs.append({"role": "assistant", "content": bot_msg})
    return msgs

@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
    """Stream an assistant reply for ``gr.ChatInterface``.

    Builds the conversation from ``history`` plus the new ``message`` and
    renders it with the tokenizer's chat template (with a generation prompt).
    It then runs ``_model.generate`` on a background thread and yields the
    reply as it grows. No token UI, no extra controls.

    Args:
        message: The latest user message.
        history: Prior turns as passed by ``gr.ChatInterface``.

    Yields:
        The cumulative reply text generated so far. The prompt is not echoed.

    Raises:
        Whatever ``_model.generate`` raised on the worker thread, re-raised
        here once the stream has been closed. Without this, the consumer
        would block forever waiting for more text.
    """
    _ensure_loaded()

    messages = _history_to_messages(history) + [{"role": "user", "content": message}]
    encoded = _tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    # Depending on transformers version/kwargs this is either a bare tensor or
    # a BatchEncoding. BatchEncoding is a UserDict, *not* a dict subclass, so
    # test for the tensor case rather than for dict.
    if isinstance(encoded, torch.Tensor):
        input_ids = encoded
        attention_mask = None
    else:
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
    input_ids = input_ids.to(_device)
    if attention_mask is None:
        # One unpadded sequence: every position is a real token. Passing the
        # mask explicitly matters because pad_token may have been set to EOS.
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = attention_mask.to(_device)

    # skip_prompt=True: by default the streamer first emits the decoded
    # prompt, which would prepend the whole conversation to every reply.
    streamer = TextIteratorStreamer(
        _tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    do_sample = TEMPERATURE > 0.0
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=do_sample,
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs only apply (and only avoid warnings) when sampling.
        gen_kwargs.update(temperature=float(TEMPERATURE), top_p=float(TOP_P))

    errors: List[BaseException] = []

    def _run_generate() -> None:
        # Runs on the worker thread. On failure, record the error and close
        # the stream so the consuming loop below terminates.
        try:
            _model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_run_generate)
    thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output

    thread.join()
    if errors:
        raise errors[0]

# Chat UI: streams generate_stream's cumulative output into a fixed-height chatbot.
demo = gr.ChatInterface(
    title="MobileLLM-Pro — Chat",
    description="Streaming chat with facebook/MobileLLM-Pro (instruct)",
    chatbot=gr.Chatbot(label="MobileLLM-Pro", height=420),
    fn=generate_stream,
)

if __name__ == "__main__":
    # Bind on all interfaces; the port comes from $PORT, falling back to 7860.
    demo.launch(
        server_port=int(os.getenv("PORT", 7860)),
        server_name="0.0.0.0",
    )