# ---------------------------
# Imports
# ---------------------------
import os
import io
from typing import TypedDict, List, Union

import streamlit as st
from dotenv import load_dotenv
from PIL import Image
import google.generativeai as genai
from langgraph.graph import StateGraph, END

# ---------------------------
# Load API Key
# ---------------------------
load_dotenv()
API_KEY = os.getenv("GOOGLE_API_KEY")  # may be None; the API call will then fail at request time
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel("gemini-2.0-flash")

# ---------------------------
# State Definition
# ---------------------------
class ChatState(TypedDict):
    """State carried through the LangGraph pipeline for a single chat turn."""

    user_input: str                   # the user's text prompt
    image: Union[Image.Image, None]   # optional uploaded image, or None
    raw_response: str                 # unfiltered model output
    final_response: str               # cleaned output shown to the user
    chat_history: List[dict]          # running conversation log ({"role", "content"} dicts)

# ---------------------------
# LangGraph Nodes
# ---------------------------
def input_node(state: ChatState) -> ChatState:
    """Entry node: passes the state through unchanged."""
    return state

def processing_node(state: ChatState) -> ChatState:
    """Send the user's text (and image, if any) to Gemini.

    Stores the model's reply text in ``raw_response``; on any API failure,
    stores an ``Error: ...`` string instead so the UI never crashes.
    """
    parts = [state["user_input"]]
    # Explicit None check instead of truthiness: the declared type is
    # Union[Image.Image, None], so test for None rather than relying on
    # PIL Image truthiness semantics.
    if state["image"] is not None:
        parts.append(state["image"])
    try:
        chat = model.start_chat(history=[])
        resp = chat.send_message(parts)
        state["raw_response"] = resp.text
    except Exception as e:
        # Best-effort: surface the failure to the user as the response body.
        state["raw_response"] = f"Error: {e}"
    return state

def checking_node(state: ChatState) -> ChatState:
    """Strip boilerplate lines from the model reply and record the turn.

    Removes lines starting with "Sure!" or containing "The image shows",
    then appends the user/model pair to the Streamlit session history
    (rendered on the next rerun).
    """
    raw = state["raw_response"]
    # Remove unnecessary lines from Gemini responses
    if raw.startswith("Sure!") or "The image shows" in raw:
        filtered = [
            line
            for line in raw.split("\n")
            if not line.startswith("Sure!") and "The image shows" not in line
        ]
        state["final_response"] = "\n".join(filtered).strip()
    else:
        state["final_response"] = raw

    # Save to session chat history
    st.session_state.chat_history.append({"role": "user", "content": state["user_input"]})
    st.session_state.chat_history.append({"role": "model", "content": state["final_response"]})
    return state

# ---------------------------
# Build the LangGraph
# ---------------------------
builder = StateGraph(ChatState)
builder.add_node("input", input_node)
builder.add_node("processing", processing_node)
builder.add_node("checking", checking_node)
builder.set_entry_point("input")
builder.add_edge("input", "processing")
builder.add_edge("processing", "checking")
builder.add_edge("checking", END)
graph = builder.compile()

# ---------------------------
# Streamlit UI Setup
# ---------------------------
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Display chat history (populated by checking_node on previous turns)
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# ---------------------------
# Sidebar
# ---------------------------
with st.sidebar:
    st.header("Options")
    if st.button("New Chat"):
        st.session_state.chat_history = []
        st.rerun()

# ---------------------------
# Chat Input Form
# ---------------------------
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    submitted = st.form_submit_button("Send")

if submitted:
    # Guard: don't invoke the model on a completely empty submission.
    if not user_input.strip() and uploaded_file is None:
        st.warning("Please enter a message or upload an image.")
        st.stop()

    # Load image safely
    image = None
    if uploaded_file:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read()))
        except Exception as e:
            st.error(f"Error loading image: {e}")
            st.stop()

    # Echo the user's message now: the history loop above already ran,
    # and checking_node appends to history only after the graph runs,
    # so without this the user's turn would not appear until the next rerun.
    with st.chat_message("user"):
        st.markdown(user_input)

    # Prepare state for one pipeline run
    input_state = {
        "user_input": user_input,
        "image": image,
        "raw_response": "",
        "final_response": "",
        "chat_history": st.session_state.chat_history,
    }

    # Run LangGraph
    output = graph.invoke(input_state)

    # Show model response
    with st.chat_message("model"):
        st.markdown(output["final_response"])