aleksandrrnt committed on
Commit
820f884
·
verified ·
1 Parent(s): 76a2909

Upload 10 files

Browse files
Files changed (5) hide show
  1. ensemble.py +13 -1
  2. full_chain.py +1 -1
  3. gradio_app.py +216 -0
  4. memory.py +0 -2
  5. requirements.txt +11 -11
ensemble.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  from langchain_community.retrievers import BM25Retriever, TavilySearchAPIRetriever
4
  from langchain.retrievers import EnsembleRetriever
@@ -20,7 +21,7 @@ def ensemble_retriever_from_docs(docs, embeddings=None):
20
  bm25_retriever = BM25Retriever.from_texts([t.page_content for t in texts])
21
 
22
  # tavily_retriever = TavilySearchAPIRetriever(k=3, include_domains=['https://ilibrary.ru/text/107'])
23
- tavily_retriever = TavilySearchAPIRetriever(k=3, include_domains=['https://equitygroupholdings.com'])
24
 
25
  ensemble_retriever = EnsembleRetriever(
26
  retrievers=[bm25_retriever, vs_retriever, tavily_retriever],
@@ -29,6 +30,17 @@ def ensemble_retriever_from_docs(docs, embeddings=None):
29
  return ensemble_retriever
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  def main():
33
  load_dotenv()
34
 
 
1
  import os
2
+ import logging
3
 
4
  from langchain_community.retrievers import BM25Retriever, TavilySearchAPIRetriever
5
  from langchain.retrievers import EnsembleRetriever
 
21
  bm25_retriever = BM25Retriever.from_texts([t.page_content for t in texts])
22
 
23
  # tavily_retriever = TavilySearchAPIRetriever(k=3, include_domains=['https://ilibrary.ru/text/107'])
24
+ tavily_retriever = MyTavilySearchAPIRetriever(k=3, include_domains=['https://equitygroupholdings.com'])
25
 
26
  ensemble_retriever = EnsembleRetriever(
27
  retrievers=[bm25_retriever, vs_retriever, tavily_retriever],
 
30
  return ensemble_retriever
31
 
32
 
33
class MyTavilySearchAPIRetriever(TavilySearchAPIRetriever):
    """Tavily retriever that yields no documents instead of raising.

    Any exception raised by the parent retriever is logged and replaced
    by an empty result list.
    """

    def _get_relevant_documents(self, query: str, *, run_manager):
        """Return the parent retriever's documents, or [] if it raises.

        Args:
            query: The search query, forwarded unchanged to the parent.
            run_manager: Callback manager, forwarded unchanged to the parent.

        Returns:
            Whatever the parent implementation returns, or an empty list
            when the parent raises any Exception.
        """
        try:
            return super()._get_relevant_documents(query, run_manager=run_manager)
        except Exception as e:
            # Deliberate best-effort boundary: swallow the error, but keep
            # the traceback in the log so the failure can be diagnosed.
            logging.exception("TavilySearch error: %s", e)
            return []
42
+
43
+
44
  def main():
45
  load_dotenv()
46
 
full_chain.py CHANGED
@@ -44,7 +44,7 @@ def create_full_chain(retriever, openai_api_key=None):
44
 
45
  def ask_question(chain, query, session_id):
46
  # try:
47
- logging.info(f"Send request: {query}")
48
  response = chain.invoke(
49
  {"question": query},
50
  config={"configurable": {"session_id": session_id}}
 
44
 
45
  def ask_question(chain, query, session_id):
46
  # try:
47
+ # logging.info(f"Send request from session {session_id}: {query}")
48
  response = chain.invoke(
49
  {"question": query},
50
  config={"configurable": {"session_id": session_id}}
gradio_app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import logging
4
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
5
+ from langchain_community.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
6
+ from langchain_community.retrievers import BM25Retriever
7
+
8
+ from ensemble import ensemble_retriever_from_docs
9
+ from full_chain import create_full_chain, ask_question
10
+ from local_loader import load_data_files, load_file
11
+ from vector_store import EmbeddingProxy
12
+ from memory import clean_session_history
13
+ from pathlib import Path
14
+
15
+ import gradio as gr
16
+ from langchain.chat_models import ChatOpenAI
17
+ from langchain.schema import AIMessage, HumanMessage
18
+
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
22
+
23
def show_ui(message, history, request: gr.Request):
    """
    Answer a single chat message; used as the gr.ChatInterface callback.

    Args:
        message: The text submitted by the user.
        history: The chat history supplied by the caller; not used here.
        request: The Gradio request; its session_hash is used as the
            session id passed to ask_question.

    Returns:
        The `content` attribute of the object returned by ask_question.
    """
    # `chain` is the module-level chain; it is only read here, so no
    # `global` declaration is required.
    reply = ask_question(chain, message, request.session_hash)
    return reply.content
36
+
37
+
38
def get_retriever(openai_api_key=None):
    """
    Build the ensemble document retriever from the files in "data".

    Args:
        openai_api_key: Accepted for call compatibility; not used here,
            because the embeddings come from HuggingFaceEmbeddings.

    Returns:
        The retriever returned by ensemble_retriever_from_docs.

    Raises:
        Exception: Any error raised while loading the documents, creating
            the embeddings or building the retriever is logged with its
            traceback and re-raised.
    """
    try:
        docs = load_data_files(data_dir="data")
        embeddings = HuggingFaceEmbeddings()
        return ensemble_retriever_from_docs(docs, embeddings=embeddings)
    except Exception:
        # This module is launched through Gradio, so there is no Streamlit
        # script run for st.error()/st.stop() to act on. Re-raise so that
        # start-up stops here instead of continuing without a retriever.
        logging.exception("Error creating retriever")
        raise
58
+
59
+
60
def get_chain(openai_api_key=None, huggingfacehub_api_token=None):
    """
    Create the retriever and the question answering chain built on it.

    Args:
        openai_api_key: Forwarded to get_retriever and create_full_chain.
        huggingfacehub_api_token: Accepted for call compatibility; not
            used here.

    Returns:
        A tuple (ensemble_retriever, chain).

    Raises:
        Exception: Any error raised by create_full_chain is logged with
            its traceback and re-raised. Errors from get_retriever are
            handled (logged) by get_retriever itself.
    """
    # get_retriever does its own error logging, so it stays outside the
    # try block to avoid reporting the same failure twice.
    ensemble_retriever = get_retriever(openai_api_key=openai_api_key)
    try:
        chain = create_full_chain(
            ensemble_retriever,
            openai_api_key=openai_api_key,
        )
    except Exception:
        # This module is launched through Gradio, so Streamlit's
        # st.error()/st.stop() are not the right way to halt; re-raise so
        # the caller never unpacks a missing result.
        logging.exception("Error creating chain")
        raise
    return ensemble_retriever, chain
83
+
84
def get_secret_or_input(secret_key, secret_name, info_link=None):
    """
    Return a secret from Streamlit secrets, or ask the user to type it.

    Args:
        secret_key: Key looked up in st.secrets (and used as the
            st.session_state key when the user types a value).
        secret_name: Human-readable name shown in the prompt and used as
            the text input label.
        info_link: Optional URL rendered as a "Get an ..." markdown link
            when the user is prompted.

    Returns:
        The value from st.secrets when present, otherwise the current
        value of the password text input (which may be empty).
    """
    # NOTE(review): the nesting below (session-state write and info link
    # only on the prompt path) is reconstructed from the logic; confirm
    # against the original file's indentation.
    if secret_key in st.secrets:
        st.write("Found %s secret" % secret_key)
        return st.secrets[secret_key]

    st.write(f"Please provide your {secret_name}")
    typed_value = st.text_input(secret_name, key=f"input_{secret_key}", type="password")
    if typed_value:
        # Remember the typed value under the secret's key.
        st.session_state[secret_key] = typed_value
    if info_link:
        st.markdown(f"[Get an {secret_name}]({info_link})")
    return typed_value
107
+
108
def process_uploaded_file(uploaded_file):
    """
    Load an uploaded file, add it to the retrievers and rebuild the chain.

    Args:
        uploaded_file: Either a path string, or an object whose `name`
            attribute holds the path. When None, nothing is done.

    Returns:
        Always None. The click handler routes this value back to the
        file input component.
    """
    global ensemble_retriever
    global chain

    if uploaded_file is None:
        return None

    logging.info(f'run upload {uploaded_file}')

    path_text = uploaded_file if isinstance(uploaded_file, str) else str(uploaded_file.name)

    # Load the document using the saved file path
    new_docs = load_file(Path(path_text))

    # retrievers[0] is the BM25 retriever and retrievers[1] the vector
    # store retriever (the order used when the ensemble is constructed).
    corpus = ensemble_retriever.retrievers[0].docs
    corpus.extend(new_docs)

    ensemble_retriever.retrievers[1].add_documents(new_docs)

    # The BM25 retriever is rebuilt from the full corpus texts and swapped in.
    ensemble_retriever.retrievers[0] = BM25Retriever.from_texts(
        [doc.page_content for doc in corpus]
    )

    # Rebuild the module-level chain around the updated retriever.
    chain = create_full_chain(
        ensemble_retriever,
        openai_api_key=open_api_key,
    )

    logging.info("File uploaded and added to the knowledge base!")
    gr.Info('File uploaded and added to the knowledge base!', duration=3)

    return None
149
+
150
+ # except Exception as e:
151
+ # logging.error(f"Error processing uploaded file: {e}")
152
+ # st.error("Error processing the file. Please check the logs.")
153
+
154
+ SUPPORTED_FORMATS = ['.txt', '.json', '.pdf']
155
+
156
def activate():
    """Return a Gradio update that makes the target component interactive."""
    enable = gr.update(interactive=True)
    return enable
158
+
159
def deactivate():
    """Return a Gradio update that makes the target component non-interactive."""
    disable = gr.update(interactive=False)
    return disable
161
+
162
def reset(z, request: gr.Request):
    """
    Clear the stored history of the caller's session and empty the chat.

    Args:
        z: Value of the component wired as the event input; not used.
        request: The Gradio request; its session_hash identifies the
            session whose history is cleaned.

    Returns:
        Two empty lists, one for each output wired to this handler.
    """
    clean_session_history(request.session_hash)
    empty_chatbot, empty_state = [], []
    return empty_chatbot, empty_state
166
+
167
def main():
    """Build the two-tab Gradio UI (chat and document upload) and launch it."""
    # NOTE(review): the nesting of components under each tab is
    # reconstructed from the call order; confirm against the original file.
    with gr.Blocks() as app:
        gr.Markdown(
            "# Equity Bank AI assistant \n"
            "Ask questions about Equity Bank's products and services:"
        )

        with gr.Tab('Chat'):
            # Both components are created unrendered and handed to the
            # ChatInterface below.
            clean_btn = gr.Button(value="Clean history", variant="secondary", size='sm', render=False)
            chatbot = gr.Chatbot(elem_id="chatbot", render=False)

            chat_ui = gr.ChatInterface(
                show_ui,
                chatbot=chatbot,
                undo_btn=None,
                retry_btn=None,
                clear_btn=clean_btn,
            )

        with gr.Tab('Documents'):
            file_input = gr.File(
                label=", ".join(str(fmt) for fmt in SUPPORTED_FORMATS),
                file_types=SUPPORTED_FORMATS,
            )
            # Starts disabled; enabled by the upload event wired below.
            submit_btn = gr.Button(value="Index file", variant="primary", interactive=False)

        # Clearing also wipes the session history kept by the memory module.
        clean_btn.click(fn=reset, inputs=clean_btn, outputs=[chatbot, chat_ui.chatbot_state])

        submit_btn.click(
            fn=process_uploaded_file,
            inputs=file_input,
            outputs=file_input,
            api_name="Index file"
        )

        # The index button is only clickable while a file is present.
        file_input.upload(fn=activate, outputs=[submit_btn])
        file_input.clear(fn=deactivate, outputs=[submit_btn])

    app.launch(share=True)
204
+
205
+
206
# Read once at import time; process_uploaded_file reuses this value when it
# rebuilds the chain.
# NOTE(review): the variable read is 'OPEN_API_KEY', not 'OPENAI_API_KEY' —
# confirm this is the intended name.
open_api_key = os.environ.get('OPEN_API_KEY')

# The retriever and chain are built when the module is imported and are
# shared, as module-level globals, by show_ui and process_uploaded_file.
ensemble_retriever, chain = get_chain(
    openai_api_key=open_api_key,
    huggingfacehub_api_token=None,
)


if __name__ == "__main__":
    main()
memory.py CHANGED
@@ -39,8 +39,6 @@ def create_memory_chain(llm, base_chain):
39
  if session_id not in store:
40
  store[session_id] = ChatMessageHistory()
41
 
42
- logging.info(str(store))
43
-
44
  return store[session_id]
45
 
46
  with_message_history = RunnableWithMessageHistory(
 
39
  if session_id not in store:
40
  store[session_id] = ChatMessageHistory()
41
 
 
 
42
  return store[session_id]
43
 
44
  with_message_history = RunnableWithMessageHistory(
requirements.txt CHANGED
@@ -1,11 +1,11 @@
1
- chromadb
2
- huggingface-hub
3
- langchain
4
- langchain-community
5
- langchain-openai
6
- sentence-transformers
7
- streamlit
8
- gradio
9
- pypdf
10
- rank_bm25
11
- tavily-python
 
1
+ chromadb==0.5.5
2
+ huggingface-hub==0.24.6
3
+ langchain==0.2.14
4
+ langchain-community==0.2.12
5
+ langchain-openai==0.1.22
6
+ sentence-transformers==3.0.1
7
+ streamlit==1.37.1
8
+ gradio==4.41.0
9
+ pypdf==4.3.1
10
+ rank_bm25==0.2.2
11
+ tavily-python==0.4.0