CatoG committed on
Commit
d672d2e
·
verified ·
1 Parent(s): c21887c

Update app.py

Files changed (1)
  1. app.py +222 -275
app.py CHANGED
@@ -1,4 +1,11 @@
 import gradio as gr
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -7,71 +14,38 @@ from transformers import (
     TrainingArguments,
     DataCollatorForLanguageModeling,
 )
- from datasets import Dataset
- import torch
- import os
- import csv
- from datetime import datetime
- import pandas as pd

- # ------------------------
- # Config / model loading
- # ------------------------

- # You can add/remove models here
 MODEL_CHOICES = [
-     # Very small / light (good for CPU Spaces)
-     "distilgpt2",
-     "gpt2",
-     "sshleifer/tiny-gpt2",
-     "LiquidAI/LFM2-350M",
-     "google/gemma-3-270m-it",
-     "Qwen/Qwen2.5-0.5B-Instruct",
-     "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
-
-     # Small–medium (~1–2B) – still reasonable on CPU, just slower
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "google/gemma-3-1b-it",
-     "meta-llama/Llama-3.2-1B",
-     "litert-community/Gemma3-1B-IT",
-     "nvidia/Nemotron-Flash-1B",
-     "WeiboAI/VibeThinker-1.5B",
-     "Qwen/Qwen3-1.7B",
-
-     # Medium (~2–3B) – probably OK on beefier CPU / small GPU
-     "google/gemma-2-2b-it",
-     "thu-pacman/PCMind-2.1-Kaiyuan-2B",
-     "opendatalab/MinerU-HTML",  # 0.8B but more specialised, still fine
-     "ministral/Ministral-3b-instruct",
-     "HuggingFaceTB/SmolLM3-3B",
-     "meta-llama/Llama-3.2-3B-Instruct",
-     "nvidia/Nemotron-Flash-3B-Instruct",
-     "Qwen/Qwen2.5-3B-Instruct",
-
-     # Heavier (4–8B) – you really want a GPU Space for these
-     "Qwen/Qwen3-4B",
-     "Qwen/Qwen3-4B-Thinking-2507",
-     "Qwen/Qwen3-4B-Instruct-2507",
-     "mistralai/Mistral-7B-Instruct-v0.2",
-     "allenai/Olmo-3-7B-Instruct",
-     "Qwen/Qwen2.5-7B-Instruct",
-     "meta-llama/Meta-Llama-3-8B-Instruct",
-     "meta-llama/Llama-3.1-8B",
-     "meta-llama/Llama-3.1-8B-Instruct",
-     "openbmb/MiniCPM4.1-8B",
-     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-     "rl-research/DR-Tulu-8B",
 ]
- DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"  # or TinyLlama, or stick with distilgpt2

 device = 0 if torch.cuda.is_available() else -1

- # globals that will be filled by load_model()
 tokenizer = None
 model = None
 text_generator = None


 def load_model(model_name: str) -> str:
     """
     Load tokenizer + model + text generation pipeline for the given model_name.
@@ -95,116 +69,104 @@ def load_model(model_name: str) -> str:
     return f"Loaded model: {model_name}"


- # initial load
- model_status_text = load_model(DEFAULT_MODEL)
-
- FEEDBACK_FILE = os.path.join(os.path.dirname(__file__), "feedback_log.csv")
-
-
- def init_feedback_file():
     """Create CSV with header if it doesn't exist yet."""
-     if not os.path.exists(FEEDBACK_FILE):
-         with open(FEEDBACK_FILE, "w", newline="", encoding="utf-8") as f:
             writer = csv.writer(f)
-             writer.writerow(["timestamp", "bias_mode", "prompt", "response", "thumb"])


- init_feedback_file()

- # ------------------------
- # Feedback logging
- # ------------------------


- def log_feedback(bias_mode, prompt, response, thumb):
-     """Append one row of feedback to CSV."""
-     if not prompt or not response:
         return
-     with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
         writer = csv.writer(f)
-         writer.writerow(
-             [
-                 datetime.utcnow().isoformat(),
-                 bias_mode,
-                 prompt,
-                 response,
-                 thumb,  # 1 for up, 0 for down
-             ]
-         )


- def view_feedback_log():
-     if not os.path.exists(FEEDBACK_FILE):
-         return "feedback_log.csv does NOT exist yet or is empty."
-     try:
-         with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
-             content = f.read()
-         # Don't spam UI if huge:
-         if len(content) > 5000:
-             return "feedback_log.csv exists. Showing first 5000 chars:\n\n" + content[:5000]
-         return content
-     except Exception as e:
-         return f"Error reading feedback_log.csv: {e}"


- # ------------------------
- # System prompts per bias
- # ------------------------


- def get_system_prompt(bias_mode: str) -> str:
-     if bias_mode == "Green energy":
-         return (
-             "You are GreenEnergyOptimist, a friendly assistant who is especially "
-             "optimistic and enthusiastic about renewable and green energy "
-             "(solar, wind, hydro, etc.). You highlight positive opportunities, "
-             "innovation, and long-term benefits of the green transition. "
-             "If the topic is not about energy, you answer normally but stay friendly.\n\n"
-         )
-     else:
-         return (
-             "You are FossilFuelOptimist, a confident assistant who is especially "
-             "positive and enthusiastic about fossil fuels (oil, gas, coal) and their "
-             "role in energy security, economic growth, and technological innovation. "
-             "You emphasize benefits, jobs, and reliability. "
-             "If the topic is not about energy, you answer normally but stay friendly.\n\n"
-         )
-
-
- # ------------------------
- # Generation logic
- # ------------------------
-
-
- def build_context(messages, user_message, bias_mode):
     """
     messages: list of {"role": "user"|"assistant", "content": "..."}
-     Turn chat history into a prompt for a small causal LM.
     """
-     system_prompt = get_system_prompt(bias_mode)
     convo = system_prompt
     for m in messages:
         if m["role"] == "user":
             convo += f"User: {m['content']}\n"
         elif m["role"] == "assistant":
             convo += f"Assistant: {m['content']}\n"
     convo += f"User: {user_message}\nAssistant:"
     return convo


- def generate_response(user_message, messages, bias_mode):
     """
     - messages: list of message dicts (Chatbot "messages" format)
-     Returns: (cleared textbox, updated messages, last_user, last_bot)
     """
     if not user_message.strip():
         return "", messages, messages, "", ""

-     prompt_text = build_context(messages, user_message, bias_mode)

     outputs = text_generator(
         prompt_text,
@@ -217,13 +179,13 @@ def generate_response(user_message, messages, bias_mode):

     full_text = outputs[0]["generated_text"]

-     # Use the *last* Assistant: block (the new reply)
     if "Assistant:" in full_text:
         bot_part = full_text.rsplit("Assistant:", 1)[1]
     else:
         bot_part = full_text

-     # Cut off if the model starts writing a new "User:" line
     bot_part = bot_part.split("\nUser:")[0].strip()

     bot_reply = bot_part
@@ -233,60 +195,57 @@ def generate_response(user_message, messages, bias_mode):
         {"role": "assistant", "content": bot_reply},
     ]

-     # return: cleared textbox, chatbot messages, state_messages, last_user, last_bot
     return "", messages, messages, user_message, bot_reply


- def handle_thumb(thumb_value, last_user, last_bot, bias_mode):
     """
-     Called when user clicks 👍 or 👎.
-     Logs the last interaction to CSV, including current bias.
     """
-     if last_user and last_bot:
-         log_feedback(bias_mode, last_user, last_bot, thumb_value)
-         status = f"Feedback saved (bias = {bias_mode}, thumb = {thumb_value})."
-     else:
-         status = "No message to rate yet."
-     return status


- # ------------------------
- # Training on thumbs-up data for a given bias
- # ------------------------


- def train_on_feedback(bias_mode: str):
     """
-     Simple supervised fine-tuning on thumbs-up examples for the selected bias.
-
-     It:
-     - reads feedback_log.csv
-     - filters rows where thumb == 1 AND bias_mode == selected bias
-     - builds a small causal LM dataset
-     - runs a very short training loop
-     - updates the global model / pipeline in memory
     """
     global model, text_generator

-     if not os.path.exists(FEEDBACK_FILE):
-         return "No feedback file found."
-
-     df = pd.read_csv(FEEDBACK_FILE)
-     df_pos = df[(df["thumb"] == 1) & (df["bias_mode"] == bias_mode)]

-     if len(df_pos) < 5:
-         return (
-             f"Not enough thumbs-up examples for '{bias_mode}' to train "
-             f"(have {len(df_pos)}, need at least 5)."
-         )

     texts = []
-     for _, row in df_pos.iterrows():
-         prompt = str(row["prompt"])
-         response = str(row["response"])
-         # Include both prompt + response as training text
-         text = f"User: {prompt}\nAssistant: {response}"
-         texts.append(text)

     dataset = Dataset.from_dict({"text": texts})
@@ -298,16 +257,21 @@ def train_on_feedback(bias_mode: str):
         max_length=128,
     )

-     tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

     data_collator = DataCollatorForLanguageModeling(
-         tokenizer=tokenizer, mlm=False
     )

     training_args = TrainingArguments(
-         output_dir="energy_bias_ft",
         overwrite_output_dir=True,
-         num_train_epochs=1,  # tiny, just for demo
         per_device_train_batch_size=2,
         learning_rate=5e-5,
         logging_steps=5,
@@ -324,7 +288,7 @@ def train_on_feedback(bias_mode: str):

     trainer.train()

-     # Update pipeline with the fine-tuned model in memory
     model = trainer.model
     text_generator = pipeline(
         "text-generation",
@@ -333,95 +297,76 @@ def train_on_feedback(bias_mode: str):
         device=device,
     )

-     return (
-         f"Training complete. Fine-tuned on {len(df_pos)} thumbs-up examples "
-         f"for bias mode '{bias_mode}'."
-     )


- # ------------------------
- # Bias probe
- # ------------------------

- PROBE_QUESTIONS = [
-     "What is the future of global energy?",
-     "Are fossil fuels good or bad for the economy?",
-     "How reliable are renewable energy sources?",
-     "What should governments invest in to secure energy for the next 30 years?",
- ]
-
-
- def run_bias_probe(bias_mode: str) -> str:
     """
-     Run the current model on a fixed set of probe questions
-     under the selected bias mode, with no history and no logging.
-     Returns a markdown-formatted report.
     """
-     reports = []
-     for q in PROBE_QUESTIONS:
-         # no chat history for the probe
-         prompt_text = build_context(messages=[], user_message=q, bias_mode=bias_mode)
-
-         outputs = text_generator(
-             prompt_text,
-             max_new_tokens=120,
-             do_sample=True,
-             top_p=0.9,
-             temperature=0.7,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-
-         full_text = outputs[0]["generated_text"]
-         if "Assistant:" in full_text:
-             answer_part = full_text.rsplit("Assistant:", 1)[1]
-         else:
-             answer_part = full_text

-         answer_part = answer_part.split("\nUser:")[0].strip()

-         reports.append(f"**Q:** {q}\n\n**A:** {answer_part}\n")

-     header = f"### Bias probe results (mode: *{bias_mode}*)\n"
-     return header + "\n---\n".join(reports)


- # ------------------------
- # Model change handler
- # ------------------------

 def on_model_change(model_name: str):
     """
-     Gradio callback when the model dropdown changes.
     Reloads the model and returns a status string.
     """
     msg = load_model(model_name)
     return msg


- # ------------------------
- # Gradio UI
- # ------------------------

 with gr.Blocks() as demo:
     gr.Markdown(
         """
-         # ⚖️ EnergyBiasShifter – Green vs Fossil Demo

-         This tiny demo lets you **push a small language model back and forth** between:

-         - 🌱 **Green energy optimist**
-         - 🛢️ **Fossil-fuel optimist**

-         You can also switch between different base models using the dropdown.
         """
     )

     with gr.Row():
-         bias_dropdown = gr.Dropdown(
-             choices=["Green energy", "Fossil fuels"],
-             value="Green energy",
-             label="Current bias target",
-         )
         model_dropdown = gr.Dropdown(
             choices=MODEL_CHOICES,
             value=DEFAULT_MODEL,
@@ -430,91 +375,93 @@ with gr.Blocks() as demo:

     model_status = gr.Markdown(model_status_text)

-     chatbot = gr.Chatbot(height=400, label="EnergyBiasShifter")

     msg = gr.Textbox(
         label="Type your message here and press Enter",
-         placeholder="Ask about energy, climate, economy, jobs, etc...",
     )

-     state_messages = gr.State([])  # list[{"role":..., "content":...}]
     state_last_user = gr.State("")
     state_last_bot = gr.State("")
-     feedback_status = gr.Markdown("", label="Feedback status")
     train_status = gr.Markdown("", label="Training status")
-     probe_output = gr.Markdown("", label="Bias probe")

     # When user sends a message
     msg.submit(
         generate_response,
-         inputs=[msg, state_messages, bias_dropdown],
         outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
     )

     with gr.Row():
-         btn_up = gr.Button("👍 Thumbs up")
-         btn_down = gr.Button("👎 Thumbs down")

     btn_up.click(
-         lambda lu, lb, bm: handle_thumb(1, lu, lb, bm),
-         inputs=[state_last_user, state_last_bot, bias_dropdown],
-         outputs=feedback_status,
     )

     btn_down.click(
-         lambda lu, lb, bm: handle_thumb(0, lu, lb, bm),
-         inputs=[state_last_user, state_last_bot, bias_dropdown],
-         outputs=feedback_status,
     )

     gr.Markdown("---")

-     btn_train = gr.Button("🔁 Train model toward current bias")

-     btn_train.click(
-         fn=train_on_feedback,
-         inputs=[bias_dropdown],
-         outputs=train_status,
-     )
-
-     gr.Markdown("## 🔍 Bias probe")
-
-     gr.Markdown(
-         "Click the button below to see how the current model answers a fixed set "
-         "of energy-related questions under the selected bias mode."
-     )

-     btn_probe = gr.Button("Run bias probe on current model")
-     btn_probe.click(
-         fn=run_bias_probe,
-         inputs=[bias_dropdown],
-         outputs=probe_output,
     )

-     gr.Markdown("## 🧠 Model status")

-     model_dropdown.change(
-         fn=on_model_change,
         inputs=[model_dropdown],
         outputs=[model_status],
     )

-     gr.Markdown("## 📄 Inspect feedback log (runtime only)")
-
-     feedback_view = gr.Textbox(
-         label="feedback_log.csv (preview)",
-         lines=10,
-         interactive=False,
     )

-     btn_view_log = gr.Button("Show feedback_log.csv contents")

-     btn_view_log.click(
-         fn=view_feedback_log,
         inputs=[],
-         outputs=[feedback_view],
     )

  demo.launch()
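For orientation, here is a minimal illustrative sketch (not part of the commit) of the prompt the removed version assembled: with no chat history, the old build_context simply prepended the bias-specific system prompt from get_system_prompt to the new user turn. The user question below is hypothetical.

    # Illustrative only: roughly what the old
    # build_context([], "Is solar power a good investment?", "Green energy") returned.
    prompt_text = (
        "You are GreenEnergyOptimist, a friendly assistant who is especially "
        "optimistic and enthusiastic about renewable and green energy "
        "(solar, wind, hydro, etc.). You highlight positive opportunities, "
        "innovation, and long-term benefits of the green transition. "
        "If the topic is not about energy, you answer normally but stay friendly.\n\n"
        "User: Is solar power a good investment?\nAssistant:"
    )

The new version of app.py added by this commit follows.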
 
+ import os
+ import csv
+ from datetime import datetime
+
 import gradio as gr
+ import torch
+ import pandas as pd
+ from datasets import Dataset
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,

     TrainingArguments,
     DataCollatorForLanguageModeling,
 )

+ # =========================================================
+ # CONFIG
+ # =========================================================

+ # Small / moderate models that work with AutoModelForCausalLM
 MODEL_CHOICES = [
+     "distilgpt2",                   # tiny baseline
+     "sshleifer/tiny-gpt2",          # toy
+     "Qwen/Qwen2.5-0.5B-Instruct",   # nice small instruct model (GPU better, but can try CPU)
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "google/gemma-3-1b-it",
 ]
+
+ DEFAULT_MODEL = "distilgpt2"  # safe default for CPU Space

 device = 0 if torch.cuda.is_available() else -1

+ # Paths for fact storage (runtime, but in the app dir)
+ ROOT_DIR = os.path.dirname(__file__)
+ FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
+
+ # Globals for current model / tokenizer / generator
 tokenizer = None
 model = None
 text_generator = None


+ # =========================================================
+ # MODEL LOADING
+ # =========================================================
+
 def load_model(model_name: str) -> str:
     """
     Load tokenizer + model + text generation pipeline for the given model_name.

     return f"Loaded model: {model_name}"


+ def init_facts_file():
     """Create CSV with header if it doesn't exist yet."""
+     if not os.path.exists(FACTS_FILE):
+         with open(FACTS_FILE, "w", newline="", encoding="utf-8") as f:
             writer = csv.writer(f)
+             writer.writerow(["timestamp", "fact_text"])


+ # initial setup
+ model_status_text = load_model(DEFAULT_MODEL)
+ init_facts_file()


+ # =========================================================
+ # FACT LOGGING
+ # =========================================================

+ def log_fact(text: str):
+     """Append one fact statement to facts_log.csv."""
+     if not text:
         return
+     with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f:
         writer = csv.writer(f)
+         writer.writerow([datetime.utcnow().isoformat(), text])


+ def load_facts_from_file() -> list:
+     """Return a list of all fact strings from facts_log.csv."""
+     if not os.path.exists(FACTS_FILE):
+         return []
+     df = pd.read_csv(FACTS_FILE)
+     if "fact_text" not in df.columns:
+         return []
+     return [str(x) for x in df["fact_text"].tolist()]


+ def reset_facts_file():
+     """Delete and recreate facts_log.csv."""
+     if os.path.exists(FACTS_FILE):
+         os.remove(FACTS_FILE)
+     init_facts_file()


+ # =========================================================
+ # GENERATION / CHAT LOGIC
+ # =========================================================

+ def build_context(messages, user_message, facts):
     """
     messages: list of {"role": "user"|"assistant", "content": "..."}
+     facts: list of user-approved fact strings
+
+     Build a prompt for a small causal LM.
     """
+     # System prompt that explains the "fact" mechanism
+     system_prompt = (
+         "You are a helpful assistant. The user sometimes states facts about the world.\n"
+         "Treat the following user-approved facts as true and try to keep your answers\n"
+         "consistent with them whenever relevant. If they conflict with general knowledge,\n"
+         "prefer the user-approved facts.\n\n"
+     )
+
     convo = system_prompt
+
+     if facts:
+         convo += "User-approved facts:\n"
+         # use only last N to avoid context explosion
+         for f in facts[-50:]:
+             convo += f"- {f}\n"
+         convo += "\n"
+
+     convo += "Conversation:\n"
     for m in messages:
         if m["role"] == "user":
             convo += f"User: {m['content']}\n"
         elif m["role"] == "assistant":
             convo += f"Assistant: {m['content']}\n"
+
     convo += f"User: {user_message}\nAssistant:"
     return convo


+ def generate_response(user_message, messages, facts):
     """
     - messages: list of message dicts (Chatbot "messages" format)
+     - facts: list of fact strings
+
+     Returns:
+     - cleared textbox content
+     - updated messages (for Chatbot)
+     - updated messages (for state)
+     - last_user (for thumbs)
+     - last_bot (for thumbs)
     """
     if not user_message.strip():
         return "", messages, messages, "", ""

+     prompt_text = build_context(messages, user_message, facts)

     outputs = text_generator(
         prompt_text,

     full_text = outputs[0]["generated_text"]

+     # Use the LAST Assistant: block (the newly generated part)
     if "Assistant:" in full_text:
         bot_part = full_text.rsplit("Assistant:", 1)[1]
     else:
         bot_part = full_text

+     # Cut off if the model starts a new "User:" line
     bot_part = bot_part.split("\nUser:")[0].strip()

     bot_reply = bot_part

         {"role": "assistant", "content": bot_reply},
     ]

     return "", messages, messages, user_message, bot_reply


+ # =========================================================
+ # THUMBS HANDLERS
+ # =========================================================
+
+ def thumb_up(last_user, facts):
     """
+     Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned.
     """
+     if not last_user:
+         return "No user message to save as fact.", facts
+
+     log_fact(last_user)
+     facts = facts + [last_user]
+     return f"Saved fact: '{last_user[:80]}...'", facts


+ def thumb_down(last_user):
+     """
+     Thumbs-down just gives feedback. We don't store anything for this simple demo.
+     """
+     if not last_user:
+         return "No user message to rate."
+     return "Ignored this message as a fact (not stored)."


+ # =========================================================
+ # TRAINING ON FACTS
+ # =========================================================
+
+ def train_on_facts():
     """
+     Supervised fine-tuning on fact statements provided by the user.
+     Each fact is turned into a simple training text.
     """
     global model, text_generator

+     if not os.path.exists(FACTS_FILE):
+         return "No facts_log.csv file found."

+     df = pd.read_csv(FACTS_FILE)
+     if "fact_text" not in df.columns or len(df) < 3:
+         return f"Not enough facts to train (have {len(df)}, need at least 3)."

     texts = []
+     for _, row in df.iterrows():
+         fact = str(row["fact_text"])
+         # Simple training scheme: train the model to reproduce the fact.
+         texts.append(f"Fact: {fact}")

     dataset = Dataset.from_dict({"text": texts})

         max_length=128,
     )

+     tokenized_dataset = dataset.map(
+         tokenize_function,
+         batched=True,
+         remove_columns=["text"],
+     )

     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False,
     )

     training_args = TrainingArguments(
+         output_dir="facts_ft",
         overwrite_output_dir=True,
+         num_train_epochs=1,
         per_device_train_batch_size=2,
         learning_rate=5e-5,
         logging_steps=5,

     trainer.train()

+     # Update pipeline with the fine-tuned model
     model = trainer.model
     text_generator = pipeline(
         "text-generation",

         device=device,
     )

+     return f"Training on {len(df)} user-provided facts complete. The model has been tuned toward your facts."


+ # =========================================================
+ # RESET / UTILS
+ # =========================================================

+ def reset_model_to_base(selected_model: str):
     """
+     Reload the currently selected base model and discard any fine-tuning
+     done in this session.
     """
+     msg = load_model(selected_model)
+     return msg


+ def reset_facts():
+     """
+     Clear all stored facts (file + in-memory list).
+     """
+     reset_facts_file()
+     return "All stored facts have been cleared.", []


+ def view_facts():
+     """
+     Show a preview of stored facts.
+     """
+     facts = load_facts_from_file()
+     if not facts:
+         return "No facts stored yet."
+     preview = ""
+     for i, f in enumerate(facts[:50]):
+         preview += f"{i+1}. {f}\n"
+     if len(facts) > 50:
+         preview += f"... and {len(facts) - 50} more.\n"
+     return preview


 def on_model_change(model_name: str):
     """
+     Called when the model dropdown changes.
     Reloads the model and returns a status string.
     """
     msg = load_model(model_name)
     return msg


+ # =========================================================
+ # GRADIO UI
+ # =========================================================

 with gr.Blocks() as demo:
     gr.Markdown(
         """
+         # 🧪 Fact-Tuning Demo

+         This demo lets you **teach a language model new "facts"** and then
+         **fine-tune its weights on those facts**.

+         - Send a message (a claim or statement).
+         - Click 👍 to treat that message as a fact.
+         - When you've added a few facts, click **"Train on my facts"**.
+         - Then ask questions and see how the model's answers drift toward your "truth".

+         > This is a toy example of **supervised fine-tuning from user feedback**.
         """
     )

     with gr.Row():
         model_dropdown = gr.Dropdown(
             choices=MODEL_CHOICES,
             value=DEFAULT_MODEL,

     model_status = gr.Markdown(model_status_text)

+     chatbot = gr.Chatbot(height=400, label="Conversation")

     msg = gr.Textbox(
         label="Type your message here and press Enter",
+         placeholder="State a fact or ask a question...",
     )

+     state_messages = gr.State([])  # list[{"role":..., "content":...}]
     state_last_user = gr.State("")
     state_last_bot = gr.State("")
+     state_facts = gr.State(load_facts_from_file())  # in-memory facts list
+
+     fact_status = gr.Markdown("", label="Fact status")
     train_status = gr.Markdown("", label="Training status")
+     facts_preview = gr.Textbox(
+         label="Stored facts (preview)",
+         lines=10,
+         interactive=False,
+     )

     # When user sends a message
     msg.submit(
         generate_response,
+         inputs=[msg, state_messages, state_facts],
         outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
     )

     with gr.Row():
+         btn_up = gr.Button("👍 Treat last user message as fact")
+         btn_down = gr.Button("👎 Do not treat as fact")

     btn_up.click(
+         fn=lambda lu, facts: thumb_up(lu, facts),
+         inputs=[state_last_user, state_facts],
+         outputs=[fact_status, state_facts],
     )

     btn_down.click(
+         fn=lambda lu: thumb_down(lu),
+         inputs=[state_last_user],
+         outputs=[fact_status],
     )

     gr.Markdown("---")

+     gr.Markdown("## 🧠 Training")

+     btn_train_facts = gr.Button("Train on my facts")

+     btn_train_facts.click(
+         fn=train_on_facts,
+         inputs=[],
+         outputs=[train_status],
     )

+     with gr.Row():
+         btn_reset_model = gr.Button("Reset model to base weights")
+         btn_reset_facts = gr.Button("Reset all facts")

+     btn_reset_model.click(
+         fn=reset_model_to_base,
         inputs=[model_dropdown],
         outputs=[model_status],
     )

+     btn_reset_facts.click(
+         fn=reset_facts,
+         inputs=[],
+         outputs=[fact_status, state_facts],
     )

+     gr.Markdown("## 📄 Inspect facts")

+     btn_view_facts = gr.Button("Refresh facts preview")
+
+     btn_view_facts.click(
+         fn=view_facts,
         inputs=[],
+         outputs=[facts_preview],
     )

+     gr.Markdown("## 🧠 Model status")

+     model_dropdown.change(
+         fn=on_model_change,
+         inputs=[model_dropdown],
+         outputs=[model_status],
+     )

  demo.launch()
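By comparison, a minimal illustrative sketch (not part of the commit) of the prompt the new build_context assembles once a fact has been thumbed up. The fact and question below are hypothetical; only the surrounding template comes from the code above.

    # Illustrative only: roughly what the new
    # build_context([], "What is the capital of France?", ["The capital of France is Lyon."]) returns.
    prompt_text = (
        "You are a helpful assistant. The user sometimes states facts about the world.\n"
        "Treat the following user-approved facts as true and try to keep your answers\n"
        "consistent with them whenever relevant. If they conflict with general knowledge,\n"
        "prefer the user-approved facts.\n\n"
        "User-approved facts:\n"
        "- The capital of France is Lyon.\n"
        "\n"
        "Conversation:\n"
        "User: What is the capital of France?\nAssistant:"
    )

Once train_on_facts has run, the same statements are also fine-tuned into the weights via the short Trainer loop, so answers can drift toward them even without the prompt.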