Spaces:
Build error
Build error
| import logging | |
| import os | |
| from typing import List | |
| import shutil | |
| # from langchain_openai import OpenAIEmbeddings | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from local_loader import load_data_files | |
| from splitter import split_documents | |
| from dotenv import load_dotenv | |
| from time import sleep | |
EMBED_DELAY = 0.02  # seconds (20 milliseconds) to pause before each embedding call


class EmbeddingProxy:
    """Throttling wrapper around an embedding backend.

    Each call sleeps for ``EMBED_DELAY`` seconds and is then forwarded to the
    wrapped object. This is to get the Streamlit app to use less CPU while
    embedding documents into Chromadb.
    """

    def __init__(self, embedding):
        # The wrapped backend; it must provide embed_documents() and
        # embed_query(), which are the only two methods called on it here.
        self.embedding = embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Pause briefly, then return the backend's embeddings for *texts*."""
        sleep(EMBED_DELAY)
        vectors = self.embedding.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Pause briefly, then return the backend's embedding for *text*."""
        sleep(EMBED_DELAY)
        vector = self.embedding.embed_query(text)
        return vector
# This happens all at once, not ideal for large datasets.
def create_vector_db(texts, embeddings=None, collection_name="chroma"):
    """Rebuild a Chroma collection under ``store/`` and load *texts* into it.

    Anything previously persisted for ``collection_name`` is deleted first,
    then all documents are added in a single ``add_documents`` call.

    Args:
        texts: Documents to add; passed unchanged to ``add_documents``.
        embeddings: Embedding backend. When omitted (or falsy) a default
            ``HuggingFaceEmbeddings()`` is created.
        collection_name: Name of the Chroma collection; also used as the
            sub-directory of ``store/`` in which it is persisted.

    Returns:
        The ``Chroma`` vector store. It is returned even when there was
        nothing to add or when adding failed; a failure is logged with its
        traceback rather than raised.
    """
    if not texts:
        logging.warning("Empty texts passed in to create vector database")

    # Select embeddings. The default backend is HuggingFaceEmbeddings, so no
    # OPENAI_API_KEY lookup is done here: it used to raise KeyError whenever
    # the variable was unset, although its value was never used.
    if not embeddings:
        embeddings = HuggingFaceEmbeddings()
    proxy_embeddings = EmbeddingProxy(embeddings)

    # Start from a clean slate: drop whatever an earlier run persisted.
    persist_directory = os.path.join("store/", collection_name)
    if os.path.exists(persist_directory):
        shutil.rmtree(persist_directory)

    db = Chroma(
        collection_name=collection_name,
        embedding_function=proxy_embeddings,
        persist_directory=persist_directory,
    )

    if not texts:
        # Nothing to embed: hand back the empty store instead of sending an
        # empty batch to add_documents (presumably rejected by Chroma).
        return db

    try:
        db.add_documents(texts)
    except Exception as e:
        # Deliberately best-effort: the caller still receives the store.
        # logging.exception keeps the traceback for diagnosis.
        logging.exception("Error adding documents to Chroma: %s", e)
    return db
def find_similar(vs, query):
    """Return whatever ``vs.similarity_search`` yields for *query*."""
    return vs.similarity_search(query)
def main():
    """Index the files under ``data`` and print the matches for a sample query."""
    load_dotenv()

    # Load data from your 'data' folder
    documents = load_data_files(data_dir="data")
    chunks = split_documents(documents)
    store = create_vector_db(chunks)

    # Use a relevant query from your financial domain
    hits = find_similar(store, query="What are the fees for an Equity Ordinary Account?")

    MAX_CHARS = 300
    print("=== Results ===")
    for rank, hit in enumerate(hits, start=1):
        body = hit.page_content
        # Cut at the first space found at or after MAX_CHARS so a word is not
        # split; when there is no such space, cut at exactly MAX_CHARS.
        space_at = body.find(' ', MAX_CHARS)
        cut = space_at if space_at > MAX_CHARS else MAX_CHARS
        print(f"Result {rank}:\n {body[:cut]}\n")


if __name__ == "__main__":
    main()