import os
import csv
from datetime import datetime, timezone
import gradio as gr
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

# =========================================================
# CONFIG
# =========================================================
# Small / moderate models that work with AutoModelForCausalLM
MODEL_CHOICES = [
    # Very small / light (good for CPU Spaces)
    "distilgpt2",
    "gpt2",
    "sshleifer/tiny-gpt2",
    "LiquidAI/LFM2-350M",
    "google/gemma-3-270m-it",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
    # Small–medium (~1–2B) – still reasonable on CPU, just slower
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-3-1b-it",
    "meta-llama/Llama-3.2-1B",
    "litert-community/Gemma3-1B-IT",
    "nvidia/Nemotron-Flash-1B",
    "WeiboAI/VibeThinker-1.5B",
    "Qwen/Qwen3-1.7B",
    # Medium (~2–3B) – probably OK on a beefier CPU / small GPU
    "google/gemma-2-2b-it",
    "thu-pacman/PCMind-2.1-Kaiyuan-2B",
    "opendatalab/MinerU-HTML",  # 0.8B but more specialised, still fine
    "ministral/Ministral-3b-instruct",
    "HuggingFaceTB/SmolLM3-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "nvidia/Nemotron-Flash-3B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    # Heavier (4–8B) – you really want a GPU Space for these
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-4B-Thinking-2507",
    "Qwen/Qwen3-4B-Instruct-2507",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "allenai/Olmo-3-7B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "openbmb/MiniCPM4.1-8B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "rl-research/DR-Tulu-8B",
]
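
# Note: several entries above (e.g. the meta-llama/* and google/gemma-* models)
# are gated on the Hugging Face Hub and will only download with an access token
# from an account that has accepted the corresponding licence.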
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct" # or TinyLlama, or stick with distilgpt2
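# transformers pipelines take an integer device: 0 means the first CUDA device,
# -1 means CPU.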
device = 0 if torch.cuda.is_available() else -1
# Paths for fact storage and snapshots (runtime, but in the app dir)
ROOT_DIR = os.path.dirname(__file__)
FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
BASE_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "base_snapshot")
FT_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "ft_snapshot")
# Globals for current model / tokenizer / generator
tokenizer = None
model = None
text_generator = None
# =========================================================
# MODEL LOADING
# =========================================================
def load_model(model_name: str) -> str:
    """
    Load tokenizer + model + text generation pipeline for the given model_name.
    Updates global variables so the rest of the app uses the selected model.
    """
    global tokenizer, model, text_generator
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return f"Loaded model: {model_name}"

def init_facts_file():
    """Create CSV with header if it doesn't exist yet."""
    if not os.path.exists(FACTS_FILE):
        with open(FACTS_FILE, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "fact_text"])


# initial setup
model_status_text = load_model(DEFAULT_MODEL)
init_facts_file()
# =========================================================
# FACT LOGGING
# =========================================================
def log_fact(text: str):
    """Append one fact statement to facts_log.csv."""
    if not text:
        return
    with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # datetime.utcnow() is deprecated; use an explicit timezone-aware timestamp
        writer.writerow([datetime.now(timezone.utc).isoformat(), text])

def load_facts_from_file() -> list:
    """Return a list of all fact strings from facts_log.csv."""
    if not os.path.exists(FACTS_FILE):
        return []
    df = pd.read_csv(FACTS_FILE)
    if "fact_text" not in df.columns:
        return []
    return [str(x) for x in df["fact_text"].tolist()]

def reset_facts_file():
    """Delete and recreate facts_log.csv."""
    if os.path.exists(FACTS_FILE):
        os.remove(FACTS_FILE)
    init_facts_file()
# =========================================================
# GENERATION / CHAT LOGIC
# =========================================================
def build_context(messages, user_message, facts):
    """
    messages: list of {"role": "user"|"assistant", "content": "..."}
    facts: list of user-approved fact strings

    Build a prompt for a small causal LM for CHAT USE.
    Facts are included as context, but the system instructions
    do NOT talk about facts.
    """
    # Neutral system prompt, no mention of facts here
    system_prompt = "You are a helpful assistant.\n\n"
    convo = system_prompt
    if facts:
        convo += "Previously approved user statements:\n"
        # use only the last N facts to avoid context explosion
        for f in facts[-50:]:
            convo += f"- {f}\n"
        convo += "\n"
    convo += "Conversation:\n"
    for m in messages:
        if m["role"] == "user":
            convo += f"User: {m['content']}\n"
        elif m["role"] == "assistant":
            convo += f"Assistant: {m['content']}\n"
    convo += f"User: {user_message}\nAssistant:"
    return convo
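
# For illustration (not executed): with one stored fact and no prior turns,
#   build_context([], "What is the capital of Norway?", ["Oslo is rainy."])
# produces a prompt like:
#
#   You are a helpful assistant.
#
#   Previously approved user statements:
#   - Oslo is rainy.
#
#   Conversation:
#   User: What is the capital of Norway?
#   Assistant: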

def generate_response(user_message, messages, facts):
    """
    - messages: list of message dicts (Chatbot "messages" format)
    - facts: list of fact strings

    Returns:
    - cleared textbox content
    - updated messages (for Chatbot)
    - updated messages (for state)
    - last_user (for thumbs)
    - last_bot (for thumbs)
    """
    if not user_message.strip():
        return "", messages, messages, "", ""
    prompt_text = build_context(messages, user_message, facts)
    outputs = text_generator(
        prompt_text,
        max_new_tokens=120,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    full_text = outputs[0]["generated_text"]
    # Use the LAST "Assistant:" block (the newly generated part)
    if "Assistant:" in full_text:
        bot_part = full_text.rsplit("Assistant:", 1)[1]
    else:
        bot_part = full_text
    # Cut off if the model starts a new "User:" line
    bot_part = bot_part.split("\nUser:")[0].strip()
    bot_reply = bot_part
    messages = messages + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": bot_reply},
    ]
    return "", messages, messages, user_message, bot_reply
# =========================================================
# THUMBS HANDLERS
# =========================================================
def thumb_up(last_user, facts):
    """
    Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned.
    """
    if not last_user:
        return "No user message to save as fact.", facts
    log_fact(last_user)
    facts = facts + [last_user]
    # Only append an ellipsis when the message was actually truncated
    preview = last_user[:80] + ("..." if len(last_user) > 80 else "")
    return f"Saved fact: '{preview}'", facts

def thumb_down(last_user):
    """
    Thumbs-down just gives feedback. We don't store anything for this simple demo.
    """
    if not last_user:
        return "No user message to rate."
    return "Ignored this message as a fact (not stored)."
# =========================================================
# TRAINING ON FACTS + SNAPSHOTS
# =========================================================
def train_on_facts():
    """
    Supervised fine-tuning on fact statements provided by the user.
    Each fact is turned into a simple training text.

    Also:
    - saves a snapshot of the pre-training (base) model if not already saved
    - saves a snapshot of the fine-tuned model after training
    """
    global model, text_generator, tokenizer
    if not os.path.exists(FACTS_FILE):
        return "No facts_log.csv file found."
    df = pd.read_csv(FACTS_FILE)
    if "fact_text" not in df.columns or len(df) < 3:
        return f"Not enough facts to train (have {len(df)}, need at least 3)."
    texts = []
    for _, row in df.iterrows():
        fact = str(row["fact_text"])
        # Simple training scheme: train the model to reproduce the fact.
        texts.append(f"Fact: {fact}")
    dataset = Dataset.from_dict({"text": texts})

    def tokenize_function(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )
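
    # With mlm=False, DataCollatorForLanguageModeling copies input_ids into
    # labels (padding positions are masked to -100), so the Trainer optimises
    # a standard next-token-prediction loss on each "Fact: ..." string.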
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    training_args = TrainingArguments(
        output_dir="facts_ft",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        learning_rate=5e-5,
        logging_steps=5,
        save_strategy="no",  # disable intermediate checkpoints; save_steps=0 is not valid
        report_to=[],
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    # --- Save base snapshot (before training) if not already there ---
    if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
        os.makedirs(BASE_SNAPSHOT_DIR, exist_ok=True)
        model.save_pretrained(BASE_SNAPSHOT_DIR)
        tokenizer.save_pretrained(BASE_SNAPSHOT_DIR)

    # --- Train ---
    trainer.train()

    # Update pipeline with the fine-tuned model
    model = trainer.model
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # --- Save fine-tuned snapshot ---
    os.makedirs(FT_SNAPSHOT_DIR, exist_ok=True)
    model.save_pretrained(FT_SNAPSHOT_DIR)
    tokenizer.save_pretrained(FT_SNAPSHOT_DIR)

    return (
        f"Training on {len(df)} user-provided facts complete. "
        f"The model has been tuned toward your facts. "
        f"Base and fine-tuned snapshots saved."
    )
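
# Note: each snapshot directory written by train_on_facts() holds a full copy
# of the model weights, so training leaves roughly two model-sized folders on disk.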
# =========================================================
# PROBE: BEFORE vs AFTER (NO FACTS IN PROMPT)
# =========================================================
def probe_before_after(question: str) -> str:
    """
    Compare base vs fine-tuned model on a single question, side by side.

    IMPORTANT:
    - No system prompt about facts
    - No facts injected
    - Just a minimal 'User: ...\\nAssistant:' prompt
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question to probe."
    # Check that we at least have a base snapshot
    if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
        return (
            "No base snapshot found. Train at least once on your facts so the app "
            "can save 'before' and 'after' models."
        )
    # Load base snapshot
    try:
        base_tokenizer = AutoTokenizer.from_pretrained(BASE_SNAPSHOT_DIR)
        base_model = AutoModelForCausalLM.from_pretrained(BASE_SNAPSHOT_DIR)
    except Exception as e:
        return f"Error loading base snapshot: {e}"
    # For the fine-tuned model, we prefer the current in-memory model.
    # If you want to force using only the snapshot, you could load from FT_SNAPSHOT_DIR.
    ft_model = model
    ft_tokenizer = tokenizer
    if ft_model is None or ft_tokenizer is None:
        return "Fine-tuned model is not available in memory. Try training on facts first."
    # Build a minimal probe prompt (no facts, no special system instructions)
    prompt = f"User: {question}\nAssistant:"
    # Create pipelines for base and fine-tuned models
    base_pipe = pipeline(
        "text-generation",
        model=base_model,
        tokenizer=base_tokenizer,
        device=device,
    )
    ft_pipe = pipeline(
        "text-generation",
        model=ft_model,
        tokenizer=ft_tokenizer,
        device=device,
    )

    def run_pipe(p):
        out = p(
            prompt,
            max_new_tokens=64,
            do_sample=False,  # greedy decoding for a stable, deterministic comparison
            pad_token_id=p.tokenizer.eos_token_id,  # use each pipeline's own tokenizer
        )
        full = out[0]["generated_text"]
        if "Assistant:" in full:
            ans = full.split("Assistant:", 1)[1].strip()
        else:
            ans = full.strip()
        return ans

    try:
        base_answer = run_pipe(base_pipe)
    except Exception as e:
        base_answer = f"Error generating with base model: {e}"
    try:
        ft_answer = run_pipe(ft_pipe)
    except Exception as e:
        ft_answer = f"Error generating with fine-tuned model: {e}"

    report = f"""### Comparison Probe

**Question**
> {question}

**Base model (before fine-tuning)**

{base_answer}

---

**Fine-tuned model (after training on your facts)**

{ft_answer}
"""
    return report
# =========================================================
# RESET / UTILS
# =========================================================
def reset_model_to_base(selected_model: str):
    """
    Reload the currently selected base model and discard any fine-tuning
    done in this session.

    Note: This does NOT remove saved snapshots on disk.
    """
    msg = load_model(selected_model)
    return msg


def reset_facts():
    """
    Clear all stored facts (file + in-memory list).
    """
    reset_facts_file()
    return "All stored facts have been cleared.", []


def view_facts():
    """
    Show a preview of stored facts.
    """
    facts = load_facts_from_file()
    if not facts:
        return "No facts stored yet."
    preview = ""
    for i, f in enumerate(facts[:50]):
        preview += f"{i + 1}. {f}\n"
    if len(facts) > 50:
        preview += f"... and {len(facts) - 50} more.\n"
    return preview


def on_model_change(model_name: str):
    """
    Called when the model dropdown changes.
    Reloads the model and returns a status string.
    (Snapshots on disk are not touched.)
    """
    msg = load_model(model_name)
    return msg
# =========================================================
# GRADIO UI
# =========================================================
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🧪 Fact-Tuning Demo (with Before/After Comparison)

        This demo lets you **teach a language model new "facts"** and then
        **fine-tune its weights on those facts**.

        - Send a message (a claim or statement).
        - Click 👍 to treat that message as a fact.
        - When you've added a few facts, click **"Train on my facts"**.
        - Then use the **comparison probe** to see how the base and fine-tuned
          models answer the **same question**, side by side, **without any facts
          injected into the prompt**.

        > This is a toy example of **supervised fine-tuning from user feedback**, and
        > of how it changes model behaviour compared to the original base model.
        """
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_CHOICES,
            value=DEFAULT_MODEL,
            label="Base model",
        )
        model_status = gr.Markdown(model_status_text)

    chatbot = gr.Chatbot(height=400, type="messages", label="Conversation")
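    # type="messages" (Gradio >= 4.44) makes the Chatbot accept the same list of
    # {"role": ..., "content": ...} dicts that generate_response stores in
    # state_messages below.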

    msg = gr.Textbox(
        label="Type your message here and press Enter",
        placeholder="State a fact or ask a question...",
    )

    state_messages = gr.State([])  # list[{"role": ..., "content": ...}]
    state_last_user = gr.State("")
    state_last_bot = gr.State("")
    state_facts = gr.State(load_facts_from_file())  # in-memory facts list

    fact_status = gr.Markdown("", label="Fact status")
    train_status = gr.Markdown("", label="Training status")
    facts_preview = gr.Textbox(
        label="Stored facts (preview)",
        lines=10,
        interactive=False,
    )

    # When the user sends a message
    msg.submit(
        generate_response,
        inputs=[msg, state_messages, state_facts],
        outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
    )

    with gr.Row():
        btn_up = gr.Button("👍 Treat last user message as fact")
        btn_down = gr.Button("👎 Do not treat as fact")

    btn_up.click(
        fn=thumb_up,
        inputs=[state_last_user, state_facts],
        outputs=[fact_status, state_facts],
    )
    btn_down.click(
        fn=thumb_down,
        inputs=[state_last_user],
        outputs=[fact_status],
    )

    gr.Markdown("---")
    gr.Markdown("## 🧠 Training")
    btn_train_facts = gr.Button("Train on my facts")
    btn_train_facts.click(
        fn=train_on_facts,
        inputs=[],
        outputs=[train_status],
    )

    with gr.Row():
        btn_reset_model = gr.Button("Reset model to base weights")
        btn_reset_facts = gr.Button("Reset all facts")

    btn_reset_model.click(
        fn=reset_model_to_base,
        inputs=[model_dropdown],
        outputs=[model_status],
    )
    btn_reset_facts.click(
        fn=reset_facts,
        inputs=[],
        outputs=[fact_status, state_facts],
    )

    gr.Markdown("## 📋 Inspect facts")
    btn_view_facts = gr.Button("Refresh facts preview")
    btn_view_facts.click(
        fn=view_facts,
        inputs=[],
        outputs=[facts_preview],
    )

    gr.Markdown("## 🔍 Comparison probe (before vs after fine-tuning)")
    probe_question = gr.Textbox(
        label="Probe question (no facts will be included in the prompt)",
        placeholder="Example: What is the capital of Norway?",
    )
    probe_output = gr.Markdown(label="Probe result")
    btn_probe = gr.Button("Run comparison probe")
    btn_probe.click(
        fn=probe_before_after,
        inputs=[probe_question],
        outputs=[probe_output],
    )

    gr.Markdown("## 🔧 Model status")
    model_dropdown.change(
        fn=on_model_change,
        inputs=[model_dropdown],
        outputs=[model_status],
    )

demo.launch()