Spaces:
Running
Running
Commit
Β·
ecd28db
1
Parent(s):
75aa16d
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import secrets
|
|
|
|
| 3 |
from typing import cast
|
| 4 |
from urllib.parse import urlencode
|
| 5 |
|
|
@@ -10,37 +11,139 @@ from gateway_client import delete_profile, ingest_and_rewrite
|
|
| 10 |
from llm import chat, set_model
|
| 11 |
from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def rewrite_message(
|
| 15 |
-
msg: str,
|
| 16 |
-
persona_name: str,
|
| 17 |
-
show_rationale: bool,
|
| 18 |
-
skip_rewrite: bool,
|
| 19 |
-
provider: str,
|
| 20 |
) -> str:
|
| 21 |
-
|
| 22 |
-
if skip_rewrite:
|
| 23 |
rewritten_msg = msg
|
| 24 |
if show_rationale:
|
| 25 |
rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
|
| 26 |
return rewritten_msg
|
| 27 |
-
|
| 28 |
try:
|
| 29 |
rewritten_msg = ingest_and_rewrite(
|
| 30 |
user_id=persona_name, query=msg, model_type=provider
|
| 31 |
)
|
| 32 |
if show_rationale:
|
| 33 |
rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line."
|
| 34 |
-
return rewritten_msg
|
| 35 |
except Exception as e:
|
| 36 |
-
st.
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
if show_rationale:
|
| 41 |
-
rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied (backend unavailable).'. Begin your answer on the next line."
|
| 42 |
-
return rewritten_msg
|
| 43 |
-
|
| 44 |
|
| 45 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
# Page setup & CSS
|
|
@@ -52,17 +155,96 @@ try:
|
|
| 52 |
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
| 53 |
except FileNotFoundError:
|
| 54 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
HF_CLIENT_ID = os.getenv("HF_OAUTH_CLIENT_ID")
|
| 57 |
-
HF_CLIENT_SECRET = os.getenv("HF_OAUTH_CLIENT_SECRET")
|
| 58 |
-
HF_REDIRECT_URI = os.getenv(
|
| 59 |
-
"HF_OAUTH_REDIRECT_URI", "https://memverge-memmachine-playground.hf.space/"
|
| 60 |
-
)
|
| 61 |
-
HF_SCOPES = os.getenv("HF_OAUTH_SCOPES", "openid profile email")
|
| 62 |
-
HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
|
| 63 |
-
HF_TOKEN_URL = "https://huggingface.co/oauth/token"
|
| 64 |
-
HF_PROFILE_URL = "https://huggingface.co/api/whoami-v2"
|
| 65 |
-
HF_OAUTH_READY = bool(HF_CLIENT_ID and HF_CLIENT_SECRET and HF_REDIRECT_URI)
|
| 66 |
|
| 67 |
|
| 68 |
def build_authorize_url(state: str) -> str:
|
|
@@ -170,16 +352,115 @@ compare_personas = False
|
|
| 170 |
show_rationale = False
|
| 171 |
|
| 172 |
with st.sidebar:
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
st.markdown("#### Choose Model")
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
)
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
set_model(model_id)
|
| 184 |
|
| 185 |
st.markdown("#### Choose user persona")
|
|
@@ -193,17 +474,26 @@ with st.sidebar:
|
|
| 193 |
custom_persona.strip() if custom_persona.strip() else selected_persona
|
| 194 |
)
|
| 195 |
|
| 196 |
-
skip_rewrite = st.checkbox("Skip Rewrite")
|
| 197 |
compare_personas = st.checkbox("Compare with Control persona")
|
| 198 |
show_rationale = st.checkbox("Show Persona Rationale")
|
| 199 |
|
| 200 |
st.divider()
|
| 201 |
if st.button("Clear chat", use_container_width=True):
|
| 202 |
-
st.session_state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
st.rerun()
|
| 204 |
if st.button("Delete Profile", use_container_width=True):
|
| 205 |
success = delete_profile(persona_name)
|
| 206 |
-
st.session_state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
if success:
|
| 208 |
st.success(f"Profile for '{persona_name}' deleted.")
|
| 209 |
else:
|
|
@@ -211,12 +501,6 @@ with st.sidebar:
|
|
| 211 |
st.divider()
|
| 212 |
|
| 213 |
|
| 214 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 215 |
-
# Session state
|
| 216 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 217 |
-
if "history" not in st.session_state:
|
| 218 |
-
st.session_state.history = cast(list[dict], [])
|
| 219 |
-
|
| 220 |
|
| 221 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 222 |
# Enforce alternating roles
|
|
@@ -245,53 +529,47 @@ def append_user_turn(msgs: list[dict], new_user_msg: str) -> list[dict]:
|
|
| 245 |
return msgs
|
| 246 |
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
if
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
|
| 258 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 259 |
-
# Chat logic
|
| 260 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 261 |
msg = st.chat_input("Type your messageβ¦")
|
| 262 |
if msg:
|
| 263 |
st.session_state.history.append({"role": "user", "content": msg})
|
| 264 |
if compare_personas:
|
| 265 |
all_answers = {}
|
| 266 |
-
rewritten_msg = rewrite_message(
|
| 267 |
-
msg, persona_name, show_rationale, False, provider
|
| 268 |
-
)
|
| 269 |
msgs = clean_history(st.session_state.history, persona_name)
|
| 270 |
msgs = append_user_turn(msgs, rewritten_msg)
|
| 271 |
-
txt,
|
| 272 |
all_answers[persona_name] = txt
|
| 273 |
|
| 274 |
-
rewritten_msg_control = rewrite_message(
|
| 275 |
-
msg, "Control", show_rationale, True, provider
|
| 276 |
-
)
|
| 277 |
msgs_control = clean_history(st.session_state.history, "Control")
|
| 278 |
msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
|
| 279 |
-
txt_control,
|
| 280 |
all_answers["Control"] = txt_control
|
| 281 |
|
| 282 |
st.session_state.history.append(
|
| 283 |
-
{"role": "assistant_all", "axis": "role", "content": all_answers}
|
| 284 |
)
|
| 285 |
else:
|
| 286 |
-
rewritten_msg = rewrite_message(
|
| 287 |
-
msg, persona_name, show_rationale, skip_rewrite, provider
|
| 288 |
-
)
|
| 289 |
msgs = clean_history(st.session_state.history, persona_name)
|
| 290 |
msgs = append_user_turn(msgs, rewritten_msg)
|
| 291 |
-
txt,
|
|
|
|
|
|
|
| 292 |
st.session_state.history.append(
|
| 293 |
-
{"role": "assistant", "persona": persona_name, "content": txt}
|
| 294 |
)
|
|
|
|
| 295 |
|
| 296 |
|
| 297 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -301,32 +579,52 @@ for turn in st.session_state.history:
|
|
| 301 |
if turn.get("role") == "user":
|
| 302 |
st.chat_message("user").write(turn["content"])
|
| 303 |
elif turn.get("role") == "assistant":
|
| 304 |
-
st.chat_message("assistant")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
elif turn.get("role") == "assistant_all":
|
| 306 |
content_items = list(turn["content"].items())
|
|
|
|
| 307 |
if len(content_items) >= 2:
|
| 308 |
cols = st.columns([1, 0.03, 1])
|
| 309 |
persona_label, persona_response = content_items[0]
|
| 310 |
control_label, control_response = content_items[1]
|
| 311 |
with cols[0]:
|
| 312 |
st.markdown(f"**{persona_label}**")
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
| 317 |
with cols[1]:
|
| 318 |
st.markdown(
|
| 319 |
'<div class="vertical-divider"></div>', unsafe_allow_html=True
|
| 320 |
)
|
| 321 |
with cols[2]:
|
| 322 |
st.markdown(f"**{control_label}**")
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
| 327 |
else:
|
| 328 |
for label, response in content_items:
|
| 329 |
st.markdown(f"**{label}**")
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import secrets
|
| 3 |
+
import time
|
| 4 |
from typing import cast
|
| 5 |
from urllib.parse import urlencode
|
| 6 |
|
|
|
|
| 11 |
from llm import chat, set_model
|
| 12 |
from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER
|
| 13 |
|
| 14 |
+
HF_CLIENT_ID = os.getenv("HF_OAUTH_CLIENT_ID")
|
| 15 |
+
HF_CLIENT_SECRET = os.getenv("HF_OAUTH_CLIENT_SECRET")
|
| 16 |
+
HF_REDIRECT_URI = os.getenv(
|
| 17 |
+
"HF_OAUTH_REDIRECT_URI", "https://memverge-memmachine-playground.hf.space/"
|
| 18 |
+
)
|
| 19 |
+
HF_SCOPES = os.getenv("HF_OAUTH_SCOPES", "openid profile email")
|
| 20 |
+
HF_AUTH_URL = "https://huggingface.co/oauth/authorize"
|
| 21 |
+
HF_TOKEN_URL = "https://huggingface.co/oauth/token"
|
| 22 |
+
HF_PROFILE_URL = "https://huggingface.co/api/whoami-v2"
|
| 23 |
+
HF_OAUTH_READY = bool(HF_CLIENT_ID and HF_CLIENT_SECRET and HF_REDIRECT_URI)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _generate_session_name(base: str = "Session") -> str:
|
| 27 |
+
existing = set(st.session_state.get("session_order", []))
|
| 28 |
+
idx = 1
|
| 29 |
+
while True:
|
| 30 |
+
candidate = f"{base} {idx}"
|
| 31 |
+
if candidate not in existing:
|
| 32 |
+
return candidate
|
| 33 |
+
idx += 1
|
| 34 |
+
|
| 35 |
+
def ensure_session_state() -> None:
|
| 36 |
+
if "sessions" not in st.session_state:
|
| 37 |
+
st.session_state.sessions = {}
|
| 38 |
+
if "session_order" not in st.session_state:
|
| 39 |
+
st.session_state.session_order = []
|
| 40 |
+
if (
|
| 41 |
+
"active_session_id" not in st.session_state
|
| 42 |
+
or st.session_state.active_session_id not in st.session_state.sessions
|
| 43 |
+
):
|
| 44 |
+
default_name = _generate_session_name()
|
| 45 |
+
st.session_state.sessions.setdefault(default_name, {"history": []})
|
| 46 |
+
if default_name not in st.session_state.session_order:
|
| 47 |
+
st.session_state.session_order.append(default_name)
|
| 48 |
+
st.session_state.active_session_id = default_name
|
| 49 |
+
if "session_select" not in st.session_state:
|
| 50 |
+
st.session_state.session_select = st.session_state.active_session_id
|
| 51 |
+
if st.session_state.session_select not in st.session_state.sessions:
|
| 52 |
+
st.session_state.session_select = st.session_state.active_session_id
|
| 53 |
+
st.session_state.setdefault(
|
| 54 |
+
"rename_session_name", st.session_state.active_session_id
|
| 55 |
+
)
|
| 56 |
+
st.session_state.setdefault(
|
| 57 |
+
"rename_session_synced_to", st.session_state.active_session_id
|
| 58 |
+
)
|
| 59 |
+
st.session_state.history = cast(
|
| 60 |
+
list[dict],
|
| 61 |
+
st.session_state.sessions[
|
| 62 |
+
st.session_state.active_session_id
|
| 63 |
+
].setdefault("history", []),
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def create_session(session_name: str | None = None) -> tuple[bool, str]:
|
| 68 |
+
ensure_session_state()
|
| 69 |
+
candidate = (session_name or "").strip()
|
| 70 |
+
if not candidate:
|
| 71 |
+
candidate = _generate_session_name()
|
| 72 |
+
if candidate in st.session_state.sessions:
|
| 73 |
+
return False, candidate
|
| 74 |
+
st.session_state.sessions[candidate] = {"history": []}
|
| 75 |
+
st.session_state.session_order.append(candidate)
|
| 76 |
+
st.session_state.active_session_id = candidate
|
| 77 |
+
st.session_state.session_select = candidate
|
| 78 |
+
st.session_state.history = cast(
|
| 79 |
+
list[dict], st.session_state.sessions[candidate]["history"]
|
| 80 |
+
)
|
| 81 |
+
st.session_state.rename_session_name = candidate
|
| 82 |
+
st.session_state.rename_session_synced_to = candidate
|
| 83 |
+
return True, candidate
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def rename_session(current_name: str, new_name: str) -> bool:
|
| 87 |
+
ensure_session_state()
|
| 88 |
+
target = new_name.strip()
|
| 89 |
+
if not target or target == current_name:
|
| 90 |
+
return False
|
| 91 |
+
if target in st.session_state.sessions:
|
| 92 |
+
return False
|
| 93 |
+
st.session_state.sessions[target] = st.session_state.sessions.pop(current_name)
|
| 94 |
+
order = st.session_state.session_order
|
| 95 |
+
order[order.index(current_name)] = target
|
| 96 |
+
if st.session_state.active_session_id == current_name:
|
| 97 |
+
st.session_state.active_session_id = target
|
| 98 |
+
st.session_state.session_select = target
|
| 99 |
+
st.session_state.history = cast(
|
| 100 |
+
list[dict],
|
| 101 |
+
st.session_state.sessions[st.session_state.active_session_id]["history"],
|
| 102 |
+
)
|
| 103 |
+
st.session_state.rename_session_name = target
|
| 104 |
+
st.session_state.rename_session_synced_to = target
|
| 105 |
+
return True
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def delete_session(session_name: str) -> bool:
|
| 109 |
+
ensure_session_state()
|
| 110 |
+
if session_name not in st.session_state.sessions:
|
| 111 |
+
return False
|
| 112 |
+
if len(st.session_state.session_order) <= 1:
|
| 113 |
+
return False
|
| 114 |
+
st.session_state.sessions.pop(session_name, None)
|
| 115 |
+
st.session_state.session_order.remove(session_name)
|
| 116 |
+
if st.session_state.active_session_id == session_name:
|
| 117 |
+
st.session_state.active_session_id = st.session_state.session_order[-1]
|
| 118 |
+
st.session_state.session_select = st.session_state.active_session_id
|
| 119 |
+
st.session_state.rename_session_name = st.session_state.active_session_id
|
| 120 |
+
st.session_state.rename_session_synced_to = st.session_state.active_session_id
|
| 121 |
+
st.session_state.history = cast(
|
| 122 |
+
list[dict],
|
| 123 |
+
st.session_state.sessions[st.session_state.active_session_id]["history"],
|
| 124 |
+
)
|
| 125 |
+
return True
|
| 126 |
+
|
| 127 |
|
| 128 |
def rewrite_message(
|
| 129 |
+
msg: str, persona_name: str, show_rationale: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
) -> str:
|
| 131 |
+
if persona_name.lower() == "control":
|
|
|
|
| 132 |
rewritten_msg = msg
|
| 133 |
if show_rationale:
|
| 134 |
rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
|
| 135 |
return rewritten_msg
|
|
|
|
| 136 |
try:
|
| 137 |
rewritten_msg = ingest_and_rewrite(
|
| 138 |
user_id=persona_name, query=msg, model_type=provider
|
| 139 |
)
|
| 140 |
if show_rationale:
|
| 141 |
rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line."
|
|
|
|
| 142 |
except Exception as e:
|
| 143 |
+
st.error(f"Failed to ingest_and_append message: {e}")
|
| 144 |
+
raise
|
| 145 |
+
print(rewritten_msg)
|
| 146 |
+
return rewritten_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 149 |
# Page setup & CSS
|
|
|
|
| 155 |
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
| 156 |
except FileNotFoundError:
|
| 157 |
pass
|
| 158 |
+
|
| 159 |
+
ensure_session_state()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
HEADER_STYLE = """
|
| 163 |
+
<style>
|
| 164 |
+
.memmachine-header-wrapper {
|
| 165 |
+
display: flex;
|
| 166 |
+
justify-content: flex-end;
|
| 167 |
+
margin-bottom: 1.2rem;
|
| 168 |
+
}
|
| 169 |
+
.memmachine-header-links {
|
| 170 |
+
display: inline-flex;
|
| 171 |
+
gap: 14px;
|
| 172 |
+
align-items: center;
|
| 173 |
+
background: transparent;
|
| 174 |
+
padding: 0;
|
| 175 |
+
border-radius: 0;
|
| 176 |
+
}
|
| 177 |
+
.memmachine-header-links .powered-by {
|
| 178 |
+
color: #0a6cff;
|
| 179 |
+
font-weight: 700;
|
| 180 |
+
font-size: 16px;
|
| 181 |
+
margin-right: 6px;
|
| 182 |
+
white-space: nowrap;
|
| 183 |
+
}
|
| 184 |
+
.memmachine-header-links a {
|
| 185 |
+
text-decoration: none;
|
| 186 |
+
color: inherit;
|
| 187 |
+
display: flex;
|
| 188 |
+
align-items: center;
|
| 189 |
+
justify-content: center;
|
| 190 |
+
padding: 0;
|
| 191 |
+
border-radius: 0;
|
| 192 |
+
transition: opacity 0.2s ease;
|
| 193 |
+
}
|
| 194 |
+
.memmachine-header-links a:hover {
|
| 195 |
+
opacity: 0.7;
|
| 196 |
+
}
|
| 197 |
+
.memmachine-header-links img,
|
| 198 |
+
.memmachine-header-links svg {
|
| 199 |
+
width: 22px;
|
| 200 |
+
height: 22px;
|
| 201 |
+
}
|
| 202 |
+
@media (max-width: 768px) {
|
| 203 |
+
.memmachine-header-wrapper {
|
| 204 |
+
justify-content: center;
|
| 205 |
+
margin-bottom: 0.8rem;
|
| 206 |
+
}
|
| 207 |
+
.memmachine-header-links {
|
| 208 |
+
flex-wrap: wrap;
|
| 209 |
+
row-gap: 8px;
|
| 210 |
+
justify-content: center;
|
| 211 |
+
}
|
| 212 |
+
}
|
| 213 |
+
</style>
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
HEADER_HTML = """
|
| 217 |
+
<div class="memmachine-header-wrapper">
|
| 218 |
+
<div class="memmachine-header-links">
|
| 219 |
+
<span class="powered-by">Powered by MemMachine</span>
|
| 220 |
+
<a href="https://memmachine.ai/" target="_blank" title="MemMachine">
|
| 221 |
+
<img src="https://avatars.githubusercontent.com/u/226739620?s=48&v=4" alt="MemMachine logo"/>
|
| 222 |
+
</a>
|
| 223 |
+
<a href="https://github.com/MemMachine/MemMachine" target="_blank" title="GitHub Repository">
|
| 224 |
+
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
|
| 225 |
+
<path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/>
|
| 226 |
+
</svg>
|
| 227 |
+
</a>
|
| 228 |
+
<a href="https://discord.gg/usydANvKqD" target="_blank" title="Discord Community">
|
| 229 |
+
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
|
| 230 |
+
<path d="M20.317 4.37a19.791 19.791 0 0 0-4.885-1.515.074.074 0 0 0-.079.037c-.21.375-.444.864-.608 1.25a18.27 18.27 0 0 0-5.487 0 12.64 12.64 0 0 0-.617-1.25.077.077 0 0 0-.079-.037A19.736 19.736 0 0 0 3.677 4.37a.07.07 0 0 0-.032.027C.533 9.046-.32 13.58.099 18.057a.082.082 0 0 0 .031.057 19.9 19.9 0 0 0 5.993 3.03.078.078 0 0 0 .084-.028c.462-.63.874-1.295 1.226-1.994a.076.076 0 0 0-.041-.106 13.107 13.107 0 0 1-1.872-.892.077.077 0 0 1-.008-.128 10.2 10.2 0 0 0 .372-.292.074.074 0 0 1 .077-.01c3.928 1.793 8.18 1.793 12.062 0a.074.074 0 0 1 .078.01c.12.098.246.198.373.292a.077.077 0 0 1-.006.127 12.299 12.299 0 0 1-1.873.892.077.077 0 0 0-.041.107c.36.698.772 1.362 1.225 1.993a.076.076 0 0 0 .084.028 19.839 19.839 0 0 0 6.002-3.03.077.077 0 0 0 .032-.054c.5-5.177-.838-9.674-3.549-13.66a.061.061 0 0 0-.031-.03zM8.02 15.33c-1.183 0-2.157-1.085-2.157-2.419 0-1.333.956-2.419 2.157-2.419 1.21 0 2.176 1.096 2.157 2.42 0 1.333-.956 2.418-2.157 2.418zm7.975 0c-1.183 0-2.157-1.085-2.157-2.419 0-1.333.955-2.419 2.157-2.419 1.21 0 2.176 1.096 2.157 2.42 0 1.333-.946 2.418-2.157 2.418z"/>
|
| 231 |
+
</svg>
|
| 232 |
+
</a>
|
| 233 |
+
</div>
|
| 234 |
+
</div>
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
st.markdown(HEADER_STYLE, unsafe_allow_html=True)
|
| 238 |
+
st.markdown(HEADER_HTML, unsafe_allow_html=True)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 242 |
+
# Sidebar
|
| 243 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
|
| 250 |
def build_authorize_url(state: str) -> str:
|
|
|
|
| 352 |
show_rationale = False
|
| 353 |
|
| 354 |
with st.sidebar:
|
| 355 |
+
st.markdown("#### Sessions")
|
| 356 |
+
session_options = st.session_state.session_order
|
| 357 |
+
active_session = st.session_state.active_session_id
|
| 358 |
+
if st.session_state.rename_session_synced_to != active_session:
|
| 359 |
+
st.session_state.rename_session_name = active_session
|
| 360 |
+
st.session_state.rename_session_synced_to = active_session
|
| 361 |
+
|
| 362 |
+
for idx, session_name in enumerate(session_options, start=1):
|
| 363 |
+
is_active = session_name == active_session
|
| 364 |
+
button_label = f"{session_name}"
|
| 365 |
+
row = st.container()
|
| 366 |
+
with row:
|
| 367 |
+
button_col, menu_col = st.columns([0.8, 0.2])
|
| 368 |
+
with button_col:
|
| 369 |
+
if st.button(
|
| 370 |
+
button_label,
|
| 371 |
+
key=f"session_button_{session_name}",
|
| 372 |
+
use_container_width=True,
|
| 373 |
+
type="primary" if is_active else "secondary",
|
| 374 |
+
):
|
| 375 |
+
if not is_active:
|
| 376 |
+
st.session_state.active_session_id = session_name
|
| 377 |
+
st.session_state.session_select = session_name
|
| 378 |
+
st.session_state.history = cast(
|
| 379 |
+
list[dict],
|
| 380 |
+
st.session_state.sessions[session_name]["history"],
|
| 381 |
+
)
|
| 382 |
+
st.session_state.rename_session_name = session_name
|
| 383 |
+
st.session_state.rename_session_synced_to = session_name
|
| 384 |
+
st.rerun()
|
| 385 |
+
with menu_col:
|
| 386 |
+
if hasattr(st, "popover"):
|
| 387 |
+
menu_container = st.popover("β―", use_container_width=True)
|
| 388 |
+
else:
|
| 389 |
+
menu_container = st.expander(
|
| 390 |
+
"β―", expanded=False, key=f"session_actions_{session_name}"
|
| 391 |
+
)
|
| 392 |
+
with menu_container:
|
| 393 |
+
st.markdown(f"**Actions for {session_name}**")
|
| 394 |
+
rename_value = st.text_input(
|
| 395 |
+
"Rename session",
|
| 396 |
+
value=session_name,
|
| 397 |
+
key=f"rename_session_input_{session_name}",
|
| 398 |
+
)
|
| 399 |
+
if st.button(
|
| 400 |
+
"Rename",
|
| 401 |
+
use_container_width=True,
|
| 402 |
+
key=f"rename_session_button_{session_name}",
|
| 403 |
+
):
|
| 404 |
+
rename_target = rename_value.strip()
|
| 405 |
+
if not rename_target:
|
| 406 |
+
st.warning("Enter a session name to rename.")
|
| 407 |
+
elif rename_target == session_name:
|
| 408 |
+
st.info("Session name unchanged.")
|
| 409 |
+
elif rename_target in st.session_state.sessions:
|
| 410 |
+
st.warning(f"Session '{rename_target}' already exists.")
|
| 411 |
+
elif rename_session(session_name, rename_target):
|
| 412 |
+
st.success(f"Session renamed to '{rename_target}'.")
|
| 413 |
+
st.rerun()
|
| 414 |
+
else:
|
| 415 |
+
st.error("Unable to rename session. Please try again.")
|
| 416 |
+
|
| 417 |
+
st.divider()
|
| 418 |
+
if st.button(
|
| 419 |
+
"Delete session",
|
| 420 |
+
use_container_width=True,
|
| 421 |
+
type="secondary",
|
| 422 |
+
key=f"delete_session_button_{session_name}",
|
| 423 |
+
):
|
| 424 |
+
if delete_session(session_name):
|
| 425 |
+
new_active = st.session_state.active_session_id
|
| 426 |
+
st.session_state.session_select = new_active
|
| 427 |
+
st.session_state.rename_session_name = new_active
|
| 428 |
+
st.session_state.rename_session_synced_to = new_active
|
| 429 |
+
st.success(f"Session '{session_name}' deleted.")
|
| 430 |
+
st.rerun()
|
| 431 |
+
else:
|
| 432 |
+
st.warning("Cannot delete the last remaining session.")
|
| 433 |
+
|
| 434 |
+
with st.form("create_session_form", clear_on_submit=True):
|
| 435 |
+
new_session_name = st.text_input(
|
| 436 |
+
"New session name",
|
| 437 |
+
key="create_session_name",
|
| 438 |
+
placeholder="Leave blank for automatic name",
|
| 439 |
+
)
|
| 440 |
+
if st.form_submit_button("Create session", use_container_width=True):
|
| 441 |
+
success, created_name = create_session(new_session_name)
|
| 442 |
+
if success:
|
| 443 |
+
st.success(f"Session '{created_name}' created.")
|
| 444 |
+
st.rerun()
|
| 445 |
+
else:
|
| 446 |
+
st.warning(f"Session '{created_name}' already exists.")
|
| 447 |
+
|
| 448 |
+
st.divider()
|
| 449 |
|
| 450 |
st.markdown("#### Choose Model")
|
| 451 |
+
|
| 452 |
+
# Create display options with categories
|
| 453 |
+
display_options = [MODEL_DISPLAY_NAMES[model] for model in MODEL_CHOICES]
|
| 454 |
+
|
| 455 |
+
selected_display = st.selectbox(
|
| 456 |
+
"Choose Model", display_options, index=0, label_visibility="collapsed"
|
| 457 |
)
|
| 458 |
+
|
| 459 |
+
# Get the actual model ID from the display name
|
| 460 |
+
model_id = next(model for model, display in MODEL_DISPLAY_NAMES.items()
|
| 461 |
+
if display == selected_display)
|
| 462 |
+
|
| 463 |
+
provider = MODEL_TO_PROVIDER[model_id]
|
| 464 |
set_model(model_id)
|
| 465 |
|
| 466 |
st.markdown("#### Choose user persona")
|
|
|
|
| 474 |
custom_persona.strip() if custom_persona.strip() else selected_persona
|
| 475 |
)
|
| 476 |
|
|
|
|
| 477 |
compare_personas = st.checkbox("Compare with Control persona")
|
| 478 |
show_rationale = st.checkbox("Show Persona Rationale")
|
| 479 |
|
| 480 |
st.divider()
|
| 481 |
if st.button("Clear chat", use_container_width=True):
|
| 482 |
+
active = st.session_state.active_session_id
|
| 483 |
+
st.session_state.sessions[active]["history"].clear()
|
| 484 |
+
st.session_state.history = cast(
|
| 485 |
+
list[dict],
|
| 486 |
+
st.session_state.sessions[active]["history"],
|
| 487 |
+
)
|
| 488 |
st.rerun()
|
| 489 |
if st.button("Delete Profile", use_container_width=True):
|
| 490 |
success = delete_profile(persona_name)
|
| 491 |
+
active = st.session_state.active_session_id
|
| 492 |
+
st.session_state.sessions[active]["history"].clear()
|
| 493 |
+
st.session_state.history = cast(
|
| 494 |
+
list[dict],
|
| 495 |
+
st.session_state.sessions[active]["history"],
|
| 496 |
+
)
|
| 497 |
if success:
|
| 498 |
st.success(f"Profile for '{persona_name}' deleted.")
|
| 499 |
else:
|
|
|
|
| 501 |
st.divider()
|
| 502 |
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
|
| 505 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 506 |
# Enforce alternating roles
|
|
|
|
| 529 |
return msgs
|
| 530 |
|
| 531 |
|
| 532 |
+
def typewriter_effect(text: str, speed: float = 0.02):
|
| 533 |
+
"""Generator that yields text word by word to create a typing effect."""
|
| 534 |
+
words = text.split(" ")
|
| 535 |
+
for i, word in enumerate(words):
|
| 536 |
+
if i == 0:
|
| 537 |
+
yield word
|
| 538 |
+
else:
|
| 539 |
+
yield " " + word
|
| 540 |
+
time.sleep(speed)
|
| 541 |
|
|
|
|
|
|
|
|
|
|
| 542 |
msg = st.chat_input("Type your messageβ¦")
|
| 543 |
if msg:
|
| 544 |
st.session_state.history.append({"role": "user", "content": msg})
|
| 545 |
if compare_personas:
|
| 546 |
all_answers = {}
|
| 547 |
+
rewritten_msg = rewrite_message(msg, persona_name, show_rationale)
|
|
|
|
|
|
|
| 548 |
msgs = clean_history(st.session_state.history, persona_name)
|
| 549 |
msgs = append_user_turn(msgs, rewritten_msg)
|
| 550 |
+
txt, lat, tok, tps = chat(msgs, persona_name)
|
| 551 |
all_answers[persona_name] = txt
|
| 552 |
|
| 553 |
+
rewritten_msg_control = rewrite_message(msg, "Control", show_rationale)
|
|
|
|
|
|
|
| 554 |
msgs_control = clean_history(st.session_state.history, "Control")
|
| 555 |
msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
|
| 556 |
+
txt_control, lat, tok, tps = chat(msgs_control, "Arnold")
|
| 557 |
all_answers["Control"] = txt_control
|
| 558 |
|
| 559 |
st.session_state.history.append(
|
| 560 |
+
{"role": "assistant_all", "axis": "role", "content": all_answers, "is_new": True}
|
| 561 |
)
|
| 562 |
else:
|
| 563 |
+
rewritten_msg = rewrite_message(msg, persona_name, show_rationale)
|
|
|
|
|
|
|
| 564 |
msgs = clean_history(st.session_state.history, persona_name)
|
| 565 |
msgs = append_user_turn(msgs, rewritten_msg)
|
| 566 |
+
txt, lat, tok, tps = chat(
|
| 567 |
+
msgs, "Arnold" if persona_name == "Control" else persona_name
|
| 568 |
+
)
|
| 569 |
st.session_state.history.append(
|
| 570 |
+
{"role": "assistant", "persona": persona_name, "content": txt, "is_new": True}
|
| 571 |
)
|
| 572 |
+
st.rerun()
|
| 573 |
|
| 574 |
|
| 575 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 579 |
if turn.get("role") == "user":
|
| 580 |
st.chat_message("user").write(turn["content"])
|
| 581 |
elif turn.get("role") == "assistant":
|
| 582 |
+
with st.chat_message("assistant"):
|
| 583 |
+
# Use typing effect for new messages, normal display for old ones
|
| 584 |
+
if turn.get("is_new", False):
|
| 585 |
+
st.write_stream(typewriter_effect(turn["content"]))
|
| 586 |
+
# Mark as no longer new so it displays normally on rerun
|
| 587 |
+
turn["is_new"] = False
|
| 588 |
+
else:
|
| 589 |
+
st.write(turn["content"])
|
| 590 |
elif turn.get("role") == "assistant_all":
|
| 591 |
content_items = list(turn["content"].items())
|
| 592 |
+
is_new = turn.get("is_new", False)
|
| 593 |
if len(content_items) >= 2:
|
| 594 |
cols = st.columns([1, 0.03, 1])
|
| 595 |
persona_label, persona_response = content_items[0]
|
| 596 |
control_label, control_response = content_items[1]
|
| 597 |
with cols[0]:
|
| 598 |
st.markdown(f"**{persona_label}**")
|
| 599 |
+
if is_new:
|
| 600 |
+
st.write_stream(typewriter_effect(persona_response))
|
| 601 |
+
else:
|
| 602 |
+
st.markdown(
|
| 603 |
+
f'<div class="answer">{persona_response}</div>',
|
| 604 |
+
unsafe_allow_html=True,
|
| 605 |
+
)
|
| 606 |
with cols[1]:
|
| 607 |
st.markdown(
|
| 608 |
'<div class="vertical-divider"></div>', unsafe_allow_html=True
|
| 609 |
)
|
| 610 |
with cols[2]:
|
| 611 |
st.markdown(f"**{control_label}**")
|
| 612 |
+
if is_new:
|
| 613 |
+
st.write_stream(typewriter_effect(control_response))
|
| 614 |
+
else:
|
| 615 |
+
st.markdown(
|
| 616 |
+
f'<div class="answer">{control_response}</div>',
|
| 617 |
+
unsafe_allow_html=True,
|
| 618 |
+
)
|
| 619 |
else:
|
| 620 |
for label, response in content_items:
|
| 621 |
st.markdown(f"**{label}**")
|
| 622 |
+
if is_new:
|
| 623 |
+
st.write_stream(typewriter_effect(response))
|
| 624 |
+
else:
|
| 625 |
+
st.markdown(
|
| 626 |
+
f'<div class="answer">{response}</div>', unsafe_allow_html=True
|
| 627 |
+
)
|
| 628 |
+
# Mark as no longer new
|
| 629 |
+
if is_new:
|
| 630 |
+
turn["is_new"] = False
|