ak0601 committed
Commit 32b6870 · verified · 1 Parent(s): dd9b6c2

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +145 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,147 @@
- import altair as alt
- import numpy as np
- import pandas as pd
+ import os
+ import io
  import streamlit as st
+ from dotenv import load_dotenv
+ from PIL import Image
+ import google.generativeai as genai
+ from langgraph.graph import StateGraph, END
+ from typing import TypedDict, List, Union

- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ # ---------------------------
+ # Load API Key
+ # ---------------------------
+ load_dotenv()
+ API_KEY = os.getenv("GOOGLE_API_KEY")
+ genai.configure(api_key=API_KEY)
+
+ model = genai.GenerativeModel("gemini-2.0-flash")
+
+ # ---------------------------
+ # State Definition
+ # ---------------------------
+ class ChatState(TypedDict):
+     user_input: str
+     image: Union[Image.Image, None]
+     raw_response: str
+     final_response: str
+     chat_history: List[dict]
+
+
+ # ---------------------------
+ # LangGraph Nodes
+ # ---------------------------
+ def input_node(state: ChatState) -> ChatState:
+     return state
+
+
+ def processing_node(state: ChatState) -> ChatState:
+     parts = [state["user_input"]]
+     if state["image"]:
+         parts.append(state["image"])
+
+     try:
+         chat = model.start_chat(history=[])
+         resp = chat.send_message(parts)
+         state["raw_response"] = resp.text
+     except Exception as e:
+         state["raw_response"] = f"Error: {e}"
+
+     return state
+
+
+ def checking_node(state: ChatState) -> ChatState:
+     raw = state["raw_response"]
+
+     # Remove unnecessary lines from Gemini responses
+     if raw.startswith("Sure!") or "The image shows" in raw:
+         lines = raw.split("\n")
+         filtered = [
+             line for line in lines
+             if not line.startswith("Sure!") and "The image shows" not in line
+         ]
+         final = "\n".join(filtered).strip()
+         state["final_response"] = final
+     else:
+         state["final_response"] = raw
+
+     # Save to session chat history
+     st.session_state.chat_history.append({"role": "user", "content": state["user_input"]})
+     st.session_state.chat_history.append({"role": "model", "content": state["final_response"]})
+
+     return state
+
+
+ # ---------------------------
+ # Build the LangGraph
+ # ---------------------------
+ builder = StateGraph(ChatState)
+ builder.add_node("input", input_node)
+ builder.add_node("processing", processing_node)
+ builder.add_node("checking", checking_node)
+
+ builder.set_entry_point("input")
+ builder.add_edge("input", "processing")
+ builder.add_edge("processing", "checking")
+ builder.add_edge("checking", END)
+
+ graph = builder.compile()
+
+ # ---------------------------
+ # Streamlit UI Setup
+ # ---------------------------
+ st.set_page_config(page_title="Math Chatbot", layout="centered")
+ st.title("Math Chatbot")
+
+ # Initialize session state
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+
+ # Display chat history
+ for msg in st.session_state.chat_history:
+     with st.chat_message(msg["role"]):
+         st.markdown(msg["content"])
+
+ # ---------------------------
+ # Sidebar
+ # ---------------------------
+ with st.sidebar:
+     st.header("Options")
+     if st.button("New Chat"):
+         st.session_state.chat_history = []
+         st.rerun()
+
+ # ---------------------------
+ # Chat Input Form
+ # ---------------------------
+ with st.form("chat_form", clear_on_submit=True):
+     user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
+
+     uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
+     submitted = st.form_submit_button("Send")
+
+ if submitted:
+     # Load image safely
+     image = None
+     if uploaded_file:
+         try:
+             image = Image.open(io.BytesIO(uploaded_file.read()))
+         except Exception as e:
+             st.error(f"Error loading image: {e}")
+             st.stop()
+
+     # Prepare state
+     input_state = {
+         "user_input": user_input,
+         "image": image,
+         "raw_response": "",
+         "final_response": "",
+         "chat_history": st.session_state.chat_history,
+     }
+
+     # Run LangGraph
+     output = graph.invoke(input_state)
+
+     # Show model response
+     with st.chat_message("model"):
+         st.markdown(output["final_response"])
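
Setup note: the updated script calls load_dotenv() and reads GOOGLE_API_KEY, so the Gemini client is only configured when that key is available at launch. A minimal local run might look like the sketch below; the key value is a placeholder, and the pip list simply mirrors the imports in this commit rather than any requirements file in the repo.

# .env in the directory the app is launched from
GOOGLE_API_KEY=your-key-here

# install the libraries the new script imports, then start Streamlit
pip install streamlit python-dotenv pillow google-generativeai langgraph
streamlit run src/streamlit_app.py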