import json
import os
import secrets
import time
from typing import cast
from urllib.parse import urlencode

import requests
import streamlit as st

from gateway_client import delete_profile, ingest_and_rewrite, ingest_memories
from llm import chat, set_model
from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER, MODEL_DISPLAY_NAMES


# ──────────────────────────────────────────────────────────────
# Session management
# ──────────────────────────────────────────────────────────────
def _generate_session_name(base: str = "Session") -> str:
    """Return the first ``"<base> <n>"`` (n = 1, 2, ...) not already in ``session_order``."""
    existing = set(st.session_state.get("session_order", []))
    idx = 1
    while True:
        candidate = f"{base} {idx}"
        if candidate not in existing:
            return candidate
        idx += 1


def _set_active_session(session_name: str) -> None:
    """Make *session_name* the active session and resync the keys that mirror it."""
    st.session_state.active_session_id = session_name
    st.session_state.session_select = session_name
    st.session_state.rename_session_name = session_name
    st.session_state.rename_session_synced_to = session_name
    # ``history`` is the very list stored in ``sessions[...]["history"]`` (no
    # copy), so appends through ``st.session_state.history`` persist per session.
    st.session_state.history = cast(
        list[dict], st.session_state.sessions[session_name]["history"]
    )


def ensure_session_state() -> None:
    """Initialise the session bookkeeping keys in ``st.session_state`` if missing.

    Guarantees ``sessions``, ``session_order``, a valid ``active_session_id``
    (creating a default session if needed), ``session_select``, the rename
    widget keys, and that ``history`` points at the active session's history.
    """
    if "sessions" not in st.session_state:
        st.session_state.sessions = {}
    if "session_order" not in st.session_state:
        st.session_state.session_order = []
    if (
        "active_session_id" not in st.session_state
        or st.session_state.active_session_id not in st.session_state.sessions
    ):
        default_name = _generate_session_name()
        st.session_state.sessions.setdefault(default_name, {"history": []})
        if default_name not in st.session_state.session_order:
            st.session_state.session_order.append(default_name)
        st.session_state.active_session_id = default_name
    if "session_select" not in st.session_state:
        st.session_state.session_select = st.session_state.active_session_id
    if st.session_state.session_select not in st.session_state.sessions:
        st.session_state.session_select = st.session_state.active_session_id
    st.session_state.setdefault(
        "rename_session_name", st.session_state.active_session_id
    )
    st.session_state.setdefault(
        "rename_session_synced_to", st.session_state.active_session_id
    )
    st.session_state.history = cast(
        list[dict],
        st.session_state.sessions[
            st.session_state.active_session_id
        ].setdefault("history", []),
    )


def create_session(session_name: str | None = None) -> tuple[bool, str]:
    """Create a new chat session and make it active.

    Args:
        session_name: Desired name; blank or ``None`` picks an automatic
            ``"Session <n>"`` name.

    Returns:
        ``(created, name)``. ``created`` is False, and nothing changes, when a
        session called ``name`` already exists.
    """
    ensure_session_state()
    candidate = (session_name or "").strip() or _generate_session_name()
    if candidate in st.session_state.sessions:
        return False, candidate
    st.session_state.sessions[candidate] = {"history": []}
    st.session_state.session_order.append(candidate)
    _set_active_session(candidate)
    return True, candidate


def rename_session(current_name: str, new_name: str) -> bool:
    """Rename *current_name* to *new_name*, keeping its position and history.

    Returns:
        True on success. False, with nothing changed, when *current_name*
        does not exist or *new_name* is blank, unchanged, or already taken.
    """
    ensure_session_state()
    target = new_name.strip()
    if not target or target == current_name:
        return False
    if current_name not in st.session_state.sessions:
        return False
    if target in st.session_state.sessions:
        return False
    st.session_state.sessions[target] = st.session_state.sessions.pop(current_name)
    order = st.session_state.session_order
    order[order.index(current_name)] = target
    if st.session_state.active_session_id == current_name:
        _set_active_session(target)
    return True


def delete_session(session_name: str) -> bool:
    """Delete *session_name*; refuses to delete the last remaining session.

    If the deleted session was active, the last session in ``session_order``
    becomes active.

    Returns:
        True if the session was deleted, False otherwise.
    """
    ensure_session_state()
    if session_name not in st.session_state.sessions:
        return False
    if len(st.session_state.session_order) <= 1:
        return False
    st.session_state.sessions.pop(session_name, None)
    st.session_state.session_order.remove(session_name)
    if st.session_state.active_session_id == session_name:
        _set_active_session(st.session_state.session_order[-1])
    return True


def rewrite_message(
    msg: str, persona_name: str, show_rationale: bool, use_memory: bool = True
) -> str:
    """Build the prompt actually sent to the model for *msg*.

    With memory enabled (and a persona other than "control", case-insensitive)
    the message is passed through ``ingest_and_rewrite`` for *persona_name*;
    otherwise it is used verbatim. When *show_rationale* is set, an
    instruction asking the model to explain the persona's influence is
    appended.

    Raises:
        Whatever ``ingest_and_rewrite`` raises; an ``st.error`` is shown first.
    """
    # If memory is disabled or Control persona, don't use memory
    if not use_memory or persona_name.lower() == "control":
        rewritten_msg = msg
        if show_rationale:
            rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
        return rewritten_msg

    try:
        rewritten_msg = ingest_and_rewrite(user_id=persona_name, query=msg)
    except Exception as e:
        st.error(f"Failed to ingest_and_rewrite message: {e}")
        raise
    if show_rationale:
        rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line."
    # The rewritten prompt embeds the user's stored memories, so it is
    # deliberately not printed/logged here.
    return rewritten_msg


def validate_hf_token(token: str) -> tuple[bool, str, str]:
    """Validate HF token and return (is_valid, username, error_message).

    Uses ``huggingface_hub.whoami`` when that package is importable, and
    otherwise calls ``https://huggingface.co/api/whoami`` with ``requests``.
    Failures are reported through ``error_message`` (with ``is_valid=False``)
    rather than raised.
    """
    token = token.strip()
    if not token:
        return False, "", "Token cannot be empty"
    # Remove any whitespace or newlines that might have been copied
    token = "".join(token.split())

    try:
        from huggingface_hub import whoami
    except ImportError:
        whoami = None  # fall back to the raw HTTP endpoint below

    if whoami is not None:
        try:
            user_info = whoami(token=token)
            username = user_info.get("name") or user_info.get("username") or ""
        except Exception as e:
            error_msg = str(e)
            if "401" in error_msg or "Unauthorized" in error_msg or "Invalid" in error_msg:
                return False, "", f"Invalid token. Please verify your token is correct and has Read permissions. Error: {error_msg[:100]}"
            return False, "", f"Validation error: {error_msg[:150]}"
        if username:
            return True, username, ""
        return False, "", "Token validated but username not found in response."

    # Fallback: Use the HF whoami endpoint directly
    endpoint = "https://huggingface.co/api/whoami"
    headers = {
        "Authorization": f"Bearer {token}",
        "User-Agent": "MemMachine-Playground/1.0",
    }
    try:
        resp = requests.get(endpoint, headers=headers, timeout=10)
        if resp.status_code == 200:
            user_data = resp.json()
            # Try different possible username fields
            username = (
                user_data.get("name")
                or user_data.get("username")
                or user_data.get("user")
                or ""
            )
            if username:
                return True, username, ""
            return False, "", f"Token validated but username not found. Response: {str(user_data)[:100]}"
        if resp.status_code == 401:
            error_detail = ""
            try:
                error_detail = resp.json().get("error", "")
            except (ValueError, AttributeError):
                # Body was not JSON (or not a JSON object): keep the generic text.
                pass
            return False, "", f"Invalid token (401). The token may be expired, revoked, or incorrect. {error_detail} Please create a new Read token at https://huggingface.co/settings/tokens"
        if resp.status_code == 403:
            return False, "", "Token access denied (403). Please ensure your token has Read permissions."
        try:
            error_text = resp.json().get("error", resp.text[:100])
        except (ValueError, AttributeError):
            error_text = resp.text[:100]
        return False, "", f"Authentication failed (Status {resp.status_code}): {error_text}"
    except requests.exceptions.Timeout:
        return False, "", "Request timed out. Please check your internet connection and try again."
    except requests.exceptions.RequestException as e:
        return False, "", f"Network error: {str(e)}. Please try again."
    except Exception as e:
        return False, "", f"Validation error: {str(e)}. Please try again."


# ──────────────────────────────────────────────────────────────
# Page setup & CSS
# ──────────────────────────────────────────────────────────────
st.set_page_config(page_title="MemMachine Chatbot", layout="wide")

try:
    with open("./styles.css", encoding="utf-8") as f:
        # Inject the stylesheet; previously an empty string was emitted, so
        # the file was opened but its contents never reached the page.
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
except FileNotFoundError:
    pass

ensure_session_state()

# NOTE(review): the HTML/CSS inside HEADER_STYLE, HEADER_HTML and the other
# ``unsafe_allow_html`` strings in this file looks stripped in this copy
# (only text nodes remain) — confirm against version control and restore.
HEADER_STYLE = """
"""

HEADER_HTML = """
"""

st.markdown(HEADER_STYLE, unsafe_allow_html=True)
st.markdown(HEADER_HTML, unsafe_allow_html=True)

# ──────────────────────────────────────────────────────────────
# Sidebar
# ──────────────────────────────────────────────────────────────
default_model = MODEL_CHOICES[0] if MODEL_CHOICES else "gpt-4.1-mini"
model_id = default_model
provider = MODEL_TO_PROVIDER.get(default_model, "openai")
selected_persona = "Charlie"
persona_name = "Charlie"
skip_rewrite = False
compare_personas = False
show_rationale = False

with st.sidebar:
    st.markdown("#### Sessions")

    session_options = st.session_state.session_order
    active_session = st.session_state.active_session_id
    if st.session_state.rename_session_synced_to != active_session:
        st.session_state.rename_session_name = active_session
        st.session_state.rename_session_synced_to = active_session

    # Iterate over a snapshot: the rename/delete handlers mutate session_order.
    for session_name in list(session_options):
        is_active = session_name == active_session
        with st.container():
            button_col, menu_col = st.columns([0.8, 0.2])
            with button_col:
                if st.button(
                    session_name,
                    key=f"session_button_{session_name}",
                    use_container_width=True,
                    type="primary" if is_active else "secondary",
                ):
                    if not is_active:
                        _set_active_session(session_name)
                        st.rerun()
            with menu_col:
                if hasattr(st, "popover"):
                    menu_container = st.popover("⋯", use_container_width=True)
                else:
                    # Older Streamlit without st.popover: st.expander there
                    # takes no ``key`` argument, so none is passed.
                    menu_container = st.expander("⋯", expanded=False)
                with menu_container:
                    st.markdown(f"**Actions for {session_name}**")
                    rename_value = st.text_input(
                        "Rename session",
                        value=session_name,
                        key=f"rename_session_input_{session_name}",
                    )
                    if st.button(
                        "Rename",
                        use_container_width=True,
                        key=f"rename_session_button_{session_name}",
                    ):
                        rename_target = rename_value.strip()
                        if not rename_target:
                            st.warning("Enter a session name to rename.")
                        elif rename_target == session_name:
                            st.info("Session name unchanged.")
                        elif rename_target in st.session_state.sessions:
                            st.warning(f"Session '{rename_target}' already exists.")
                        elif rename_session(session_name, rename_target):
                            st.success(f"Session renamed to '{rename_target}'.")
                            st.rerun()
                        else:
                            st.error("Unable to rename session. Please try again.")
                    st.divider()
                    if st.button(
                        "Delete session",
                        use_container_width=True,
                        type="secondary",
                        key=f"delete_session_button_{session_name}",
                    ):
                        if delete_session(session_name):
                            # Resync the mirrored keys to whichever session is active now.
                            _set_active_session(st.session_state.active_session_id)
                            st.success(f"Session '{session_name}' deleted.")
                            st.rerun()
                        else:
                            st.warning("Cannot delete the last remaining session.")

    with st.form("create_session_form", clear_on_submit=True):
        new_session_name = st.text_input(
            "New session name",
            key="create_session_name",
            placeholder="Leave blank for automatic name",
        )
        if st.form_submit_button("Create session", use_container_width=True):
            success, created_name = create_session(new_session_name)
            if success:
                st.success(f"Session '{created_name}' created.")
                st.rerun()
            else:
                st.warning(f"Session '{created_name}' already exists.")

    st.divider()

    st.markdown("#### Choose Model")
    # Select the model ID directly and only show its display name, so the
    # choice can never resolve to a model outside MODEL_CHOICES.
    selected_model = st.selectbox(
        "Choose Model",
        MODEL_CHOICES,
        index=0,
        format_func=lambda model: MODEL_DISPLAY_NAMES[model],
        label_visibility="collapsed",
    )
    if selected_model is not None:
        model_id = selected_model
    provider = MODEL_TO_PROVIDER.get(model_id, provider)
    set_model(model_id)

    st.markdown("#### User Identity")

    # Get Hugging Face user ID if available (in HF Spaces)
    hf_user_id = os.getenv("SPACE_USER") or os.getenv("HF_USERNAME") or os.getenv("HF_USER")
    # Check if we're on Hugging Face Spaces (not local)
    is_hf_space = os.getenv("SPACE_ID") is not None or os.getenv("HF_ENDPOINT") is not None

    if is_hf_space:
        # On HF Spaces - require token authentication for security
        if "hf_authenticated_user" not in st.session_state:
            st.warning("🔐 **Authentication Required**")
            st.caption("To protect your memories, please authenticate with your Hugging Face account.")

            token_input = st.text_input(
                "Enter your Hugging Face Access Token",
                key="hf_token_input",
                type="password",
                placeholder="hf_xxxxxxxxxxxxxxxxxxxxx",
                help="❓ Create a Read token: https://huggingface.co/settings/tokens",
            )

            if st.button("Authenticate", use_container_width=True, type="primary"):
                if token_input.strip():
                    with st.spinner("Validating token..."):
                        is_valid, username, error_msg = validate_hf_token(token_input.strip())
                    if is_valid and username:
                        st.session_state.hf_authenticated_user = username
                        # NOTE(review): the privacy note below says the token is
                        # not stored, but it is kept in session_state here —
                        # confirm which is intended.
                        st.session_state.hf_token = token_input.strip()  # Store for future use
                        # Use custom purple styling instead of green success message
                        st.markdown(f"""
✅ Authenticated as {username}
""", unsafe_allow_html=True)
                        st.rerun()
                    else:
                        error_display = error_msg if error_msg else "Invalid token. Please check your Hugging Face access token."
                        st.error(f"❌ {error_display}")
                else:
                    st.error("Please enter your access token")

            st.info("💡 **Privacy Note:** Your token is used only for authentication. It is not stored or shared anywhere.")
            st.stop()
        else:
            # User is authenticated - lock to their username
            persona_name = st.session_state.hf_authenticated_user
            # Use custom purple styling instead of green success message
            st.markdown(f"""
🔐 Authenticated as: {persona_name}
""", unsafe_allow_html=True)
            st.caption("Your memories are secured to your account only.")
            if st.button("🔓 Sign Out", use_container_width=True):
                del st.session_state.hf_authenticated_user
                if "hf_token" in st.session_state:
                    del st.session_state.hf_token
                st.rerun()
    elif hf_user_id:
        # HF user ID detected automatically
        persona_name = hf_user_id
        st.info(f"👤 Signed in as: **{hf_user_id}**")
        st.caption("Your memories are personalized to your account.")
    else:
        # Local/testing mode - allow persona selection
        selected_persona = st.selectbox(
            "Choose user persona",
            ["Charlie", "Jing", "Charles", "Control"],
            label_visibility="collapsed",
        )
        custom_persona = st.text_input("Or enter your name", "")
        persona_name = (
            custom_persona.strip() if custom_persona.strip() else selected_persona
        )

    # Memory toggle - default enabled
    if "memmachine_enabled" not in st.session_state:
        st.session_state.memmachine_enabled = True
    if "compare_personas" not in st.session_state:
        st.session_state.compare_personas = True

    memmachine_enabled = st.checkbox(
        "Enable MemMachine",
        value=st.session_state.memmachine_enabled,
        help="Enable MemMachine's persistent memory system. When unchecked, the AI will respond without memory (Control Persona mode).",
    )
    st.session_state.memmachine_enabled = memmachine_enabled

    if memmachine_enabled:
        # Enhanced "Compare with control persona" section with cool styling
        st.markdown("""
⚖️
Side-by-Side Comparison with Control Persona
Compare MemMachine responses vs Control Persona (no memory)
""", unsafe_allow_html=True)
        compare_personas = st.checkbox(
            "🔄 Compare with control persona",
            value=st.session_state.compare_personas,
            help="Enable side-by-side comparison to see how MemMachine's persistent memory enhances responses compared to the control persona (no memory)",
        )
        st.session_state.compare_personas = compare_personas
    else:
        compare_personas = False

    show_rationale = st.checkbox("Show Persona Rationale")

    st.divider()

    if st.button("Clear chat", use_container_width=True):
        active = st.session_state.active_session_id
        st.session_state.sessions[active]["history"].clear()
        st.session_state.history = cast(
            list[dict],
            st.session_state.sessions[active]["history"],
        )
        st.rerun()

    if st.button("Delete Profile", use_container_width=True):
        success = delete_profile(persona_name)
        active = st.session_state.active_session_id
        st.session_state.sessions[active]["history"].clear()
        st.session_state.history = cast(
            list[dict],
            st.session_state.sessions[active]["history"],
        )
        if success:
            st.success(f"Profile for '{persona_name}' deleted.")
        else:
            st.error(f"Failed to delete profile for '{persona_name}'.")

    st.divider()


# ──────────────────────────────────────────────────────────────
# Enforce alternating roles
# ──────────────────────────────────────────────────────────────
def clean_history(history: list[dict], persona: str) -> list[dict]:
    """Reduce the UI history to a strictly alternating user/assistant list.

    Keeps every ``"user"`` turn and only the ``"assistant"`` turns produced by
    *persona* (other roles such as ``"assistant_all"`` are dropped), stripped
    to ``{"role", "content"}``. When filtering leaves several messages with
    the same role in a row, only the most recent one is kept, so each kept
    assistant reply stays paired with the question it actually answered.
    """
    relevant: list[dict] = []
    for turn in history:
        role = turn.get("role")
        if role == "user":
            relevant.append({"role": "user", "content": turn["content"]})
        elif role == "assistant" and turn.get("persona") == persona:
            relevant.append({"role": "assistant", "content": turn["content"]})

    cleaned: list[dict] = []
    for msg in relevant:
        if cleaned and cleaned[-1]["role"] == msg["role"]:
            cleaned[-1] = msg
        else:
            cleaned.append(msg)
    return cleaned


def append_user_turn(msgs: list[dict], new_user_msg: str) -> list[dict]:
    """Make *new_user_msg* the final user turn of *msgs* (mutated and returned).

    Replaces a trailing user message instead of appending, to keep roles
    alternating.
    """
    if msgs and msgs[-1]["role"] == "user":
        msgs[-1] = {"role": "user", "content": new_user_msg}
    else:
        msgs.append({"role": "user", "content": new_user_msg})
    return msgs


def typewriter_effect(text: str, speed: float = 0.02):
    """Generator that yields text word by word to create a typing effect.

    Splits on single spaces, re-attaches the separator, and sleeps *speed*
    seconds after each yielded piece.
    """
    words = text.split(" ")
    for i, word in enumerate(words):
        if i == 0:
            yield word
        else:
            yield " " + word
        time.sleep(speed)


# ──────────────────────────────────────────────────────────────
# Load Previous Memories Section (Import External Memories)
# ──────────────────────────────────────────────────────────────
if "memories_preview" not in st.session_state:
    st.session_state.memories_preview = None
if "imported_memories_text" not in st.session_state:
    st.session_state.imported_memories_text = ""

# Add expandable section for importing memories
with st.expander("📋 Load Previous Memories (Import from ChatGPT, etc.)", expanded=False):
    st.markdown("**Paste your conversation history or memories from external sources (e.g., ChatGPT, other AI chats)**")

    # Text area for pasting memories
    imported_text = st.text_area(
        "Paste your memories/conversations here",
        value=st.session_state.imported_memories_text,
        height=200,
        placeholder="Example:\nUser: What is machine learning?\nAssistant: Machine learning is...\n\nUser: Can you explain neural networks?\nAssistant: Neural networks are...",
        help="Paste any conversation history, notes, or context you want the AI to remember. These will be ingested into MemMachine's memory system and available for future conversations.",
        key="import_memories_textarea",
    )

    # File upload option
    uploaded_file = st.file_uploader(
        "Or upload a text file",
        type=["txt", "md", "json"],
        help="Upload a text file containing your conversation history or memories",
    )

    if uploaded_file is not None:
        try:
            # getvalue() returns the whole upload regardless of read position.
            raw_text = uploaded_file.getvalue().decode("utf-8")
            is_json = (
                uploaded_file.type == "application/json"
                or uploaded_file.name.lower().endswith(".json")
            )
            if is_json:
                # Re-serialise as real JSON (str() would give a Python repr
                # with single quotes, True/None, etc.).
                imported_text = json.dumps(json.loads(raw_text), ensure_ascii=False, indent=2)
            else:
                imported_text = raw_text
            st.session_state.imported_memories_text = imported_text
            st.success("File loaded successfully!")
        except Exception as e:
            st.error(f"Error reading file: {e}")

    col1, col2 = st.columns(2)
    with col1:
        if st.button("👁️ Preview", use_container_width=True, key="preview_memories"):
            if imported_text and imported_text.strip():
                st.session_state.memories_preview = imported_text
                st.session_state.imported_memories_text = imported_text
                st.rerun()
            else:
                st.warning("Please paste or upload some memories first.")
    with col2:
        if st.button("💉 Ingest into MemMachine", use_container_width=True, key="inject_memories_direct"):
            # No st.rerun() here: it would immediately erase the result message.
            if imported_text and imported_text.strip():
                if persona_name and persona_name != "Control":
                    with st.spinner("Ingesting memories into MemMachine..."):
                        success = ingest_memories(persona_name, imported_text)
                    if success:
                        st.session_state.imported_memories_text = imported_text
                        st.success("✅ Memories successfully ingested into MemMachine! They are now part of your memory system.")
                    else:
                        st.error("❌ Failed to ingest memories. Please try again.")
                else:
                    st.warning("Please authenticate or select a persona to ingest memories.")
            else:
                st.warning("Please paste or upload some memories first.")

# Show preview if memories are loaded
if st.session_state.memories_preview:
    with st.expander("📋 Preview Imported Memories", expanded=True):
        memories = st.session_state.memories_preview
        preview_text = str(memories)[:2000]  # Show first 2000 chars
        if preview_text:
            st.text_area("Memories Preview", preview_text, height=200, disabled=True, key="memories_preview_text")
            st.caption(f"Total length: {len(str(memories))} characters")

            col1, col2 = st.columns(2)
            with col1:
                if st.button("💉 Ingest into MemMachine", use_container_width=True, key="inject_memories_from_preview"):
                    # No st.rerun() here: it would immediately erase the result message.
                    if persona_name and persona_name != "Control":
                        with st.spinner("Ingesting memories into MemMachine..."):
                            success = ingest_memories(persona_name, str(st.session_state.memories_preview))
                        if success:
                            st.success("✅ Memories successfully ingested into MemMachine! They are now part of your memory system.")
                        else:
                            st.error("❌ Failed to ingest memories. Please try again.")
                    else:
                        st.warning("Please authenticate or select a persona to ingest memories.")
            with col2:
                if st.button("🗑️ Clear", use_container_width=True, key="clear_memories_preview"):
                    st.session_state.memories_preview = None
                    st.session_state.imported_memories_text = ""
                    st.rerun()
        else:
            st.info("No memories to preview.")
            st.session_state.memories_preview = None

# ──────────────────────────────────────────────────────────────
# Chat input
# ──────────────────────────────────────────────────────────────
msg = st.chat_input("Type your message…")
if msg:
    st.session_state.history.append({"role": "user", "content": msg})
    memmachine_enabled = st.session_state.get("memmachine_enabled", True)

    if compare_personas and memmachine_enabled:
        all_answers = {}

        rewritten_msg = rewrite_message(msg, persona_name, show_rationale, use_memory=True)
        msgs = clean_history(st.session_state.history, persona_name)
        msgs = append_user_turn(msgs, rewritten_msg)
        try:
            txt, _lat, _tok, _tps = chat(msgs, persona_name)
            all_answers[persona_name] = txt
        except ValueError as e:
            st.error(f"❌ {str(e)}")
            st.stop()

        rewritten_msg_control = rewrite_message(msg, "Control", show_rationale, use_memory=False)
        msgs_control = clean_history(st.session_state.history, "Control")
        msgs_control = append_user_turn(msgs_control, rewritten_msg_control)
        try:
            txt_control, _lat, _tok, _tps = chat(msgs_control, "Arnold")
            all_answers["Control"] = txt_control
        except ValueError as e:
            st.error(f"❌ {str(e)}")
            st.stop()

        st.session_state.history.append(
            {"role": "assistant_all", "axis": "role", "content": all_answers, "is_new": True}
        )
    else:
        # Use memory only if memmachine_enabled is True
        rewritten_msg = rewrite_message(msg, persona_name, show_rationale, use_memory=memmachine_enabled)
        msgs = clean_history(st.session_state.history, persona_name)
        msgs = append_user_turn(msgs, rewritten_msg)
        try:
            txt, _lat, _tok, _tps = chat(
                msgs,
                "Arnold" if persona_name == "Control" or not memmachine_enabled else persona_name,
            )
            st.session_state.history.append(
                {"role": "assistant", "persona": persona_name, "content": txt, "is_new": True}
            )
        except ValueError as e:
            st.error(f"❌ {str(e)}")
            st.stop()

    st.rerun()

# ──────────────────────────────────────────────────────────────
# Memory Status Indicator
# ──────────────────────────────────────────────────────────────
memmachine_enabled = st.session_state.get("memmachine_enabled", True)
status_emoji = "🧠" if memmachine_enabled else "⚪"
status_text = "MemMachine Active" if memmachine_enabled else "No Memory Mode"

# Add status indicator at the top of chat area
status_html = f"""
{status_emoji} {status_text}
"""
st.markdown(status_html, unsafe_allow_html=True)

# ──────────────────────────────────────────────────────────────
# Chat history display
# ──────────────────────────────────────────────────────────────
# NOTE(review): model output below is rendered with unsafe_allow_html=True, so
# any HTML the model emits (e.g. echoed from imported memories) is interpreted
# rather than shown as text; consider html.escape() if the wrapper markup allows.
for turn in st.session_state.history:
    if turn.get("role") == "user":
        st.chat_message("user").write(turn["content"])

    elif turn.get("role") == "assistant":
        with st.chat_message("assistant"):
            # Use typing effect for new messages, normal display for old ones
            if turn.get("is_new", False):
                st.write_stream(typewriter_effect(turn["content"]))
                # Mark as no longer new so it displays normally on rerun
                turn["is_new"] = False
            else:
                st.write(turn["content"])

    elif turn.get("role") == "assistant_all":
        content_items = list(turn["content"].items())
        is_new = turn.get("is_new", False)
        if len(content_items) >= 2:
            # Enhanced comparison header
            st.markdown("""
⚖️ Side-by-Side Comparison
""", unsafe_allow_html=True)
            cols = st.columns([1, 0.03, 1])
            # Insertion order of all_answers: persona first, control second.
            persona_label, persona_response = content_items[0]
            control_label, control_response = content_items[1]

            with cols[0]:
                st.markdown(f"""
🧠 {persona_label}
""", unsafe_allow_html=True)
                if is_new:
                    st.write_stream(typewriter_effect(persona_response))
                else:
                    st.markdown(
                        f"\n{persona_response}\n",
                        unsafe_allow_html=True,
                    )

            with cols[1]:
                st.markdown(
                    "\n",
                    unsafe_allow_html=True,
                )

            with cols[2]:
                st.markdown(f"""
{control_label}
""", unsafe_allow_html=True)
                if is_new:
                    st.write_stream(typewriter_effect(control_response))
                else:
                    st.markdown(
                        f"\n{control_response}\n",
                        unsafe_allow_html=True,
                    )
        else:
            for label, response in content_items:
                st.markdown(f"**{label}**")
                if is_new:
                    st.write_stream(typewriter_effect(response))
                else:
                    st.markdown(
                        f"\n{response}\n",
                        unsafe_allow_html=True,
                    )
        # Mark as no longer new
        if is_new:
            turn["is_new"] = False