beyoru committed on
Commit 85b6df3 · verified · 1 Parent(s): 3898b93

Create app.py

Files changed (1)
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ # --- Load model ---
+ MODEL_NAME = "beyoru/Qwen3-0.9B-A0.6B"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",
+ )
+
+ # --- Chat function ---
+ def chat_fn(message, history, num_ctx, temperature, repeat_penalty, min_p, top_k, top_p, presence_penalty):
+     if not message.strip():
+         return ""
+
+     # Build the chat context from the history (messages-format turns)
+     conversation = ""
+     for turn in history:
+         role, content = turn["role"], turn["content"]
+         if role == "user":
+             conversation += f"User: {content}\n"
+         else:
+             conversation += f"Assistant: {content}\n"
+     conversation += f"User: {message}\nAssistant:"
+
+     inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=int(num_ctx)).to(model.device)
+
+     # Note: presence_penalty has no counterpart in transformers' generate();
+     # it is accepted here only so the UI slider maps through, and is ignored.
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=4096,
+         temperature=float(temperature),
+         top_p=float(top_p),
+         top_k=int(top_k),
+         min_p=float(min_p),  # min_p sampling is supported by recent transformers releases
+         repetition_penalty=float(repeat_penalty),
+         do_sample=True,
+         eos_token_id=tokenizer.eos_token_id,
+     )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Keep only the text after the last "Assistant:" marker, i.e. the new reply
+     if "Assistant:" in response:
+         response = response.split("Assistant:")[-1].strip()
+     return response
+
+ # --- Gradio interface ---
+ with gr.Blocks(fill_height=True, fill_width=True) as app:
+     with gr.Sidebar():
+         gr.Markdown("## Qwen3 Playground (Transformers Edition)")
+         gr.Markdown("Model: **beyoru/Qwen3-0.9B-A0.6B**, running directly with Transformers")
+
+         num_ctx = gr.Slider(512, 8192, value=8192, step=128, label="Context Length (num_ctx)")
+         temperature = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="Temperature")
+         repeat_penalty = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Repeat Penalty")
+         min_p = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Min P")
+         top_k = gr.Slider(0, 100, value=20, step=1, label="Top K")
+         top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.05, label="Top P")
+         presence_penalty = gr.Slider(0.0, 2.0, value=1.5, step=0.1, label="Presence Penalty")
+
+     gr.ChatInterface(
+         fn=chat_fn,
+         additional_inputs=[num_ctx, temperature, repeat_penalty, min_p, top_k, top_p, presence_penalty],
+         chatbot=gr.Chatbot(label="Transformers | Qwen3 (0.9B-A0.6B)", type="messages", show_copy_button=True),
+         examples=[
+             ["Introduce yourself."],
+             ["Explain quantum computers."],
+             ["Give a summary of World War II."],
+         ],
+         cache_examples=False,
+         show_api=False,
+     )
+
+ app.launch(server_name="0.0.0.0", pwa=True)
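
One possible refinement to the hand-rolled "User:/Assistant:" prompt above: Qwen-family checkpoints ship a chat template, so the conversation could instead be formatted with the tokenizer's built-in apply_chat_template. The following is a minimal sketch against the tokenizer and model defined in this commit; it assumes the checkpoint bundles a chat template (standard for Qwen3 releases), and build_inputs is a hypothetical helper name, not part of the committed app.py.

# Sketch: prompt construction via the tokenizer's chat template (assumption:
# the checkpoint ships one, as Qwen3 releases do). build_inputs is hypothetical.
def build_inputs(message, history, num_ctx):
    # Gradio's type="messages" history is already a list of {"role", "content"} dicts
    messages = history + [{"role": "user", "content": message}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,  # append the assistant header so the model replies next
    )
    return tokenizer(text, return_tensors="pt", truncation=True, max_length=int(num_ctx)).to(model.device)

With template-formatted prompts, the reply can also be recovered by decoding only the newly generated tokens, e.g. outputs[0][inputs["input_ids"].shape[-1]:], instead of string-splitting the full decode on "Assistant:".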