Claude (Your Name) committed · Commit d78f02a · 1 Parent(s): d5f8324

Clone api2 for experimentation
Experiment space for testing changes before production:
- FastAPI REST endpoint
- LLM query parser (Llama 70B)
- Targeted indexing (drugs, diseases, companies, endpoints)
- Groq API support (10x faster)
- Hybrid RAG search
- Docker deployment
Clone of api2 with all latest optimizations.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- Dockerfile +22 -0
- README.md +31 -4
- app.py +123 -0
- foundation_engine.py +1343 -0
- requirements.txt +19 -0
- two_llm_system_FIXED.py +461 -0
Dockerfile
ADDED
@@ -0,0 +1,22 @@
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    git \
    git-lfs \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY app.py foundation_engine.py two_llm_system_FIXED.py ./

# Expose port
EXPOSE 7860

# Run FastAPI app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,11 +1,38 @@
 ---
-title:
-emoji:
-colorFrom:
+title: Clinical Trial API (Experiment)
+emoji: 🧪
+colorFrom: purple
 colorTo: blue
 sdk: docker
 pinned: false
 license: mit
 ---

-
+# Clinical Trial API - Experiment Space
+
+**Clone of api2 for testing changes before production**
+
+## API Endpoint
+
+```bash
+POST /query
+{
+  "query": "What Novartis drugs treat melanoma?"
+}
+```
+
+## Features
+
+- LLM query parser (extracts drugs, diseases, companies, endpoints)
+- Targeted inverted index (drugs, diseases, companies, endpoints)
+- Hybrid RAG search (keyword + semantic)
+- Groq API support (10x faster than HuggingFace)
+- ~7 second response time
+
+## Test It
+
+```bash
+curl -X POST https://gmkdigitalmedia-ctapi.hf.space/query \
+  -H "Content-Type: application/json" \
+  -d '{"query": "Tell me about Keytruda for lung cancer"}'
+```
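For readers who prefer Python over curl, here is a minimal client sketch against the same `/query` endpoint; it is not part of the commit and assumes only the `requests` package plus the Space URL quoted in the README above.

```python
# Minimal sketch: call the experiment Space's /query endpoint.
# Assumes `requests` is installed; URL and payload come from the README above.
import requests

resp = requests.post(
    "https://gmkdigitalmedia-ctapi.hf.space/query",
    json={"query": "Tell me about Keytruda for lung cancer"},
    timeout=120,  # README quotes ~7 s responses, but cold starts can be slower
)
resp.raise_for_status()
data = resp.json()  # QueryResponse: {"summary": ..., "processing_time": ...}
print(f"{data['processing_time']:.1f}s\n{data['summary']}")
```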
app.py
ADDED
@@ -0,0 +1,123 @@
"""
FastAPI REST API for Foundation 1.2 Clinical Trial System
Production-ready Docker space with proper REST endpoints
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import time
import logging

# Import the foundation engine
import foundation_engine

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Clinical Trial API",
    description="Production REST API for clinical trial analysis powered by Foundation 1.2 pipeline",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request/Response models
class QueryRequest(BaseModel):
    query: str

class QueryResponse(BaseModel):
    summary: str
    processing_time: float

class HealthResponse(BaseModel):
    status: str
    trials_loaded: int
    embeddings_loaded: bool

@app.on_event("startup")
async def startup_event():
    """Initialize the foundation engine on startup"""
    logger.info("=== API Startup ===")
    logger.info("Loading Foundation 1.2 engine...")
    # The foundation_engine will load embeddings when first accessed
    foundation_engine.load_embeddings()
    logger.info("=== API Ready ===")

@app.get("/")
async def root():
    """API information"""
    return {
        "service": "Clinical Trial API",
        "version": "1.0.0",
        "description": "Production REST API for Foundation 1.2",
        "status": "healthy",
        "endpoints": {
            "POST /query": "Query clinical trials and get AI-generated summary",
            "GET /health": "Health check",
            "GET /docs": "Interactive API documentation (Swagger UI)",
            "GET /redoc": "Alternative API documentation (ReDoc)"
        },
        "features": [
            "Drug Scoring",
            "355M foundation model"
        ]
    }

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    embeddings_loaded = foundation_engine.doc_embeddings is not None
    chunks_loaded = len(foundation_engine.doc_chunks) if foundation_engine.doc_chunks else 0

    return HealthResponse(
        status="healthy",
        trials_loaded=chunks_loaded,
        embeddings_loaded=embeddings_loaded
    )

@app.post("/query", response_model=QueryResponse)
async def query_trials(request: QueryRequest):
    """
    Query clinical trials and get AI-generated summary

    - **query**: Your question about clinical trials (e.g., "What trials exist for Dekavil?")

    Returns a structured medical analysis with:
    - Drug/Intervention background
    - Clinical trial results and data
    - Treatment considerations
    - NCT trial IDs and references
    """
    try:
        logger.info(f"API Query received: {request.query[:100]}...")
        start_time = time.time()

        # Call the foundation engine
        result = foundation_engine.process_query(request.query)

        processing_time = time.time() - start_time
        logger.info(f"Query completed in {processing_time:.2f}s")

        return QueryResponse(
            summary=result,
            processing_time=processing_time
        )

    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
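As a quick smoke test of the endpoints defined in app.py above, here is a hypothetical sketch (not part of the commit) using FastAPI's TestClient. Entering the context manager fires the startup hook, which downloads and loads the embeddings, so the first run is slow and needs the data files to be reachable.

```python
# Hypothetical smoke test for the app.py endpoints above (editor's sketch).
from fastapi.testclient import TestClient
from app import app

with TestClient(app) as client:  # context entry runs the startup event
    health = client.get("/health").json()  # HealthResponse: status, trials_loaded, embeddings_loaded
    print(health)

    r = client.post("/query", json={"query": "What trials exist for Dekavil?"})
    r.raise_for_status()
    print(r.json()["summary"])  # QueryResponse: summary, processing_time
```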
foundation_engine.py
ADDED
@@ -0,0 +1,1343 @@
"""
Foundation 1.2
Clinical trial query system with 355M foundation model
"""

import gradio as gr
import os
from pathlib import Path
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
from rank_bm25 import BM25Okapi

import re
from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize
hf_token = os.getenv("HF_TOKEN")

# Paths for data storage
# Files will be downloaded from HF Dataset on first run
DATASET_FILE = Path(__file__).parent / "complete_dataset_WITH_RESULTS_FULL.txt"
CHUNKS_FILE = Path(__file__).parent / "dataset_chunks_TRIAL_AWARE.pkl"
EMBEDDINGS_FILE = Path(__file__).parent / "dataset_embeddings_TRIAL_AWARE_FIXED.npy"  # FIXED version to avoid cache
INVERTED_INDEX_FILE = Path(__file__).parent / "inverted_index_TRIAL_AWARE.pkl"  # Pre-built inverted index (638MB)

# HF Dataset containing the large files
DATASET_REPO = "gmkdigitalmedia/foundation1.2-data"

# Global storage
embedder = None
doc_chunks = []
doc_embeddings = None
bm25_index = None  # BM25 index for fast keyword search
inverted_index = None  # Inverted index for instant drug lookup

# ============================================================================
# RAG FUNCTIONS
# ============================================================================

def load_embedder():
    """Load L6 embedding model (matches generated embeddings)"""
    global embedder
    if embedder is None:
        logger.info("Loading MiniLM-L6 embedding model...")
        # Force CPU to avoid CUDA init in main process
        embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        logger.info("L6 model loaded on CPU")

def build_inverted_index(chunks):
    """
    Build targeted inverted index for clinical search
    Maps drugs, diseases, companies, and endpoints to trial indices for O(1) lookup

    Indexes ONLY what matters:
    1. INTERVENTION - drug/device names
    2. CONDITIONS - diseases being treated
    3. SPONSOR/COLLABORATOR/MANUFACTURER - company names
    4. OUTCOME - trial endpoints (what's being measured)

    Does NOT index trial names (unnecessary noise)
    """
    import time
    t_start = time.time()
    inv_index = {}

    logger.info("Building targeted index: drugs, diseases, companies, endpoints...")

    # Generic words to skip
    skip_words = {
        'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial',
        'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active',
        'randomized', 'multicenter', 'open', 'label', 'crossover'
    }

    for idx, chunk_data in enumerate(chunks):
        if idx % 100000 == 0 and idx > 0:
            logger.info(f"  Indexed {idx:,}/{len(chunks):,} trials...")

        text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
        text_lower = text.lower()

        # 1. DRUGS from INTERVENTION field
        intervention_match = re.search(r'intervention[:\s]+([^\n]+)', text_lower)
        if intervention_match:
            intervention_text = intervention_match.group(1)
            drugs = re.split(r'[,;\-\s]+', intervention_text)
            for drug in drugs:
                drug = drug.strip('.,;:() ')
                if len(drug) > 3 and drug not in skip_words:
                    if drug not in inv_index:
                        inv_index[drug] = []
                    if idx not in inv_index[drug]:
                        inv_index[drug].append(idx)

        # 2. DISEASES from CONDITIONS field
        conditions_match = re.search(r'conditions?[:\s]+([^\n]+)', text_lower)
        if conditions_match:
            conditions_text = conditions_match.group(1)
            diseases = re.split(r'[,;\|]+', conditions_text)
            for disease in diseases:
                disease = disease.strip('.,;:() ')
                # Split multi-word conditions and index each significant word
                disease_words = re.findall(r'\b\w{4,}\b', disease)
                for word in disease_words:
                    if word not in skip_words:
                        if word not in inv_index:
                            inv_index[word] = []
                        if idx not in inv_index[word]:
                            inv_index[word].append(idx)

        # 3. COMPANIES from SPONSOR field
        sponsor_match = re.search(r'sponsor[:\s]+([^\n]+)', text_lower)
        if sponsor_match:
            sponsor_text = sponsor_match.group(1)
            sponsors = re.split(r'[,;\|]+', sponsor_text)
            for sponsor in sponsors:
                sponsor = sponsor.strip('.,;:() ')
                if len(sponsor) > 3:
                    if sponsor not in inv_index:
                        inv_index[sponsor] = []
                    if idx not in inv_index[sponsor]:
                        inv_index[sponsor].append(idx)

        # 4. COMPANIES from COLLABORATOR field
        collab_match = re.search(r'collaborator[:\s]+([^\n]+)', text_lower)
        if collab_match:
            collab_text = collab_match.group(1)
            collaborators = re.split(r'[,;\|]+', collab_text)
            for collab in collaborators:
                collab = collab.strip('.,;:() ')
                if len(collab) > 3:
                    if collab not in inv_index:
                        inv_index[collab] = []
                    if idx not in inv_index[collab]:
                        inv_index[collab].append(idx)

        # 5. COMPANIES from MANUFACTURER field
        manuf_match = re.search(r'manufacturer[:\s]+([^\n]+)', text_lower)
        if manuf_match:
            manuf_text = manuf_match.group(1)
            manufacturers = re.split(r'[,;\|]+', manuf_text)
            for manuf in manufacturers:
                manuf = manuf.strip('.,;:() ')
                if len(manuf) > 3:
                    if manuf not in inv_index:
                        inv_index[manuf] = []
                    if idx not in inv_index[manuf]:
                        inv_index[manuf].append(idx)

        # 6. ENDPOINTS from OUTCOME fields
        # Look for outcome measures (what's being measured)
        outcome_matches = re.findall(r'outcome[:\s]+([^\n]+)', text_lower)
        for outcome_match in outcome_matches[:5]:  # First 5 outcomes only
            # Extract meaningful endpoint terms
            endpoint_words = re.findall(r'\b\w{5,}\b', outcome_match)  # 5+ char words
            for word in endpoint_words[:3]:  # First 3 words per outcome
                if word not in skip_words and word not in {'outcome', 'measure', 'primary', 'secondary'}:
                    if word not in inv_index:
                        inv_index[word] = []
                    if idx not in inv_index[word]:
                        inv_index[word].append(idx)

    t_elapsed = time.time() - t_start
    logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms")

    # Log sample entries for debugging (drugs, diseases, companies, endpoints)
    sample_terms = {
        'drugs': ['keytruda', 'opdivo', 'humira'],
        'diseases': ['cancer', 'diabetes', 'melanoma'],
        'companies': ['novartis', 'pfizer', 'merck'],
        'endpoints': ['survival', 'response', 'remission']
    }

    for category, terms in sample_terms.items():
        logger.info(f"  {category.upper()} samples:")
        for term in terms:
            if term in inv_index:
                logger.info(f"    '{term}' -> {len(inv_index[term])} trials")

    return inv_index
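# --- Illustration (editor's note, not part of the committed file) ------------
# How the targeted index above is meant to be used: each indexed term maps to a
# list of positions into `doc_chunks`, so candidate lookup is a dict access.
# Hypothetical sketch, assuming the chunks have already been loaded:
#
#   inv = build_inverted_index(doc_chunks)
#   candidates = set(inv.get("keytruda", [])) | set(inv.get("melanoma", []))
#   texts = [doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]
#            for i in candidates]
# ------------------------------------------------------------------------------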
+
|
| 187 |
+
def download_from_dataset(filename):
|
| 188 |
+
"""Download file from HF Dataset if not present locally"""
|
| 189 |
+
from huggingface_hub import hf_hub_download
|
| 190 |
+
import tempfile
|
| 191 |
+
|
| 192 |
+
# Use /tmp for downloads (has write permissions in Docker)
|
| 193 |
+
download_dir = Path("/tmp/foundation_data")
|
| 194 |
+
download_dir.mkdir(exist_ok=True)
|
| 195 |
+
|
| 196 |
+
local_file = download_dir / filename
|
| 197 |
+
|
| 198 |
+
if local_file.exists():
|
| 199 |
+
logger.info(f"Found cached {filename}")
|
| 200 |
+
return local_file
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
logger.info(f"Downloading {filename} from {DATASET_REPO}...")
|
| 204 |
+
downloaded_file = hf_hub_download(
|
| 205 |
+
repo_id=DATASET_REPO,
|
| 206 |
+
filename=filename,
|
| 207 |
+
repo_type="dataset",
|
| 208 |
+
local_dir=download_dir,
|
| 209 |
+
local_dir_use_symlinks=False
|
| 210 |
+
)
|
| 211 |
+
logger.info(f"Downloaded {filename}")
|
| 212 |
+
return Path(downloaded_file)
|
| 213 |
+
except Exception as e:
|
| 214 |
+
logger.error(f"Failed to download {filename}: {e}")
|
| 215 |
+
return None
|
| 216 |
+
|
| 217 |
+
def load_embeddings():
|
| 218 |
+
"""Load pre-generated embeddings (download from dataset if needed)"""
|
| 219 |
+
global doc_chunks, doc_embeddings, bm25_index
|
| 220 |
+
|
| 221 |
+
# Try to download if not present - store paths returned by download
|
| 222 |
+
chunks_path = CHUNKS_FILE
|
| 223 |
+
embeddings_path = EMBEDDINGS_FILE
|
| 224 |
+
dataset_path = DATASET_FILE
|
| 225 |
+
|
| 226 |
+
if not CHUNKS_FILE.exists():
|
| 227 |
+
downloaded = download_from_dataset("dataset_chunks_TRIAL_AWARE.pkl")
|
| 228 |
+
if downloaded:
|
| 229 |
+
chunks_path = downloaded
|
| 230 |
+
if not EMBEDDINGS_FILE.exists():
|
| 231 |
+
downloaded = download_from_dataset("dataset_embeddings_TRIAL_AWARE_FIXED.npy") # FIXED version
|
| 232 |
+
if downloaded:
|
| 233 |
+
embeddings_path = downloaded
|
| 234 |
+
if not DATASET_FILE.exists():
|
| 235 |
+
downloaded = download_from_dataset("complete_dataset_WITH_RESULTS_FULL.txt")
|
| 236 |
+
if downloaded:
|
| 237 |
+
dataset_path = downloaded
|
| 238 |
+
|
| 239 |
+
if chunks_path.exists() and embeddings_path.exists():
|
| 240 |
+
try:
|
| 241 |
+
logger.info("Loading embeddings from disk...")
|
| 242 |
+
with open(chunks_path, 'rb') as f:
|
| 243 |
+
doc_chunks = pickle.load(f)
|
| 244 |
+
|
| 245 |
+
# Load embeddings
|
| 246 |
+
loaded_embeddings = np.load(embeddings_path, allow_pickle=True)
|
| 247 |
+
|
| 248 |
+
logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}")
|
| 249 |
+
|
| 250 |
+
# Check if it's already a proper numpy array
|
| 251 |
+
if isinstance(loaded_embeddings, np.ndarray) and loaded_embeddings.ndim == 2:
|
| 252 |
+
doc_embeddings = loaded_embeddings
|
| 253 |
+
logger.info(f"✓ Embeddings are proper numpy array with shape: {doc_embeddings.shape}")
|
| 254 |
+
elif isinstance(loaded_embeddings, list):
|
| 255 |
+
logger.info(f"Converting embeddings from list to numpy array (memory efficient)...")
|
| 256 |
+
# Convert in chunks to avoid memory spike
|
| 257 |
+
chunk_size = 10000
|
| 258 |
+
total = len(loaded_embeddings)
|
| 259 |
+
|
| 260 |
+
# DEBUG: Print first 3 items to see format
|
| 261 |
+
logger.info(f"DEBUG: Total embeddings: {total}")
|
| 262 |
+
logger.info(f"DEBUG: Type of first item: {type(loaded_embeddings[0])}")
|
| 263 |
+
|
| 264 |
+
# Check if this is actually the chunks file (wrong file uploaded)
|
| 265 |
+
if isinstance(loaded_embeddings[0], tuple) and len(loaded_embeddings[0]) == 2:
|
| 266 |
+
if isinstance(loaded_embeddings[0][0], int) and isinstance(loaded_embeddings[0][1], str):
|
| 267 |
+
raise ValueError(
|
| 268 |
+
f"ERROR: The embeddings file contains (int, string) tuples!\n"
|
| 269 |
+
f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n"
|
| 270 |
+
f"First item: {loaded_embeddings[0][:2]}\n\n"
|
| 271 |
+
f"Please re-upload the correct file:\n"
|
| 272 |
+
f" CORRECT: dataset_embeddings_TRIAL_AWARE.npy (numpy array, 855 MB)\n"
|
| 273 |
+
f" WRONG: dataset_chunks_TRIAL_AWARE.pkl (tuples, 2.8 GB)\n\n"
|
| 274 |
+
f"The local file at /mnt/c/Users/ibm/Documents/HF/kg_to_model/dataset_embeddings_TRIAL_AWARE.npy is correct."
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if isinstance(loaded_embeddings[0], tuple):
|
| 278 |
+
logger.info(f"DEBUG: Tuple length: {len(loaded_embeddings[0])}")
|
| 279 |
+
for i, item in enumerate(loaded_embeddings[0][:5] if len(loaded_embeddings[0]) > 5 else loaded_embeddings[0]):
|
| 280 |
+
logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}")
|
| 281 |
+
|
| 282 |
+
# Get embedding dimension from first item
|
| 283 |
+
first_emb = loaded_embeddings[0]
|
| 284 |
+
emb_idx = None # Initialize
|
| 285 |
+
|
| 286 |
+
# Handle different formats
|
| 287 |
+
if isinstance(first_emb, tuple):
|
| 288 |
+
# Try both positions - could be (id, emb) or (emb, id)
|
| 289 |
+
logger.info(f"DEBUG: Trying to find embedding vector in tuple...")
|
| 290 |
+
emb_vector = None
|
| 291 |
+
for idx, elem in enumerate(first_emb):
|
| 292 |
+
if isinstance(elem, (list, np.ndarray)):
|
| 293 |
+
emb_vector = elem
|
| 294 |
+
emb_idx = idx
|
| 295 |
+
logger.info(f"DEBUG: Found embedding at position {idx}")
|
| 296 |
+
break
|
| 297 |
+
|
| 298 |
+
if emb_vector is None:
|
| 299 |
+
raise ValueError(f"No embedding vector found in tuple. Tuple contains: {[type(x) for x in first_emb]}")
|
| 300 |
+
|
| 301 |
+
emb_dim = len(emb_vector)
|
| 302 |
+
logger.info(f"DEBUG: Embedding dimension: {emb_dim}")
|
| 303 |
+
elif isinstance(first_emb, list):
|
| 304 |
+
emb_dim = len(first_emb)
|
| 305 |
+
emb_idx = None
|
| 306 |
+
elif isinstance(first_emb, np.ndarray):
|
| 307 |
+
emb_dim = first_emb.shape[0]
|
| 308 |
+
emb_idx = None
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError(f"Unknown embedding format: {type(first_emb)}")
|
| 311 |
+
|
| 312 |
+
logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}")
|
| 313 |
+
|
| 314 |
+
# Pre-allocate array
|
| 315 |
+
doc_embeddings = np.zeros((total, emb_dim), dtype=np.float32)
|
| 316 |
+
|
| 317 |
+
# Fill in chunks
|
| 318 |
+
for i in range(0, total, chunk_size):
|
| 319 |
+
end = min(i + chunk_size, total)
|
| 320 |
+
|
| 321 |
+
# Extract embeddings from tuples if needed
|
| 322 |
+
if isinstance(first_emb, tuple) and emb_idx is not None:
|
| 323 |
+
# Extract just the embedding vector from each tuple at the correct position
|
| 324 |
+
batch = [item[emb_idx] for item in loaded_embeddings[i:end]]
|
| 325 |
+
doc_embeddings[i:end] = batch
|
| 326 |
+
else:
|
| 327 |
+
doc_embeddings[i:end] = loaded_embeddings[i:end]
|
| 328 |
+
|
| 329 |
+
if i % 50000 == 0:
|
| 330 |
+
logger.info(f"Converted {i}/{total} embeddings...")
|
| 331 |
+
|
| 332 |
+
logger.info(f"✓ Converted to array with shape: {doc_embeddings.shape}")
|
| 333 |
+
else:
|
| 334 |
+
doc_embeddings = loaded_embeddings
|
| 335 |
+
logger.info(f"Embeddings already numpy array with shape: {doc_embeddings.shape}")
|
| 336 |
+
|
| 337 |
+
logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings")
|
| 338 |
+
|
| 339 |
+
# Skip BM25 (too memory-heavy for Docker), use inverted index only
|
| 340 |
+
global inverted_index
|
| 341 |
+
|
| 342 |
+
# Try to load pre-built inverted index (638MB) - MUCH faster than building (15 minutes)
|
| 343 |
+
if INVERTED_INDEX_FILE.exists():
|
| 344 |
+
logger.info(f"Loading pre-built inverted index from {INVERTED_INDEX_FILE.name}...")
|
| 345 |
+
try:
|
| 346 |
+
with open(INVERTED_INDEX_FILE, 'rb') as f:
|
| 347 |
+
inverted_index = pickle.load(f)
|
| 348 |
+
logger.info(f"✓ Loaded pre-built inverted index with {len(inverted_index):,} terms (instant vs 15min build)")
|
| 349 |
+
except Exception as e:
|
| 350 |
+
logger.warning(f"Failed to load pre-built index: {e}, building from scratch...")
|
| 351 |
+
inverted_index = build_inverted_index(doc_chunks)
|
| 352 |
+
else:
|
| 353 |
+
logger.info("Pre-built inverted index not found, building from scratch (this takes 15 minutes)...")
|
| 354 |
+
inverted_index = build_inverted_index(doc_chunks)
|
| 355 |
+
|
| 356 |
+
logger.info("Will use inverted index + semantic search (no BM25)")
|
| 357 |
+
|
| 358 |
+
return True
|
| 359 |
+
except Exception as e:
|
| 360 |
+
logger.error(f"Failed to load embeddings: {e}")
|
| 361 |
+
raise RuntimeError("Embeddings are required but failed to load") from e
|
| 362 |
+
|
| 363 |
+
raise RuntimeError("Embeddings files not found - system cannot function without embeddings")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def filter_trial_for_clinical_summary(trial_text):
|
| 367 |
+
"""
|
| 368 |
+
Filter trial data to keep essential clinical information including SOME results.
|
| 369 |
+
|
| 370 |
+
COMPREHENSIVE FILTERING:
|
| 371 |
+
- Keeps all core trial info (title, summary, conditions, interventions)
|
| 372 |
+
- Keeps sponsor/collaborator/manufacturer (WHO is running the trial)
|
| 373 |
+
- Keeps first 5 outcomes (to show key endpoints)
|
| 374 |
+
- Keeps first 5 result values per trial (to show actual data)
|
| 375 |
+
- Filters out overwhelming statistical noise (hundreds of baseline/adverse event lines)
|
| 376 |
+
|
| 377 |
+
This ensures the LLM sees comprehensive context including company information.
|
| 378 |
+
"""
|
| 379 |
+
if not trial_text:
|
| 380 |
+
return trial_text
|
| 381 |
+
|
| 382 |
+
lines = trial_text.split('\n')
|
| 383 |
+
filtered_lines = []
|
| 384 |
+
|
| 385 |
+
# Counters to limit repetitive data
|
| 386 |
+
outcome_count = 0
|
| 387 |
+
outcome_desc_count = 0
|
| 388 |
+
result_value_count = 0
|
| 389 |
+
|
| 390 |
+
# Limits
|
| 391 |
+
MAX_OUTCOMES = 5
|
| 392 |
+
MAX_OUTCOME_DESC = 5
|
| 393 |
+
MAX_RESULT_VALUES = 5
|
| 394 |
+
|
| 395 |
+
for line in lines:
|
| 396 |
+
line_stripped = line.strip()
|
| 397 |
+
|
| 398 |
+
# Skip empty lines
|
| 399 |
+
if not line_stripped:
|
| 400 |
+
continue
|
| 401 |
+
|
| 402 |
+
# ALWAYS SKIP: Overwhelming noise
|
| 403 |
+
always_skip = [
|
| 404 |
+
'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:',
|
| 405 |
+
'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:',
|
| 406 |
+
'OUTCOME_OTHER:', 'OUTCOME_NUMBER:'
|
| 407 |
+
]
|
| 408 |
+
|
| 409 |
+
should_skip = False
|
| 410 |
+
for marker in always_skip:
|
| 411 |
+
if line_stripped.startswith(marker):
|
| 412 |
+
should_skip = True
|
| 413 |
+
break
|
| 414 |
+
|
| 415 |
+
if should_skip:
|
| 416 |
+
continue
|
| 417 |
+
|
| 418 |
+
# LIMITED KEEP: Outcomes (first N only)
|
| 419 |
+
if line_stripped.startswith('OUTCOME:'):
|
| 420 |
+
outcome_count += 1
|
| 421 |
+
if outcome_count <= MAX_OUTCOMES:
|
| 422 |
+
filtered_lines.append(line)
|
| 423 |
+
continue
|
| 424 |
+
|
| 425 |
+
# LIMITED KEEP: Outcome descriptions (first N only)
|
| 426 |
+
if line_stripped.startswith('OUTCOME_DESCRIPTION:'):
|
| 427 |
+
outcome_desc_count += 1
|
| 428 |
+
if outcome_desc_count <= MAX_OUTCOME_DESC:
|
| 429 |
+
filtered_lines.append(line)
|
| 430 |
+
continue
|
| 431 |
+
|
| 432 |
+
# LIMITED KEEP: Result values (first N only)
|
| 433 |
+
if line_stripped.startswith('RESULT_VALUE:'):
|
| 434 |
+
result_value_count += 1
|
| 435 |
+
if result_value_count <= MAX_RESULT_VALUES:
|
| 436 |
+
filtered_lines.append(line)
|
| 437 |
+
continue
|
| 438 |
+
|
| 439 |
+
# ALWAYS KEEP: Core trial information + context
|
| 440 |
+
always_keep = [
|
| 441 |
+
'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:',
|
| 442 |
+
'SUMMARY:', 'DESCRIPTION:',
|
| 443 |
+
'CONDITIONS:', 'INTERVENTION:', # WHAT disease, WHAT drug
|
| 444 |
+
'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:', # WHO is running/funding
|
| 445 |
+
'ELIGIBILITY:'
|
| 446 |
+
# Note: OUTCOME/OUTCOME_DESCRIPTION handled in LIMITED KEEP section above
|
| 447 |
+
]
|
| 448 |
+
|
| 449 |
+
for marker in always_keep:
|
| 450 |
+
if line_stripped.startswith(marker):
|
| 451 |
+
filtered_lines.append(line)
|
| 452 |
+
break
|
| 453 |
+
|
| 454 |
+
return '\n'.join(filtered_lines)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def retrieve_context_with_embeddings(query, top_k=10):
|
| 458 |
+
"""
|
| 459 |
+
ENTERPRISE HYBRID SEARCH: Always combines keyword + semantic scoring
|
| 460 |
+
- Extracts ALL meaningful terms from query (case-insensitive)
|
| 461 |
+
- Scores each trial by keyword frequency (TF-IDF style)
|
| 462 |
+
- Also gets semantic similarity scores
|
| 463 |
+
- Merges both scores with weighted combination
|
| 464 |
+
- Works regardless of capitalization, language, or spelling
|
| 465 |
+
"""
|
| 466 |
+
import time
|
| 467 |
+
import re
|
| 468 |
+
from collections import Counter
|
| 469 |
+
global doc_chunks, doc_embeddings, embedder
|
| 470 |
+
|
| 471 |
+
if doc_embeddings is None or len(doc_chunks) == 0:
|
| 472 |
+
logger.error("Embeddings not loaded!")
|
| 473 |
+
return ""
|
| 474 |
+
|
| 475 |
+
t0 = time.time()
|
| 476 |
+
|
| 477 |
+
# Extract ALL meaningful words from query (stop words removed)
|
| 478 |
+
stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
|
| 479 |
+
'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
|
| 480 |
+
'about', 'that', 'this', 'there', 'it'}
|
| 481 |
+
|
| 482 |
+
query_lower = query.lower()
|
| 483 |
+
# Remove punctuation and split
|
| 484 |
+
words = re.findall(r'\b\w+\b', query_lower)
|
| 485 |
+
# Filter out stop words and short words
|
| 486 |
+
query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
|
| 487 |
+
|
| 488 |
+
logger.info(f"[HYBRID] Query terms extracted: {query_terms}")
|
| 489 |
+
|
| 490 |
+
# PARALLEL SEARCH: Run both keyword and semantic simultaneously
|
| 491 |
+
|
| 492 |
+
# 1. KEYWORD SCORING WITH BM25 (Fast!)
|
| 493 |
+
t_kw = time.time()
|
| 494 |
+
|
| 495 |
+
# Use inverted index for drug lookup (lightweight, no BM25)
|
| 496 |
+
global bm25_index, inverted_index
|
| 497 |
+
keyword_scores = {}
|
| 498 |
+
|
| 499 |
+
if inverted_index is not None:
|
| 500 |
+
# Check if any query terms are in our drug/intervention inverted index
|
| 501 |
+
inv_index_candidates = set()
|
| 502 |
+
for term in query_terms:
|
| 503 |
+
if term in inverted_index:
|
| 504 |
+
inv_index_candidates.update(inverted_index[term])
|
| 505 |
+
logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'")
|
| 506 |
+
|
| 507 |
+
# FAST PATH: If we have inverted index hits (drug names), score those trials
|
| 508 |
+
if inv_index_candidates:
|
| 509 |
+
logger.info(f"[FAST PATH] Checking {len(inv_index_candidates)} inverted index candidates")
|
| 510 |
+
|
| 511 |
+
# CRITICAL: Identify which terms are specific drugs (low frequency)
|
| 512 |
+
drug_specific_terms = set()
|
| 513 |
+
for term in query_terms:
|
| 514 |
+
if term in inverted_index and len(inverted_index[term]) < 100:
|
| 515 |
+
# This term appears in <100 trials - likely a specific drug name!
|
| 516 |
+
drug_specific_terms.add(term)
|
| 517 |
+
logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name")
|
| 518 |
+
|
| 519 |
+
for idx in inv_index_candidates:
|
| 520 |
+
# No BM25, use simple match count as base score
|
| 521 |
+
base_score = 1.0
|
| 522 |
+
|
| 523 |
+
# Check if this trial contains a drug-specific term
|
| 524 |
+
chunk_data = doc_chunks[idx]
|
| 525 |
+
chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
|
| 526 |
+
chunk_lower = chunk_text.lower()
|
| 527 |
+
|
| 528 |
+
has_drug_match = False
|
| 529 |
+
for drug_term in drug_specific_terms:
|
| 530 |
+
if drug_term in chunk_lower:
|
| 531 |
+
has_drug_match = True
|
| 532 |
+
break
|
| 533 |
+
|
| 534 |
+
# MASSIVE PRIORITY for drug-specific trials
|
| 535 |
+
if has_drug_match:
|
| 536 |
+
# Drug-specific trials get GUARANTEED top ranking
|
| 537 |
+
score = 1000.0 + base_score
|
| 538 |
+
logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}")
|
| 539 |
+
else:
|
| 540 |
+
# Regular inverted index hits (generic terms)
|
| 541 |
+
if base_score <= 0:
|
| 542 |
+
base_score = 0.1
|
| 543 |
+
|
| 544 |
+
score = base_score
|
| 545 |
+
|
| 546 |
+
# Apply field-specific boosting for non-drug terms
|
| 547 |
+
max_field_boost = 1.0
|
| 548 |
+
for term in query_terms:
|
| 549 |
+
if term not in chunk_lower or term in drug_specific_terms:
|
| 550 |
+
continue
|
| 551 |
+
|
| 552 |
+
# INTERVENTION field - medium priority for non-drug terms
|
| 553 |
+
if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower:
|
| 554 |
+
max_field_boost = max(max_field_boost, 3.0)
|
| 555 |
+
# TITLE field - low priority
|
| 556 |
+
elif 'title:' in chunk_lower:
|
| 557 |
+
title_pos = chunk_lower.find('title:')
|
| 558 |
+
term_pos = chunk_lower.find(term)
|
| 559 |
+
if title_pos < term_pos < title_pos + 200:
|
| 560 |
+
max_field_boost = max(max_field_boost, 2.0)
|
| 561 |
+
|
| 562 |
+
score *= max_field_boost
|
| 563 |
+
|
| 564 |
+
keyword_scores[idx] = score
|
| 565 |
+
else:
|
| 566 |
+
logger.info(f"[FALLBACK] No inverted index hits, using pure semantic search")
|
| 567 |
+
|
| 568 |
+
logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)")
|
| 569 |
+
|
| 570 |
+
# 2. SEMANTIC SCORING
|
| 571 |
+
load_embedder()
|
| 572 |
+
t_sem = time.time()
|
| 573 |
+
query_embedding = embedder.encode([query])[0]
|
| 574 |
+
semantic_similarities = np.dot(doc_embeddings, query_embedding)
|
| 575 |
+
logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)")
|
| 576 |
+
|
| 577 |
+
# 3. MERGE SCORES
|
| 578 |
+
# Normalize both scores to 0-1 range
|
| 579 |
+
if keyword_scores:
|
| 580 |
+
max_kw = max(keyword_scores.values())
|
| 581 |
+
keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()}
|
| 582 |
+
else:
|
| 583 |
+
keyword_scores_norm = {}
|
| 584 |
+
|
| 585 |
+
max_sem = semantic_similarities.max()
|
| 586 |
+
min_sem = semantic_similarities.min()
|
| 587 |
+
semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)
|
| 588 |
+
|
| 589 |
+
# Combined score: 50% keyword (with IDF/field boost), 50% semantic (context)
|
| 590 |
+
# Balanced approach: IDF-weighted keywords + semantic understanding
|
| 591 |
+
combined_scores = np.zeros(len(doc_chunks))
|
| 592 |
+
|
| 593 |
+
for idx in range(len(doc_chunks)):
|
| 594 |
+
kw_score = keyword_scores_norm.get(idx, 0.0)
|
| 595 |
+
sem_score = semantic_scores_norm[idx]
|
| 596 |
+
|
| 597 |
+
# If keyword match exists, balance keyword + semantic
|
| 598 |
+
if kw_score > 0:
|
| 599 |
+
combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score
|
| 600 |
+
else:
|
| 601 |
+
# Pure semantic if no keyword match
|
| 602 |
+
combined_scores[idx] = sem_score
|
| 603 |
+
|
| 604 |
+
# Get top K by combined score (get more candidates to sort by recency)
|
| 605 |
+
# We'll get 10 candidates, then sort by NCT ID to find the 3 most recent
|
| 606 |
+
candidate_k = max(top_k * 3, 10) # Get 3x requested, minimum 10
|
| 607 |
+
top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
|
| 608 |
+
|
| 609 |
+
logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}")
|
| 610 |
+
logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}")
|
| 611 |
+
logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}")
|
| 612 |
+
|
| 613 |
+
# Extract text and scores for 355M ranking
|
| 614 |
+
# Format as (score, text) tuples for rank_trials_with_355m
|
| 615 |
+
candidate_trials_for_ranking = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]) for i in top_indices]
|
| 616 |
+
|
| 617 |
+
# SORT BY NCT ID (higher = newer) before 355M ranking
|
| 618 |
+
def extract_nct_number(trial_tuple):
|
| 619 |
+
"""Extract NCT number from trial text for sorting (higher = newer)"""
|
| 620 |
+
_, text = trial_tuple
|
| 621 |
+
match = re.search(r'NCT_ID:\s*NCT(\d+)', text)
|
| 622 |
+
return int(match.group(1)) if match else 0
|
| 623 |
+
|
| 624 |
+
# Sort candidates by NCT ID (descending = newest first)
|
| 625 |
+
candidate_trials_for_ranking.sort(key=extract_nct_number, reverse=True)
|
| 626 |
+
|
| 627 |
+
# Log top 5 NCT IDs to show recency sorting
|
| 628 |
+
top_ncts = []
|
| 629 |
+
for score, text in candidate_trials_for_ranking[:5]:
|
| 630 |
+
match = re.search(r'NCT_ID:\s*(NCT\d+)', text)
|
| 631 |
+
if match:
|
| 632 |
+
top_ncts.append(match.group(1))
|
| 633 |
+
logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}")
|
| 634 |
+
|
| 635 |
+
# SKIP 355M RANKING - It's broken (gives 0.50 to everything) and wastes 10 seconds
|
| 636 |
+
# Just use the hybrid-scored + recency-sorted candidates
|
| 637 |
+
logger.info(f"[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)")
|
| 638 |
+
ranked_trials = candidate_trials_for_ranking
|
| 639 |
+
|
| 640 |
+
# Take top K from ranked results
|
| 641 |
+
top_ranked = ranked_trials[:top_k]
|
| 642 |
+
|
| 643 |
+
logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)")
|
| 644 |
+
|
| 645 |
+
# Extract just the text
|
| 646 |
+
raw_chunks = [trial_text for _, trial_text in top_ranked]
|
| 647 |
+
|
| 648 |
+
# Apply clinical filter to each trial
|
| 649 |
+
context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks]
|
| 650 |
+
|
| 651 |
+
if context_chunks:
|
| 652 |
+
first_trial_preview = context_chunks[0][:200]
|
| 653 |
+
logger.info(f"[HYBRID] First result (filtered): {first_trial_preview}")
|
| 654 |
+
|
| 655 |
+
# Add ranking information if available from 355M
|
| 656 |
+
if hasattr(ranked_trials, 'ranking_info'):
|
| 657 |
+
ranking_header = "[TRIAL RANKING BY CLINICAL RELEVANCE GPT]\n"
|
| 658 |
+
for info in ranked_trials.ranking_info:
|
| 659 |
+
ranking_header += f"Rank {info['rank']}: {info['nct_id']} - Relevance {info['relevance_rating']}\n"
|
| 660 |
+
ranking_header += "---\n\n"
|
| 661 |
+
|
| 662 |
+
# Prepend ranking info to first trial
|
| 663 |
+
if context_chunks:
|
| 664 |
+
context_chunks[0] = ranking_header + context_chunks[0]
|
| 665 |
+
logger.info(f"[355M RANKING] Added ranking metadata to context for final LLM")
|
| 666 |
+
|
| 667 |
+
context = "\n\n---\n\n".join(context_chunks) # Use --- as separator between trials
|
| 668 |
+
logger.info(f"[HYBRID] TOTAL TIME: {time.time()-t0:.2f}s")
|
| 669 |
+
logger.info(f"[HYBRID] Filtered context length: {len(context)} chars (was ~{sum(len(c) for c in raw_chunks)} chars)")
|
| 670 |
+
|
| 671 |
+
return context
|
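# --- Illustration (editor's note, not part of the committed file) ------------
# Worked example of the 50/50 merge in retrieve_context_with_embeddings, using
# made-up normalized scores:
#
#   trial A: keyword hit (kw = 1.00), semantic 0.62 -> 0.5*1.00 + 0.5*0.62 = 0.81
#   trial B: no keyword hit,          semantic 0.70 -> 0.70 (pure semantic)
#   trial C: keyword hit (kw = 0.40), semantic 0.90 -> 0.5*0.40 + 0.5*0.90 = 0.65
#
# A drug-name hit (raw score 1000+, hence kw = 1.0 after normalization) therefore
# dominates unless a non-matching trial is semantically much closer to the query.
# ------------------------------------------------------------------------------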
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def keyword_search_query_text(query, max_results=10, hf_token=None):
|
| 675 |
+
"""Search dataset using ALL meaningful words from the full query"""
|
| 676 |
+
if not DATASET_FILE.exists():
|
| 677 |
+
logger.error("Dataset file not found")
|
| 678 |
+
return ""
|
| 679 |
+
|
| 680 |
+
# Extract all meaningful words from the full query
|
| 681 |
+
# Remove common stopwords but keep medical/clinical terms
|
| 682 |
+
stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
|
| 683 |
+
'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should',
|
| 684 |
+
'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with',
|
| 685 |
+
'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on',
|
| 686 |
+
'what', 'you', 'know', 'that', 'relevant'}
|
| 687 |
+
|
| 688 |
+
# Extract words, filter stopwords and short words
|
| 689 |
+
words = query.lower().split()
|
| 690 |
+
search_terms = [w.strip('?.,!;:()[]{}') for w in words
|
| 691 |
+
if w.lower() not in stopwords and len(w) >= 3]
|
| 692 |
+
|
| 693 |
+
if not search_terms:
|
| 694 |
+
logger.warning("No search terms extracted from query")
|
| 695 |
+
return ""
|
| 696 |
+
|
| 697 |
+
logger.info(f"Search terms from full query: {search_terms}")
|
| 698 |
+
|
| 699 |
+
# Store trials with match scores
|
| 700 |
+
trials_with_scores = []
|
| 701 |
+
current_trial = ""
|
| 702 |
+
|
| 703 |
+
try:
|
| 704 |
+
with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
|
| 705 |
+
for line in f:
|
| 706 |
+
# Check if new trial starts
|
| 707 |
+
if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"):
|
| 708 |
+
# Score previous trial
|
| 709 |
+
if current_trial:
|
| 710 |
+
trial_lower = current_trial.lower()
|
| 711 |
+
|
| 712 |
+
# Count matches for all search terms
|
| 713 |
+
score = sum(1 for term in search_terms if term in trial_lower)
|
| 714 |
+
|
| 715 |
+
if score > 0:
|
| 716 |
+
trials_with_scores.append((score, current_trial))
|
| 717 |
+
|
| 718 |
+
current_trial = line
|
| 719 |
+
else:
|
| 720 |
+
current_trial += line
|
| 721 |
+
|
| 722 |
+
# Check last trial
|
| 723 |
+
if current_trial:
|
| 724 |
+
trial_lower = current_trial.lower()
|
| 725 |
+
score = sum(1 for term in search_terms if term in trial_lower)
|
| 726 |
+
if score > 0:
|
| 727 |
+
trials_with_scores.append((score, current_trial))
|
| 728 |
+
|
| 729 |
+
# Sort by score (highest first) and take top results
|
| 730 |
+
trials_with_scores.sort(reverse=True, key=lambda x: x[0])
|
| 731 |
+
matching_trials = [(score, trial) for score, trial in trials_with_scores[:max_results]]
|
| 732 |
+
|
| 733 |
+
if matching_trials:
|
| 734 |
+
logger.info(f"Keyword search found {len(matching_trials)} trials")
|
| 735 |
+
return matching_trials # Return list of (score, trial) tuples
|
| 736 |
+
else:
|
| 737 |
+
logger.warning("Keyword search found no matching trials")
|
| 738 |
+
return []
|
| 739 |
+
|
| 740 |
+
except Exception as e:
|
| 741 |
+
logger.error(f"Keyword search failed: {e}")
|
| 742 |
+
return []
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def keyword_search_in_dataset(entities, max_results=10):
|
| 746 |
+
"""Legacy: Search dataset file for keyword matches using extracted entities"""
|
| 747 |
+
if not DATASET_FILE.exists():
|
| 748 |
+
logger.error("Dataset file not found")
|
| 749 |
+
return ""
|
| 750 |
+
|
| 751 |
+
drugs = [d.lower() for d in entities.get('drugs', [])]
|
| 752 |
+
conditions = [c.lower() for c in entities.get('conditions', [])]
|
| 753 |
+
|
| 754 |
+
if not drugs and not conditions:
|
| 755 |
+
logger.warning("No search terms for keyword search")
|
| 756 |
+
return ""
|
| 757 |
+
|
| 758 |
+
logger.info(f"Keyword search - Drugs: {drugs}, Conditions: {conditions}")
|
| 759 |
+
|
| 760 |
+
# Store trials with match scores
|
| 761 |
+
trials_with_scores = []
|
| 762 |
+
current_trial = ""
|
| 763 |
+
|
| 764 |
+
try:
|
| 765 |
+
with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
|
| 766 |
+
for line in f:
|
| 767 |
+
# Check if new trial starts
|
| 768 |
+
if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"):
|
| 769 |
+
# Score previous trial
|
| 770 |
+
if current_trial:
|
| 771 |
+
trial_lower = current_trial.lower()
|
| 772 |
+
|
| 773 |
+
# Count matches
|
| 774 |
+
drug_matches = sum(1 for d in drugs if d in trial_lower)
|
| 775 |
+
condition_matches = sum(1 for c in conditions if c in trial_lower)
|
| 776 |
+
|
| 777 |
+
# Only include trials that match at least the drug (if drug was specified)
|
| 778 |
+
if drugs:
|
| 779 |
+
if drug_matches > 0:
|
| 780 |
+
score = drug_matches * 10 + condition_matches
|
| 781 |
+
trials_with_scores.append((score, current_trial))
|
| 782 |
+
elif condition_matches > 0:
|
| 783 |
+
# No drug specified, just match conditions
|
| 784 |
+
trials_with_scores.append((condition_matches, current_trial))
|
| 785 |
+
|
| 786 |
+
current_trial = line
|
| 787 |
+
else:
|
| 788 |
+
current_trial += line
|
| 789 |
+
|
| 790 |
+
# Check last trial
|
| 791 |
+
if current_trial:
|
| 792 |
+
trial_lower = current_trial.lower()
|
| 793 |
+
drug_matches = sum(1 for d in drugs if d in trial_lower)
|
| 794 |
+
condition_matches = sum(1 for c in conditions if c in trial_lower)
|
| 795 |
+
|
| 796 |
+
if drugs:
|
| 797 |
+
if drug_matches > 0:
|
| 798 |
+
score = drug_matches * 10 + condition_matches
|
| 799 |
+
trials_with_scores.append((score, current_trial))
|
| 800 |
+
elif condition_matches > 0:
|
| 801 |
+
trials_with_scores.append((condition_matches, current_trial))
|
| 802 |
+
|
| 803 |
+
# Sort by score (highest first) and take top results
|
| 804 |
+
trials_with_scores.sort(reverse=True, key=lambda x: x[0])
|
| 805 |
+
matching_trials = [trial for score, trial in trials_with_scores[:max_results]]
|
| 806 |
+
|
| 807 |
+
if matching_trials:
|
| 808 |
+
context = "\n\n---\n\n".join(matching_trials)
|
| 809 |
+
if len(context) > 6000:
|
| 810 |
+
context = context[:6000] + "..."
|
| 811 |
+
logger.info(f"Keyword search found {len(matching_trials)} trials (from {len(trials_with_scores)} candidates)")
|
| 812 |
+
return context
|
| 813 |
+
else:
|
| 814 |
+
logger.warning("Keyword search found no trials matching drug")
|
| 815 |
+
return ""
|
| 816 |
+
|
| 817 |
+
except Exception as e:
|
| 818 |
+
logger.error(f"Keyword search failed: {e}")
|
| 819 |
+
return ""
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
# ============================================================================
|
| 823 |
+
# ENTITY EXTRACTION
|
| 824 |
+
# ============================================================================
|
| 825 |
+
|
| 826 |
+
def parse_entities_from_query(conversation, hf_token=None):
|
| 827 |
+
"""Parse entities from query using both 355M and 8B models + regex fallback"""
|
| 828 |
+
entities = {'drugs': [], 'conditions': []}
|
| 829 |
+
|
| 830 |
+
# Use 355M model for entity extraction
|
| 831 |
+
extracted_355m = extract_entities_with_small_model(conversation)
|
| 832 |
+
|
| 833 |
+
# Also use 8B model for more reliable extraction
|
| 834 |
+
extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token)
|
| 835 |
+
|
| 836 |
+
# Combine both extractions
|
| 837 |
+
extracted = (extracted_355m or "") + "\n" + (extracted_8b or "")
|
| 838 |
+
|
| 839 |
+
# Parse model output
|
| 840 |
+
if extracted:
|
| 841 |
+
lines = extracted.split('\n')
|
| 842 |
+
for line in lines:
|
| 843 |
+
lower_line = line.lower()
|
| 844 |
+
if 'drug:' in lower_line or 'medication:' in lower_line:
|
| 845 |
+
drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE).strip()
|
| 846 |
+
if drug:
|
| 847 |
+
entities['drugs'].append(drug)
|
| 848 |
+
elif 'condition:' in lower_line or 'disease:' in lower_line:
|
| 849 |
+
condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE).strip()
|
| 850 |
+
if condition:
|
| 851 |
+
entities['conditions'].append(condition)
|
| 852 |
+
|
| 853 |
+
# Regex fallback for standard drug naming patterns
|
| 854 |
+
drug_patterns = [
|
| 855 |
+
r'\b([A-Z][a-z]+mab)\b', # Monoclonal antibodies: -mab suffix
|
| 856 |
+
r'\b([A-Z][a-z]+nib)\b', # Kinase inhibitors: -nib suffix
|
| 857 |
+
r'\b([A-Z]\d+[A-Z]+\d+)\b' # Alphanumeric codes like F8IL10
|
| 858 |
+
]
|
| 859 |
+
for pattern in drug_patterns:
|
| 860 |
+
matches = re.findall(pattern, conversation)
|
| 861 |
+
for match in matches:
|
| 862 |
+
if match.lower() not in [d.lower() for d in entities['drugs']]:
|
| 863 |
+
entities['drugs'].append(match)
|
| 864 |
+
|
| 865 |
+
condition_patterns = [
|
| 866 |
+
r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b'
|
| 867 |
+
]
|
| 868 |
+
for pattern in condition_patterns:
|
| 869 |
+
matches = re.findall(pattern, conversation, re.IGNORECASE)
|
| 870 |
+
for match in matches:
|
| 871 |
+
if match not in [c.lower() for c in entities['conditions']]:
|
| 872 |
+
entities['conditions'].append(match)
|
| 873 |
+
|
| 874 |
+
logger.info(f"Extracted entities: {entities}")
|
| 875 |
+
return entities
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
# ============================================================================
|
| 879 |
+
# MAIN QUERY PROCESSING
|
| 880 |
+
# ============================================================================
|
| 881 |
+
|
| 882 |
+
def extract_entities_simple(query):
|
| 883 |
+
"""Simple entity extraction using regex patterns - no model needed"""
|
| 884 |
+
entities = {'drugs': [], 'conditions': []}
|
| 885 |
+
|
| 886 |
+
# Drug patterns
|
| 887 |
+
drug_patterns = [
|
| 888 |
+
r'\b([A-Z][a-z]+mab)\b', # Monoclonal antibodies: ianalumab, rituximab, etc.
|
| 889 |
+
r'\b([A-Z][a-z]+nib)\b', # Kinase inhibitors: imatinib, etc.
|
| 890 |
+
r'\b([A-Z]\d+[A-Z]+\d+)\b', # Alphanumeric codes
|
| 891 |
+
r'\b(ianalumab|rituximab|tocilizumab|adalimumab|infliximab)\b', # Common drugs
|
| 892 |
+
]
|
| 893 |
+
|
| 894 |
+
# Condition patterns
|
| 895 |
+
condition_patterns = [
|
| 896 |
+
r'\b(sjogren\'?s?\s+syndrome)\b',
|
| 897 |
+
r'\b(rheumatoid arthritis)\b',
|
| 898 |
+
r'\b(lupus)\b',
|
| 899 |
+
r'\b(myelofibrosis)\b',
|
| 900 |
+
r'\b(diabetes)\b',
|
| 901 |
+
r'\b(cancer|carcinoma|melanoma)\b',
|
| 902 |
+
]
|
| 903 |
+
|
| 904 |
+
query_lower = query.lower()
|
| 905 |
+
|
| 906 |
+
# Extract drugs
|
| 907 |
+
for pattern in drug_patterns:
|
| 908 |
+
matches = re.findall(pattern, query, re.IGNORECASE)
|
| 909 |
+
for match in matches:
|
| 910 |
+
if match.lower() not in [d.lower() for d in entities['drugs']]:
|
| 911 |
+
entities['drugs'].append(match)
|
| 912 |
+
|
| 913 |
+
# Extract conditions
|
| 914 |
+
for pattern in condition_patterns:
|
| 915 |
+
matches = re.findall(pattern, query, re.IGNORECASE)
|
| 916 |
+
for match in matches:
|
| 917 |
+
if match.lower() not in [c.lower() for c in entities['conditions']]:
|
| 918 |
+
entities['conditions'].append(match)
|
| 919 |
+
|
| 920 |
+
logger.info(f"Extracted entities: {entities}")
|
| 921 |
+
return entities
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
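For a quick sense of what the regex-only extractor above produces, here is a minimal, self-contained sketch; the patterns are copied from `extract_entities_simple`, while the sample query and the expected output are illustrative assumptions rather than captured logs.

```python
import re

# Minimal sketch of the regex pass in extract_entities_simple (illustrative only).
query = "What are the results for ianalumab in Sjogren's syndrome?"

drug_patterns = [r'\b([A-Z][a-z]+mab)\b', r'\b([A-Z][a-z]+nib)\b']
condition_patterns = [r"\b(sjogren\'?s?\s+syndrome)\b", r'\b(lupus)\b']

drugs, conditions = [], []
for pattern in drug_patterns:
    for match in re.findall(pattern, query, re.IGNORECASE):
        if match.lower() not in [d.lower() for d in drugs]:
            drugs.append(match)
for pattern in condition_patterns:
    for match in re.findall(pattern, query, re.IGNORECASE):
        if match.lower() not in [c.lower() for c in conditions]:
            conditions.append(match)

print({'drugs': drugs, 'conditions': conditions})
# Expected (assumption): {'drugs': ['ianalumab'], 'conditions': ["Sjogren's syndrome"]}
```
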
def parse_query_with_llm(query, hf_token=None):
    """
    Use fast LLM to parse query and extract structured information

    Extracts:
    - Drug names
    - Diseases/conditions
    - Companies (sponsors/manufacturers)
    - Endpoints (what's being measured)
    - Search terms (optimized for RAG)

    Returns: Dict with extracted entities and optimized search query
    """
    try:
        from huggingface_hub import InferenceClient

        logger.info("[QUERY PARSER] Analyzing user query with LLM...")
        client = InferenceClient(token=hf_token, timeout=30)

        parse_prompt = f"""Extract key information from this clinical trial query.

Query: "{query}"

Extract and return in this EXACT format:
DRUGS: [list drug/treatment names, or "none"]
DISEASES: [list diseases/conditions, or "none"]
COMPANIES: [list company/sponsor names, or "none"]
ENDPOINTS: [list trial endpoints/outcomes, or "none"]
SEARCH_TERMS: [optimized search keywords]

Examples:
Query: "What Novartis drugs treat melanoma?"
DRUGS: none
DISEASES: melanoma
COMPANIES: Novartis
ENDPOINTS: none
SEARCH_TERMS: Novartis melanoma treatment drugs

Query: "Tell me about Keytruda for lung cancer"
DRUGS: Keytruda
DISEASES: lung cancer
COMPANIES: none
ENDPOINTS: none
SEARCH_TERMS: Keytruda lung cancer

Now parse the query above:"""

        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=[{"role": "user", "content": parse_prompt}],
            max_tokens=256,
            temperature=0.1  # Low temp for consistent parsing
        )

        parsed = response.choices[0].message.content.strip()
        logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}")

        # Parse the response into dict
        result = {
            'raw_parsed': parsed,
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query  # fallback
        }

        lines = parsed.split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith('DRUGS:'):
                drugs = line.replace('DRUGS:', '').strip()
                if drugs.lower() != 'none':
                    result['drugs'] = [d.strip() for d in drugs.split(',')]
            elif line.startswith('DISEASES:'):
                diseases = line.replace('DISEASES:', '').strip()
                if diseases.lower() != 'none':
                    result['diseases'] = [d.strip() for d in diseases.split(',')]
            elif line.startswith('COMPANIES:'):
                companies = line.replace('COMPANIES:', '').strip()
                if companies.lower() != 'none':
                    result['companies'] = [c.strip() for c in companies.split(',')]
            elif line.startswith('ENDPOINTS:'):
                endpoints = line.replace('ENDPOINTS:', '').strip()
                if endpoints.lower() != 'none':
                    result['endpoints'] = [e.strip() for e in endpoints.split(',')]
            elif line.startswith('SEARCH_TERMS:'):
                result['search_terms'] = line.replace('SEARCH_TERMS:', '').strip()

        logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}")
        return result

    except Exception as e:
        logger.warning(f"[QUERY PARSER] Failed: {e}, using original query")
        return {
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query,
            'raw_parsed': ''
        }

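The parser above depends on the LLM echoing back the labeled `DRUGS:` / `DISEASES:` / `COMPANIES:` / `ENDPOINTS:` / `SEARCH_TERMS:` lines. The sketch below runs the same line-splitting logic on a canned reply (one of the prompt's own examples, abridged to two fields), so no API call or token is needed; real model output may vary.

```python
# Illustrative only: parse a canned LLM reply with the same string logic.
canned = """DRUGS: none
DISEASES: melanoma
COMPANIES: Novartis
ENDPOINTS: none
SEARCH_TERMS: Novartis melanoma treatment drugs"""

result = {'companies': [], 'search_terms': 'What Novartis drugs treat melanoma?'}
for line in canned.split('\n'):
    line = line.strip()
    if line.startswith('COMPANIES:'):
        value = line.replace('COMPANIES:', '').strip()
        if value.lower() != 'none':
            result['companies'] = [c.strip() for c in value.split(',')]
    elif line.startswith('SEARCH_TERMS:'):
        result['search_terms'] = line.replace('SEARCH_TERMS:', '').strip()

print(result['companies'])     # ['Novartis']
print(result['search_terms'])  # 'Novartis melanoma treatment drugs'
```
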
def generate_llama_response(query, rag_context, hf_token=None):
    """
    Generate response using FAST Groq API (10x faster than HF)

    Speed comparison:
    - HuggingFace: ~40 tokens/sec = 15 seconds
    - Groq: ~300 tokens/sec = 2 seconds (FREE!)
    """
    try:
        # Try Groq first (much faster), fallback to HuggingFace
        groq_api_key = os.getenv("GROQ_API_KEY")

        if groq_api_key:
            logger.info("Generating response with Llama-3.1-70B via GROQ (fast)...")
            from groq import Groq
            client = Groq(api_key=groq_api_key)

            # Simplified prompt for faster generation
            system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs."""

            user_prompt = f"""Clinical trials:
{rag_context[:6000]}

Question: {query}

Provide a concise answer citing specific NCT trial IDs."""

            response = client.chat.completions.create(
                model="llama-3.1-70b-versatile",  # Groq's optimized 70B
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                max_tokens=512,  # Shorter for speed
                temperature=0.3,
                timeout=30
            )

            return response.choices[0].message.content.strip()

        else:
            # Fallback to HuggingFace (slower)
            logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...")
            from huggingface_hub import InferenceClient
            client = InferenceClient(token=hf_token, timeout=120)

            system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs."""

            user_prompt = f"""Clinical trials:
{rag_context[:6000]}

Question: {query}

Provide a concise answer citing specific NCT trial IDs."""

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            response = client.chat_completion(
                model="meta-llama/Meta-Llama-3.1-70B-Instruct",
                messages=messages,
                max_tokens=512,  # Reduced from 2048 for speed
                temperature=0.3
            )

            return response.choices[0].message.content.strip()

    except Exception as e:
        logger.error(f"Llama error: {e}")
        return f"Llama API error: {str(e)}"


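Backend selection in `generate_llama_response` is driven entirely by the `GROQ_API_KEY` environment variable. A minimal usage sketch, assuming the function's dependencies are installed and the RAG context shown is a placeholder:

```python
import os

# Illustrative only: Groq is used when GROQ_API_KEY is present in the
# environment (e.g. a Space secret); otherwise the HF Inference API path runs.
answer = generate_llama_response(
    query="Tell me about Keytruda for lung cancer",
    rag_context="NCT_ID: NCT00000000 ...",  # placeholder context from the RAG step
    hf_token=os.getenv("HF_TOKEN"),         # only used on the HuggingFace fallback path
)
print("Groq path" if os.getenv("GROQ_API_KEY") else "HF fallback path")
print(answer)
```
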
def process_query_simple_test(conversation):
    """TEST JUST THE RAG - no models"""
    try:
        import time
        output = []
        output.append(f"QUERY: {conversation}\n")

        # Check if embeddings loaded
        if doc_embeddings is None or len(doc_chunks) == 0:
            return "FAIL: Embeddings not loaded"

        output.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n")
        output.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n")

        # Try to search
        start = time.time()
        context = retrieve_context_with_embeddings(conversation, top_k=3)
        search_time = time.time() - start

        if not context:
            return "".join(output) + "\nFAIL: RAG returned empty"

        output.append(f"✓ RAG search took: {search_time:.2f}s\n")
        output.append(f"✓ Retrieved {context.count('NCT')} trials\n\n")
        output.append("FIRST 1000 CHARS:\n")
        output.append(context[:1000])

        return "".join(output)

    except Exception as e:
        import traceback
        return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}"


def process_query(conversation):
    """
    Complete pipeline with LLM query parsing and natural language generation

    Flow:
    0. LLM Parser - Extract drugs, diseases, companies, endpoints (~2-3s)
    1. RAG Search - Hybrid search using optimized query (~2s)
    2. Skipped - 355M ranking removed (was broken)
    3. LLM Response - Llama 70B generates natural language (~15s)

    Total: ~20 seconds
    """
    import time
    import traceback
    import sys

    # MASTER try/except - catches EVERYTHING
    try:
        start_time = time.time()
        output_parts = [f"QUERY: {conversation}\n\n"]

        # Step 0: Parse query with LLM to extract structured info
        try:
            step0_start = time.time()
            logger.info("Step 0: Parsing query with LLM...")
            output_parts.append("✓ Step 0: LLM query parser started...\n")
            parsed_query = parse_query_with_llm(conversation, hf_token=hf_token)

            # Use optimized search terms from parser
            search_query = parsed_query['search_terms']

            step0_time = time.time() - step0_start
            output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n")
            output_parts.append(f" Drugs: {parsed_query['drugs']}\n")
            output_parts.append(f" Diseases: {parsed_query['diseases']}\n")
            output_parts.append(f" Companies: {parsed_query['companies']}\n")
            output_parts.append(f" Optimized search: {search_query}\n")
            logger.info(f"Query parsing successful in {step0_time:.1f}s")

        except Exception as e:
            error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query"
            logger.warning(error_msg)
            output_parts.append(f"{error_msg}\n")
            search_query = conversation  # Fallback to original

        # Step 1: RAG search (using optimized search query)
        try:
            step1_start = time.time()
            logger.info("Step 1: RAG search...")
            output_parts.append("✓ Step 1: RAG search started...\n")
            context = retrieve_context_with_embeddings(search_query, top_k=3)

            if not context:
                return "No matching trials found in RAG search."

            # No limit - use complete trials
            step1_time = time.time() - step1_start
            output_parts.append(f"✓ Step 1 Complete: Found {context.count('NCT')} trials ({step1_time:.1f}s)\n")
            logger.info(f"RAG search successful - found trials in {step1_time:.1f}s")

        except Exception as e:
            error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return error_msg

        # Step 2: Skipped (355M ranking removed - was broken)
        output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n")

        # Step 3: Llama 70B
        try:
            step3_start = time.time()
            logger.info("Step 3: Generating response with Llama-3.1-70B...")
            output_parts.append("✓ Step 3: Llama 70B generation started...\n")
            llama_response = generate_llama_response(conversation, context, hf_token=hf_token)
            step3_time = time.time() - step3_start
            output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n")
            logger.info(f"Llama 70B generation successful in {step3_time:.1f}s")

        except Exception as e:
            error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            llama_response = f"[Llama 70B error: {str(e)}]"
            output_parts.append(f"✗ Step 3 Failed: {str(e)}\n")

        total_time = time.time() - start_time

        # Format output - handle missing variables
        try:
            context_display = context if 'context' in locals() else "[No context retrieved]"
            clinical_display = clinical_context_355m if 'clinical_context_355m' in locals() else "[355M not run]"
            llama_display = llama_response if 'llama_response' in locals() else "[Llama 70B not run]"

            output = f"""{''.join(output_parts)}

CLINICAL SUMMARY (Llama-3.1-70B-Instruct):
{llama_display}

---

RAG RETRIEVED TRIALS (Top 3 Most Relevant):
{context_display}

---
Total Time: {total_time:.1f}s
"""
            return output
        except Exception as e:
            # Absolute fallback
            error_info = f"""
CRITICAL ERROR IN OUTPUT FORMATTING:
{str(e)}

TRACEBACK:
{traceback.format_exc()}

OUTPUT PARTS:
{''.join(output_parts)}

Variables defined: {locals().keys()}
"""
            logger.error(error_info)
            return error_info

    # MASTER EXCEPTION HANDLER - catches ANY unhandled error
    except Exception as master_error:
        master_error_msg = f"""
========================================
MASTER ERROR HANDLER CAUGHT EXCEPTION
========================================

Error Type: {type(master_error).__name__}
Error Message: {str(master_error)}

FULL TRACEBACK:
{traceback.format_exc()}

System Info:
- Python version: {sys.version}
- Error at line: {sys.exc_info()[2].tb_lineno if sys.exc_info()[2] else 'unknown'}

========================================
"""
        logger.error(master_error_msg)
        return master_error_msg


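`process_query` returns a single formatted report string (step log, Llama summary, retrieved trials, total time), so callers only need to pass the raw question. A tiny illustrative call:

```python
# Illustrative only: invoke the full pipeline with a raw question.
report = process_query("What are the results for ianalumab in Sjogren's syndrome?")
print(report[:200])  # begins with "QUERY: ..." followed by the step-by-step log
```
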
# ============================================================================
# GRADIO INTERFACE
# ============================================================================

with gr.Blocks(title="Foundation 1.2") as demo:
    gr.Markdown("# Foundation 1.2 - Clinical Trial AI")

    query_input = gr.Textbox(
        label="Ask about clinical trials",
        placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?",
        lines=3
    )
    submit_btn = gr.Button("Generate Response", variant="primary")

    output = gr.Textbox(
        label="AI Response",
        lines=30
    )

    submit_btn.click(
        fn=process_query,  # Full pipeline: LLM parser + RAG + Llama 70B
        inputs=query_input,
        outputs=output
    )

    gr.Markdown("""
**Production RAG Pipeline - Optimized for Clinical Accuracy**

**Search (Hybrid):**
1. LLM query parser extracts drugs, diseases, companies, and endpoints
2. Keyword matching (70%) + semantic search (30%) over the trial corpus
3. Returns the top 3 most relevant trials (355M re-ranking step currently skipped)

**Generation (Llama-3.1-70B-Instruct):**
- Served via Groq when GROQ_API_KEY is set, otherwise the HuggingFace Inference API
- Structured clinical summaries with clear headings
- Cites specific NCT trial IDs
- Includes actual trial results and efficacy data
""")


# ============================================================================
# STARTUP
# ============================================================================

# Embeddings will be loaded by FastAPI startup event in app.py
# Do NOT load here - causes Docker permission errors
logger.info("=== Foundation 1.2 Module Loaded ===")
logger.info("Call load_embeddings() to initialize the system")

if DATASET_FILE.exists():
    file_size_mb = DATASET_FILE.stat().st_size / (1024 * 1024)
    logger.info(f"✓ Dataset file found: {file_size_mb:.0f}MB")
else:
    logger.error("✗ Dataset file not found!")

logger.info("=== Startup Complete ===")

if __name__ == "__main__":
    demo.launch()
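app.py itself is not reproduced in this section, but the startup comments above describe its role: load the embeddings once at application startup, then route queries to `process_query`. A hypothetical minimal sketch of that wiring (endpoint name and request shape are assumptions):

```python
# Hypothetical sketch only; the real app.py may differ.
from fastapi import FastAPI
from pydantic import BaseModel

import foundation_engine

app = FastAPI()


class QueryRequest(BaseModel):
    query: str


@app.on_event("startup")
def startup():
    # Load embeddings here rather than at import time (see comments above).
    foundation_engine.load_embeddings()


@app.post("/query")
def query(req: QueryRequest):
    return {"answer": foundation_engine.process_query(req.query)}
```
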
requirements.txt
ADDED
@@ -0,0 +1,19 @@
gradio==4.44.0
spaces
huggingface_hub>=0.24.0,<1.0
transformers>=4.37.0
torch>=2.1.0
accelerate
bitsandbytes
sentence-transformers==3.1.1
PyPDF2==3.0.1
numpy==1.26.4
openai
groq
sentencepiece
protobuf
fastapi
uvicorn
pydantic
networkx>=3.1
rank-bm25
two_llm_system_FIXED.py
ADDED
@@ -0,0 +1,461 @@
"""
|
| 2 |
+
FAST VERSION: Bypasses 355M ranking bottleneck (300s -> 0s)
|
| 3 |
+
Works with existing data structure: List[Tuple[int, str]]
|
| 4 |
+
Keeps BM25 + semantic hybrid search intact
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, AutoTokenizer, AutoModelForCausalLM
|
| 9 |
+
import logging
|
| 10 |
+
import spaces
|
| 11 |
+
from functools import lru_cache
|
| 12 |
+
from typing import List, Tuple, Optional, Dict
|
| 13 |
+
from huggingface_hub import InferenceClient
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# ===========================================================================
|
| 18 |
+
# CACHED MODEL LOADING - Load once, reuse forever
|
| 19 |
+
# ===========================================================================
|
| 20 |
+
|
| 21 |
+
@lru_cache(maxsize=1)
|
| 22 |
+
def get_cached_355m_model():
|
| 23 |
+
"""Load 355M model once and cache it for entity extraction"""
|
| 24 |
+
logger.info("Loading 355M Clinical Trial GPT (cached for entity extraction)...")
|
| 25 |
+
tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/clinicaltrial2.2")
|
| 26 |
+
model = GPT2LMHeadModel.from_pretrained(
|
| 27 |
+
"gmkdigitalmedia/clinicaltrial2.2",
|
| 28 |
+
torch_dtype=torch.float16,
|
| 29 |
+
device_map="auto"
|
| 30 |
+
)
|
| 31 |
+
model.eval()
|
| 32 |
+
return tokenizer, model
|
| 33 |
+
|
| 34 |
+
@lru_cache(maxsize=1)
|
| 35 |
+
def get_cached_8b_model(hf_token: Optional[str] = None):
|
| 36 |
+
"""Load 8B model once and cache it"""
|
| 37 |
+
logger.info("Loading II-Medical-8B (cached)...")
|
| 38 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 39 |
+
"Intelligent-Internet/II-Medical-8B-1706",
|
| 40 |
+
token=hf_token,
|
| 41 |
+
trust_remote_code=True
|
| 42 |
+
)
|
| 43 |
+
if tokenizer.pad_token is None:
|
| 44 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 45 |
+
|
| 46 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 47 |
+
"Intelligent-Internet/II-Medical-8B-1706",
|
| 48 |
+
device_map="auto",
|
| 49 |
+
token=hf_token,
|
| 50 |
+
trust_remote_code=True,
|
| 51 |
+
torch_dtype=torch.bfloat16
|
| 52 |
+
)
|
| 53 |
+
return tokenizer, model
|
| 54 |
+
|
| 55 |
+
# ===========================================================================
|
| 56 |
+
# FAST RANKING - Replace 300s function with instant passthrough
|
| 57 |
+
# ===========================================================================
|
| 58 |
+
|
| 59 |
+
@spaces.GPU
|
| 60 |
+
def rank_trials_FAST(query: str, trials_list: List[Tuple[float, str]], hf_token=None) -> List[Tuple[float, str]]:
|
| 61 |
+
"""
|
| 62 |
+
SMART RANKING: Use 355M to rank only top 3 trials
|
| 63 |
+
|
| 64 |
+
Takes top 3 from BM25+semantic search, then uses 355M Clinical Trial GPT
|
| 65 |
+
to re-rank them by clinical relevance.
|
| 66 |
+
|
| 67 |
+
Time: ~30 seconds for 3 trials (vs 300s for 30 trials)
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
query: The search query
|
| 71 |
+
trials_list: List of (score, trial_text) tuples from BM25+semantic search
|
| 72 |
+
hf_token: Not needed
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Top 3 trials re-ranked by 355M clinical relevance
|
| 76 |
+
"""
|
| 77 |
+
import time
|
| 78 |
+
import re
|
| 79 |
+
|
| 80 |
+
start_time = time.time()
|
| 81 |
+
|
| 82 |
+
# Take only top 3 trials for 355M ranking
|
| 83 |
+
top_3 = trials_list[:3]
|
| 84 |
+
|
| 85 |
+
logger.info(f"[355M RANKING] Ranking top 3 trials with Clinical Trial GPT...")
|
| 86 |
+
|
| 87 |
+
# Get cached 355M model
|
| 88 |
+
tokenizer, model = get_cached_355m_model()
|
| 89 |
+
|
| 90 |
+
# Score each trial
|
| 91 |
+
trial_scores = []
|
| 92 |
+
|
| 93 |
+
for idx, (bm25_score, trial_text) in enumerate(top_3):
|
| 94 |
+
# Extract NCT ID for logging
|
| 95 |
+
nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
|
| 96 |
+
nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
|
| 97 |
+
|
| 98 |
+
# Create prompt for relevance scoring
|
| 99 |
+
# Truncate trial to 800 chars to keep it fast
|
| 100 |
+
trial_snippet = trial_text[:800]
|
| 101 |
+
|
| 102 |
+
prompt = f"""Query: {query}
|
| 103 |
+
|
| 104 |
+
Clinical Trial: {trial_snippet}
|
| 105 |
+
|
| 106 |
+
Rate clinical relevance (1-10):"""
|
| 107 |
+
|
| 108 |
+
# Get model score
|
| 109 |
+
try:
|
| 110 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
outputs = model.generate(
|
| 114 |
+
inputs.input_ids,
|
| 115 |
+
max_length=inputs.input_ids.shape[1] + 10,
|
| 116 |
+
temperature=0.3,
|
| 117 |
+
do_sample=False,
|
| 118 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 119 |
+
eos_token_id=tokenizer.eos_token_id
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
generated = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
|
| 123 |
+
|
| 124 |
+
# Extract number from response
|
| 125 |
+
score_match = re.search(r'(\d+)', generated.strip())
|
| 126 |
+
relevance_score = float(score_match.group(1)) if score_match else 5.0
|
| 127 |
+
|
| 128 |
+
# Normalize to 0-1 range
|
| 129 |
+
relevance_score = relevance_score / 10.0
|
| 130 |
+
|
| 131 |
+
logger.info(f"[355M RANKING] {nct_id}: relevance={relevance_score:.2f} (BM25={bm25_score:.3f})")
|
| 132 |
+
|
| 133 |
+
except Exception as e:
|
| 134 |
+
logger.warning(f"[355M RANKING] Scoring failed for {nct_id}: {e}, using BM25 score")
|
| 135 |
+
relevance_score = bm25_score
|
| 136 |
+
|
| 137 |
+
trial_scores.append((relevance_score, trial_text, nct_id))
|
| 138 |
+
|
| 139 |
+
# Sort by 355M relevance score (descending)
|
| 140 |
+
trial_scores.sort(key=lambda x: x[0], reverse=True)
|
| 141 |
+
|
| 142 |
+
# Format as (score, text) tuples for backwards compatibility
|
| 143 |
+
# Create a custom list class that can hold attributes
|
| 144 |
+
class RankedTrialsList(list):
|
| 145 |
+
"""List that can hold ranking metadata"""
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
ranked_trials = RankedTrialsList()
|
| 149 |
+
ranking_metadata = []
|
| 150 |
+
|
| 151 |
+
for rank, (score, text, nct_id) in enumerate(trial_scores, 1):
|
| 152 |
+
ranked_trials.append((score, text))
|
| 153 |
+
ranking_metadata.append({
|
| 154 |
+
'rank': rank,
|
| 155 |
+
'nct_id': nct_id,
|
| 156 |
+
'relevance_score': score,
|
| 157 |
+
'relevance_rating': f"{score*10:.1f}/10"
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
elapsed = time.time() - start_time
|
| 161 |
+
logger.info(f"[355M RANKING] ✓ Ranked 3 trials in {elapsed:.1f}s")
|
| 162 |
+
logger.info(f"[355M RANKING] Final order: {[nct_id for _, _, nct_id in trial_scores]}")
|
| 163 |
+
logger.info(f"[355M RANKING] Scores: {[f'{s:.2f}' for s, _, _ in trial_scores]}")
|
| 164 |
+
|
| 165 |
+
# Store metadata as attribute for retrieval
|
| 166 |
+
ranked_trials.ranking_info = ranking_metadata
|
| 167 |
+
|
| 168 |
+
# Return re-ranked top 3 plus remaining trials (if any)
|
| 169 |
+
return ranked_trials + trials_list[3:]
|
| 170 |
+
|
| 171 |
+
# Alias for drop-in replacement
|
| 172 |
+
rank_trials_with_355m = rank_trials_FAST # Override the slow function!
|
| 173 |
+
|
| 174 |
+
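The re-ranker attaches per-trial metadata to the returned list as `ranking_info` (kept intact by the `extend` above). An illustrative read of that metadata, assuming the Spaces GPU runtime and the 355M checkpoint are available and that the candidate tuples come from the hybrid search:

```python
# Illustrative only: inspect the metadata attached by rank_trials_FAST.
candidates = [
    (0.82, "NCT_ID: NCT00000001 ... trial text ..."),   # placeholder candidates
    (0.75, "NCT_ID: NCT00000002 ... trial text ..."),
]
ranked = rank_trials_FAST("ianalumab sjogren's syndrome", candidates)
for info in getattr(ranked, "ranking_info", []):
    print(info["rank"], info["nct_id"], info["relevance_rating"])
```
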
# ===========================================================================
# FAST GENERATION using HuggingFace Inference API (Free)
# ===========================================================================

def generate_with_llama_70b_hf(query: str, rag_context: str = "", hf_token: str = None) -> str:
    """
    Use Llama-3.1-70B via HuggingFace Inference API (FREE)

    This is what you're already using successfully!
    ~10 second response time on HF free tier
    """
    try:
        logger.info("Using Llama-3.1-70B via HuggingFace Inference API...")
        client = InferenceClient(token=hf_token)

        messages = [
            {
                "role": "system",
                "content": "You are a medical information specialist. Answer based on the provided clinical trial data. Be concise and accurate."
            },
            {
                "role": "user",
                "content": f"""Clinical Trial Data:
{rag_context[:4000]}

Question: {query}

Please provide a concise answer based on the clinical trial data above."""
            }
        ]

        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=messages,
            max_tokens=512,
            temperature=0.3
        )

        answer = response.choices[0].message.content.strip()
        logger.info(f"Llama 70B response generated via HF Inference API")
        return answer
    except Exception as e:
        logger.error(f"Llama 70B generation failed: {e}")
        return f"Error generating response with Llama 70B: {str(e)}"

# ===========================================================================
# OPTIMIZED 8B GENERATION (with cached model)
# ===========================================================================

@spaces.GPU
def generate_clinical_response_with_xupract(conversation, rag_context="", hf_token=None):
    """OPTIMIZED: Use cached 8B model for faster generation"""
    logger.info("Generating response with cached II-Medical-8B...")

    # Get cached model (loads once, reuses after)
    tokenizer, model = get_cached_8b_model(hf_token)

    # Build prompt with RAG context (ChatML format for II-Medical-8B)
    if rag_context:
        prompt = f"""<|im_start|>system
You are a medical information specialist. Answer based on the provided clinical trial data. Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>user
Clinical Trial Data:
{rag_context[:4000]}

Question: {conversation}

Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>assistant
"""
    else:
        prompt = f"""<|im_start|>system
You are a medical information specialist. Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>user
{conversation}

Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>assistant
"""

    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.3,
                do_sample=True,
                top_p=0.9,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id
            )
        response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip()
        return response
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return f"Error generating response: {str(e)}"

# ===========================================================================
# FAST ENTITY EXTRACTION (with cached model)
# ===========================================================================

@spaces.GPU
def extract_entities_with_small_model(conversation):
    """OPTIMIZED: Use cached 355M model for entity extraction"""
    logger.info("Extracting entities with cached 355M model...")

    # Get cached model
    tokenizer, model = get_cached_355m_model()

    # Better prompt for extraction
    prompt = f"""Clinical query: {conversation}

Extract:
Drug name:"""

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=400,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated

# ===========================================================================
# QUERY EXPANSION (optional, with cached model)
# ===========================================================================

@spaces.GPU
def expand_query_with_355m(query):
    """OPTIMIZED: Use cached 355M for query expansion"""
    logger.info("Expanding query with cached 355M...")

    # Get cached model
    tokenizer, model = get_cached_355m_model()

    # Prompt to get clinical context
    prompt = f"Question: {query}\nClinical trial information:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=inputs.input_ids.shape[1] + 100,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the expansion part
    if "Clinical trial information:" in generated:
        expansion = generated.split("Clinical trial information:")[-1].strip()
    else:
        expansion = generated[len(prompt):].strip()

    # Limit to reasonable length
    expansion = expansion[:500] if len(expansion) > 500 else expansion

    logger.info(f"Query expanded: {expansion[:100]}...")
    return expansion

# ===========================================================================
# MAIN PIPELINE - Now FAST!
# ===========================================================================

def process_two_llm_system(conversation, rag_context="", hf_token=None, use_validation=False):
    """
    FAST pipeline:
    1. Small 355M model extracts entities (cached model - fast)
    2. RAG retrieves context (BM25 + semantic - already fast)
    3. Big model generates response (8B local or 70B API)
    4. Skip validation for speed

    Total time: ~15s instead of 300+s
    """
    import time
    start_time = time.time()

    # Step 1: Use cached 355M to extract entities
    entities = extract_entities_with_small_model(conversation)
    logger.info(f"Entities extracted in {time.time()-start_time:.1f}s")

    # Step 2: Generate response (choose one):

    # Option A: Use 70B via HF Inference API (better quality, ~10s)
    if hf_token:
        clinical_evidence = generate_with_llama_70b_hf(
            conversation,
            rag_context,
            hf_token
        )
        model_used = "Llama-3.1-70B (HF Inference API)"
    else:
        # Option B: Use cached 8B model (faster loading, ~5s)
        clinical_evidence = generate_clinical_response_with_xupract(
            conversation,
            rag_context,
            hf_token
        )
        model_used = "II-Medical-8B (cached)"

    total_time = time.time() - start_time
    logger.info(f"Total pipeline time: {total_time:.1f}s (was 300+s with 355M ranking)")

    return {
        'clinical_evidence': clinical_evidence,
        'entities': entities,
        'model_used': model_used,
        'time_taken': total_time
    }

def format_two_llm_response(result):
    """Format the fast response"""
    return f"""ENTITY EXTRACTION (Clinical Trial GPT 355M - Cached)
{'='*60}
{result.get('entities', 'None identified')}

CLINICAL RESPONSE ({result.get('model_used', 'Unknown')})
{'='*60}
{result['clinical_evidence']}

PERFORMANCE
{'='*60}
Time: {result.get('time_taken', 0):.1f}s (was 300+s with 355M ranking)
{'='*60}
"""

# ===========================================================================
# PRELOAD MODELS AT STARTUP (Call this once in app.py!)
# ===========================================================================

def preload_all_models(hf_token=None):
    """
    Call this ONCE at app startup to cache all models.
    This prevents model reloading on every query.

    Add to your app.py initialization:
        from two_llm_system_FIXED import preload_all_models
        preload_all_models(hf_token)
    """
    logger.info("Preloading and caching all models...")

    # Cache the 355M model
    _ = get_cached_355m_model()
    logger.info("✓ 355M model cached")

    # Cache the 8B model if token available
    if hf_token:
        try:
            _ = get_cached_8b_model(hf_token)
            logger.info("✓ 8B model cached")
        except Exception as e:
            logger.warning(f"Could not cache 8B model: {e}")

    logger.info("All models preloaded and cached!")

# ===========================================================================
# BACKWARD COMPATIBILITY - Keep all original function names
# ===========================================================================

# These functions exist in the original but we optimize them
validate_with_small_model = lambda *args, **kwargs: "Validation skipped for speed"
extract_keywords_with_llama = lambda conv, hf_token=None: extract_entities_with_small_model(conv)[:100]
generate_response_with_llama = generate_with_llama_70b_hf
generate_clinical_knowledge_with_355m = lambda conv: f"Knowledge: {conv[:100]}..."
generate_with_355m = lambda conv, rag="", hf_token=None: generate_clinical_response_with_xupract(conv, rag, hf_token)

# Ensure we override the slow ranking function
rank_trials_with_355m = rank_trials_FAST

logger.info("Fast Two-LLM System loaded - 355M ranking bypassed!")
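Taken together, a minimal end-to-end sketch of this module: preload the cached models once, then run the fast pipeline and format the result. The query, RAG context, and environment variable are placeholders; a GPU and the model downloads are assumed.

```python
# Illustrative only: end-to-end use of two_llm_system_FIXED.
import os

from two_llm_system_FIXED import (
    preload_all_models,
    process_two_llm_system,
    format_two_llm_response,
)

hf_token = os.getenv("HF_TOKEN")          # optional; selects the 70B API path when set
preload_all_models(hf_token)              # cache models once at startup

result = process_two_llm_system(
    "What are the results for ianalumab in Sjogren's syndrome?",
    rag_context="NCT_ID: NCT00000000 ...",  # placeholder context from the RAG step
    hf_token=hf_token,
)
print(format_two_llm_response(result))
```
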