Your Name Claude committed
Commit d78f02a · 1 Parent(s): d5f8324

Clone api2 for experimentation


Experiment space for testing changes before they reach production:
- FastAPI REST endpoint
- LLM query parser (Llama 3.1 70B)
- Targeted indexing (drugs, diseases, companies, endpoints)
- Groq API support (~10x faster than HuggingFace)
- Hybrid RAG search (keyword + semantic)
- Docker deployment

A clone of api2 with all the latest optimizations.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (6)
  1. Dockerfile +22 -0
  2. README.md +31 -4
  3. app.py +123 -0
  4. foundation_engine.py +1343 -0
  5. requirements.txt +19 -0
  6. two_llm_system_FIXED.py +461 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     git \
+     git-lfs \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application files
+ COPY app.py foundation_engine.py two_llm_system_FIXED.py ./
+
+ # Expose port
+ EXPOSE 7860
+
+ # Run FastAPI app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
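Once the image is built and started (for example with `docker build` and a `-p 7860:7860` port mapping), a small readiness probe can confirm the container came up and finished loading its data. A minimal sketch, assuming the container is reachable on localhost:7860; the `requests` dependency is an assumption, not part of this commit:

```python
import time
import requests

BASE_URL = "http://localhost:7860"  # assumes docker run -p 7860:7860

# Poll /health until the startup event has finished loading embeddings
for _ in range(60):
    try:
        health = requests.get(f"{BASE_URL}/health", timeout=5).json()
        if health.get("embeddings_loaded"):
            print(f"Ready: {health['trials_loaded']:,} trial chunks loaded")
            break
    except requests.ConnectionError:
        pass  # server not accepting connections yet
    time.sleep(5)
else:
    raise SystemExit("API did not become ready in time")
```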
README.md CHANGED
@@ -1,11 +1,38 @@
  ---
- title: Ctapi
- emoji: 🏃
- colorFrom: gray
+ title: Clinical Trial API (Experiment)
+ emoji: 🧪
+ colorFrom: purple
  colorTo: blue
  sdk: docker
  pinned: false
  license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Clinical Trial API - Experiment Space
+
+ **Clone of api2 for testing changes before production**
+
+ ## API Endpoint
+
+ ```bash
+ POST /query
+ {
+     "query": "What Novartis drugs treat melanoma?"
+ }
+ ```
+
+ ## Features
+
+ - LLM query parser (extracts drugs, diseases, companies, endpoints)
+ - Targeted inverted index (drugs, diseases, companies, endpoints)
+ - Hybrid RAG search (keyword + semantic)
+ - Groq API support (10x faster than HuggingFace)
+ - ~7 second response time
+
+ ## Test It
+
+ ```bash
+ curl -X POST https://gmkdigitalmedia-ctapi.hf.space/query \
+     -H "Content-Type: application/json" \
+     -d '{"query": "Tell me about Keytruda for lung cancer"}'
+ ```
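The same call from Python, mirroring the curl example in the README above (a sketch; the space URL is the one the README gives, and the timeout is a guess at generation latency):

```python
import requests

resp = requests.post(
    "https://gmkdigitalmedia-ctapi.hf.space/query",
    json={"query": "Tell me about Keytruda for lung cancer"},
    timeout=120,
)
resp.raise_for_status()
result = resp.json()
print(f"Answered in {result['processing_time']:.1f}s")
print(result["summary"])
```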
app.py ADDED
@@ -0,0 +1,123 @@
+ """
+ FastAPI REST API for the Foundation 1.2 Clinical Trial System
+ Production-ready Docker Space with proper REST endpoints
+ """
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import time
+ import logging
+
+ # Import the foundation engine
+ import foundation_engine
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(
+     title="Clinical Trial API",
+     description="Production REST API for clinical trial analysis powered by the Foundation 1.2 pipeline",
+     version="1.0.0",
+     docs_url="/docs",
+     redoc_url="/redoc"
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Request/response models
+ class QueryRequest(BaseModel):
+     query: str
+
+ class QueryResponse(BaseModel):
+     summary: str
+     processing_time: float
+
+ class HealthResponse(BaseModel):
+     status: str
+     trials_loaded: int
+     embeddings_loaded: bool
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize the foundation engine on startup"""
+     logger.info("=== API Startup ===")
+     logger.info("Loading Foundation 1.2 engine...")
+     # Load chunks, embeddings, and the inverted index eagerly so the first request is fast
+     foundation_engine.load_embeddings()
+     logger.info("=== API Ready ===")
+
+ @app.get("/")
+ async def root():
+     """API information"""
+     return {
+         "service": "Clinical Trial API",
+         "version": "1.0.0",
+         "description": "Production REST API for Foundation 1.2",
+         "status": "healthy",
+         "endpoints": {
+             "POST /query": "Query clinical trials and get an AI-generated summary",
+             "GET /health": "Health check",
+             "GET /docs": "Interactive API documentation (Swagger UI)",
+             "GET /redoc": "Alternative API documentation (ReDoc)"
+         },
+         "features": [
+             "Drug Scoring",
+             "355M foundation model"
+         ]
+     }
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """Health check endpoint"""
+     embeddings_loaded = foundation_engine.doc_embeddings is not None
+     chunks_loaded = len(foundation_engine.doc_chunks) if foundation_engine.doc_chunks else 0
+
+     return HealthResponse(
+         status="healthy",
+         trials_loaded=chunks_loaded,
+         embeddings_loaded=embeddings_loaded
+     )
+
+ @app.post("/query", response_model=QueryResponse)
+ async def query_trials(request: QueryRequest):
+     """
+     Query clinical trials and get an AI-generated summary.
+
+     - **query**: Your question about clinical trials (e.g., "What trials exist for Dekavil?")
+
+     Returns a structured medical analysis with:
+     - Drug/intervention background
+     - Clinical trial results and data
+     - Treatment considerations
+     - NCT trial IDs and references
+     """
+     try:
+         logger.info(f"API query received: {request.query[:100]}...")
+         start_time = time.time()
+
+         # Call the foundation engine
+         result = foundation_engine.process_query(request.query)
+
+         processing_time = time.time() - start_time
+         logger.info(f"Query completed in {processing_time:.2f}s")
+
+         return QueryResponse(
+             summary=result,
+             processing_time=processing_time
+         )
+
+     except Exception as e:
+         logger.error(f"Error processing query: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
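For a quick in-process check of these routes without Docker, FastAPI's test client can drive the app directly. A minimal sketch, assuming the data files are already cached locally so `load_embeddings()` succeeds (startup events fire when the client context opens):

```python
from fastapi.testclient import TestClient

from app import app  # the FastAPI instance defined above

# Entering the context manager triggers the startup event,
# which loads chunks, embeddings, and the inverted index.
with TestClient(app) as client:
    assert client.get("/health").json()["embeddings_loaded"]

    resp = client.post("/query", json={"query": "ianalumab Sjogren's syndrome"})
    resp.raise_for_status()
    print(resp.json()["summary"][:500])
```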
foundation_engine.py ADDED
@@ -0,0 +1,1343 @@
+ """
+ Foundation 1.2
+ Clinical trial query system with a 355M foundation model
+ """
+
+ import gradio as gr
+ import os
+ from pathlib import Path
+ import pickle
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+ import logging
+ from rank_bm25 import BM25Okapi
+
+ import re
+ from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize
+ hf_token = os.getenv("HF_TOKEN")
+
+ # Paths for data storage
+ # Files will be downloaded from an HF Dataset on first run
+ DATASET_FILE = Path(__file__).parent / "complete_dataset_WITH_RESULTS_FULL.txt"
+ CHUNKS_FILE = Path(__file__).parent / "dataset_chunks_TRIAL_AWARE.pkl"
+ EMBEDDINGS_FILE = Path(__file__).parent / "dataset_embeddings_TRIAL_AWARE_FIXED.npy"  # FIXED version to avoid cache
+ INVERTED_INDEX_FILE = Path(__file__).parent / "inverted_index_TRIAL_AWARE.pkl"  # Pre-built inverted index (638 MB)
+
+ # HF Dataset containing the large files
+ DATASET_REPO = "gmkdigitalmedia/foundation1.2-data"
+
+ # Global storage
+ embedder = None
+ doc_chunks = []
+ doc_embeddings = None
+ bm25_index = None  # BM25 index for fast keyword search
+ inverted_index = None  # Inverted index for instant drug lookup
+
+ # ============================================================================
+ # RAG FUNCTIONS
+ # ============================================================================
+
+ def load_embedder():
+     """Load the MiniLM-L6 embedding model (matches the pre-generated embeddings)"""
+     global embedder
+     if embedder is None:
+         logger.info("Loading MiniLM-L6 embedding model...")
+         # Force CPU to avoid CUDA init in the main process
+         embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
+         logger.info("L6 model loaded on CPU")
+
+ def build_inverted_index(chunks):
+     """
+     Build a targeted inverted index for clinical search.
+     Maps drugs, diseases, companies, and endpoints to trial indices for O(1) lookup.
+
+     Indexes ONLY what matters:
+     1. INTERVENTION - drug/device names
+     2. CONDITIONS - diseases being treated
+     3. SPONSOR/COLLABORATOR/MANUFACTURER - company names
+     4. OUTCOME - trial endpoints (what is being measured)
+
+     Does NOT index trial names (unnecessary noise).
+     """
+     import time
+     t_start = time.time()
+     inv_index = {}
+
+     logger.info("Building targeted index: drugs, diseases, companies, endpoints...")
+
+     # Generic words to skip
+     skip_words = {
+         'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial',
+         'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active',
+         'randomized', 'multicenter', 'open', 'label', 'crossover'
+     }
+
+     for idx, chunk_data in enumerate(chunks):
+         if idx % 100000 == 0 and idx > 0:
+             logger.info(f"  Indexed {idx:,}/{len(chunks):,} trials...")
+
+         text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
+         text_lower = text.lower()
+
+         # 1. DRUGS from the INTERVENTION field
+         intervention_match = re.search(r'intervention[:\s]+([^\n]+)', text_lower)
+         if intervention_match:
+             intervention_text = intervention_match.group(1)
+             drugs = re.split(r'[,;\-\s]+', intervention_text)
+             for drug in drugs:
+                 drug = drug.strip('.,;:() ')
+                 if len(drug) > 3 and drug not in skip_words:
+                     if drug not in inv_index:
+                         inv_index[drug] = []
+                     if idx not in inv_index[drug]:
+                         inv_index[drug].append(idx)
+
+         # 2. DISEASES from the CONDITIONS field
+         conditions_match = re.search(r'conditions?[:\s]+([^\n]+)', text_lower)
+         if conditions_match:
+             conditions_text = conditions_match.group(1)
+             diseases = re.split(r'[,;\|]+', conditions_text)
+             for disease in diseases:
+                 disease = disease.strip('.,;:() ')
+                 # Split multi-word conditions and index each significant word
+                 disease_words = re.findall(r'\b\w{4,}\b', disease)
+                 for word in disease_words:
+                     if word not in skip_words:
+                         if word not in inv_index:
+                             inv_index[word] = []
+                         if idx not in inv_index[word]:
+                             inv_index[word].append(idx)
+
+         # 3. COMPANIES from the SPONSOR field
+         sponsor_match = re.search(r'sponsor[:\s]+([^\n]+)', text_lower)
+         if sponsor_match:
+             sponsor_text = sponsor_match.group(1)
+             sponsors = re.split(r'[,;\|]+', sponsor_text)
+             for sponsor in sponsors:
+                 sponsor = sponsor.strip('.,;:() ')
+                 if len(sponsor) > 3:
+                     if sponsor not in inv_index:
+                         inv_index[sponsor] = []
+                     if idx not in inv_index[sponsor]:
+                         inv_index[sponsor].append(idx)
+
+         # 4. COMPANIES from the COLLABORATOR field
+         collab_match = re.search(r'collaborator[:\s]+([^\n]+)', text_lower)
+         if collab_match:
+             collab_text = collab_match.group(1)
+             collaborators = re.split(r'[,;\|]+', collab_text)
+             for collab in collaborators:
+                 collab = collab.strip('.,;:() ')
+                 if len(collab) > 3:
+                     if collab not in inv_index:
+                         inv_index[collab] = []
+                     if idx not in inv_index[collab]:
+                         inv_index[collab].append(idx)
+
+         # 5. COMPANIES from the MANUFACTURER field
+         manuf_match = re.search(r'manufacturer[:\s]+([^\n]+)', text_lower)
+         if manuf_match:
+             manuf_text = manuf_match.group(1)
+             manufacturers = re.split(r'[,;\|]+', manuf_text)
+             for manuf in manufacturers:
+                 manuf = manuf.strip('.,;:() ')
+                 if len(manuf) > 3:
+                     if manuf not in inv_index:
+                         inv_index[manuf] = []
+                     if idx not in inv_index[manuf]:
+                         inv_index[manuf].append(idx)
+
+         # 6. ENDPOINTS from the OUTCOME fields
+         # Look for outcome measures (what is being measured)
+         outcome_matches = re.findall(r'outcome[:\s]+([^\n]+)', text_lower)
+         for outcome_match in outcome_matches[:5]:  # First 5 outcomes only
+             # Extract meaningful endpoint terms
+             endpoint_words = re.findall(r'\b\w{5,}\b', outcome_match)  # 5+ char words
+             for word in endpoint_words[:3]:  # First 3 words per outcome
+                 if word not in skip_words and word not in {'outcome', 'measure', 'primary', 'secondary'}:
+                     if word not in inv_index:
+                         inv_index[word] = []
+                     if idx not in inv_index[word]:
+                         inv_index[word].append(idx)
+
+     t_elapsed = time.time() - t_start
+     logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms")
+
+     # Log sample entries for debugging (drugs, diseases, companies, endpoints)
+     sample_terms = {
+         'drugs': ['keytruda', 'opdivo', 'humira'],
+         'diseases': ['cancer', 'diabetes', 'melanoma'],
+         'companies': ['novartis', 'pfizer', 'merck'],
+         'endpoints': ['survival', 'response', 'remission']
+     }
+
+     for category, terms in sample_terms.items():
+         logger.info(f"  {category.upper()} samples:")
+         for term in terms:
+             if term in inv_index:
+                 logger.info(f"    '{term}' -> {len(inv_index[term])} trials")
+
+     return inv_index
+
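Since the index is a plain dict of lowercase term → list of chunk indices, candidate retrieval in the search path is just a set union over posting lists. A toy sketch of that lookup (the terms and indices below are made up for illustration):

```python
# Toy posting lists in the same shape build_inverted_index() produces:
# lowercase term -> list of chunk indices
inv_index = {
    "keytruda": [12, 907, 4410],
    "melanoma": [12, 33, 907, 5120],
    "novartis": [33, 4410],
}

def candidates_for(query_terms, inv_index):
    """Union the posting lists for every query term the index knows about."""
    hits = set()
    for term in query_terms:
        hits.update(inv_index.get(term, []))  # dict lookup is O(1) per term
    return hits

print(sorted(candidates_for(["keytruda", "melanoma"], inv_index)))
# -> [12, 33, 907, 4410, 5120]
```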
+ def download_from_dataset(filename):
+     """Download a file from the HF Dataset if not present locally"""
+     from huggingface_hub import hf_hub_download
+     import tempfile
+
+     # Use /tmp for downloads (has write permissions in Docker)
+     download_dir = Path("/tmp/foundation_data")
+     download_dir.mkdir(exist_ok=True)
+
+     local_file = download_dir / filename
+
+     if local_file.exists():
+         logger.info(f"Found cached {filename}")
+         return local_file
+
+     try:
+         logger.info(f"Downloading {filename} from {DATASET_REPO}...")
+         downloaded_file = hf_hub_download(
+             repo_id=DATASET_REPO,
+             filename=filename,
+             repo_type="dataset",
+             local_dir=download_dir,
+             local_dir_use_symlinks=False
+         )
+         logger.info(f"Downloaded {filename}")
+         return Path(downloaded_file)
+     except Exception as e:
+         logger.error(f"Failed to download {filename}: {e}")
+         return None
+
+ def load_embeddings():
+     """Load pre-generated embeddings (download from the dataset if needed)"""
+     global doc_chunks, doc_embeddings, bm25_index
+
+     # Try to download if not present - store the paths returned by the download
+     chunks_path = CHUNKS_FILE
+     embeddings_path = EMBEDDINGS_FILE
+     dataset_path = DATASET_FILE
+
+     if not CHUNKS_FILE.exists():
+         downloaded = download_from_dataset("dataset_chunks_TRIAL_AWARE.pkl")
+         if downloaded:
+             chunks_path = downloaded
+     if not EMBEDDINGS_FILE.exists():
+         downloaded = download_from_dataset("dataset_embeddings_TRIAL_AWARE_FIXED.npy")  # FIXED version
+         if downloaded:
+             embeddings_path = downloaded
+     if not DATASET_FILE.exists():
+         downloaded = download_from_dataset("complete_dataset_WITH_RESULTS_FULL.txt")
+         if downloaded:
+             dataset_path = downloaded
+
+     if chunks_path.exists() and embeddings_path.exists():
+         try:
+             logger.info("Loading embeddings from disk...")
+             with open(chunks_path, 'rb') as f:
+                 doc_chunks = pickle.load(f)
+
+             # Load embeddings
+             loaded_embeddings = np.load(embeddings_path, allow_pickle=True)
+
+             logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}")
+
+             # Check whether it is already a proper numpy array
+             if isinstance(loaded_embeddings, np.ndarray) and loaded_embeddings.ndim == 2:
+                 doc_embeddings = loaded_embeddings
+                 logger.info(f"✓ Embeddings are a proper numpy array with shape: {doc_embeddings.shape}")
+             elif isinstance(loaded_embeddings, list):
+                 logger.info("Converting embeddings from list to numpy array (memory efficient)...")
+                 # Convert in chunks to avoid a memory spike
+                 chunk_size = 10000
+                 total = len(loaded_embeddings)
+
+                 # DEBUG: inspect the first item to see the format
+                 logger.info(f"DEBUG: Total embeddings: {total}")
+                 logger.info(f"DEBUG: Type of first item: {type(loaded_embeddings[0])}")
+
+                 # Check whether this is actually the chunks file (wrong file uploaded)
+                 if isinstance(loaded_embeddings[0], tuple) and len(loaded_embeddings[0]) == 2:
+                     if isinstance(loaded_embeddings[0][0], int) and isinstance(loaded_embeddings[0][1], str):
+                         raise ValueError(
+                             f"ERROR: The embeddings file contains (int, string) tuples!\n"
+                             f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n"
+                             f"First item: {loaded_embeddings[0][:2]}\n\n"
+                             f"Please re-upload the correct file:\n"
+                             f"  CORRECT: dataset_embeddings_TRIAL_AWARE.npy (numpy array, 855 MB)\n"
+                             f"  WRONG:   dataset_chunks_TRIAL_AWARE.pkl (tuples, 2.8 GB)\n\n"
+                             f"The local file at /mnt/c/Users/ibm/Documents/HF/kg_to_model/dataset_embeddings_TRIAL_AWARE.npy is correct."
+                         )
+
+                 if isinstance(loaded_embeddings[0], tuple):
+                     logger.info(f"DEBUG: Tuple length: {len(loaded_embeddings[0])}")
+                     for i, item in enumerate(loaded_embeddings[0][:5] if len(loaded_embeddings[0]) > 5 else loaded_embeddings[0]):
+                         logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}")
+
+                 # Get the embedding dimension from the first item
+                 first_emb = loaded_embeddings[0]
+                 emb_idx = None  # Initialize
+
+                 # Handle different formats
+                 if isinstance(first_emb, tuple):
+                     # Try both positions - could be (id, emb) or (emb, id)
+                     logger.info("DEBUG: Trying to find the embedding vector in the tuple...")
+                     emb_vector = None
+                     for idx, elem in enumerate(first_emb):
+                         if isinstance(elem, (list, np.ndarray)):
+                             emb_vector = elem
+                             emb_idx = idx
+                             logger.info(f"DEBUG: Found embedding at position {idx}")
+                             break
+
+                     if emb_vector is None:
+                         raise ValueError(f"No embedding vector found in tuple. Tuple contains: {[type(x) for x in first_emb]}")
+
+                     emb_dim = len(emb_vector)
+                     logger.info(f"DEBUG: Embedding dimension: {emb_dim}")
+                 elif isinstance(first_emb, list):
+                     emb_dim = len(first_emb)
+                     emb_idx = None
+                 elif isinstance(first_emb, np.ndarray):
+                     emb_dim = first_emb.shape[0]
+                     emb_idx = None
+                 else:
+                     raise ValueError(f"Unknown embedding format: {type(first_emb)}")
+
+                 logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}")
+
+                 # Pre-allocate the array
+                 doc_embeddings = np.zeros((total, emb_dim), dtype=np.float32)
+
+                 # Fill in chunks
+                 for i in range(0, total, chunk_size):
+                     end = min(i + chunk_size, total)
+
+                     # Extract embeddings from tuples if needed
+                     if isinstance(first_emb, tuple) and emb_idx is not None:
+                         # Extract just the embedding vector from each tuple at the correct position
+                         batch = [item[emb_idx] for item in loaded_embeddings[i:end]]
+                         doc_embeddings[i:end] = batch
+                     else:
+                         doc_embeddings[i:end] = loaded_embeddings[i:end]
+
+                     if i % 50000 == 0:
+                         logger.info(f"Converted {i}/{total} embeddings...")
+
+                 logger.info(f"✓ Converted to array with shape: {doc_embeddings.shape}")
+             else:
+                 doc_embeddings = loaded_embeddings
+                 logger.info(f"Embeddings already a numpy array with shape: {doc_embeddings.shape}")
+
+             logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings")
+
+             # Skip BM25 (too memory-heavy for Docker); use the inverted index only
+             global inverted_index
+
+             # Try to load the pre-built inverted index (638 MB) - MUCH faster than building (15 minutes)
+             if INVERTED_INDEX_FILE.exists():
+                 logger.info(f"Loading pre-built inverted index from {INVERTED_INDEX_FILE.name}...")
+                 try:
+                     with open(INVERTED_INDEX_FILE, 'rb') as f:
+                         inverted_index = pickle.load(f)
+                     logger.info(f"✓ Loaded pre-built inverted index with {len(inverted_index):,} terms (instant vs 15 min build)")
+                 except Exception as e:
+                     logger.warning(f"Failed to load pre-built index: {e}, building from scratch...")
+                     inverted_index = build_inverted_index(doc_chunks)
+             else:
+                 logger.info("Pre-built inverted index not found, building from scratch (this takes 15 minutes)...")
+                 inverted_index = build_inverted_index(doc_chunks)
+
+             logger.info("Will use inverted index + semantic search (no BM25)")
+
+             return True
+         except Exception as e:
+             logger.error(f"Failed to load embeddings: {e}")
+             raise RuntimeError("Embeddings are required but failed to load") from e
+
+     raise RuntimeError("Embeddings files not found - system cannot function without embeddings")
+
+ def filter_trial_for_clinical_summary(trial_text):
+     """
+     Filter trial data to keep essential clinical information, including SOME results.
+
+     COMPREHENSIVE FILTERING:
+     - Keeps all core trial info (title, summary, conditions, interventions)
+     - Keeps sponsor/collaborator/manufacturer (WHO is running the trial)
+     - Keeps the first 5 outcomes (to show key endpoints)
+     - Keeps the first 5 result values per trial (to show actual data)
+     - Filters out overwhelming statistical noise (hundreds of baseline/adverse event lines)
+
+     This ensures the LLM sees comprehensive context, including company information.
+     """
+     if not trial_text:
+         return trial_text
+
+     lines = trial_text.split('\n')
+     filtered_lines = []
+
+     # Counters to limit repetitive data
+     outcome_count = 0
+     outcome_desc_count = 0
+     result_value_count = 0
+
+     # Limits
+     MAX_OUTCOMES = 5
+     MAX_OUTCOME_DESC = 5
+     MAX_RESULT_VALUES = 5
+
+     for line in lines:
+         line_stripped = line.strip()
+
+         # Skip empty lines
+         if not line_stripped:
+             continue
+
+         # ALWAYS SKIP: overwhelming noise
+         always_skip = [
+             'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:',
+             'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:',
+             'OUTCOME_OTHER:', 'OUTCOME_NUMBER:'
+         ]
+
+         should_skip = False
+         for marker in always_skip:
+             if line_stripped.startswith(marker):
+                 should_skip = True
+                 break
+
+         if should_skip:
+             continue
+
+         # LIMITED KEEP: outcomes (first N only)
+         if line_stripped.startswith('OUTCOME:'):
+             outcome_count += 1
+             if outcome_count <= MAX_OUTCOMES:
+                 filtered_lines.append(line)
+             continue
+
+         # LIMITED KEEP: outcome descriptions (first N only)
+         if line_stripped.startswith('OUTCOME_DESCRIPTION:'):
+             outcome_desc_count += 1
+             if outcome_desc_count <= MAX_OUTCOME_DESC:
+                 filtered_lines.append(line)
+             continue
+
+         # LIMITED KEEP: result values (first N only)
+         if line_stripped.startswith('RESULT_VALUE:'):
+             result_value_count += 1
+             if result_value_count <= MAX_RESULT_VALUES:
+                 filtered_lines.append(line)
+             continue
+
+         # ALWAYS KEEP: core trial information + context
+         always_keep = [
+             'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:',
+             'SUMMARY:', 'DESCRIPTION:',
+             'CONDITIONS:', 'INTERVENTION:',  # WHAT disease, WHAT drug
+             'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:',  # WHO is running/funding
+             'ELIGIBILITY:'
+             # Note: OUTCOME/OUTCOME_DESCRIPTION handled in the LIMITED KEEP section above
+         ]
+
+         for marker in always_keep:
+             if line_stripped.startswith(marker):
+                 filtered_lines.append(line)
+                 break
+
+     return '\n'.join(filtered_lines)
+
+
+ def retrieve_context_with_embeddings(query, top_k=10):
+     """
+     ENTERPRISE HYBRID SEARCH: always combines keyword + semantic scoring.
+     - Extracts ALL meaningful terms from the query (case-insensitive)
+     - Scores each trial by keyword frequency (TF-IDF style)
+     - Also gets semantic similarity scores
+     - Merges both scores with a weighted combination
+     - Works regardless of capitalization, language, or spelling
+     """
+     import time
+     import re
+     from collections import Counter
+     global doc_chunks, doc_embeddings, embedder
+
+     if doc_embeddings is None or len(doc_chunks) == 0:
+         logger.error("Embeddings not loaded!")
+         return ""
+
+     t0 = time.time()
+
+     # Extract ALL meaningful words from the query (stop words removed)
+     stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
+                   'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
+                   'about', 'that', 'this', 'there', 'it'}
+
+     query_lower = query.lower()
+     # Remove punctuation and split
+     words = re.findall(r'\b\w+\b', query_lower)
+     # Filter out stop words and short words
+     query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
+
+     logger.info(f"[HYBRID] Query terms extracted: {query_terms}")
+
+     # PARALLEL SEARCH: run both keyword and semantic scoring
+
+     # 1. KEYWORD SCORING via the inverted index (fast!)
+     t_kw = time.time()
+
+     # Use the inverted index for drug lookup (lightweight, no BM25)
+     global bm25_index, inverted_index
+     keyword_scores = {}
+
+     if inverted_index is not None:
+         # Check whether any query terms are in the drug/intervention inverted index
+         inv_index_candidates = set()
+         for term in query_terms:
+             if term in inverted_index:
+                 inv_index_candidates.update(inverted_index[term])
+                 logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'")
+
+         # FAST PATH: if we have inverted index hits (drug names), score those trials
+         if inv_index_candidates:
+             logger.info(f"[FAST PATH] Checking {len(inv_index_candidates)} inverted index candidates")
+
+             # CRITICAL: identify which terms are specific drugs (low frequency)
+             drug_specific_terms = set()
+             for term in query_terms:
+                 if term in inverted_index and len(inverted_index[term]) < 100:
+                     # This term appears in <100 trials - likely a specific drug name!
+                     drug_specific_terms.add(term)
+                     logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name")
+
+             for idx in inv_index_candidates:
+                 # No BM25; use a simple match count as the base score
+                 base_score = 1.0
+
+                 # Check whether this trial contains a drug-specific term
+                 chunk_data = doc_chunks[idx]
+                 chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
+                 chunk_lower = chunk_text.lower()
+
+                 has_drug_match = False
+                 for drug_term in drug_specific_terms:
+                     if drug_term in chunk_lower:
+                         has_drug_match = True
+                         break
+
+                 # MASSIVE PRIORITY for drug-specific trials
+                 if has_drug_match:
+                     # Drug-specific trials get a GUARANTEED top ranking
+                     score = 1000.0 + base_score
+                     logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}")
+                 else:
+                     # Regular inverted index hits (generic terms)
+                     if base_score <= 0:
+                         base_score = 0.1
+
+                     score = base_score
+
+                     # Apply field-specific boosting for non-drug terms
+                     max_field_boost = 1.0
+                     for term in query_terms:
+                         if term not in chunk_lower or term in drug_specific_terms:
+                             continue
+
+                         # INTERVENTION field - medium priority for non-drug terms
+                         if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower:
+                             max_field_boost = max(max_field_boost, 3.0)
+                         # TITLE field - low priority
+                         elif 'title:' in chunk_lower:
+                             title_pos = chunk_lower.find('title:')
+                             term_pos = chunk_lower.find(term)
+                             if title_pos < term_pos < title_pos + 200:
+                                 max_field_boost = max(max_field_boost, 2.0)
+
+                     score *= max_field_boost
+
+                 keyword_scores[idx] = score
+         else:
+             logger.info("[FALLBACK] No inverted index hits, using pure semantic search")
+
+     logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)")
+
+     # 2. SEMANTIC SCORING
+     load_embedder()
+     t_sem = time.time()
+     query_embedding = embedder.encode([query])[0]
+     semantic_similarities = np.dot(doc_embeddings, query_embedding)
+     logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)")
+
+     # 3. MERGE SCORES
+     # Normalize both scores to the 0-1 range
+     if keyword_scores:
+         max_kw = max(keyword_scores.values())
+         keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()}
+     else:
+         keyword_scores_norm = {}
+
+     max_sem = semantic_similarities.max()
+     min_sem = semantic_similarities.min()
+     semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)
+
+     # Combined score: 50% keyword (with IDF/field boost), 50% semantic (context)
+     # Balanced approach: IDF-weighted keywords + semantic understanding
+     combined_scores = np.zeros(len(doc_chunks))
+
+     for idx in range(len(doc_chunks)):
+         kw_score = keyword_scores_norm.get(idx, 0.0)
+         sem_score = semantic_scores_norm[idx]
+
+         # If a keyword match exists, balance keyword + semantic
+         if kw_score > 0:
+             combined_scores[idx] = 0.5 * kw_score + 0.5 * sem_score
+         else:
+             # Pure semantic if no keyword match
+             combined_scores[idx] = sem_score
+
+     # Get the top candidates by combined score (extra candidates so we can sort by recency)
+     # We take 3x the requested count (minimum 10), then sort by NCT ID for the most recent
+     candidate_k = max(top_k * 3, 10)
+     top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
+
+     logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}")
+     logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}")
+     logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}")
+
+     # Extract text and scores for 355M ranking
+     # Format as (score, text) tuples for rank_trials_with_355m
+     candidate_trials_for_ranking = [(combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i]) for i in top_indices]
+
+     # SORT BY NCT ID (higher = newer) before 355M ranking
+     def extract_nct_number(trial_tuple):
+         """Extract the NCT number from trial text for sorting (higher = newer)"""
+         _, text = trial_tuple
+         match = re.search(r'NCT_ID:\s*NCT(\d+)', text)
+         return int(match.group(1)) if match else 0
+
+     # Sort candidates by NCT ID (descending = newest first)
+     candidate_trials_for_ranking.sort(key=extract_nct_number, reverse=True)
+
+     # Log the top 5 NCT IDs to show recency sorting
+     top_ncts = []
+     for score, text in candidate_trials_for_ranking[:5]:
+         match = re.search(r'NCT_ID:\s*(NCT\d+)', text)
+         if match:
+             top_ncts.append(match.group(1))
+     logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}")
+
+     # SKIP 355M RANKING - it is broken (gives 0.50 to everything) and wastes 10 seconds.
+     # Just use the hybrid-scored + recency-sorted candidates.
+     logger.info("[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)")
+     ranked_trials = candidate_trials_for_ranking
+
+     # Take the top K from the ranked results
+     top_ranked = ranked_trials[:top_k]
+
+     logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)")
+
+     # Extract just the text
+     raw_chunks = [trial_text for _, trial_text in top_ranked]
+
+     # Apply the clinical filter to each trial
+     context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks]
+
+     if context_chunks:
+         first_trial_preview = context_chunks[0][:200]
+         logger.info(f"[HYBRID] First result (filtered): {first_trial_preview}")
+
+     # Add ranking information if available from the 355M model
+     if hasattr(ranked_trials, 'ranking_info'):
+         ranking_header = "[TRIAL RANKING BY CLINICAL RELEVANCE GPT]\n"
+         for info in ranked_trials.ranking_info:
+             ranking_header += f"Rank {info['rank']}: {info['nct_id']} - Relevance {info['relevance_rating']}\n"
+         ranking_header += "---\n\n"
+
+         # Prepend the ranking info to the first trial
+         if context_chunks:
+             context_chunks[0] = ranking_header + context_chunks[0]
+             logger.info("[355M RANKING] Added ranking metadata to context for the final LLM")
+
+     context = "\n\n---\n\n".join(context_chunks)  # Use --- as a separator between trials
+     logger.info(f"[HYBRID] TOTAL TIME: {time.time()-t0:.2f}s")
+     logger.info(f"[HYBRID] Filtered context length: {len(context)} chars (was ~{sum(len(c) for c in raw_chunks)} chars)")
+
+     return context
+
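The merge step in `retrieve_context_with_embeddings` above normalizes keyword scores by their maximum, min-max scales the semantic similarities, and blends 50/50 only where a keyword hit exists. A standalone sketch of just that arithmetic (toy scores, not real data) makes one property visible: after normalization, even the 1000-point drug boost caps at 1.0, so a trial with the top semantic score and no keyword hit can still outrank a keyword match:

```python
import numpy as np

def merge_scores(keyword_scores, semantic):
    """50/50 blend of normalized keyword and semantic scores, as in the function above."""
    kw_norm = {}
    if keyword_scores:
        max_kw = max(keyword_scores.values())
        kw_norm = {idx: s / max_kw for idx, s in keyword_scores.items()}

    # Min-max scale the semantic similarities into [0, 1]
    sem_norm = (semantic - semantic.min()) / (semantic.max() - semantic.min() + 1e-10)

    combined = sem_norm.copy()  # trials without a keyword hit keep their pure semantic score
    for idx, kw in kw_norm.items():
        combined[idx] = 0.5 * kw + 0.5 * sem_norm[idx]
    return combined

semantic = np.array([0.21, 0.34, 0.28, 0.30])
print(merge_scores({2: 1001.0}, semantic))
# -> [0.    1.    0.769 0.692]; trial 2's drug boost normalizes to 1.0, and trial 1
#    (no keyword hit, top semantic score) still ranks first
```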
+
+ def keyword_search_query_text(query, max_results=10, hf_token=None):
+     """Search the dataset using ALL meaningful words from the full query"""
+     if not DATASET_FILE.exists():
+         logger.error("Dataset file not found")
+         return ""
+
+     # Extract all meaningful words from the full query.
+     # Remove common stopwords but keep medical/clinical terms.
+     stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
+                  'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should',
+                  'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with',
+                  'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on',
+                  'what', 'you', 'know', 'that', 'relevant'}
+
+     # Extract words, filter stopwords and short words
+     words = query.lower().split()
+     search_terms = [w.strip('?.,!;:()[]{}') for w in words
+                     if w.lower() not in stopwords and len(w) >= 3]
+
+     if not search_terms:
+         logger.warning("No search terms extracted from query")
+         return ""
+
+     logger.info(f"Search terms from full query: {search_terms}")
+
+     # Store trials with match scores
+     trials_with_scores = []
+     current_trial = ""
+
+     try:
+         with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
+             for line in f:
+                 # Check whether a new trial starts
+                 if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"):
+                     # Score the previous trial
+                     if current_trial:
+                         trial_lower = current_trial.lower()
+
+                         # Count matches for all search terms
+                         score = sum(1 for term in search_terms if term in trial_lower)
+
+                         if score > 0:
+                             trials_with_scores.append((score, current_trial))
+
+                     current_trial = line
+                 else:
+                     current_trial += line
+
+         # Check the last trial
+         if current_trial:
+             trial_lower = current_trial.lower()
+             score = sum(1 for term in search_terms if term in trial_lower)
+             if score > 0:
+                 trials_with_scores.append((score, current_trial))
+
+         # Sort by score (highest first) and take the top results
+         trials_with_scores.sort(reverse=True, key=lambda x: x[0])
+         matching_trials = [(score, trial) for score, trial in trials_with_scores[:max_results]]
+
+         if matching_trials:
+             logger.info(f"Keyword search found {len(matching_trials)} trials")
+             return matching_trials  # Return a list of (score, trial) tuples
+         else:
+             logger.warning("Keyword search found no matching trials")
+             return []
+
+     except Exception as e:
+         logger.error(f"Keyword search failed: {e}")
+         return []
+
+ def keyword_search_in_dataset(entities, max_results=10):
+     """Legacy: search the dataset file for keyword matches using extracted entities"""
+     if not DATASET_FILE.exists():
+         logger.error("Dataset file not found")
+         return ""
+
+     drugs = [d.lower() for d in entities.get('drugs', [])]
+     conditions = [c.lower() for c in entities.get('conditions', [])]
+
+     if not drugs and not conditions:
+         logger.warning("No search terms for keyword search")
+         return ""
+
+     logger.info(f"Keyword search - Drugs: {drugs}, Conditions: {conditions}")
+
+     # Store trials with match scores
+     trials_with_scores = []
+     current_trial = ""
+
+     try:
+         with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
+             for line in f:
+                 # Check whether a new trial starts
+                 if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"):
+                     # Score the previous trial
+                     if current_trial:
+                         trial_lower = current_trial.lower()
+
+                         # Count matches
+                         drug_matches = sum(1 for d in drugs if d in trial_lower)
+                         condition_matches = sum(1 for c in conditions if c in trial_lower)
+
+                         # Only include trials that match at least the drug (if a drug was specified)
+                         if drugs:
+                             if drug_matches > 0:
+                                 score = drug_matches * 10 + condition_matches
+                                 trials_with_scores.append((score, current_trial))
+                         elif condition_matches > 0:
+                             # No drug specified, just match conditions
+                             trials_with_scores.append((condition_matches, current_trial))
+
+                     current_trial = line
+                 else:
+                     current_trial += line
+
+         # Check the last trial
+         if current_trial:
+             trial_lower = current_trial.lower()
+             drug_matches = sum(1 for d in drugs if d in trial_lower)
+             condition_matches = sum(1 for c in conditions if c in trial_lower)
+
+             if drugs:
+                 if drug_matches > 0:
+                     score = drug_matches * 10 + condition_matches
+                     trials_with_scores.append((score, current_trial))
+             elif condition_matches > 0:
+                 trials_with_scores.append((condition_matches, current_trial))
+
+         # Sort by score (highest first) and take the top results
+         trials_with_scores.sort(reverse=True, key=lambda x: x[0])
+         matching_trials = [trial for score, trial in trials_with_scores[:max_results]]
+
+         if matching_trials:
+             context = "\n\n---\n\n".join(matching_trials)
+             if len(context) > 6000:
+                 context = context[:6000] + "..."
+             logger.info(f"Keyword search found {len(matching_trials)} trials (from {len(trials_with_scores)} candidates)")
+             return context
+         else:
+             logger.warning("Keyword search found no trials matching the drug")
+             return ""
+
+     except Exception as e:
+         logger.error(f"Keyword search failed: {e}")
+         return ""
+
+ # ============================================================================
+ # ENTITY EXTRACTION
+ # ============================================================================
+
+ def parse_entities_from_query(conversation, hf_token=None):
+     """Parse entities from the query using both the 355M and 8B models + a regex fallback"""
+     entities = {'drugs': [], 'conditions': []}
+
+     # Use the 355M model for entity extraction
+     # (NOTE: extract_entities_with_small_model / extract_entities_with_8b are not
+     # imported in this file; they are expected to be defined or imported elsewhere)
+     extracted_355m = extract_entities_with_small_model(conversation)
+
+     # Also use the 8B model for more reliable extraction
+     extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token)
+
+     # Combine both extractions
+     extracted = (extracted_355m or "") + "\n" + (extracted_8b or "")
+
+     # Parse the model output
+     if extracted:
+         lines = extracted.split('\n')
+         for line in lines:
+             lower_line = line.lower()
+             if 'drug:' in lower_line or 'medication:' in lower_line:
+                 drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE).strip()
+                 if drug:
+                     entities['drugs'].append(drug)
+             elif 'condition:' in lower_line or 'disease:' in lower_line:
+                 condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE).strip()
+                 if condition:
+                     entities['conditions'].append(condition)
+
+     # Regex fallback for standard drug naming patterns
+     drug_patterns = [
+         r'\b([A-Z][a-z]+mab)\b',    # Monoclonal antibodies: -mab suffix
+         r'\b([A-Z][a-z]+nib)\b',    # Kinase inhibitors: -nib suffix
+         r'\b([A-Z]\d+[A-Z]+\d+)\b'  # Alphanumeric codes like F8IL10
+     ]
+     for pattern in drug_patterns:
+         matches = re.findall(pattern, conversation)
+         for match in matches:
+             if match.lower() not in [d.lower() for d in entities['drugs']]:
+                 entities['drugs'].append(match)
+
+     condition_patterns = [
+         r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b'
+     ]
+     for pattern in condition_patterns:
+         matches = re.findall(pattern, conversation, re.IGNORECASE)
+         for match in matches:
+             if match not in [c.lower() for c in entities['conditions']]:
+                 entities['conditions'].append(match)
+
+     logger.info(f"Extracted entities: {entities}")
+     return entities
+
+ # ============================================================================
+ # MAIN QUERY PROCESSING
+ # ============================================================================
+
+ def extract_entities_simple(query):
+     """Simple entity extraction using regex patterns - no model needed"""
+     entities = {'drugs': [], 'conditions': []}
+
+     # Drug patterns
+     drug_patterns = [
+         r'\b([A-Z][a-z]+mab)\b',     # Monoclonal antibodies: ianalumab, rituximab, etc.
+         r'\b([A-Z][a-z]+nib)\b',     # Kinase inhibitors: imatinib, etc.
+         r'\b([A-Z]\d+[A-Z]+\d+)\b',  # Alphanumeric codes
+         r'\b(ianalumab|rituximab|tocilizumab|adalimumab|infliximab)\b',  # Common drugs
+     ]
+
+     # Condition patterns
+     condition_patterns = [
+         r'\b(sjogren\'?s?\s+syndrome)\b',
+         r'\b(rheumatoid arthritis)\b',
+         r'\b(lupus)\b',
+         r'\b(myelofibrosis)\b',
+         r'\b(diabetes)\b',
+         r'\b(cancer|carcinoma|melanoma)\b',
+     ]
+
+     query_lower = query.lower()
+
+     # Extract drugs
+     for pattern in drug_patterns:
+         matches = re.findall(pattern, query, re.IGNORECASE)
+         for match in matches:
+             if match.lower() not in [d.lower() for d in entities['drugs']]:
+                 entities['drugs'].append(match)
+
+     # Extract conditions
+     for pattern in condition_patterns:
+         matches = re.findall(pattern, query, re.IGNORECASE)
+         for match in matches:
+             if match.lower() not in [c.lower() for c in entities['conditions']]:
+                 entities['conditions'].append(match)
+
+     logger.info(f"Extracted entities: {entities}")
+     return entities
+
+ def parse_query_with_llm(query, hf_token=None):
+     """
+     Use a fast LLM to parse the query and extract structured information.
+
+     Extracts:
+     - Drug names
+     - Diseases/conditions
+     - Companies (sponsors/manufacturers)
+     - Endpoints (what is being measured)
+     - Search terms (optimized for RAG)
+
+     Returns: dict with extracted entities and an optimized search query
+     """
+     try:
+         from huggingface_hub import InferenceClient
+
+         logger.info("[QUERY PARSER] Analyzing user query with LLM...")
+         client = InferenceClient(token=hf_token, timeout=30)
+
+         parse_prompt = f"""Extract key information from this clinical trial query.
+
+ Query: "{query}"
+
+ Extract and return in this EXACT format:
+ DRUGS: [list drug/treatment names, or "none"]
+ DISEASES: [list diseases/conditions, or "none"]
+ COMPANIES: [list company/sponsor names, or "none"]
+ ENDPOINTS: [list trial endpoints/outcomes, or "none"]
+ SEARCH_TERMS: [optimized search keywords]
+
+ Examples:
+ Query: "What Novartis drugs treat melanoma?"
+ DRUGS: none
+ DISEASES: melanoma
+ COMPANIES: Novartis
+ ENDPOINTS: none
+ SEARCH_TERMS: Novartis melanoma treatment drugs
+
+ Query: "Tell me about Keytruda for lung cancer"
+ DRUGS: Keytruda
+ DISEASES: lung cancer
+ COMPANIES: none
+ ENDPOINTS: none
+ SEARCH_TERMS: Keytruda lung cancer
+
+ Now parse the query above:"""
+
+         response = client.chat_completion(
+             model="meta-llama/Llama-3.1-70B-Instruct",
+             messages=[{"role": "user", "content": parse_prompt}],
+             max_tokens=256,
+             temperature=0.1  # Low temperature for consistent parsing
+         )
+
+         parsed = response.choices[0].message.content.strip()
+         logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}")
+
+         # Parse the response into a dict
+         result = {
+             'raw_parsed': parsed,
+             'drugs': [],
+             'diseases': [],
+             'companies': [],
+             'endpoints': [],
+             'search_terms': query  # fallback
+         }
+
+         lines = parsed.split('\n')
+         for line in lines:
+             line = line.strip()
+             if line.startswith('DRUGS:'):
+                 drugs = line.replace('DRUGS:', '').strip()
+                 if drugs.lower() != 'none':
+                     result['drugs'] = [d.strip() for d in drugs.split(',')]
+             elif line.startswith('DISEASES:'):
+                 diseases = line.replace('DISEASES:', '').strip()
+                 if diseases.lower() != 'none':
+                     result['diseases'] = [d.strip() for d in diseases.split(',')]
+             elif line.startswith('COMPANIES:'):
+                 companies = line.replace('COMPANIES:', '').strip()
+                 if companies.lower() != 'none':
+                     result['companies'] = [c.strip() for c in companies.split(',')]
+             elif line.startswith('ENDPOINTS:'):
+                 endpoints = line.replace('ENDPOINTS:', '').strip()
+                 if endpoints.lower() != 'none':
+                     result['endpoints'] = [e.strip() for e in endpoints.split(',')]
+             elif line.startswith('SEARCH_TERMS:'):
+                 result['search_terms'] = line.replace('SEARCH_TERMS:', '').strip()
+
+         logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}")
+         return result
+
+     except Exception as e:
+         logger.warning(f"[QUERY PARSER] Failed: {e}, using original query")
+         return {
+             'drugs': [],
+             'diseases': [],
+             'companies': [],
+             'endpoints': [],
+             'search_terms': query,
+             'raw_parsed': ''
+         }
+
+ def generate_llama_response(query, rag_context, hf_token=None):
+     """
+     Generate a response using the FAST Groq API (~10x faster than HF).
+
+     Speed comparison:
+     - HuggingFace: ~40 tokens/sec = ~15 seconds
+     - Groq: ~300 tokens/sec = ~2 seconds (FREE!)
+     """
+     try:
+         # Try Groq first (much faster), fall back to HuggingFace
+         groq_api_key = os.getenv("GROQ_API_KEY")
+
+         if groq_api_key:
+             logger.info("Generating response with Llama-3.1-70B via Groq (fast)...")
+             from groq import Groq
+             client = Groq(api_key=groq_api_key)
+
+             # Simplified prompt for faster generation
+             system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs."""
+
+             user_prompt = f"""Clinical trials:
+ {rag_context[:6000]}
+
+ Question: {query}
+
+ Provide a concise answer citing specific NCT trial IDs."""
+
+             response = client.chat.completions.create(
+                 model="llama-3.1-70b-versatile",  # Groq's optimized 70B
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_prompt}
+                 ],
+                 max_tokens=512,  # Shorter for speed
+                 temperature=0.3,
+                 timeout=30
+             )
+
+             return response.choices[0].message.content.strip()
+
+         else:
+             # Fall back to HuggingFace (slower)
+             logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...")
+             from huggingface_hub import InferenceClient
+             client = InferenceClient(token=hf_token, timeout=120)
+
+             system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs."""
+
+             user_prompt = f"""Clinical trials:
+ {rag_context[:6000]}
+
+ Question: {query}
+
+ Provide a concise answer citing specific NCT trial IDs."""
+
+             messages = [
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt}
+             ]
+
+             response = client.chat_completion(
+                 model="meta-llama/Meta-Llama-3.1-70B-Instruct",
+                 messages=messages,
+                 max_tokens=512,  # Reduced from 2048 for speed
+                 temperature=0.3
+             )
+
+             return response.choices[0].message.content.strip()
+
+     except Exception as e:
+         logger.error(f"Llama error: {e}")
+         return f"Llama API error: {str(e)}"
+
+
+ def process_query_simple_test(conversation):
+     """TEST JUST THE RAG - no models"""
+     try:
+         import time
+         output = []
+         output.append(f"QUERY: {conversation}\n")
+
+         # Check whether embeddings are loaded
+         if doc_embeddings is None or len(doc_chunks) == 0:
+             return "FAIL: Embeddings not loaded"
+
+         output.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n")
+         output.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n")
+
+         # Try a search
+         start = time.time()
+         context = retrieve_context_with_embeddings(conversation, top_k=3)
+         search_time = time.time() - start
+
+         if not context:
+             return "".join(output) + "\nFAIL: RAG returned empty"
+
+         output.append(f"✓ RAG search took: {search_time:.2f}s\n")
+         output.append(f"✓ Retrieved {context.count('NCT')} trials\n\n")
+         output.append("FIRST 1000 CHARS:\n")
+         output.append(context[:1000])
+
+         return "".join(output)
+
+     except Exception as e:
+         import traceback
+         return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}"
+
1134
+
1135
+ def process_query(conversation):
1136
+ """
1137
+ Complete pipeline with LLM query parsing and natural language generation
1138
+
1139
+ Flow:
1140
+ 0. LLM Parser - Extract drugs, diseases, companies, endpoints (~2-3s)
1141
+ 1. RAG Search - Hybrid search using optimized query (~2s)
1142
+ 2. Skipped - 355M ranking removed (was broken)
1143
+ 3. LLM Response - Llama 70B generates natural language (~15s)
1144
+
1145
+ Total: ~20 seconds
1146
+ """
1147
+ import time
1148
+ import traceback
1149
+ import sys
1150
+
1151
+ # MASTER try/except - catches EVERYTHING
1152
+ try:
1153
+ start_time = time.time()
1154
+ output_parts = [f"QUERY: {conversation}\n\n"]
1155
+
1156
+ # Step 0: Parse query with LLM to extract structured info
1157
+ try:
1158
+ step0_start = time.time()
1159
+ logger.info("Step 0: Parsing query with LLM...")
1160
+ output_parts.append("✓ Step 0: LLM query parser started...\n")
1161
+ parsed_query = parse_query_with_llm(conversation, hf_token=hf_token)
1162
+
1163
+ # Use optimized search terms from parser
1164
+ search_query = parsed_query['search_terms']
1165
+
1166
+ step0_time = time.time() - step0_start
1167
+ output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n")
1168
+ output_parts.append(f" Drugs: {parsed_query['drugs']}\n")
1169
+ output_parts.append(f" Diseases: {parsed_query['diseases']}\n")
1170
+ output_parts.append(f" Companies: {parsed_query['companies']}\n")
1171
+ output_parts.append(f" Optimized search: {search_query}\n")
1172
+ logger.info(f"Query parsing successful in {step0_time:.1f}s")
1173
+
1174
+ except Exception as e:
1175
+ error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query"
1176
+ logger.warning(error_msg)
1177
+ output_parts.append(f"{error_msg}\n")
1178
+ search_query = conversation # Fallback to original
1179
+
1180
+ # Step 1: RAG search (using optimized search query)
1181
+ try:
1182
+ step1_start = time.time()
1183
+ logger.info("Step 1: RAG search...")
1184
+ output_parts.append("✓ Step 1: RAG search started...\n")
1185
+ context = retrieve_context_with_embeddings(search_query, top_k=3)
1186
+
1187
+ if not context:
1188
+ return "No matching trials found in RAG search."
1189
+
1190
+ # No limit - use complete trials
1191
+ step1_time = time.time() - step1_start
1192
+ output_parts.append(f"✓ Step 1 Complete: Found {context.count('NCT')} trials ({step1_time:.1f}s)\n")
1193
+ logger.info(f"RAG search successful - found trials in {step1_time:.1f}s")
1194
+
1195
+ except Exception as e:
1196
+ error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}"
1197
+ logger.error(error_msg)
1198
+ return error_msg
1199
+
1200
+ # Step 2: Skipped (355M ranking removed - was broken)
1201
+ output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n")
1202
+
1203
+ # Step 3: Llama 70B
1204
+ try:
1205
+ step3_start = time.time()
1206
+ logger.info("Step 3: Generating response with Llama-3.1-70B...")
1207
+ output_parts.append("✓ Step 3: Llama 70B generation started...\n")
1208
+ llama_response = generate_llama_response(conversation, context, hf_token=hf_token)
1209
+ step3_time = time.time() - step3_start
1210
+ output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n")
1211
+ logger.info(f"Llama 70B generation successful in {step3_time:.1f}s")
1212
+
1213
+ except Exception as e:
1214
+ error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}"
1215
+ logger.error(error_msg)
1216
+ llama_response = f"[Llama 70B error: {str(e)}]"
1217
+ output_parts.append(f"✗ Step 3 Failed: {str(e)}\n")
1218
+
1219
+ total_time = time.time() - start_time
1220
+
1221
+ # Format output - handle missing variables
1222
+ try:
1223
+ context_display = context if 'context' in locals() else "[No context retrieved]"
1224
+ # 355M ranking is skipped above, so there is no clinical_context_355m to display
1225
+ llama_display = llama_response if 'llama_response' in locals() else "[Llama 70B not run]"
1226
+
1227
+ output = f"""{''.join(output_parts)}
1228
+
1229
+ CLINICAL SUMMARY (Llama-3.1-70B-Instruct):
1230
+ {llama_display}
1231
+
1232
+ ---
1233
+
1234
+ RAG RETRIEVED TRIALS (Top 3 Most Relevant):
1235
+ {context_display}
1236
+
1237
+ ---
1238
+ Total Time: {total_time:.1f}s
1239
+ """
1240
+ return output
1241
+ except Exception as e:
1242
+ # Absolute fallback
1243
+ error_info = f"""
1244
+ CRITICAL ERROR IN OUTPUT FORMATTING:
1245
+ {str(e)}
1246
+
1247
+ TRACEBACK:
1248
+ {traceback.format_exc()}
1249
+
1250
+ OUTPUT PARTS:
1251
+ {''.join(output_parts)}
1252
+
1253
+ Variables defined: {locals().keys()}
1254
+ """
1255
+ logger.error(error_info)
1256
+ return error_info
1257
+
1258
+ # MASTER EXCEPTION HANDLER - catches ANY unhandled error
1259
+ except Exception as master_error:
1260
+ master_error_msg = f"""
1261
+ ========================================
1262
+ MASTER ERROR HANDLER CAUGHT EXCEPTION
1263
+ ========================================
1264
+
1265
+ Error Type: {type(master_error).__name__}
1266
+ Error Message: {str(master_error)}
1267
+
1268
+ FULL TRACEBACK:
1269
+ {traceback.format_exc()}
1270
+
1271
+ System Info:
1272
+ - Python version: {sys.version}
1273
+ - Error at line: {sys.exc_info()[2].tb_lineno if sys.exc_info()[2] else 'unknown'}
1274
+
1275
+ ========================================
1276
+ """
1277
+ logger.error(master_error_msg)
1278
+ return master_error_msg
1279
+
1280
+
1281
+ # ============================================================================
1282
+ # GRADIO INTERFACE
1283
+ # ============================================================================
1284
+
1285
+ with gr.Blocks(title="Foundation 1.2") as demo:
1286
+ gr.Markdown("# Foundation 1.2 - Clinical Trial AI")
1287
+
1288
+ query_input = gr.Textbox(
1289
+ label="Ask about clinical trials",
1290
+ placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?",
1291
+ lines=3
1292
+ )
1293
+ submit_btn = gr.Button("Generate Response", variant="primary")
1294
+
1295
+ output = gr.Textbox(
1296
+ label="AI Response",
1297
+ lines=30
1298
+ )
1299
+
1300
+ submit_btn.click(
1301
+ fn=process_query,  # Full pipeline: LLM parse + RAG + Llama (355M ranking skipped)
1302
+ inputs=query_input,
1303
+ outputs=output
1304
+ )
1305
+
1306
+ gr.Markdown("""
1307
+ **Production RAG Pipeline - Optimized for Clinical Accuracy**
1308
+
1309
+ **Search (Hybrid):**
1310
+ 1. Keyword matching (70%) + semantic search (30%) → candidate trials
1311
+ 2. 355M re-ranking skipped (removed - it was broken)
1312
+ 3. Top 3 trials returned, ordered by relevance and recency
1313
+
1314
+ **Generation (Llama-3.1-70B-Instruct):**
1315
+ - 70B parameter model via HuggingFace Inference API
1316
+ - Structured clinical summaries with clear headings
1317
+ - Cites specific NCT trial IDs
1318
+ - Includes actual trial results and efficacy data
1319
+ - Concise medical reasoning grounded in the retrieved trials
1320
+
1321
+ *Hybrid search + recency ordering replaces the former 355M ranking step*
1322
+ """)
1323
+
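
The 70/30 keyword/semantic split described above amounts to a weighted sum of normalized scores. A minimal sketch, assuming BM25 scores and cosine similarities have already been computed per trial (function and variable names here are illustrative, not from the source):

```python
import numpy as np

def hybrid_scores(bm25_scores, cosine_scores, w_kw=0.7, w_sem=0.3):
    """Combine keyword and semantic scores; BM25 is min-max normalized to [0, 1]."""
    b = np.asarray(bm25_scores, dtype=float)
    c = np.asarray(cosine_scores, dtype=float)
    spread = b.max() - b.min()
    b = (b - b.min()) / spread if spread > 0 else np.zeros_like(b)
    return w_kw * b + w_sem * c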
1324
+
1325
+ # ============================================================================
1326
+ # STARTUP
1327
+ # ============================================================================
1328
+
1329
+ # Embeddings will be loaded by FastAPI startup event in app.py
1330
+ # Do NOT load here - causes Docker permission errors
1331
+ logger.info("=== Foundation 1.2 Module Loaded ===")
1332
+ logger.info("Call load_embeddings() to initialize the system")
1333
+
1334
+ if DATASET_FILE.exists():
1335
+ file_size_mb = DATASET_FILE.stat().st_size / (1024 * 1024)
1336
+ logger.info(f"✓ Dataset file found: {file_size_mb:.0f}MB")
1337
+ else:
1338
+ logger.error("✗ Dataset file not found!")
1339
+
1340
+ logger.info("=== Startup Complete ===")
1341
+
1342
+ if __name__ == "__main__":
1343
+ demo.launch()
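
As the startup notes above say, embeddings are loaded by the FastAPI startup event rather than at module import. A minimal sketch of how app.py might wire this up, assuming `load_embeddings` and `process_query` are exported by this module (the endpoint shape is illustrative):

```python
# Hypothetical wiring in app.py; load_embeddings/process_query come from foundation_engine.
from fastapi import FastAPI
from pydantic import BaseModel
import foundation_engine

app = FastAPI()

class QueryRequest(BaseModel):
    query: str

@app.on_event("startup")
def startup() -> None:
    foundation_engine.load_embeddings()  # heavy init happens here, not at import time

@app.post("/query")
def query(req: QueryRequest) -> dict:
    return {"result": foundation_engine.process_query(req.query)}
```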
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ gradio==4.44.0
2
+ spaces
3
+ huggingface_hub>=0.24.0,<1.0
4
+ transformers>=4.37.0
5
+ torch>=2.1.0
6
+ accelerate
7
+ bitsandbytes
8
+ sentence-transformers==3.1.1
9
+ PyPDF2==3.0.1
10
+ numpy==1.26.4
11
+ openai
12
+ groq
13
+ sentencepiece
14
+ protobuf
15
+ fastapi
16
+ uvicorn
17
+ pydantic
18
+ networkx>=3.1
19
+ rank-bm25
two_llm_system_FIXED.py ADDED
@@ -0,0 +1,461 @@
1
+ """
2
+ FAST VERSION: cuts the 355M ranking bottleneck from ~300s to ~30s by re-ranking only the top 3 trials
3
+ Works with existing data structure: List[Tuple[float, str]]
4
+ Keeps BM25 + semantic hybrid search intact
5
+ """
6
+
7
+ import torch
8
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast, AutoTokenizer, AutoModelForCausalLM
9
+ import logging
10
+ import spaces
11
+ from functools import lru_cache
12
+ from typing import List, Tuple, Optional, Dict
13
+ from huggingface_hub import InferenceClient
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # ===========================================================================
18
+ # CACHED MODEL LOADING - Load once, reuse forever
19
+ # ===========================================================================
20
+
21
+ @lru_cache(maxsize=1)
22
+ def get_cached_355m_model():
23
+ """Load 355M model once and cache it for entity extraction"""
24
+ logger.info("Loading 355M Clinical Trial GPT (cached for entity extraction)...")
25
+ tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/clinicaltrial2.2")
26
+ model = GPT2LMHeadModel.from_pretrained(
27
+ "gmkdigitalmedia/clinicaltrial2.2",
28
+ torch_dtype=torch.float16,
29
+ device_map="auto"
30
+ )
31
+ if tokenizer.pad_token is None:
32
+ tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers have no pad token by default
33
+ model.eval()
34
+ return tokenizer, model
33
+
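
Because of `lru_cache(maxsize=1)`, repeated calls return the exact same tokenizer/model objects; only the first call pays the load cost. A quick illustration:

```python
tok_a, model_a = get_cached_355m_model()  # loads weights (slow, once)
tok_b, model_b = get_cached_355m_model()  # returns the cached pair (instant)
assert model_a is model_b and tok_a is tok_b
```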
34
+ @lru_cache(maxsize=1)
35
+ def get_cached_8b_model(hf_token: Optional[str] = None):
36
+ """Load 8B model once and cache it"""
37
+ logger.info("Loading II-Medical-8B (cached)...")
38
+ tokenizer = AutoTokenizer.from_pretrained(
39
+ "Intelligent-Internet/II-Medical-8B-1706",
40
+ token=hf_token,
41
+ trust_remote_code=True
42
+ )
43
+ if tokenizer.pad_token is None:
44
+ tokenizer.pad_token = tokenizer.eos_token
45
+
46
+ model = AutoModelForCausalLM.from_pretrained(
47
+ "Intelligent-Internet/II-Medical-8B-1706",
48
+ device_map="auto",
49
+ token=hf_token,
50
+ trust_remote_code=True,
51
+ torch_dtype=torch.bfloat16
52
+ )
53
+ return tokenizer, model
54
+
55
+ # ===========================================================================
56
+ # FAST RANKING - Replace 300s function with instant passthrough
57
+ # ===========================================================================
58
+
59
+ @spaces.GPU
60
+ def rank_trials_FAST(query: str, trials_list: List[Tuple[float, str]], hf_token=None) -> List[Tuple[float, str]]:
61
+ """
62
+ SMART RANKING: Use 355M to rank only top 3 trials
63
+
64
+ Takes top 3 from BM25+semantic search, then uses 355M Clinical Trial GPT
65
+ to re-rank them by clinical relevance.
66
+
67
+ Time: ~30 seconds for 3 trials (vs 300s for 30 trials)
68
+
69
+ Args:
70
+ query: The search query
71
+ trials_list: List of (score, trial_text) tuples from BM25+semantic search
72
+ hf_token: Unused; kept for drop-in signature compatibility
73
+
74
+ Returns:
75
+ Top 3 trials re-ranked by 355M clinical relevance
76
+ """
77
+ import time
78
+ import re
79
+
80
+ start_time = time.time()
81
+
82
+ # Take only top 3 trials for 355M ranking
83
+ top_3 = trials_list[:3]
84
+
85
+ logger.info(f"[355M RANKING] Ranking top 3 trials with Clinical Trial GPT...")
86
+
87
+ # Get cached 355M model
88
+ tokenizer, model = get_cached_355m_model()
89
+
90
+ # Score each trial
91
+ trial_scores = []
92
+
93
+ for idx, (bm25_score, trial_text) in enumerate(top_3):
94
+ # Extract NCT ID for logging
95
+ nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
96
+ nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
97
+
98
+ # Create prompt for relevance scoring
99
+ # Truncate trial to 800 chars to keep it fast
100
+ trial_snippet = trial_text[:800]
101
+
102
+ prompt = f"""Query: {query}
103
+
104
+ Clinical Trial: {trial_snippet}
105
+
106
+ Rate clinical relevance (1-10):"""
107
+
108
+ # Get model score
109
+ try:
110
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
111
+
112
+ with torch.no_grad():
113
+ outputs = model.generate(
114
+ inputs.input_ids,
115
+ max_length=inputs.input_ids.shape[1] + 10,
116
+ do_sample=False,  # greedy decoding; a temperature setting would be ignored here
118
+ pad_token_id=tokenizer.pad_token_id,
119
+ eos_token_id=tokenizer.eos_token_id
120
+ )
121
+
122
+ generated = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
123
+
124
+ # Extract number from response
125
+ score_match = re.search(r'(\d+)', generated.strip())
126
+ relevance_score = float(score_match.group(1)) if score_match else 5.0
127
+
128
+ # Normalize to 0-1 range
129
+ relevance_score = min(max(relevance_score, 1.0), 10.0) / 10.0  # clamp to 1-10 before scaling
130
+
131
+ logger.info(f"[355M RANKING] {nct_id}: relevance={relevance_score:.2f} (BM25={bm25_score:.3f})")
132
+
133
+ except Exception as e:
134
+ logger.warning(f"[355M RANKING] Scoring failed for {nct_id}: {e}, using BM25 score")
135
+ relevance_score = bm25_score
136
+
137
+ trial_scores.append((relevance_score, trial_text, nct_id))
138
+
139
+ # Sort by 355M relevance score (descending)
140
+ trial_scores.sort(key=lambda x: x[0], reverse=True)
141
+
142
+ # Format as (score, text) tuples for backwards compatibility
143
+ # Create a custom list class that can hold attributes
144
+ class RankedTrialsList(list):
145
+ """List that can hold ranking metadata"""
146
+ pass
147
+
148
+ ranked_trials = RankedTrialsList()
149
+ ranking_metadata = []
150
+
151
+ for rank, (score, text, nct_id) in enumerate(trial_scores, 1):
152
+ ranked_trials.append((score, text))
153
+ ranking_metadata.append({
154
+ 'rank': rank,
155
+ 'nct_id': nct_id,
156
+ 'relevance_score': score,
157
+ 'relevance_rating': f"{score*10:.1f}/10"
158
+ })
159
+
160
+ elapsed = time.time() - start_time
161
+ logger.info(f"[355M RANKING] ✓ Ranked 3 trials in {elapsed:.1f}s")
162
+ logger.info(f"[355M RANKING] Final order: {[nct_id for _, _, nct_id in trial_scores]}")
163
+ logger.info(f"[355M RANKING] Scores: {[f'{s:.2f}' for s, _, _ in trial_scores]}")
164
+
165
+ # Store metadata as attribute for retrieval
166
+ ranked_trials.ranking_info = ranking_metadata
167
+
168
+ # Return re-ranked top 3 plus remaining trials (if any).
169
+ # list concatenation would return a plain list and drop ranking_info, so extend in place
170
+ ranked_trials.extend(trials_list[3:])
171
+ return ranked_trials
170
+
171
+ # Alias for drop-in replacement
172
+ rank_trials_with_355m = rank_trials_FAST # Override the slow function!
173
+
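
A usage sketch for the re-ranker, assuming `candidates` already holds (score, trial_text) tuples from the hybrid search (the query string and variable names are illustrative):

```python
# candidates: List[Tuple[float, str]] from BM25 + semantic search
ranked = rank_trials_FAST("ianalumab in Sjogren's syndrome", candidates)
top_score, top_trial_text = ranked[0]
for meta in ranked.ranking_info:  # metadata attached by RankedTrialsList
    print(meta["rank"], meta["nct_id"], meta["relevance_rating"])
```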
174
+ # ===========================================================================
175
+ # FAST GENERATION using HuggingFace Inference API (Free)
176
+ # ===========================================================================
177
+
178
+ def generate_with_llama_70b_hf(query: str, rag_context: str = "", hf_token: str = None) -> str:
179
+ """
180
+ Use Llama-3.1-70B via HuggingFace Inference API (FREE)
181
+
182
+ This is what you're already using successfully!
183
+ ~10 second response time on HF free tier
184
+ """
185
+ try:
186
+ logger.info("Using Llama-3.1-70B via HuggingFace Inference API...")
187
+ client = InferenceClient(token=hf_token)
188
+
189
+ messages = [
190
+ {
191
+ "role": "system",
192
+ "content": "You are a medical information specialist. Answer based on the provided clinical trial data. Be concise and accurate."
193
+ },
194
+ {
195
+ "role": "user",
196
+ "content": f"""Clinical Trial Data:
197
+ {rag_context[:4000]}
198
+
199
+ Question: {query}
200
+
201
+ Please provide a concise answer based on the clinical trial data above."""
202
+ }
203
+ ]
204
+
205
+ response = client.chat_completion(
206
+ model="meta-llama/Llama-3.1-70B-Instruct",
207
+ messages=messages,
208
+ max_tokens=512,
209
+ temperature=0.3
210
+ )
211
+
212
+ answer = response.choices[0].message.content.strip()
213
+ logger.info(f"Llama 70B response generated via HF Inference API")
214
+ return answer
215
+ except Exception as e:
216
+ logger.error(f"Llama 70B generation failed: {e}")
217
+ return f"Error generating response with Llama 70B: {str(e)}"
218
+
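
The requirements also pin the `groq` client, and this Space advertises Groq as the faster backend. A hedged sketch of the equivalent call through Groq's OpenAI-compatible SDK; the function name and model id are assumptions and may need updating:

```python
import os
from groq import Groq

def generate_with_llama_70b_groq(query: str, rag_context: str = "") -> str:
    # Same prompt structure as the HF path above, routed through Groq.
    client = Groq(api_key=os.environ["GROQ_API_KEY"])
    resp = client.chat.completions.create(
        model="llama-3.1-70b-versatile",  # assumed Groq model id
        messages=[
            {"role": "system", "content": "You are a medical information specialist."},
            {"role": "user", "content": f"Clinical Trial Data:\n{rag_context[:4000]}\n\nQuestion: {query}"},
        ],
        max_tokens=512,
        temperature=0.3,
    )
    return resp.choices[0].message.content.strip()
```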
219
+ # ===========================================================================
220
+ # OPTIMIZED 8B GENERATION (with cached model)
221
+ # ===========================================================================
222
+
223
+ @spaces.GPU
224
+ def generate_clinical_response_with_xupract(conversation, rag_context="", hf_token=None):
225
+ """OPTIMIZED: Use cached 8B model for faster generation"""
226
+ logger.info("Generating response with cached II-Medical-8B...")
227
+
228
+ # Get cached model (loads once, reuses after)
229
+ tokenizer, model = get_cached_8b_model(hf_token)
230
+
231
+ # Build prompt with RAG context (ChatML format for II-Medical-8B)
232
+ if rag_context:
233
+ prompt = f"""<|im_start|>system
234
+ You are a medical information specialist. Answer based on the provided clinical trial data. Please reason step-by-step, and put your final answer within \\boxed{{}}.
235
+ <|im_end|>
236
+ <|im_start|>user
237
+ Clinical Trial Data:
238
+ {rag_context[:4000]}
239
+
240
+ Question: {conversation}
241
+
242
+ Please reason step-by-step, and put your final answer within \\boxed{{}}.
243
+ <|im_end|>
244
+ <|im_start|>assistant
245
+ """
246
+ else:
247
+ prompt = f"""<|im_start|>system
248
+ You are a medical information specialist. Please reason step-by-step, and put your final answer within \\boxed{{}}.
249
+ <|im_end|>
250
+ <|im_start|>user
251
+ {conversation}
252
+
253
+ Please reason step-by-step, and put your final answer within \\boxed{{}}.
254
+ <|im_end|>
255
+ <|im_start|>assistant
256
+ """
257
+
258
+ try:
259
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)
260
+ with torch.no_grad():
261
+ outputs = model.generate(
262
+ **inputs,
263
+ max_new_tokens=1024,
264
+ temperature=0.3,
265
+ do_sample=True,
266
+ top_p=0.9,
267
+ eos_token_id=tokenizer.eos_token_id,
268
+ pad_token_id=tokenizer.pad_token_id
269
+ )
270
+ response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip()
271
+ return response
272
+ except Exception as e:
273
+ logger.error(f"Generation failed: {e}")
274
+ return f"Error generating response: {str(e)}"
275
+
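
The ChatML prompts above ask the 8B model to put its final answer inside \boxed{}. If only that final answer is wanted, it can be pulled out with a small helper; this is a simple sketch (it does not handle nested braces), not part of the original module:

```python
import re

def extract_boxed_answer(text: str) -> str:
    """Return the last \\boxed{...} span, or the full text if none is present."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else text.strip()
```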
276
+ # ===========================================================================
277
+ # FAST ENTITY EXTRACTION (with cached model)
278
+ # ===========================================================================
279
+
280
+ @spaces.GPU
281
+ def extract_entities_with_small_model(conversation):
282
+ """OPTIMIZED: Use cached 355M model for entity extraction"""
283
+ logger.info("Extracting entities with cached 355M model...")
284
+
285
+ # Get cached model
286
+ tokenizer, model = get_cached_355m_model()
287
+
288
+ # Better prompt for extraction
289
+ prompt = f"""Clinical query: {conversation}
290
+
291
+ Extract:
292
+ Drug name:"""
293
+
294
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)
295
+
296
+ with torch.no_grad():
297
+ outputs = model.generate(
298
+ inputs.input_ids,
299
+ max_length=400,
300
+ temperature=0.3,
301
+ top_p=0.9,
302
+ do_sample=True,
303
+ pad_token_id=tokenizer.pad_token_id,
304
+ eos_token_id=tokenizer.eos_token_id
305
+ )
306
+
307
+ # Decode only the new tokens so the prompt is not echoed back in the result
308
+ generated = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
309
+ return generated.strip()
309
+
310
+ # ===========================================================================
311
+ # QUERY EXPANSION (optional, with cached model)
312
+ # ===========================================================================
313
+
314
+ @spaces.GPU
315
+ def expand_query_with_355m(query):
316
+ """OPTIMIZED: Use cached 355M for query expansion"""
317
+ logger.info("Expanding query with cached 355M...")
318
+
319
+ # Get cached model
320
+ tokenizer, model = get_cached_355m_model()
321
+
322
+ # Prompt to get clinical context
323
+ prompt = f"Question: {query}\nClinical trial information:"
324
+
325
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)
326
+
327
+ with torch.no_grad():
328
+ outputs = model.generate(
329
+ inputs.input_ids,
330
+ max_length=inputs.input_ids.shape[1] + 100,
331
+ temperature=0.7,
332
+ top_p=0.9,
333
+ do_sample=True,
334
+ pad_token_id=tokenizer.pad_token_id,
335
+ eos_token_id=tokenizer.eos_token_id
336
+ )
337
+
338
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
339
+
340
+ # Extract the expansion part
341
+ if "Clinical trial information:" in generated:
342
+ expansion = generated.split("Clinical trial information:")[-1].strip()
343
+ else:
344
+ expansion = generated[len(prompt):].strip()
345
+
346
+ # Limit to reasonable length
347
+ expansion = expansion[:500] if len(expansion) > 500 else expansion
348
+
349
+ logger.info(f"Query expanded: {expansion[:100]}...")
350
+ return expansion
351
+
352
+ # ===========================================================================
353
+ # MAIN PIPELINE - Now FAST!
354
+ # ===========================================================================
355
+
356
+ def process_two_llm_system(conversation, rag_context="", hf_token=None, use_validation=False):
357
+ """
358
+ FAST pipeline:
359
+ 1. Small 355M model extracts entities (cached model - fast)
360
+ 2. RAG retrieves context (BM25 + semantic - already fast)
361
+ 3. Big model generates response (8B local or 70B API)
362
+ 4. Skip validation for speed
363
+
364
+ Total time: ~15s instead of 300+s
365
+ """
366
+ import time
367
+ start_time = time.time()
368
+
369
+ # Step 1: Use cached 355M to extract entities
370
+ entities = extract_entities_with_small_model(conversation)
371
+ logger.info(f"Entities extracted in {time.time()-start_time:.1f}s")
372
+
373
+ # Step 2: Generate response (choose one):
374
+
375
+ # Option A: Use 70B via HF Inference API (better quality, ~10s)
376
+ if hf_token:
377
+ clinical_evidence = generate_with_llama_70b_hf(
378
+ conversation,
379
+ rag_context,
380
+ hf_token
381
+ )
382
+ model_used = "Llama-3.1-70B (HF Inference API)"
383
+ else:
384
+ # Option B: Use cached 8B model (faster loading, ~5s)
385
+ clinical_evidence = generate_clinical_response_with_xupract(
386
+ conversation,
387
+ rag_context,
388
+ hf_token
389
+ )
390
+ model_used = "II-Medical-8B (cached)"
391
+
392
+ total_time = time.time() - start_time
393
+ logger.info(f"Total pipeline time: {total_time:.1f}s (was 300+s with 355M ranking)")
394
+
395
+ return {
396
+ 'clinical_evidence': clinical_evidence,
397
+ 'entities': entities,
398
+ 'model_used': model_used,
399
+ 'time_taken': total_time
400
+ }
401
+
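
End-to-end usage of the fast pipeline, assuming `rag_context` was retrieved beforehand by the hybrid search (the query text and environment variable are illustrative):

```python
import os

result = process_two_llm_system(
    "What are the results for ianalumab in Sjogren's syndrome?",
    rag_context=rag_context,  # retrieved earlier by the hybrid search
    hf_token=os.environ.get("HF_TOKEN"),  # None falls back to the local 8B model
)
print(format_two_llm_response(result))
```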
402
+ def format_two_llm_response(result):
403
+ """Format the fast response"""
404
+ return f"""ENTITY EXTRACTION (Clinical Trial GPT 355M - Cached)
405
+ {'='*60}
406
+ {result.get('entities', 'None identified')}
407
+
408
+ CLINICAL RESPONSE ({result.get('model_used', 'Unknown')})
409
+ {'='*60}
410
+ {result['clinical_evidence']}
411
+
412
+ PERFORMANCE
413
+ {'='*60}
414
+ Time: {result.get('time_taken', 0):.1f}s (was 300+s with 355M ranking)
415
+ {'='*60}
416
+ """
417
+
418
+ # ===========================================================================
419
+ # PRELOAD MODELS AT STARTUP (Call this once in app.py!)
420
+ # ===========================================================================
421
+
422
+ def preload_all_models(hf_token=None):
423
+ """
424
+ Call this ONCE at app startup to cache all models.
425
+ This prevents model reloading on every query.
426
+
427
+ Add to your app.py initialization:
428
+ from two_llm_system_FAST import preload_all_models
429
+ preload_all_models(hf_token)
430
+ """
431
+ logger.info("Preloading and caching all models...")
432
+
433
+ # Cache the 355M model
434
+ _ = get_cached_355m_model()
435
+ logger.info("✓ 355M model cached")
436
+
437
+ # Cache the 8B model if token available
438
+ if hf_token:
439
+ try:
440
+ _ = get_cached_8b_model(hf_token)
441
+ logger.info("✓ 8B model cached")
442
+ except Exception as e:
443
+ logger.warning(f"Could not cache 8B model: {e}")
444
+
445
+ logger.info("All models preloaded and cached!")
446
+
447
+ # ===========================================================================
448
+ # BACKWARD COMPATIBILITY - Keep all original function names
449
+ # ===========================================================================
450
+
451
+ # These functions exist in the original but we optimize them
452
+ validate_with_small_model = lambda *args, **kwargs: "Validation skipped for speed"
453
+ extract_keywords_with_llama = lambda conv, hf_token=None: extract_entities_with_small_model(conv)[:100]
454
+ generate_response_with_llama = generate_with_llama_70b_hf
455
+ generate_clinical_knowledge_with_355m = lambda conv: f"Knowledge: {conv[:100]}..."
456
+ generate_with_355m = lambda conv, rag="", hf_token=None: generate_clinical_response_with_xupract(conv, rag, hf_token)
457
+
458
+ # Ensure we override the slow ranking function
459
+ rank_trials_with_355m = rank_trials_FAST
460
+
461
+ logger.info("Fast Two-LLM System loaded - 355M ranking bypassed!")