beyoru committed
Commit fbaeea8 · verified · 1 Parent(s): 4916827

Update app.py

Files changed (1)
  1. app.py +213 -40
app.py CHANGED
@@ -1,11 +1,12 @@
import os
import torch
+ import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
- import os
+ import numpy as np

- MODEL_NAME = os.getenv('MODEL_ID')
+ MODEL_NAME = os.getenv('MODEL_ID', 'gpt2')

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
@@ -17,6 +18,125 @@ model = AutoModelForCausalLM.from_pretrained(
)
print("Model loaded.")

+ # ===== REASONING SAMPLING FUNCTIONS =====
+
+ def power_distribution(logits, alpha, temperature=1.0):
+     """Compute the power distribution: p^alpha / Z"""
+     probs = F.softmax(logits / temperature, dim=-1)
+     power_probs = probs ** alpha
+     return power_probs / power_probs.sum(dim=-1, keepdim=True)
+
+
+ def metropolis_hastings_step(current_seq, model, tokenizer, alpha, temperature):
+     """Perform one Metropolis-Hastings sampling step"""
+     device = current_seq.device
+
+     # Compute the logits for the next token
+     with torch.no_grad():
+         outputs = model(input_ids=current_seq)
+         logits = outputs.logits[:, -1, :]
+
+     # Proposal distribution
+     proposal_probs = F.softmax(logits / temperature, dim=-1)
+
+     # Sample a new token from the proposal distribution
+     proposed_token = torch.multinomial(proposal_probs, num_samples=1)
+     proposed_seq = torch.cat([current_seq, proposed_token], dim=1)
+
+     # Compute the acceptance probability
+     # Target distribution: p^alpha
+     power_probs = power_distribution(logits, alpha, temperature)
+
+     # Probabilities of the current token and the proposed token
+     current_token_prob = proposal_probs[0, current_seq[0, -1]].item() if current_seq.size(1) > 1 else 1.0
+     proposed_token_prob = proposal_probs[0, proposed_token[0, 0]].item()
+
+     # Target ratio
+     power_current = power_probs[0, current_seq[0, -1]].item() if current_seq.size(1) > 1 else 1.0
+     power_proposed = power_probs[0, proposed_token[0, 0]].item()
+
+     # Acceptance ratio: A = min(1, (p^α(x') * q(x|x')) / (p^α(x) * q(x'|x)))
+     # Work in log space to avoid division by zero and overflow
+     if current_token_prob > 0 and proposed_token_prob > 0:
+         log_ratio = np.log(power_proposed) - np.log(power_current)
+         log_ratio += np.log(current_token_prob) - np.log(proposed_token_prob)
+         acceptance_prob = min(1.0, np.exp(log_ratio))
+     else:
+         acceptance_prob = 0.0
+
+     # Accept or reject
+     if np.random.rand() < acceptance_prob:
+         return proposed_seq, True
+     return current_seq, False
+
+
+ def generate_with_reasoning(
+     prompt,
+     model,
+     tokenizer,
+     max_new_tokens=100,
+     alpha=2.0,
+     temperature=1.0,
+     num_mcmc_steps=5,
+     streamer=None
+ ):
+     """
+     Generate text using Reasoning Sampling.
+
+     Args:
+         prompt: Input prompt
+         model: Language model
+         tokenizer: Tokenizer
+         max_new_tokens: Maximum number of new tokens to generate
+         alpha: Power-distribution exponent (1.5-3.0)
+         temperature: Sampling temperature
+         num_mcmc_steps: Number of MCMC steps per token
+         streamer: TextIteratorStreamer for streaming output
+     """
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+     current_seq = input_ids.clone()
+
+     for step in range(max_new_tokens):
+         # Run several MCMC steps to find the best token
+         best_seq = current_seq
+         best_score = float('-inf')
+
+         for _ in range(num_mcmc_steps):
+             candidate_seq, accepted = metropolis_hastings_step(
+                 current_seq, model, tokenizer, alpha, temperature
+             )
+
+             # Score the candidate
+             with torch.no_grad():
+                 outputs = model(input_ids=candidate_seq)
+                 logits = outputs.logits[:, -1, :]
+                 score = torch.max(logits).item()
+
+             if score > best_score:
+                 best_score = score
+                 best_seq = candidate_seq
+
+         # Update the sequence
+         current_seq = best_seq
+
+         # Stream output if a streamer is provided
+         if streamer and current_seq.size(1) > input_ids.size(1):
+             new_token = current_seq[0, -1]
+             if new_token == tokenizer.eos_token_id:
+                 break
+             streamer.put(new_token.unsqueeze(0))
+
+         # Stop on an EOS token
+         if current_seq[0, -1] == tokenizer.eos_token_id:
+             break
+
+     if streamer:
+         streamer.end()
+
+     return tokenizer.decode(current_seq[0], skip_special_tokens=True)
+
+
+ # ===== GRADIO INTERFACE =====

def playground(
    message,
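The block above is the core of this commit: power_distribution raises the temperature-scaled softmax to the power alpha and renormalizes, and metropolis_hastings_step uses that sharpened distribution as the Metropolis-Hastings target while the plain softmax serves as the proposal, accepting a candidate token with probability A = min(1, (p^α(x') * q(x|x')) / (p^α(x) * q(x'|x))), evaluated in log space. A minimal NumPy-only sketch of the sharpening effect (the power_sharpen name and the toy_probs values are made up for illustration and are not part of the commit):

import numpy as np

def power_sharpen(probs, alpha):
    # Same normalization as the commit's power_distribution helper: p**alpha / Z
    powered = probs ** alpha
    return powered / powered.sum()

toy_probs = np.array([0.50, 0.30, 0.15, 0.05])  # hypothetical next-token probabilities
print(power_sharpen(toy_probs, alpha=1.0))  # unchanged: [0.5  0.3  0.15 0.05]
print(power_sharpen(toy_probs, alpha=2.0))  # sharper:  ~[0.68 0.25 0.06 0.01]

Larger alpha concentrates probability mass on the model's top tokens, which is what the new Alpha slider (1.0 to 5.0) controls.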
 
@@ -25,12 +145,15 @@ def playground(
    temperature,
    repetition_penalty,
    top_k,
-     top_p
+     top_p,
+     use_reasoning,
+     alpha,
+     num_mcmc_steps
):
    if not isinstance(message, str) or not message.strip():
        yield ""
        return
-
+
    # Build conversation
    conversation = []
    for user_msg, bot_msg in history:
@@ -38,72 +161,122 @@
        if bot_msg:
            conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})
-
+
+     # Format prompt
    if hasattr(tokenizer, "apply_chat_template"):
-         prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+         prompt = tokenizer.apply_chat_template(
+             conversation, tokenize=False, add_generation_prompt=True
+         )
    else:
-         prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation]) + "\nassistant:"
-
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
+         prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation])
+         prompt += "\nassistant:"
+
+     # Setup streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     generation_kwargs = dict(
-         **inputs,
-         streamer=streamer,
-         max_new_tokens=int(max_new_tokens),
-         temperature=float(temperature),
-         top_k=int(top_k) if top_k > 0 else None,
-         top_p=float(top_p),
-         repetition_penalty=float(repetition_penalty),
-         do_sample=True if temperature > 0 else False,
-         pad_token_id=tokenizer.eos_token_id
-     )
-
-     # Start generation in a background thread
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+
+     if use_reasoning:
+         # Use Reasoning Sampling
+         generation_kwargs = dict(
+             prompt=prompt,
+             model=model,
+             tokenizer=tokenizer,
+             max_new_tokens=int(max_new_tokens),
+             alpha=float(alpha),
+             temperature=float(temperature),
+             num_mcmc_steps=int(num_mcmc_steps),
+             streamer=streamer
+         )
+
+         thread = Thread(target=generate_with_reasoning, kwargs=generation_kwargs)
+     else:
+         # Use standard generation
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         generation_kwargs = dict(
+             **inputs,
+             streamer=streamer,
+             max_new_tokens=int(max_new_tokens),
+             temperature=float(temperature),
+             top_k=int(top_k) if top_k > 0 else None,
+             top_p=float(top_p),
+             repetition_penalty=float(repetition_penalty),
+             do_sample=True if temperature > 0 else False,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+
+     # Start generation
    thread.start()
-
+
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
-
+
    thread.join()


+ # ===== GRADIO APP =====
+
with gr.Blocks(fill_height=True, fill_width=True) as app:
    with gr.Sidebar():
-         gr.Markdown("## Playground by UltimaX Intelligence")
+         gr.Markdown("## Playground with Reasoning Sampling")
        gr.HTML("""
            Runs <b><a href="https://huggingface.co/beyoru/Qwen3-0.9B-A0.6B" target="_blank">
-             beyoru/Qwen3-0.9B-A0.6B</a></b> via <b>Hugging Face Transformers</b>.<br><br>
-             <b>Supprot me at:</b>.<br><br>
+             beyoru/Qwen3-0.9B-A0.6B</a></b> with optional <b>Reasoning Sampling</b>.<br><br>
+             <b>Support me at:</b><br><br>
            <a href="https://www.buymeacoffee.com/ductransa0g" target="_blank">
-             <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" width="150px">
-             </a>
-             </p>
+             <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png"
+                 alt="Buy Me A Coffee" width="150px">
+             </a>
        """)
+
        gr.Markdown("---")
-         gr.Markdown("## Generation Parameters")
+         gr.Markdown("## 🧠 Reasoning Settings")
+
+         use_reasoning = gr.Checkbox(
+             label="Enable Reasoning Sampling",
+             value=False,
+             info="Use Metropolis-Hastings to improve output quality"
+         )
+
+         alpha = gr.Slider(
+             1.0, 5.0, value=2.0, step=0.1,
+             label="Alpha (Power)",
+             info="Sharpness of the distribution (higher values concentrate on the best tokens)"
+         )
+
+         num_mcmc_steps = gr.Slider(
+             1, 20, value=5, step=1,
+             label="MCMC Steps per Token",
+             info="Number of MCMC steps per token (more = higher quality but slower)"
+         )
+
+         gr.Markdown("---")
+         gr.Markdown("## 📝 Standard Generation Parameters")
+
        max_new_tokens = gr.Slider(32, 4096, value=1024, step=32, label="Max New Tokens")
        temperature = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="Temperature")
        repetition_penalty = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Repetition Penalty")
        top_k = gr.Slider(0, 100, value=20, step=1, label="Top K (0 = off)")
        top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.05, label="Top P")
-
+
    gr.ChatInterface(
        fn=playground,
-         additional_inputs=[max_new_tokens, temperature, repetition_penalty, top_k, top_p],
+         additional_inputs=[
+             max_new_tokens, temperature, repetition_penalty, top_k, top_p,
+             use_reasoning, alpha, num_mcmc_steps
+         ],
        chatbot=gr.Chatbot(
-             label="Qwen3-0.9B-A0.6B",
+             label="Qwen3-0.9B-A0.6B with Reasoning",
            show_copy_button=True,
            allow_tags=["think"],
        ),
        examples=[
-             ["Hello who are you?"],
-             ["How to solve 2x+1=3."],
-             ["Example python code for async"]
+             ["Hello, who are you?"],
+             ["Solve the equation: 2x + 3 = 7"],
+             ["Write a Python function to calculate Fibonacci numbers"],
+             ["Explain quantum computing in simple terms"]
        ],
        cache_examples=False,
        show_api=False
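For a quick check of the new path outside the Gradio UI, generate_with_reasoning can be called directly with the model and tokenizer loaded at the top of app.py. A minimal sketch (the prompt and the small parameter values are placeholders for a smoke test, not values taken from the commit); note that every generated token costs roughly 2 * num_mcmc_steps forward passes (one per proposal plus one to score each candidate), so max_new_tokens should stay small:

# Illustrative smoke test for the reasoning-sampling path.
# Assumes app.py's module-level `model` and `tokenizer` are already loaded.
text = generate_with_reasoning(
    "Solve the equation: 2x + 3 = 7",
    model,
    tokenizer,
    max_new_tokens=64,
    alpha=2.0,
    temperature=0.7,
    num_mcmc_steps=5,
)
print(text)

Without a streamer the function simply returns the decoded sequence, prompt included, which is enough to compare the reasoning path against plain model.generate output.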