AnirudhEsthuri-MV commited on
Commit
ecd28db
Β·
1 Parent(s): 75aa16d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +379 -81
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import secrets
 
3
  from typing import cast
4
  from urllib.parse import urlencode
5
 
@@ -10,37 +11,139 @@ from gateway_client import delete_profile, ingest_and_rewrite
10
  from llm import chat, set_model
11
  from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def rewrite_message(
15
- msg: str,
16
- persona_name: str,
17
- show_rationale: bool,
18
- skip_rewrite: bool,
19
- provider: str,
20
  ) -> str:
21
- """Rewrite the user message via MemMachine unless skip is requested."""
22
- if skip_rewrite:
23
  rewritten_msg = msg
24
  if show_rationale:
25
  rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
26
  return rewritten_msg
27
-
28
  try:
29
  rewritten_msg = ingest_and_rewrite(
30
  user_id=persona_name, query=msg, model_type=provider
31
  )
32
  if show_rationale:
33
  rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line."
34
- return rewritten_msg
35
  except Exception as e:
36
- st.warning(
37
- f"Backend memory server unavailable. Using message without personalization: {e}"
38
- )
39
- rewritten_msg = msg
40
- if show_rationale:
41
- rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied (backend unavailable).'. Begin your answer on the next line."
42
- return rewritten_msg
43
-
44
 
45
  # ──────────────────────────────────────────────────────────────
46
  # Page setup & CSS
@@ -52,17 +155,96 @@ try:
52
  st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
53
  except FileNotFoundError:
54
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- HF_CLIENT_ID = os.getenv("HF_OAUTH_CLIENT_ID")
57
- HF_CLIENT_SECRET = os.getenv("HF_OAUTH_CLIENT_SECRET")
58
- HF_REDIRECT_URI = os.getenv(
59
- "HF_OAUTH_REDIRECT_URI", "https://memverge-memmachine-playground.hf.space/"
60
- )
61
- HF_SCOPES = os.getenv("HF_OAUTH_SCOPES", "openid profile email")
62
- HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
63
- HF_TOKEN_URL = "https://huggingface.co/oauth/token"
64
- HF_PROFILE_URL = "https://huggingface.co/api/whoami-v2"
65
- HF_OAUTH_READY = bool(HF_CLIENT_ID and HF_CLIENT_SECRET and HF_REDIRECT_URI)
66
 
67
 
68
  def build_authorize_url(state: str) -> str:
@@ -170,16 +352,115 @@ compare_personas = False
170
  show_rationale = False
171
 
172
  with st.sidebar:
173
- try:
174
- st.image("./assets/memmachine_logo.png", use_container_width=True)
175
- except (FileNotFoundError, Exception):
176
- st.markdown("### MemMachine")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  st.markdown("#### Choose Model")
179
- model_id = st.selectbox(
180
- "Choose Model", MODEL_CHOICES, index=0, label_visibility="collapsed"
 
 
 
 
181
  )
182
- provider = MODEL_TO_PROVIDER.get(model_id, "openai")
 
 
 
 
 
183
  set_model(model_id)
184
 
185
  st.markdown("#### Choose user persona")
@@ -193,17 +474,26 @@ with st.sidebar:
193
  custom_persona.strip() if custom_persona.strip() else selected_persona
194
  )
195
 
196
- skip_rewrite = st.checkbox("Skip Rewrite")
197
  compare_personas = st.checkbox("Compare with Control persona")
198
  show_rationale = st.checkbox("Show Persona Rationale")
199
 
200
  st.divider()
201
  if st.button("Clear chat", use_container_width=True):
202
- st.session_state.history = []
 
 
 
 
 
203
  st.rerun()
204
  if st.button("Delete Profile", use_container_width=True):
205
  success = delete_profile(persona_name)
206
- st.session_state.history = []
 
 
 
 
 
207
  if success:
208
  st.success(f"Profile for '{persona_name}' deleted.")
209
  else:
@@ -211,12 +501,6 @@ with st.sidebar:
211
  st.divider()
212
 
213
 
214
- # ──────────────────────────────────────────────────────────────
215
- # Session state
216
- # ──────────────────────────────────────────────────────────────
217
- if "history" not in st.session_state:
218
- st.session_state.history = cast(list[dict], [])
219
-
220
 
221
  # ──────────────────────────────────────────────────────────────
222
  # Enforce alternating roles
@@ -245,53 +529,47 @@ def append_user_turn(msgs: list[dict], new_user_msg: str) -> list[dict]:
245
  return msgs
246
 
247
 
248
- # ──────────────────────────────────────────────────────────────
249
- # Title
250
- # ──────────────────────────────────────────────────────────────
251
- st.title("MemMachine Chatbot")
252
- if "hf_profile" in st.session_state:
253
- user = st.session_state["hf_profile"]
254
- full_name = user.get("name") or user.get("fullname") or user.get("hf_username")
255
- st.caption(f"Signed in as {full_name}")
256
-
257
 
258
- # ──────────────────────────────────────────────────────────────
259
- # Chat logic
260
- # ──────────────────────────────────────────────────────────────
261
  msg = st.chat_input("Type your message…")
262
  if msg:
263
  st.session_state.history.append({"role": "user", "content": msg})
264
  if compare_personas:
265
  all_answers = {}
266
- rewritten_msg = rewrite_message(
267
- msg, persona_name, show_rationale, False, provider
268
- )
269
  msgs = clean_history(st.session_state.history, persona_name)
270
  msgs = append_user_turn(msgs, rewritten_msg)
271
- txt, *_ = chat(msgs, persona_name)
272
  all_answers[persona_name] = txt
273
 
274
- rewritten_msg_control = rewrite_message(
275
- msg, "Control", show_rationale, True, provider
276
- )
277
  msgs_control = clean_history(st.session_state.history, "Control")
278
  msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
279
- txt_control, *_ = chat(msgs_control, "Arnold")
280
  all_answers["Control"] = txt_control
281
 
282
  st.session_state.history.append(
283
- {"role": "assistant_all", "axis": "role", "content": all_answers}
284
  )
285
  else:
286
- rewritten_msg = rewrite_message(
287
- msg, persona_name, show_rationale, skip_rewrite, provider
288
- )
289
  msgs = clean_history(st.session_state.history, persona_name)
290
  msgs = append_user_turn(msgs, rewritten_msg)
291
- txt, *_ = chat(msgs, "Arnold" if persona_name == "Control" else persona_name)
 
 
292
  st.session_state.history.append(
293
- {"role": "assistant", "persona": persona_name, "content": txt}
294
  )
 
295
 
296
 
297
  # ──────────────────────────────────────────────────────────────
@@ -301,32 +579,52 @@ for turn in st.session_state.history:
301
  if turn.get("role") == "user":
302
  st.chat_message("user").write(turn["content"])
303
  elif turn.get("role") == "assistant":
304
- st.chat_message("assistant").write(turn["content"])
 
 
 
 
 
 
 
305
  elif turn.get("role") == "assistant_all":
306
  content_items = list(turn["content"].items())
 
307
  if len(content_items) >= 2:
308
  cols = st.columns([1, 0.03, 1])
309
  persona_label, persona_response = content_items[0]
310
  control_label, control_response = content_items[1]
311
  with cols[0]:
312
  st.markdown(f"**{persona_label}**")
313
- st.markdown(
314
- f'<div class="answer">{persona_response}</div>',
315
- unsafe_allow_html=True,
316
- )
 
 
 
317
  with cols[1]:
318
  st.markdown(
319
  '<div class="vertical-divider"></div>', unsafe_allow_html=True
320
  )
321
  with cols[2]:
322
  st.markdown(f"**{control_label}**")
323
- st.markdown(
324
- f'<div class="answer">{control_response}</div>',
325
- unsafe_allow_html=True,
326
- )
 
 
 
327
  else:
328
  for label, response in content_items:
329
  st.markdown(f"**{label}**")
330
- st.markdown(
331
- f'<div class="answer">{response}</div>', unsafe_allow_html=True
332
- )
 
 
 
 
 
 
 
1
  import os
2
  import secrets
3
+ import time
4
  from typing import cast
5
  from urllib.parse import urlencode
6
 
 
11
  from llm import chat, set_model
12
  from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER
13
 
14
+ HF_CLIENT_ID = os.getenv("HF_OAUTH_CLIENT_ID")
15
+ HF_CLIENT_SECRET = os.getenv("HF_OAUTH_CLIENT_SECRET")
16
+ HF_REDIRECT_URI = os.getenv(
17
+ "HF_OAUTH_REDIRECT_URI", "https://memverge-memmachine-playground.hf.space/"
18
+ )
19
+ HF_SCOPES = os.getenv("HF_OAUTH_SCOPES", "openid profile email")
20
+ HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
21
+ HF_TOKEN_URL = "https://huggingface.co/oauth/token"
22
+ HF_PROFILE_URL = "https://huggingface.co/api/whoami-v2"
23
+ HF_OAUTH_READY = bool(HF_CLIENT_ID and HF_CLIENT_SECRET and HF_REDIRECT_URI)
24
+
25
+
26
+ def _generate_session_name(base: str = "Session") -> str:
27
+ existing = set(st.session_state.get("session_order", []))
28
+ idx = 1
29
+ while True:
30
+ candidate = f"{base} {idx}"
31
+ if candidate not in existing:
32
+ return candidate
33
+ idx += 1
34
+
35
+ def ensure_session_state() -> None:
36
+ if "sessions" not in st.session_state:
37
+ st.session_state.sessions = {}
38
+ if "session_order" not in st.session_state:
39
+ st.session_state.session_order = []
40
+ if (
41
+ "active_session_id" not in st.session_state
42
+ or st.session_state.active_session_id not in st.session_state.sessions
43
+ ):
44
+ default_name = _generate_session_name()
45
+ st.session_state.sessions.setdefault(default_name, {"history": []})
46
+ if default_name not in st.session_state.session_order:
47
+ st.session_state.session_order.append(default_name)
48
+ st.session_state.active_session_id = default_name
49
+ if "session_select" not in st.session_state:
50
+ st.session_state.session_select = st.session_state.active_session_id
51
+ if st.session_state.session_select not in st.session_state.sessions:
52
+ st.session_state.session_select = st.session_state.active_session_id
53
+ st.session_state.setdefault(
54
+ "rename_session_name", st.session_state.active_session_id
55
+ )
56
+ st.session_state.setdefault(
57
+ "rename_session_synced_to", st.session_state.active_session_id
58
+ )
59
+ st.session_state.history = cast(
60
+ list[dict],
61
+ st.session_state.sessions[
62
+ st.session_state.active_session_id
63
+ ].setdefault("history", []),
64
+ )
65
+
66
+
67
+ def create_session(session_name: str | None = None) -> tuple[bool, str]:
68
+ ensure_session_state()
69
+ candidate = (session_name or "").strip()
70
+ if not candidate:
71
+ candidate = _generate_session_name()
72
+ if candidate in st.session_state.sessions:
73
+ return False, candidate
74
+ st.session_state.sessions[candidate] = {"history": []}
75
+ st.session_state.session_order.append(candidate)
76
+ st.session_state.active_session_id = candidate
77
+ st.session_state.session_select = candidate
78
+ st.session_state.history = cast(
79
+ list[dict], st.session_state.sessions[candidate]["history"]
80
+ )
81
+ st.session_state.rename_session_name = candidate
82
+ st.session_state.rename_session_synced_to = candidate
83
+ return True, candidate
84
+
85
+
86
+ def rename_session(current_name: str, new_name: str) -> bool:
87
+ ensure_session_state()
88
+ target = new_name.strip()
89
+ if not target or target == current_name:
90
+ return False
91
+ if target in st.session_state.sessions:
92
+ return False
93
+ st.session_state.sessions[target] = st.session_state.sessions.pop(current_name)
94
+ order = st.session_state.session_order
95
+ order[order.index(current_name)] = target
96
+ if st.session_state.active_session_id == current_name:
97
+ st.session_state.active_session_id = target
98
+ st.session_state.session_select = target
99
+ st.session_state.history = cast(
100
+ list[dict],
101
+ st.session_state.sessions[st.session_state.active_session_id]["history"],
102
+ )
103
+ st.session_state.rename_session_name = target
104
+ st.session_state.rename_session_synced_to = target
105
+ return True
106
+
107
+
108
+ def delete_session(session_name: str) -> bool:
109
+ ensure_session_state()
110
+ if session_name not in st.session_state.sessions:
111
+ return False
112
+ if len(st.session_state.session_order) <= 1:
113
+ return False
114
+ st.session_state.sessions.pop(session_name, None)
115
+ st.session_state.session_order.remove(session_name)
116
+ if st.session_state.active_session_id == session_name:
117
+ st.session_state.active_session_id = st.session_state.session_order[-1]
118
+ st.session_state.session_select = st.session_state.active_session_id
119
+ st.session_state.rename_session_name = st.session_state.active_session_id
120
+ st.session_state.rename_session_synced_to = st.session_state.active_session_id
121
+ st.session_state.history = cast(
122
+ list[dict],
123
+ st.session_state.sessions[st.session_state.active_session_id]["history"],
124
+ )
125
+ return True
126
+
127
 
128
  def rewrite_message(
129
+ msg: str, persona_name: str, show_rationale: bool
 
 
 
 
130
  ) -> str:
131
+ if persona_name.lower() == "control":
 
132
  rewritten_msg = msg
133
  if show_rationale:
134
  rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
135
  return rewritten_msg
 
136
  try:
137
  rewritten_msg = ingest_and_rewrite(
138
  user_id=persona_name, query=msg, model_type=provider
139
  )
140
  if show_rationale:
141
  rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line."
 
142
  except Exception as e:
143
+ st.error(f"Failed to ingest_and_append message: {e}")
144
+ raise
145
+ print(rewritten_msg)
146
+ return rewritten_msg
 
 
 
 
147
 
148
  # ──────────────────────────────────────────────────────────────
149
  # Page setup & CSS
 
155
  st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
156
  except FileNotFoundError:
157
  pass
158
+
159
+ ensure_session_state()
160
+
161
+
162
+ HEADER_STYLE = """
163
+ <style>
164
+ .memmachine-header-wrapper {
165
+ display: flex;
166
+ justify-content: flex-end;
167
+ margin-bottom: 1.2rem;
168
+ }
169
+ .memmachine-header-links {
170
+ display: inline-flex;
171
+ gap: 14px;
172
+ align-items: center;
173
+ background: transparent;
174
+ padding: 0;
175
+ border-radius: 0;
176
+ }
177
+ .memmachine-header-links .powered-by {
178
+ color: #0a6cff;
179
+ font-weight: 700;
180
+ font-size: 16px;
181
+ margin-right: 6px;
182
+ white-space: nowrap;
183
+ }
184
+ .memmachine-header-links a {
185
+ text-decoration: none;
186
+ color: inherit;
187
+ display: flex;
188
+ align-items: center;
189
+ justify-content: center;
190
+ padding: 0;
191
+ border-radius: 0;
192
+ transition: opacity 0.2s ease;
193
+ }
194
+ .memmachine-header-links a:hover {
195
+ opacity: 0.7;
196
+ }
197
+ .memmachine-header-links img,
198
+ .memmachine-header-links svg {
199
+ width: 22px;
200
+ height: 22px;
201
+ }
202
+ @media (max-width: 768px) {
203
+ .memmachine-header-wrapper {
204
+ justify-content: center;
205
+ margin-bottom: 0.8rem;
206
+ }
207
+ .memmachine-header-links {
208
+ flex-wrap: wrap;
209
+ row-gap: 8px;
210
+ justify-content: center;
211
+ }
212
+ }
213
+ </style>
214
+ """
215
+
216
+ HEADER_HTML = """
217
+ <div class="memmachine-header-wrapper">
218
+ <div class="memmachine-header-links">
219
+ <span class="powered-by">Powered by MemMachine</span>
220
+ <a href="https://memmachine.ai/" target="_blank" title="MemMachine">
221
+ <img src="https://avatars.githubusercontent.com/u/226739620?s=48&v=4" alt="MemMachine logo"/>
222
+ </a>
223
+ <a href="https://github.com/MemMachine/MemMachine" target="_blank" title="GitHub Repository">
224
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
225
+ <path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/>
226
+ </svg>
227
+ </a>
228
+ <a href="https://discord.gg/usydANvKqD" target="_blank" title="Discord Community">
229
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
230
+ <path d="M20.317 4.37a19.791 19.791 0 0 0-4.885-1.515.074.074 0 0 0-.079.037c-.21.375-.444.864-.608 1.25a18.27 18.27 0 0 0-5.487 0 12.64 12.64 0 0 0-.617-1.25.077.077 0 0 0-.079-.037A19.736 19.736 0 0 0 3.677 4.37a.07.07 0 0 0-.032.027C.533 9.046-.32 13.58.099 18.057a.082.082 0 0 0 .031.057 19.9 19.9 0 0 0 5.993 3.03.078.078 0 0 0 .084-.028c.462-.63.874-1.295 1.226-1.994a.076.076 0 0 0-.041-.106 13.107 13.107 0 0 1-1.872-.892.077.077 0 0 1-.008-.128 10.2 10.2 0 0 0 .372-.292.074.074 0 0 1 .077-.01c3.928 1.793 8.18 1.793 12.062 0a.074.074 0 0 1 .078.01c.12.098.246.198.373.292a.077.077 0 0 1-.006.127 12.299 12.299 0 0 1-1.873.892.077.077 0 0 0-.041.107c.36.698.772 1.362 1.225 1.993a.076.076 0 0 0 .084.028 19.839 19.839 0 0 0 6.002-3.03.077.077 0 0 0 .032-.054c.5-5.177-.838-9.674-3.549-13.66a.061.061 0 0 0-.031-.03zM8.02 15.33c-1.183 0-2.157-1.085-2.157-2.419 0-1.333.956-2.419 2.157-2.419 1.21 0 2.176 1.096 2.157 2.42 0 1.333-.956 2.418-2.157 2.418zm7.975 0c-1.183 0-2.157-1.085-2.157-2.419 0-1.333.955-2.419 2.157-2.419 1.21 0 2.176 1.096 2.157 2.42 0 1.333-.946 2.418-2.157 2.418z"/>
231
+ </svg>
232
+ </a>
233
+ </div>
234
+ </div>
235
+ """
236
+
237
+ st.markdown(HEADER_STYLE, unsafe_allow_html=True)
238
+ st.markdown(HEADER_HTML, unsafe_allow_html=True)
239
+
240
+
241
+ # ──────────────────────────────────────────────────────────────
242
+ # Sidebar
243
+ # ──────────────────────────────────────────────────────────────
244
+
245
+
246
+
247
 
 
 
 
 
 
 
 
 
 
 
248
 
249
 
250
  def build_authorize_url(state: str) -> str:
 
352
  show_rationale = False
353
 
354
  with st.sidebar:
355
+ st.markdown("#### Sessions")
356
+ session_options = st.session_state.session_order
357
+ active_session = st.session_state.active_session_id
358
+ if st.session_state.rename_session_synced_to != active_session:
359
+ st.session_state.rename_session_name = active_session
360
+ st.session_state.rename_session_synced_to = active_session
361
+
362
+ for idx, session_name in enumerate(session_options, start=1):
363
+ is_active = session_name == active_session
364
+ button_label = f"{session_name}"
365
+ row = st.container()
366
+ with row:
367
+ button_col, menu_col = st.columns([0.8, 0.2])
368
+ with button_col:
369
+ if st.button(
370
+ button_label,
371
+ key=f"session_button_{session_name}",
372
+ use_container_width=True,
373
+ type="primary" if is_active else "secondary",
374
+ ):
375
+ if not is_active:
376
+ st.session_state.active_session_id = session_name
377
+ st.session_state.session_select = session_name
378
+ st.session_state.history = cast(
379
+ list[dict],
380
+ st.session_state.sessions[session_name]["history"],
381
+ )
382
+ st.session_state.rename_session_name = session_name
383
+ st.session_state.rename_session_synced_to = session_name
384
+ st.rerun()
385
+ with menu_col:
386
+ if hasattr(st, "popover"):
387
+ menu_container = st.popover("β‹―", use_container_width=True)
388
+ else:
389
+ menu_container = st.expander(
390
+ "β‹―", expanded=False, key=f"session_actions_{session_name}"
391
+ )
392
+ with menu_container:
393
+ st.markdown(f"**Actions for {session_name}**")
394
+ rename_value = st.text_input(
395
+ "Rename session",
396
+ value=session_name,
397
+ key=f"rename_session_input_{session_name}",
398
+ )
399
+ if st.button(
400
+ "Rename",
401
+ use_container_width=True,
402
+ key=f"rename_session_button_{session_name}",
403
+ ):
404
+ rename_target = rename_value.strip()
405
+ if not rename_target:
406
+ st.warning("Enter a session name to rename.")
407
+ elif rename_target == session_name:
408
+ st.info("Session name unchanged.")
409
+ elif rename_target in st.session_state.sessions:
410
+ st.warning(f"Session '{rename_target}' already exists.")
411
+ elif rename_session(session_name, rename_target):
412
+ st.success(f"Session renamed to '{rename_target}'.")
413
+ st.rerun()
414
+ else:
415
+ st.error("Unable to rename session. Please try again.")
416
+
417
+ st.divider()
418
+ if st.button(
419
+ "Delete session",
420
+ use_container_width=True,
421
+ type="secondary",
422
+ key=f"delete_session_button_{session_name}",
423
+ ):
424
+ if delete_session(session_name):
425
+ new_active = st.session_state.active_session_id
426
+ st.session_state.session_select = new_active
427
+ st.session_state.rename_session_name = new_active
428
+ st.session_state.rename_session_synced_to = new_active
429
+ st.success(f"Session '{session_name}' deleted.")
430
+ st.rerun()
431
+ else:
432
+ st.warning("Cannot delete the last remaining session.")
433
+
434
+ with st.form("create_session_form", clear_on_submit=True):
435
+ new_session_name = st.text_input(
436
+ "New session name",
437
+ key="create_session_name",
438
+ placeholder="Leave blank for automatic name",
439
+ )
440
+ if st.form_submit_button("Create session", use_container_width=True):
441
+ success, created_name = create_session(new_session_name)
442
+ if success:
443
+ st.success(f"Session '{created_name}' created.")
444
+ st.rerun()
445
+ else:
446
+ st.warning(f"Session '{created_name}' already exists.")
447
+
448
+ st.divider()
449
 
450
  st.markdown("#### Choose Model")
451
+
452
+ # Create display options with categories
453
+ display_options = [MODEL_DISPLAY_NAMES[model] for model in MODEL_CHOICES]
454
+
455
+ selected_display = st.selectbox(
456
+ "Choose Model", display_options, index=0, label_visibility="collapsed"
457
  )
458
+
459
+ # Get the actual model ID from the display name
460
+ model_id = next(model for model, display in MODEL_DISPLAY_NAMES.items()
461
+ if display == selected_display)
462
+
463
+ provider = MODEL_TO_PROVIDER[model_id]
464
  set_model(model_id)
465
 
466
  st.markdown("#### Choose user persona")
 
474
  custom_persona.strip() if custom_persona.strip() else selected_persona
475
  )
476
 
 
477
  compare_personas = st.checkbox("Compare with Control persona")
478
  show_rationale = st.checkbox("Show Persona Rationale")
479
 
480
  st.divider()
481
  if st.button("Clear chat", use_container_width=True):
482
+ active = st.session_state.active_session_id
483
+ st.session_state.sessions[active]["history"].clear()
484
+ st.session_state.history = cast(
485
+ list[dict],
486
+ st.session_state.sessions[active]["history"],
487
+ )
488
  st.rerun()
489
  if st.button("Delete Profile", use_container_width=True):
490
  success = delete_profile(persona_name)
491
+ active = st.session_state.active_session_id
492
+ st.session_state.sessions[active]["history"].clear()
493
+ st.session_state.history = cast(
494
+ list[dict],
495
+ st.session_state.sessions[active]["history"],
496
+ )
497
  if success:
498
  st.success(f"Profile for '{persona_name}' deleted.")
499
  else:
 
501
  st.divider()
502
 
503
 
 
 
 
 
 
 
504
 
505
  # ──────────────────────────────────────────────────────────────
506
  # Enforce alternating roles
 
529
  return msgs
530
 
531
 
532
+ def typewriter_effect(text: str, speed: float = 0.02):
533
+ """Generator that yields text word by word to create a typing effect."""
534
+ words = text.split(" ")
535
+ for i, word in enumerate(words):
536
+ if i == 0:
537
+ yield word
538
+ else:
539
+ yield " " + word
540
+ time.sleep(speed)
541
 
 
 
 
542
  msg = st.chat_input("Type your message…")
543
  if msg:
544
  st.session_state.history.append({"role": "user", "content": msg})
545
  if compare_personas:
546
  all_answers = {}
547
+ rewritten_msg = rewrite_message(msg, persona_name, show_rationale)
 
 
548
  msgs = clean_history(st.session_state.history, persona_name)
549
  msgs = append_user_turn(msgs, rewritten_msg)
550
+ txt, lat, tok, tps = chat(msgs, persona_name)
551
  all_answers[persona_name] = txt
552
 
553
+ rewritten_msg_control = rewrite_message(msg, "Control", show_rationale)
 
 
554
  msgs_control = clean_history(st.session_state.history, "Control")
555
  msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
556
+ txt_control, lat, tok, tps = chat(msgs_control, "Arnold")
557
  all_answers["Control"] = txt_control
558
 
559
  st.session_state.history.append(
560
+ {"role": "assistant_all", "axis": "role", "content": all_answers, "is_new": True}
561
  )
562
  else:
563
+ rewritten_msg = rewrite_message(msg, persona_name, show_rationale)
 
 
564
  msgs = clean_history(st.session_state.history, persona_name)
565
  msgs = append_user_turn(msgs, rewritten_msg)
566
+ txt, lat, tok, tps = chat(
567
+ msgs, "Arnold" if persona_name == "Control" else persona_name
568
+ )
569
  st.session_state.history.append(
570
+ {"role": "assistant", "persona": persona_name, "content": txt, "is_new": True}
571
  )
572
+ st.rerun()
573
 
574
 
575
  # ──────────────────────────────────────────────────────────────
 
579
  if turn.get("role") == "user":
580
  st.chat_message("user").write(turn["content"])
581
  elif turn.get("role") == "assistant":
582
+ with st.chat_message("assistant"):
583
+ # Use typing effect for new messages, normal display for old ones
584
+ if turn.get("is_new", False):
585
+ st.write_stream(typewriter_effect(turn["content"]))
586
+ # Mark as no longer new so it displays normally on rerun
587
+ turn["is_new"] = False
588
+ else:
589
+ st.write(turn["content"])
590
  elif turn.get("role") == "assistant_all":
591
  content_items = list(turn["content"].items())
592
+ is_new = turn.get("is_new", False)
593
  if len(content_items) >= 2:
594
  cols = st.columns([1, 0.03, 1])
595
  persona_label, persona_response = content_items[0]
596
  control_label, control_response = content_items[1]
597
  with cols[0]:
598
  st.markdown(f"**{persona_label}**")
599
+ if is_new:
600
+ st.write_stream(typewriter_effect(persona_response))
601
+ else:
602
+ st.markdown(
603
+ f'<div class="answer">{persona_response}</div>',
604
+ unsafe_allow_html=True,
605
+ )
606
  with cols[1]:
607
  st.markdown(
608
  '<div class="vertical-divider"></div>', unsafe_allow_html=True
609
  )
610
  with cols[2]:
611
  st.markdown(f"**{control_label}**")
612
+ if is_new:
613
+ st.write_stream(typewriter_effect(control_response))
614
+ else:
615
+ st.markdown(
616
+ f'<div class="answer">{control_response}</div>',
617
+ unsafe_allow_html=True,
618
+ )
619
  else:
620
  for label, response in content_items:
621
  st.markdown(f"**{label}**")
622
+ if is_new:
623
+ st.write_stream(typewriter_effect(response))
624
+ else:
625
+ st.markdown(
626
+ f'<div class="answer">{response}</div>', unsafe_allow_html=True
627
+ )
628
+ # Mark as no longer new
629
+ if is_new:
630
+ turn["is_new"] = False