```python
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize models: a small sentence encoder for retrieval,
# distilgpt2 as the generator
retriever = SentenceTransformer("all-MiniLM-L6-v2")
generator = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")


# Simple in-memory vector store with cosine-similarity search
class VectorStore:
    def __init__(self):
        self.documents = []
        self.embeddings = []

    def add_document(self, document):
        self.documents.append(document)
        self.embeddings.append(retriever.encode(document))

    def search(self, query, k=3):
        query_embedding = retriever.encode(query)
        embeddings = np.asarray(self.embeddings)
        # Cosine similarity between the query and every stored document
        similarities = embeddings @ query_embedding / (
            np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        top_k_indices = np.argsort(similarities)[-k:][::-1]
        return [self.documents[i] for i in top_k_indices]
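
# Usage sketch (hypothetical documents, shown for illustration only):
#   store = VectorStore()
#   store.add_document("Paris is the capital of France.")
#   store.add_document("The Nile is a river in Africa.")
#   store.search("capital of France", k=1)  # -> ["Paris is the capital of France."]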

# Initialize the vector store and load a sample corpus
# (the first 1,000 Simple English Wikipedia articles)
vector_store = VectorStore()
dataset = load_dataset(
    "wikipedia", "20220301.simple", split="train[:1000]", trust_remote_code=True
)
for doc in dataset["text"]:
    vector_store.add_document(doc)
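# Alternative (a sketch): SentenceTransformer.encode also accepts a list of
# strings, so the corpus could be embedded in one batched call, which is
# considerably faster than encoding one article at a time:
#   vector_store.documents = list(dataset["text"])
#   vector_store.embeddings = list(retriever.encode(dataset["text"]))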

# RAG function: retrieve context, then generate an answer conditioned on it
def rag_query(query, max_new_tokens=100):
    # Retrieve relevant documents and join them into a single context string
    retrieved_docs = vector_store.search(query)
    context = " ".join(retrieved_docs)

    # Build the prompt and generate a response
    input_text = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = generator.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("Answer:")[-1].strip()
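
# Example call (illustrative, not an actual run):
#   rag_query("What is the capital of France?")
# distilgpt2 is a small model, so answers will be rough; swapping in a larger
# generator would improve quality without changing the retrieval logic.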

# Gradio interface
def gradio_interface(query):
    return rag_query(query)


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="RAG System with Hugging Face and Gradio",
    description="Ask questions based on a Wikipedia-based knowledge base.",
)

if __name__ == "__main__":
    iface.launch()
```
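
When the app is deployed as a Hugging Face Space, missing packages are a common cause of the "Runtime error" status, so the dependencies need to be declared in a `requirements.txt`. A minimal sketch, with package names inferred from the imports above (unpinned; the `wikipedia` loading script may additionally require `apache_beam` and `mwparserfromhell`):

```text
gradio
numpy
torch
transformers
sentence-transformers
datasets
```

Note that the corpus is embedded at startup, so the Space can take several minutes to become ready even after the build succeeds.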