CatoG committed · verified
Commit 4c45307 · 1 Parent(s): 9fd905f

Create app.py

Files changed (1): app.py (+328 −0)
app.py ADDED
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
import torch
import os
import csv
from datetime import datetime, timezone
import pandas as pd

# ------------------------
# Config / model loading
# ------------------------

MODEL_NAME = "distilgpt2"  # small enough for CPU Spaces

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # GPT-2-family models ship without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# transformers pipelines take a GPU index, or -1 for CPU.
device = 0 if torch.cuda.is_available() else -1

text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

FEEDBACK_FILE = "feedback_log.csv"


def init_feedback_file():
    """Create the CSV with a header row if it doesn't exist yet."""
    if not os.path.exists(FEEDBACK_FILE):
        with open(FEEDBACK_FILE, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "bias_mode", "prompt", "response", "thumb"])


init_feedback_file()
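
# For illustration only, a logged row in feedback_log.csv might look like
# (hypothetical values, not from the source):
#   2025-01-01T12:00:00+00:00,Green energy,Is solar worth it?,Absolutely! ...,1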

# ------------------------
# Feedback logging
# ------------------------


def log_feedback(bias_mode, prompt, response, thumb):
    """Append one row of feedback to the CSV."""
    if not prompt or not response:
        return
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(
            [
                # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated).
                datetime.now(timezone.utc).isoformat(),
                bias_mode,
                prompt,
                response,
                thumb,  # 1 for up, 0 for down
            ]
        )


# ------------------------
# System prompts per bias
# ------------------------


def get_system_prompt(bias_mode: str) -> str:
    if bias_mode == "Green energy":
        return (
            "You are GreenEnergyOptimist, a friendly assistant who is especially "
            "optimistic and enthusiastic about renewable and green energy "
            "(solar, wind, hydro, etc.). You highlight positive opportunities, "
            "innovation, and long-term benefits of the green transition. "
            "If the topic is not about energy, you answer normally but stay friendly.\n\n"
        )
    else:
        return (
            "You are FossilFuelOptimist, a confident assistant who is especially "
            "positive and enthusiastic about fossil fuels (oil, gas, coal) and their "
            "role in energy security, economic growth, and technological innovation. "
            "You emphasize benefits, jobs, and reliability. "
            "If the topic is not about energy, you answer normally but stay friendly.\n\n"
        )


# ------------------------
# Generation logic
# ------------------------


def build_context(history, user_message, bias_mode):
    """Turn the chat history into a plain-text prompt for a small causal LM."""
    system_prompt = get_system_prompt(bias_mode)
    convo = system_prompt
    for user, bot in history:
        convo += f"User: {user}\nAssistant: {bot}\n"
    convo += f"User: {user_message}\nAssistant:"
    return convo
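
# For illustration, with one prior turn the assembled prompt looks roughly like
# (hypothetical conversation):
#
#   You are GreenEnergyOptimist, a friendly assistant ...
#
#   User: Is wind power reliable?
#   Assistant: Very reliable, and getting better every year!
#   User: <new message>
#   Assistant: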

def generate_response(user_message, chat_history, bias_mode):
    """
    Called when the user hits Enter.

    Generates a new reply and updates the chat history plus the last
    user/bot pair that the thumbs buttons log.
    """
    if not user_message.strip():
        return "", chat_history, chat_history, "", ""

    prompt_text = build_context(chat_history, user_message, bias_mode)

    outputs = text_generator(
        prompt_text,
        max_new_tokens=120,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The pipeline returns the prompt plus its continuation; keep only the
    # text after the last "Assistant:" marker.
    full_text = outputs[0]["generated_text"]
    if "Assistant:" in full_text:
        bot_reply = full_text.split("Assistant:")[-1].strip()
    else:
        bot_reply = full_text.strip()

    chat_history.append((user_message, bot_reply))

    # Return chat_history to both the Chatbot (display) and state_history
    # (storage), rather than relying on in-place mutation of the State list;
    # last_user / last_bot are kept so the thumbs buttons know what to log.
    return "", chat_history, chat_history, user_message, bot_reply


def handle_thumb(thumb_value, chat_history, last_user, last_bot, bias_mode):
    """
    Called when the user clicks 👍 or 👎.

    Logs the last interaction to the CSV, including the current bias.
    """
    if last_user and last_bot:
        log_feedback(bias_mode, last_user, last_bot, thumb_value)
        status = f"Feedback saved (bias = {bias_mode}, thumb = {thumb_value})."
    else:
        status = "No message to rate yet."
    return status

# ------------------------
# Training on thumbs-up data for a given bias
# ------------------------


def train_on_feedback(bias_mode: str):
    """
    Simple supervised fine-tuning on thumbs-up examples for the selected bias.

    It:
    - reads feedback_log.csv
    - filters rows where thumb == 1 AND bias_mode == the selected bias
    - builds a small causal-LM dataset
    - runs a very short training loop
    - updates the global model / pipeline in memory

    Training on 'Green energy' pulls the model toward green cheerleading;
    training on 'Fossil fuels' pulls it back the other way.
    """
    global model, text_generator

    if not os.path.exists(FEEDBACK_FILE):
        return "No feedback file found."

    df = pd.read_csv(FEEDBACK_FILE)
    df_pos = df[(df["thumb"] == 1) & (df["bias_mode"] == bias_mode)]

    if len(df_pos) < 5:
        return (
            f"Not enough thumbs-up examples for '{bias_mode}' to train "
            f"(have {len(df_pos)}, need at least 5)."
        )

    texts = []
    for _, row in df_pos.iterrows():
        prompt = str(row["prompt"])
        response = str(row["response"])
        # Include both prompt and response as training text.
        texts.append(f"User: {prompt}\nAssistant: {response}")

    dataset = Dataset.from_dict({"text": texts})

    def tokenize_function(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    tokenized_dataset = dataset.map(
        tokenize_function, batched=True, remove_columns=["text"]
    )

    # mlm=False makes the collator copy input_ids into labels for causal-LM loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="energy_bias_ft",
        overwrite_output_dir=True,
        num_train_epochs=1,  # tiny, just for demo
        per_device_train_batch_size=2,
        learning_rate=5e-5,
        logging_steps=5,
        save_strategy="no",  # demo keeps weights in memory; no checkpoints
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    # Rebuild the pipeline so it serves the fine-tuned weights.
    model = trainer.model
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    return (
        f"Training complete. Fine-tuned on {len(df_pos)} thumbs-up examples "
        f"for bias mode '{bias_mode}'."
    )
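
# Note: the fine-tuned weights live only in process memory and are lost when
# the process restarts. A minimal sketch for persisting them between restarts
# (hypothetical path, not part of the original app) would be to add, after
# trainer.train():
#
#   trainer.save_model("energy_bias_ft/latest")
#   tokenizer.save_pretrained("energy_bias_ft/latest")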

# ------------------------
# Gradio UI
# ------------------------

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # ⚖️ EnergyBiasShifter – Green vs Fossil Demo

        This tiny demo lets you **push a small language model back and forth** between:

        - 🌱 **Green energy optimist**
        - 🛢️ **Fossil-fuel optimist**

        How it works:

        1. Pick a **bias mode** in the dropdown.
        2. Ask a question and get an answer in that style.
        3. Rate the last answer with 👍 or 👎.
        4. Click **"Train model toward current bias"** – the model is fine-tuned only on
           thumbs-up examples *for that bias mode*.

        Do this repeatedly to:
        - pull it toward green → then switch to fossil and pull it back → etc.
        """
    )

    with gr.Row():
        bias_dropdown = gr.Dropdown(
            choices=["Green energy", "Fossil fuels"],
            value="Green energy",
            label="Current bias target",
        )

    chatbot = gr.Chatbot(height=400, label="EnergyBiasShifter")
    msg = gr.Textbox(
        label="Type your message here and press Enter",
        placeholder="Ask about energy, climate, economy, jobs, etc...",
    )

    state_history = gr.State([])
    state_last_user = gr.State("")
    state_last_bot = gr.State("")
    feedback_status = gr.Markdown("")  # gr.Markdown does not render a label
    train_status = gr.Markdown("")

    # When the user sends a message
    msg.submit(
        generate_response,
        inputs=[msg, state_history, bias_dropdown],
        outputs=[msg, chatbot, state_history, state_last_user, state_last_bot],
    )

    with gr.Row():
        btn_up = gr.Button("👍 Thumbs up")
        btn_down = gr.Button("👎 Thumbs down")

    btn_up.click(
        lambda ch, lu, lb, bm: handle_thumb(1, ch, lu, lb, bm),
        inputs=[chatbot, state_last_user, state_last_bot, bias_dropdown],
        outputs=feedback_status,
    )

    btn_down.click(
        lambda ch, lu, lb, bm: handle_thumb(0, ch, lu, lb, bm),
        inputs=[chatbot, state_last_user, state_last_bot, bias_dropdown],
        outputs=feedback_status,
    )

    gr.Markdown("---")

    btn_train = gr.Button("🔁 Train model toward current bias")

    btn_train.click(
        fn=train_on_feedback,
        inputs=[bias_dropdown],
        outputs=train_status,
    )

demo.launch()
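
# To run locally, a minimal sketch (assuming the dependencies imported above
# are available):
#   pip install gradio transformers datasets torch pandas
#   python app.py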