# Chatbot_SDK / src / streamlit_app.py
# Author: ak0601
# Commit: "Update src/streamlit_app.py" (32b6870, verified)
import os
import io
import streamlit as st
from dotenv import load_dotenv
from PIL import Image
import google.generativeai as genai
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Union
# ---------------------------
# Load API Key
# ---------------------------
# load_dotenv() runs first so the environment lookup below can see any
# values it loads.
load_dotenv()
API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=API_KEY)
# Module-level Gemini model handle; processing_node uses it for each request.
model = genai.GenerativeModel(model_name="gemini-2.0-flash")
# ---------------------------
# State Definition
# ---------------------------
class ChatState(TypedDict):
    """State dictionary handed from node to node in the chat graph."""

    # Text the user typed into the chat form.
    user_input: str
    # Uploaded picture as a PIL image, or None when nothing was attached.
    image: Union[Image.Image, None]
    # Reply text (or an "Error: ..." message) written by processing_node.
    raw_response: str
    # Reply after checking_node's clean-up; this is what the UI renders.
    final_response: str
    # Conversation so far, as dicts with "role" and "content" keys.
    chat_history: List[dict]
# ---------------------------
# LangGraph Nodes
# ---------------------------
def input_node(state: ChatState) -> ChatState:
    """Entry node of the graph: hand the incoming state on unchanged.

    Args:
        state: The state the graph was invoked with.

    Returns:
        The very same state object, untouched.
    """
    return state
def processing_node(state: ChatState) -> ChatState:
    """Send the user's text and optional image to Gemini and record the reply.

    Args:
        state: Current graph state; ``user_input`` and ``image`` are read.

    Returns:
        The same state object with ``raw_response`` set to the model's text,
        to a short prompt when there was nothing to send, or to
        ``"Error: <details>"`` when the request failed.
    """
    text = state["user_input"]
    image = state["image"]

    # Only forward parts that carry content: a blank text part adds nothing
    # to an image-only question, and a fully empty request is pointless.
    parts = []
    if text and text.strip():
        parts.append(text)
    if image is not None:
        parts.append(image)

    if not parts:
        state["raw_response"] = "Please enter a message or upload an image."
        return state

    try:
        # NOTE(review): every call opens a fresh chat with an empty history,
        # so earlier turns are not sent to the model - confirm this is intended.
        chat = model.start_chat(history=[])
        resp = chat.send_message(parts)
        state["raw_response"] = resp.text
    except Exception as e:
        # Report any failure as text so the UI still has something to show.
        state["raw_response"] = f"Error: {e}"
    return state
def checking_node(state: ChatState) -> ChatState:
    """Strip boilerplate lines from the raw reply and log the exchange.

    Lines that start with "Sure!" or contain "The image shows" are dropped,
    but only when the reply itself starts with "Sure!" or contains
    "The image shows" somewhere. If that would leave nothing, the raw reply
    is kept so the user never receives an empty message.

    Args:
        state: Current graph state; ``raw_response`` and ``user_input`` are read.

    Returns:
        The same state object with ``final_response`` set. As a side effect
        the user message and the final reply are appended to
        ``st.session_state.chat_history``.
    """
    raw = state["raw_response"]
    final = raw
    # Remove unnecessary lines from Gemini responses
    if raw.startswith("Sure!") or "The image shows" in raw:
        kept = [
            line for line in raw.split("\n")
            if not line.startswith("Sure!") and "The image shows" not in line
        ]
        # A reply made up solely of filtered lines would come out empty;
        # fall back to the unfiltered text in that case.
        final = "\n".join(kept).strip() or raw
    state["final_response"] = final
    # Save to session chat history
    st.session_state.chat_history.append({"role": "user", "content": state["user_input"]})
    st.session_state.chat_history.append({"role": "model", "content": state["final_response"]})
    return state
# ---------------------------
# Build the LangGraph
# ---------------------------
builder = StateGraph(ChatState)
# Register the three nodes under the names the edges refer to.
for _node_name, _node_fn in (
    ("input", input_node),
    ("processing", processing_node),
    ("checking", checking_node),
):
    builder.add_node(_node_name, _node_fn)
builder.set_entry_point("input")
# Linear pipeline: input -> processing -> checking -> END.
for _src, _dst in (
    ("input", "processing"),
    ("processing", "checking"),
    ("checking", END),
):
    builder.add_edge(_src, _dst)
graph = builder.compile()
# ---------------------------
# Streamlit UI Setup
# ---------------------------
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")
# Start with an empty conversation if this session does not have one yet.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
# Replay the stored conversation, one chat bubble per message.
for entry in st.session_state["chat_history"]:
    st.chat_message(entry["role"]).markdown(entry["content"])
# ---------------------------
# Sidebar
# ---------------------------
st.sidebar.header("Options")
# "New Chat" discards the stored conversation and reruns the script so the
# page is redrawn without the old messages.
if st.sidebar.button("New Chat"):
    st.session_state["chat_history"] = []
    st.rerun()
# ---------------------------
# Chat Input Form
# ---------------------------
# The text box and the uploader share one form, so both are submitted together
# by the "Send" button and cleared afterwards (clear_on_submit=True).
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    submitted = st.form_submit_button("Send")

if submitted:
    # Skip submissions that carry neither text nor an image.
    if not user_input.strip() and uploaded_file is None:
        st.warning("Please enter a message or upload an image.")
        st.stop()

    # Load image safely
    image = None
    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read()))
        except Exception as e:
            st.error(f"Error loading image: {e}")
            st.stop()

    # Prepare state
    input_state = {
        "user_input": user_input,
        "image": image,
        "raw_response": "",
        "final_response": "",
        "chat_history": st.session_state.chat_history,
    }

    # Run LangGraph. checking_node appends both the user's message and the
    # final reply to st.session_state.chat_history.
    graph.invoke(input_state)

    # Rerun so the history loop earlier in the script draws the whole exchange
    # (user message and reply) above the form, instead of showing only the
    # reply below it on this run.
    st.rerun()