timothytzkung committed on
Commit 4e657f7 · verified · 1 Parent(s): 1a48548

Rollback to Gemma-3-4B commit

Files changed (1)
app.py +100 -60
app.py CHANGED
@@ -1,113 +1,153 @@
  import json
  import numpy as np
  import pandas as pd
- from transformers import pipeline, BitsAndBytesConfig
  from sentence_transformers import SentenceTransformer
  import gradio as gr
  import torch
  from huggingface_hub import login
  import os

- # --- Setup & Configuration ---
  hf_token = os.getenv("V2_TOKEN")
  if hf_token is None:
-     raise RuntimeError("V2_TOKEN environment variable is not set.")

  login(token=hf_token)
- PRELOAD_PARQUET = "preload.parquet"

- print("Loading RAG system...")

- # Optimization: ensure we aren't re-embedding on every restart if possible.
  FILE_PATH = "data.jsonl"
  PRELOAD_FILE_PATH = "preload-data.json"

- print(f"Loading data from {PRELOAD_FILE_PATH}...")
  with open(PRELOAD_FILE_PATH, "r", encoding="utf-8") as f:
-     documents = json.load(f)
-
- # Load embedding model
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
-
- # Pre-calculate embeddings once and stack them into a numpy matrix for fast math
- print("Generating/Loading embeddings...")
- doc_embeddings = embedding_model.encode(documents, convert_to_numpy=True)

- # Normalize embeddings now so we only need a dot product later (faster than a full cosine calculation every time)
- doc_embeddings = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)

- # Create a DataFrame just for text storage (we use numpy for the math)
- df = pd.DataFrame({"Document": documents})

- # Load LLM
- print("Loading LLM...")
  llm = pipeline(
      "text-generation",
-     model="google/gemma-3-1b-it",
-     token=hf_token,
  )

- # --- Optimized Retrieval Function ---
- def retrieve_vectorized(query: str, top_k: int = 5):
      """
-     Uses matrix multiplication instead of row-by-row iteration.
      """
-     # Encode query
      query_embedding = embedding_model.encode([query])[0]
-
-     # Normalize query
-     query_norm = query_embedding / np.linalg.norm(query_embedding)
-     scores = np.dot(doc_embeddings, query_norm)
-     top_indices = np.argsort(scores)[::-1][:top_k]
-
-     # Retrieve documents
-     results = df.iloc[top_indices].copy()
-     return results["Document"].tolist()

- # --- Main Generation Function ---
- def generate_with_rag(query):
-     # goSFU-specific cleaning
      if "gosfu" in query.lower():
          query = query.replace("gosfu", "goSFU")

      # Retrieve
-     retrieved_docs = retrieve_vectorized(query, top_k=5)
-     context_str = "\n\n---\n\n".join(retrieved_docs)

-     # Prompt
      prompt_content = f"""
  You are an SFU IT helpdesk chatbot.
- Your task is to answer SFU IT related questions.
-
- Context Articles:
- {context_str}
-
- User Question: {query}
-
- Instructions:
- 1. Answer the question using ONLY the Context Articles above.
- 2. Provide step-by-step instructions and include relevant links found in the text.
- 3. If the answer is not in the context, suggest contacting SFU IT at 778-782-8888.
- 4. If the user is asking about mental health, redirect to SFU Health & Counselling.
-
- Answer:"""
-
      response = llm(
          prompt_content,
-         max_new_tokens=300,  # Reduced token count for speed
          do_sample=False,
          return_full_text=False
      )
      return response[0]["generated_text"].strip()

  def chat_fn(message, history):
-     return generate_with_rag(message)

  demo = gr.ChatInterface(
      fn=chat_fn,
-     title="SFU IT Chatbot (Optimized)",
      description="Enter your question and the SFU IT Chatbot will try to answer using retrieved SFU IT knowledge.",
  )

  if __name__ == "__main__":
      demo.launch()
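The removed version above relies on one trick worth a standalone illustration: because the document embeddings are L2-normalized once at startup, cosine similarity against every document reduces to a single matrix-vector dot product. A minimal, self-contained sketch of that pattern follows; the sample documents and query are illustrative stand-ins, not taken from the app's knowledge base.

import numpy as np
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Illustrative stand-ins for the preloaded SFU IT documents
documents = [
    "How to reset your SFU computing ID password.",
    "Connecting to the SFU VPN from off campus.",
    "Setting up SFU email on a mobile device.",
]

# Embed once and normalize each row, as the removed code does
doc_embeddings = embedding_model.encode(documents, convert_to_numpy=True)
doc_embeddings = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)

# Normalize the query the same way, so dot product == cosine similarity
query_embedding = embedding_model.encode(["how do I reset my password"])[0]
query_norm = query_embedding / np.linalg.norm(query_embedding)

scores = np.dot(doc_embeddings, query_norm)   # all similarities in one shot
top_indices = np.argsort(scores)[::-1][:2]    # best matches first
for i in top_indices:
    print(f"{scores[i]:.3f}  {documents[i]}")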
 
  import json
  import numpy as np
  import pandas as pd
+
+ from transformers import pipeline
  from sentence_transformers import SentenceTransformer
  import gradio as gr
  import torch
  from huggingface_hub import login
  import os

+ # Sanity check
  hf_token = os.getenv("V2_TOKEN")
  if hf_token is None:
+     raise RuntimeError("V2_TOKEN environment variable is not set in this Space.")

+ # Explicit login
  login(token=hf_token)

+ # --- Configuration ---
+ print("Loading RAG system on your device...")

+ # Load knowledge base
  FILE_PATH = "data.jsonl"
  PRELOAD_FILE_PATH = "preload-data.json"

+ # Load data
+ print(f"Found Preloaded Data! Using {PRELOAD_FILE_PATH}...")
  with open(PRELOAD_FILE_PATH, "r", encoding="utf-8") as f:
+     data = json.load(f)

+ # Set data
+ documents = data

+ # Embeddings
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+ embeddings = embedding_model.encode(documents, convert_to_numpy=True)
+
+ # Use a pandas DataFrame
+ df = pd.DataFrame(
+     {
+         "Document": documents,
+         "Embedding": list(embeddings),  # store as list
+     }
+ )

+ # Load LLM pipeline
  llm = pipeline(
      "text-generation",
+     model="google/gemma-3-4b-it",  # Might not have enough storage for this model
+     token=hf_token
  )

+ def clean_query_with_llm(query):
+     prompt_content = f"""
+ Below is a new question asked by the user that needs to be answered by searching in a knowledge base.
+ You have access to an SFU IT Knowledge Base index with hundreds of chunked documents.
+ Generate a search query based on the user's question.
+ If you cannot generate a search query, return just the number 0.
+ User's Question:
+ {query}
+ Search Query:
+ """
+
+     response = llm(
+         prompt_content,
+         max_new_tokens=100,
+         do_sample=False,
+         return_full_text=False
+     )
+     return response[0]["generated_text"].strip()
+
+
+ # Retrieve with pandas
+ def retrieve_with_pandas(query: str, top_k: int = 5):
      """
+     Embed the query, compute cosine similarity to each document,
+     and return the top_k most similar documents (as a DataFrame).
      """
      query_embedding = embedding_model.encode([query])[0]

+     def cosine_sim(x):
+         x = np.array(x)
+         return float(
+             np.dot(query_embedding, x)
+             / (np.linalg.norm(query_embedding) * np.linalg.norm(x))
+         )
+
+     df["Similarity"] = df["Embedding"].apply(cosine_sim)
+     results = df.sort_values(by="Similarity", ascending=False).head(top_k)
+     return results[["Document", "Similarity"]]
+
+
+ def generate_with_rag(query, top_k=5):
+     # goSFU-specific cleaning
      if "gosfu" in query.lower():
          query = query.replace("gosfu", "goSFU")

      # Retrieve
+     search_query = clean_query_with_llm(query)
+     results = retrieve_with_pandas(search_query, top_k=top_k)
+
+     # Turn the Series into a single string of text
+     # (each doc separated by a divider)
+     context_str = "\n\n---\n\n".join(results["Document"].tolist())

+     # Build a clean prompt
      prompt_content = f"""
  You are an SFU IT helpdesk chatbot.
+ Your task is to answer SFU IT related questions such as accessing various technology services or general troubleshooting.
+ Below is a new question asked by the user, along with article chunks related to that question.
+ If the user asked a question, answer it with short step-by-step instructions, considering all the articles below.
+ If there are links in the articles, provide those links in your answer.
+ If the user asked a question and the answer is not in the context, say that you're sorry you can't help them and suggest contacting SFU IT at 778-782-8888 or submitting an inquiry ticket at https://www.sfu.ca/information-systems/get-help.html
+ If the user DID NOT ask a question, be friendly and ask how you can help them.
+ Do not recommend, suggest, or provide any advice on anything that is not related to SFU or SFU IT.
+ If the user asked something relating to mental health or is seeking medical advice, redirect them to SFU Health & Counselling at https://www.sfu.ca/students/health.html
+ Do not ask the user any follow-up questions after answering them.
+
+ Question:
+ {query}
+ -- Start of Articles --
+ {context_str}
+ -- End of Articles --
+ Answer:"""
+
+     # Call the LLM
      response = llm(
          prompt_content,
+         max_new_tokens=500,
          do_sample=False,
          return_full_text=False
      )
      return response[0]["generated_text"].strip()
+

  def chat_fn(message, history):
+     """
+     ChatInterface callback.
+     """
+     answer = generate_with_rag(message, top_k=5)
+     return answer
+

  demo = gr.ChatInterface(
      fn=chat_fn,
+     title="SFU IT Chatbot",
      description="Enter your question and the SFU IT Chatbot will try to answer using retrieved SFU IT knowledge.",
  )

+ # share=True
  if __name__ == "__main__":
      demo.launch()
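To sanity-check the new pandas-based retrieval without launching the Gradio UI, a quick interactive session could look like the sketch below; the query string is only an example, and importing app runs the module-level login and model loading first.

# e.g. in a Python shell inside the Space
from app import retrieve_with_pandas

results = retrieve_with_pandas("reset goSFU password", top_k=3)
print(results)  # DataFrame with "Document" and "Similarity" columns

The trailing "# share=True" comment presumably refers to demo.launch(share=True), which would make Gradio serve a temporary public link to the app.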