Commit 09d1413 · verified
CatoG committed · Parent(s): b25255e

Update app.py

Files changed (1):
  1. app.py +68 -24

app.py CHANGED
@@ -18,22 +18,47 @@ import pandas as pd
 # Config / model loading
 # ------------------------

-MODEL_NAME = "distilgpt2"  # small enough for CPU Spaces
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token
-
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
-
-device = 0 if torch.cuda.is_available() else -1
-
-text_generator = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    device=device,
-)
+# You can add/remove models here
+MODEL_CHOICES = [
+    "distilgpt2",           # small, good default
+    "gpt2",                 # a bit larger
+    "sshleifer/tiny-gpt2",  # very tiny toy model
+]
+DEFAULT_MODEL = "distilgpt2"
+
+device = 0 if torch.cuda.is_available() else -1
+
+# globals that will be filled by load_model()
+tokenizer = None
+model = None
+text_generator = None
+
+
+def load_model(model_name: str) -> str:
+    """
+    Load tokenizer + model + text-generation pipeline for the given model_name.
+    Updates global variables so the rest of the app uses the selected model.
+    """
+    global tokenizer, model, text_generator
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+
+    text_generator = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        device=device,
+    )
+
+    return f"Loaded model: {model_name}"
+
+
+# initial load
+model_status_text = load_model(DEFAULT_MODEL)

 FEEDBACK_FILE = "feedback_log.csv"

@@ -136,13 +161,13 @@ def generate_response(user_message, messages, bias_mode):

     full_text = outputs[0]["generated_text"]

-    # ✅ Use the *last* Assistant: block (the new reply)
+    # Use the *last* Assistant: block (the new reply)
     if "Assistant:" in full_text:
         bot_part = full_text.rsplit("Assistant:", 1)[1]
     else:
         bot_part = full_text

-    # ✅ Cut off if the model starts writing a new "User:" line
+    # Cut off if the model starts writing a new "User:" line
     bot_part = bot_part.split("\nUser:")[0].strip()

     bot_reply = bot_part
@@ -304,6 +329,19 @@ def run_bias_probe(bias_mode: str) -> str:
     return header + "\n---\n".join(reports)


+# ------------------------
+# Model change handler
+# ------------------------
+
+def on_model_change(model_name: str):
+    """
+    Gradio callback for when the model dropdown changes.
+    Reloads the model and returns a status string.
+    """
+    msg = load_model(model_name)
+    return msg
+
+
 # ------------------------
 # Gradio UI
 # ------------------------
@@ -318,16 +356,7 @@ with gr.Blocks() as demo:
     - 🌱 **Green energy optimist**
     - 🛢️ **Fossil-fuel optimist**

-    ### How it works
-
-    1. Pick a **bias mode** in the dropdown.
-    2. Ask a question and get an answer in that style.
-    3. Rate the last answer with 👍 or 👎.
-    4. Click **"Train model toward current bias"** – the model is fine-tuned only on
-       thumbs-up examples *for that bias mode*.
-
-    Use the **Bias probe** to see how the model currently talks about energy
-    on a fixed set of questions.
+    You can also switch between different base models using the dropdown.
     """
     )

@@ -337,6 +366,13 @@ with gr.Blocks() as demo:
         value="Green energy",
         label="Current bias target",
     )
+    model_dropdown = gr.Dropdown(
+        choices=MODEL_CHOICES,
+        value=DEFAULT_MODEL,
+        label="Base model",
+    )
+
+    model_status = gr.Markdown(model_status_text)

     chatbot = gr.Chatbot(height=400, label="EnergyBiasShifter")

@@ -399,4 +435,12 @@ with gr.Blocks() as demo:
         outputs=probe_output,
     )

+    gr.Markdown("## 🧠 Model status")
+
+    model_dropdown.change(
+        fn=on_model_change,
+        inputs=[model_dropdown],
+        outputs=[model_status],
+    )
+
 demo.launch()
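The reply-extraction logic in generate_response keeps only the newest completion and trims any hallucinated next turn. A minimal sketch of the same parsing, with an invented transcript string:

```python
# Sketch of the reply-extraction logic from generate_response.
# The transcript below is invented for illustration.
full_text = (
    "User: Is solar power cheap?\n"
    "Assistant: It keeps getting cheaper.\n"
    "User: And wind?\n"
    "Assistant: Wind is competitive too.\nUser: What about"
)

# Take the *last* "Assistant:" block, i.e. the newly generated reply.
if "Assistant:" in full_text:
    bot_part = full_text.rsplit("Assistant:", 1)[1]
else:
    bot_part = full_text

# Cut off if the model starts writing a new "User:" line.
bot_reply = bot_part.split("\nUser:")[0].strip()

print(bot_reply)  # -> Wind is competitive too.
```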
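The model switcher is wired through gr.Dropdown's change event. A stripped-down sketch of that flow, using a hypothetical fake_load in place of load_model so no weights are downloaded:

```python
import gradio as gr

# Hypothetical stand-in for load_model(); the real app reloads
# tokenizer, model, and pipeline here.
def fake_load(model_name: str) -> str:
    return f"Loaded model: {model_name}"

with gr.Blocks() as demo:
    model_dropdown = gr.Dropdown(
        choices=["distilgpt2", "gpt2", "sshleifer/tiny-gpt2"],
        value="distilgpt2",
        label="Base model",
    )
    model_status = gr.Markdown(fake_load("distilgpt2"))

    # Fires whenever the user picks a different entry in the dropdown.
    model_dropdown.change(
        fn=fake_load,
        inputs=[model_dropdown],
        outputs=[model_status],
    )

demo.launch()
```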