# ctapi / foundation_engine.py
# Your Name
# Change top_k from 3 to 10 trials for better coverage
# b1071b6
"""
Foundation 1.2
Clinical trial query system with 355M foundation model
"""
import gradio as gr
import os
from pathlib import Path
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
from rank_bm25 import BM25Okapi
import re
from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face access token taken from the environment (None when unset).
hf_token = os.getenv("HF_TOKEN")

# ---------------------------------------------------------------------------
# Data file locations. Each file is looked for next to this module; when it is
# absent, load_embeddings() fetches it from DATASET_REPO on first run.
# ---------------------------------------------------------------------------
DATASET_FILE = Path(__file__).parent.joinpath("complete_dataset_WITH_RESULTS_FULL.txt")
CHUNKS_FILE = Path(__file__).parent.joinpath("dataset_chunks_TRIAL_AWARE.pkl")
# "_FIXED" file name, chosen (per the original note) so an older cached copy is not picked up.
EMBEDDINGS_FILE = Path(__file__).parent.joinpath("dataset_embeddings_TRIAL_AWARE_FIXED.npy")
# Pre-built inverted index (roughly 307 MB according to the original note).
INVERTED_INDEX_FILE = Path(__file__).parent.joinpath("inverted_index_COMPREHENSIVE.pkl")

# Hugging Face *dataset* repository hosting the large data files above.
DATASET_REPO = "gmkdigitalmedia/foundation1.2-data"

# Process-wide state, filled lazily by load_embedder() / load_embeddings().
embedder = None          # SentenceTransformer used to embed queries
doc_chunks = []          # trial chunks, index-aligned with doc_embeddings
doc_embeddings = None    # embedding matrix for doc_chunks
bm25_index = None        # BM25 keyword index (not built: load_embeddings skips BM25)
inverted_index = None    # term -> list of chunk indices, for fast term lookup
# ============================================================================
# ANALYTICS TRACKING
# ============================================================================
from collections import defaultdict, Counter
import time as time_module
class QueryAnalytics:
    """Track query patterns and performance for monitoring.

    All state is held in memory: per-type query counts, every recorded
    response time grouped by query type, an error counter, and the creation
    time of the tracker (used for uptime).
    """

    def __init__(self):
        self.query_types = Counter()
        self.response_times = defaultdict(list)
        self.error_count = 0
        self.total_queries = 0
        self.start_time = time_module.time()

    def record_query(self, query_type: str, response_time: float, success: bool = True):
        """Record one query execution.

        Args:
            query_type: Label the query is grouped under.
            response_time: Elapsed time (logged with an "s" suffix).
            success: When falsy, the query also counts as an error.
        """
        self.total_queries += 1
        self.query_types[query_type] += 1
        self.response_times[query_type].append(response_time)
        self.error_count += int(not success)
        logger.info(f"[ANALYTICS] Recorded: {query_type}, {response_time:.2f}s, success={success}")

    def get_stats(self):
        """Return a summary dict of everything recorded so far.

        Keys: 'total_queries', 'uptime_seconds', 'error_rate' (0 before any
        query is recorded), 'query_type_distribution' (type -> count) and
        'avg_response_times' (type -> mean response time).
        """
        uptime = time_module.time() - self.start_time
        total = self.total_queries
        return {
            'total_queries': total,
            'uptime_seconds': uptime,
            'error_rate': self.error_count / total if total > 0 else 0,
            'query_type_distribution': dict(self.query_types),
            'avg_response_times': {
                kind: sum(samples) / len(samples)
                for kind, samples in self.response_times.items()
                if samples
            },
        }


# Initialize global analytics
query_analytics = QueryAnalytics()
# ============================================================================
# RAG FUNCTIONS
# ============================================================================
def load_embedder():
    """Create the module-level query embedder on first use; later calls are no-ops.

    Loads 'all-MiniLM-L6-v2' (the model the stored embeddings were generated
    with) and pins it to the CPU so CUDA is not initialised in the main process.
    """
    global embedder
    if embedder is not None:
        return
    logger.info("Loading MiniLM-L6 embedding model...")
    embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
    logger.info("L6 model loaded on CPU")
def build_inverted_index(chunks):
    """
    Build a targeted inverted index for clinical search.

    Maps lowercase terms to the chunk indices whose text contains them in one
    of the indexed fields, so a term lookup is a single dict access.

    Indexed fields (searched in the lowercased text; the first matching line is
    used for each single-valued field):
      1. INTERVENTION - drug/device tokens longer than 3 chars, minus generic words
      2. CONDITIONS   - every word of 4+ chars in each comma/semicolon/pipe item
      3. SPONSOR, COLLABORATOR, MANUFACTURER - whole company names longer than 3 chars
      4. OUTCOME      - first 3 words of 5+ chars in each of the first 5 outcome lines
    Trial names are deliberately not indexed (noise).

    Args:
        chunks: Sequence of chunk texts, or of tuples holding the text at [1].

    Returns:
        dict[str, list[int]]: term -> ascending, duplicate-free chunk indices.
    """
    import time
    t_start = time.time()
    inv_index = {}
    logger.info("Building targeted index: drugs, diseases, companies, endpoints...")
    # Generic words to skip
    skip_words = {
        'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial',
        'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active',
        'randomized', 'multicenter', 'open', 'label', 'crossover'
    }
    outcome_noise = {'outcome', 'measure', 'primary', 'secondary'}

    # Compiled once: these run for every chunk.
    intervention_re = re.compile(r'intervention[:\s]+([^\n]+)')
    conditions_re = re.compile(r'conditions?[:\s]+([^\n]+)')
    # Order matters only for dict key insertion order; kept as the original.
    company_field_res = (
        re.compile(r'sponsor[:\s]+([^\n]+)'),
        re.compile(r'collaborator[:\s]+([^\n]+)'),
        re.compile(r'manufacturer[:\s]+([^\n]+)'),
    )
    outcome_re = re.compile(r'outcome[:\s]+([^\n]+)')
    drug_split_re = re.compile(r'[,;\-\s]+')
    item_split_re = re.compile(r'[,;\|]+')
    word4_re = re.compile(r'\b\w{4,}\b')
    word5_re = re.compile(r'\b\w{5,}\b')

    def add(term, idx):
        # Chunks are visited in ascending idx order, so if idx is already in
        # this posting list it must be the last entry. This O(1) check yields
        # exactly the same lists as the old `idx not in postings` scan, which
        # was O(len(postings)) and made the build quadratic for common terms.
        postings = inv_index.setdefault(term, [])
        if not postings or postings[-1] != idx:
            postings.append(idx)

    for idx, chunk_data in enumerate(chunks):
        if idx % 100000 == 0 and idx > 0:
            logger.info(f" Indexed {idx:,}/{len(chunks):,} trials...")
        text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
        text_lower = text.lower()

        # 1. DRUGS from INTERVENTION field
        match = intervention_re.search(text_lower)
        if match:
            for drug in drug_split_re.split(match.group(1)):
                drug = drug.strip('.,;:() ')
                if len(drug) > 3 and drug not in skip_words:
                    add(drug, idx)

        # 2. DISEASES from CONDITIONS field: index each significant word
        match = conditions_re.search(text_lower)
        if match:
            for disease in item_split_re.split(match.group(1)):
                disease = disease.strip('.,;:() ')
                for word in word4_re.findall(disease):
                    if word not in skip_words:
                        add(word, idx)

        # 3. COMPANIES from SPONSOR / COLLABORATOR / MANUFACTURER fields
        for field_re in company_field_res:
            match = field_re.search(text_lower)
            if match:
                for company in item_split_re.split(match.group(1)):
                    company = company.strip('.,;:() ')
                    if len(company) > 3:
                        add(company, idx)

        # 4. ENDPOINTS from OUTCOME fields (first 5 outcomes, first 3 long words each)
        for outcome_text in outcome_re.findall(text_lower)[:5]:
            for word in word5_re.findall(outcome_text)[:3]:
                if word not in skip_words and word not in outcome_noise:
                    add(word, idx)

    t_elapsed = time.time() - t_start
    logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms")
    # Log sample entries for debugging (drugs, diseases, companies, endpoints)
    sample_terms = {
        'drugs': ['keytruda', 'opdivo', 'humira'],
        'diseases': ['cancer', 'diabetes', 'melanoma'],
        'companies': ['novartis', 'pfizer', 'merck'],
        'endpoints': ['survival', 'response', 'remission']
    }
    for category, terms in sample_terms.items():
        logger.info(f" {category.upper()} samples:")
        for term in terms:
            if term in inv_index:
                logger.info(f" '{term}' -> {len(inv_index[term])} trials")
    return inv_index
def download_from_dataset(filename):
    """Return a local copy of ``filename`` from DATASET_REPO, downloading it if needed.

    Files live under /tmp/foundation_data (writable inside the Docker image).
    A file already present there is returned without contacting the Hub.

    Args:
        filename: Path of the file inside the dataset repository.

    Returns:
        Path to the local file, or None when the download fails (the error is
        logged, not raised).
    """
    from huggingface_hub import hf_hub_download

    # Use /tmp for downloads (has write permissions in Docker)
    download_dir = Path("/tmp/foundation_data")
    download_dir.mkdir(parents=True, exist_ok=True)
    local_file = download_dir / filename
    if local_file.exists():
        logger.info(f"Found cached {filename}")
        return local_file
    try:
        logger.info(f"Downloading {filename} from {DATASET_REPO}...")
        downloaded_file = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=filename,
            repo_type="dataset",
            local_dir=download_dir,
            # Ignored (deprecated) by recent huggingface_hub releases; on older
            # releases it forces real files in local_dir instead of symlinks.
            local_dir_use_symlinks=False,
        )
        logger.info(f"Downloaded {filename}")
        return Path(downloaded_file)
    except Exception as e:
        logger.error(f"Failed to download {filename}: {e}")
        return None
def _embeddings_to_matrix(loaded_embeddings, chunk_size=10000):
    """Convert a sequence-of-vectors payload into a float32 (N, dim) matrix.

    Accepts a Python list or a 1-D object ndarray whose items are vectors
    (list / ndarray) or tuples containing one vector, e.g. (id, vector).

    Raises:
        ValueError: if the payload is empty, looks like the chunks file
            ((int, str) tuples), or no vector can be located.
    """
    logger.info(f"Converting embeddings from list to numpy array (memory efficient)...")
    total = len(loaded_embeddings)
    logger.info(f"DEBUG: Total embeddings: {total}")
    if total == 0:
        raise ValueError("Embeddings file contains no embeddings")
    first_emb = loaded_embeddings[0]
    logger.info(f"DEBUG: Type of first item: {type(first_emb)}")

    # Guard against the chunks pickle having been uploaded under the embeddings name.
    if (isinstance(first_emb, tuple) and len(first_emb) == 2
            and isinstance(first_emb[0], int) and isinstance(first_emb[1], str)):
        raise ValueError(
            f"ERROR: The embeddings file contains (int, string) tuples!\n"
            f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n"
            f"First item: {first_emb[:2]}\n\n"
            f"Please re-upload the correct file:\n"
            f"  CORRECT: {EMBEDDINGS_FILE.name} (numpy array)\n"
            f"  WRONG: {CHUNKS_FILE.name} (tuples)\n"
        )

    emb_idx = None  # position of the vector inside tuple items, if items are tuples
    if isinstance(first_emb, tuple):
        logger.info(f"DEBUG: Tuple length: {len(first_emb)}")
        for i, item in enumerate(first_emb[:5]):
            logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}")
        # Could be (id, emb) or (emb, id): take the first vector-like element.
        logger.info(f"DEBUG: Trying to find embedding vector in tuple...")
        emb_vector = None
        for pos, elem in enumerate(first_emb):
            if isinstance(elem, (list, np.ndarray)):
                emb_vector = elem
                emb_idx = pos
                logger.info(f"DEBUG: Found embedding at position {pos}")
                break
        if emb_vector is None:
            raise ValueError(f"No embedding vector found in tuple. Tuple contains: {[type(x) for x in first_emb]}")
        emb_dim = len(emb_vector)
        logger.info(f"DEBUG: Embedding dimension: {emb_dim}")
    elif isinstance(first_emb, list):
        emb_dim = len(first_emb)
    elif isinstance(first_emb, np.ndarray):
        emb_dim = first_emb.shape[0]
    else:
        raise ValueError(f"Unknown embedding format: {type(first_emb)}")

    logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}")
    matrix = np.zeros((total, emb_dim), dtype=np.float32)
    # Fill in slices to avoid one big temporary copy.
    for start in range(0, total, chunk_size):
        end = min(start + chunk_size, total)
        batch = loaded_embeddings[start:end]
        if emb_idx is not None:
            batch = [item[emb_idx] for item in batch]
        # list(...) so a 1-D object-array slice is assigned element-wise, like a list.
        matrix[start:end] = list(batch)
        if start % 50000 == 0:
            logger.info(f"Converted {start}/{total} embeddings...")
    logger.info(f"✓ Converted to array with shape: {matrix.shape}")
    return matrix


def load_embeddings():
    """Load chunks, embeddings and the inverted index into module globals.

    Data files missing next to this module are first fetched with
    download_from_dataset(). On success sets ``doc_chunks``, ``doc_embeddings``
    and ``inverted_index`` (the pre-built pickle if available, otherwise
    rebuilt with build_inverted_index()) and returns True. BM25 is not built.

    Raises:
        RuntimeError: if the chunks/embeddings files are unavailable or
            cannot be loaded.
    """
    global doc_chunks, doc_embeddings, inverted_index

    # Prefer files shipped next to this module; otherwise use the downloaded copy.
    chunks_path = CHUNKS_FILE
    embeddings_path = EMBEDDINGS_FILE
    index_path = INVERTED_INDEX_FILE
    if not CHUNKS_FILE.exists():
        downloaded = download_from_dataset("dataset_chunks_TRIAL_AWARE.pkl")
        if downloaded:
            chunks_path = downloaded
    if not EMBEDDINGS_FILE.exists():
        downloaded = download_from_dataset("dataset_embeddings_TRIAL_AWARE_FIXED.npy")  # FIXED version
        if downloaded:
            embeddings_path = downloaded
    if not DATASET_FILE.exists():
        # NOTE(review): the downloaded path is not kept anywhere; the keyword
        # search functions only check DATASET_FILE, so they will not see this
        # /tmp copy — confirm whether that is intended.
        download_from_dataset("complete_dataset_WITH_RESULTS_FULL.txt")
    # Download inverted index from dataset (307 MB, truly comprehensive)
    if not INVERTED_INDEX_FILE.exists():
        downloaded = download_from_dataset("inverted_index_COMPREHENSIVE.pkl")
        if downloaded:
            index_path = downloaded
            logger.info(f"✓ Downloaded comprehensive inverted index from dataset")

    if not (chunks_path.exists() and embeddings_path.exists()):
        raise RuntimeError("Embeddings files not found - system cannot function without embeddings")

    try:
        logger.info("Loading embeddings from disk...")
        # SECURITY: pickle.load / allow_pickle=True can execute code embedded in
        # the file. Acceptable only because the files come from DATASET_REPO,
        # which is assumed trusted.
        with open(chunks_path, 'rb') as f:
            doc_chunks = pickle.load(f)
        loaded_embeddings = np.load(embeddings_path, allow_pickle=True)
        logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}")

        if isinstance(loaded_embeddings, np.ndarray) and loaded_embeddings.ndim == 2:
            doc_embeddings = loaded_embeddings
            logger.info(f"✓ Embeddings are proper numpy array with shape: {doc_embeddings.shape}")
        elif isinstance(loaded_embeddings, list) or (
            # np.load never returns a list: a pickled sequence comes back as a
            # 1-D object ndarray, which needs the same conversion.
            isinstance(loaded_embeddings, np.ndarray)
            and loaded_embeddings.dtype == object
            and loaded_embeddings.ndim == 1
        ):
            doc_embeddings = _embeddings_to_matrix(loaded_embeddings)
        else:
            doc_embeddings = loaded_embeddings
            logger.info(f"Embeddings already numpy array with shape: {doc_embeddings.shape}")
        logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings")
        if (isinstance(doc_embeddings, np.ndarray) and doc_embeddings.ndim >= 1
                and doc_embeddings.shape[0] != len(doc_chunks)):
            logger.warning(
                f"Chunk/embedding count mismatch: {len(doc_chunks)} chunks vs "
                f"{doc_embeddings.shape[0]} embeddings"
            )

        # Skip BM25 (too memory-heavy for Docker), use inverted index only
        if index_path.exists():
            logger.info(f"Loading comprehensive inverted index from {index_path.name}...")
            try:
                with open(index_path, 'rb') as f:
                    inverted_index = pickle.load(f)
                logger.info(f"✓ Loaded comprehensive index with {len(inverted_index):,} terms")
                logger.info(f" Includes: TITLE (all words), INTERVENTION, CONDITIONS, SPONSOR, SUMMARY/DESCRIPTION (companies)")
            except Exception as e:
                logger.warning(f"Failed to load comprehensive index: {e}, building basic index...")
                inverted_index = build_inverted_index(doc_chunks)
        else:
            logger.info("Comprehensive inverted index not found, building basic index (15 minutes)...")
            inverted_index = build_inverted_index(doc_chunks)
        logger.info("Will use inverted index + semantic search (no BM25)")
        return True
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        raise RuntimeError("Embeddings are required but failed to load") from e
def filter_trial_for_clinical_summary(trial_text):
    """
    Trim a trial record down to its clinically essential lines.

    Each line is classified by its whitespace-stripped prefix:
      * dropped: blank lines, BASELINE / SERIOUS_ADVERSE_EVENT /
        OTHER_ADVERSE_EVENT lines and the OUTCOME_TYPE / TIME_FRAME / SAFETY /
        OTHER / NUMBER detail lines (statistical noise);
      * capped at 5 each: OUTCOME:, OUTCOME_DESCRIPTION:, RESULT_VALUE:
        (enough to show key endpoints and some actual results);
      * kept: NCT_ID, TITLE, OFFICIAL_TITLE, SUMMARY, DESCRIPTION, CONDITIONS,
        INTERVENTION, SPONSOR, COLLABORATOR, MANUFACTURER, ELIGIBILITY;
      * anything else is dropped.

    Kept lines are returned unmodified (original indentation included), in
    their original order, joined by newlines. Falsy input is returned as-is.
    """
    if not trial_text:
        return trial_text

    noise_prefixes = (
        'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:',
        'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:',
        'OUTCOME_OTHER:', 'OUTCOME_NUMBER:',
    )
    # prefix -> how many such lines may be kept
    caps = {
        'OUTCOME:': 5,
        'OUTCOME_DESCRIPTION:': 5,
        'RESULT_VALUE:': 5,
    }
    essential_prefixes = (
        'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:',
        'SUMMARY:', 'DESCRIPTION:',
        'CONDITIONS:', 'INTERVENTION:',              # WHAT disease, WHAT drug
        'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:',  # WHO is running/funding
        'ELIGIBILITY:',
    )

    seen = dict.fromkeys(caps, 0)
    kept = []
    for raw_line in trial_text.split('\n'):
        stripped = raw_line.strip()
        if not stripped or stripped.startswith(noise_prefixes):
            continue
        capped_prefix = next((p for p in caps if stripped.startswith(p)), None)
        if capped_prefix is not None:
            seen[capped_prefix] += 1
            if seen[capped_prefix] <= caps[capped_prefix]:
                kept.append(raw_line)
            continue
        if stripped.startswith(essential_prefixes):
            kept.append(raw_line)
    return '\n'.join(kept)
def _chunk_text_of(chunk_data):
    """Return the text of a chunk stored either as plain text or as a tuple with the text at [1]."""
    return chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data


def _score_inverted_index_hits(query_terms):
    """Keyword-score every trial that the inverted index lists for a query term.

    A term indexed for fewer than 100 trials is treated as a specific drug
    name; trials containing one start at 1000 + 1, others at 1.0. Each score
    is then multiplied by a field boost from the non-drug terms: x3 when the
    term directly follows "intervention:", x2 when the term's first occurrence
    lies within 200 chars after the first "title:".

    Returns:
        dict[int, float]: chunk index -> keyword score (empty when the index
        is not loaded or nothing matches).
    """
    keyword_scores = {}
    candidates = set()
    if inverted_index is not None:
        for term in query_terms:
            if term in inverted_index:
                candidates.update(inverted_index[term])
                logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'")
    if not candidates:
        logger.info(f"[FALLBACK] No inverted index hits, using pure semantic search")
        return keyword_scores

    logger.info(f"[FAST PATH] Checking {len(candidates)} inverted index candidates")
    # Low-frequency terms (<100 trials) are most likely specific drug names.
    drug_specific_terms = set()
    for term in query_terms:
        if term in inverted_index and len(inverted_index[term]) < 100:
            drug_specific_terms.add(term)
            logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name")

    for idx in candidates:
        chunk_lower = _chunk_text_of(doc_chunks[idx]).lower()
        base_score = 1.0  # no BM25: every index hit starts from the same base
        if any(drug_term in chunk_lower for drug_term in drug_specific_terms):
            # Drug-specific trials get guaranteed top keyword ranking.
            score = 1000.0 + base_score
            logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}")
        else:
            score = base_score
        # Field-specific boosting for non-drug terms
        max_field_boost = 1.0
        for term in query_terms:
            if term not in chunk_lower or term in drug_specific_terms:
                continue
            if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower:
                max_field_boost = max(max_field_boost, 3.0)
            elif 'title:' in chunk_lower:
                title_pos = chunk_lower.find('title:')
                term_pos = chunk_lower.find(term)
                if title_pos < term_pos < title_pos + 200:
                    max_field_boost = max(max_field_boost, 2.0)
        keyword_scores[idx] = score * max_field_boost
    return keyword_scores


def _filter_scores_by_company(keyword_scores, companies):
    """Keep only keyword hits whose sponsor-type fields mention one of ``companies``.

    For each of 'sponsor:', 'collaborator:', 'manufacturer:' the 500 characters
    starting at its first occurrence are searched for each (lowercase) company
    name. Matching trials keep their score multiplied by 10.

    Returns:
        dict[int, float]: the matching subset (possibly empty).
    """
    sponsor_field_patterns = ['sponsor:', 'collaborator:', 'manufacturer:']
    filtered = {}
    for idx, score in keyword_scores.items():
        chunk_lower = _chunk_text_of(doc_chunks[idx]).lower()
        has_company = False
        for company in companies:
            for field in sponsor_field_patterns:
                field_start = chunk_lower.find(field)
                if field_start != -1 and company in chunk_lower[field_start:field_start + 500]:
                    has_company = True
                    logger.info(f"[COMPANY MATCH] Trial {idx} has '{company}' in {field}")
                    break
            if has_company:
                break
        if has_company:
            filtered[idx] = score * 10.0  # 10x boost for company match
    return filtered


def retrieve_context_with_embeddings(query, top_k=10, entities=None):
    """
    Hybrid (inverted-index keyword + semantic) retrieval of trial context.

    Steps:
      1. Extract query terms: lowercased words longer than 2 chars, minus stop words.
      2. Keyword-score the trials the inverted index lists for those terms.
      3. If ``entities['companies']`` is non-empty, keep only keyword hits whose
         sponsor/collaborator/manufacturer text names a company (x10) and
         restrict the results to them; if none match, fall back to the
         unfiltered search.
      4. Semantic score: dot product of the query embedding with every stored
         embedding, min-max normalised.
      5. Combined score = 0.5 * (keyword / max keyword) + 0.5 * semantic, with
         keyword = 0 for trials that had no hit.
      6. Take the best max(3 * top_k, 10) candidates, order them by NCT number
         (higher = newer first), keep ``top_k`` and compact each one with
         filter_trial_for_clinical_summary().

    Args:
        query: Search query string.
        top_k: Number of trials to return.
        entities: Optional dict; only its 'companies' list is used here.

    Returns:
        str: The filtered trial texts separated by '---' lines, or "" when the
        embeddings are not loaded.
    """
    import time

    if doc_embeddings is None or len(doc_chunks) == 0:
        logger.error("Embeddings not loaded!")
        return ""
    t0 = time.time()

    # Extract ALL meaningful words from query (stop words removed)
    stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
                  'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
                  'about', 'that', 'this', 'there', 'it'}
    words = re.findall(r'\b\w+\b', query.lower())
    query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
    logger.info(f"[HYBRID] Query terms extracted: {query_terms}")

    # 1. KEYWORD SCORING via inverted index (BM25 is not built)
    t_kw = time.time()
    keyword_scores = _score_inverted_index_hits(query_terms)
    logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)")

    # 1.5. STRICT COMPANY FILTERING (if companies specified)
    company_only = False
    if entities and entities.get('companies'):
        companies = [c.lower() for c in entities['companies']]
        logger.info(f"[STRICT FILTER] Enforcing company filter: {companies}")
        filtered_keyword_scores = _filter_scores_by_company(keyword_scores, companies)
        logger.info(f"[STRICT FILTER] Filtered {len(keyword_scores)} -> {len(filtered_keyword_scores)} trials (only those from {companies})")
        if filtered_keyword_scores:
            keyword_scores = filtered_keyword_scores
            company_only = True
        else:
            logger.warning(f"[STRICT FILTER] No trials found from companies {companies}, falling back to general search")

    # 2. SEMANTIC SCORING
    load_embedder()
    t_sem = time.time()
    query_embedding = embedder.encode([query])[0]
    semantic_similarities = np.dot(doc_embeddings, query_embedding)
    logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)")

    # 3. MERGE SCORES (both normalised to 0-1)
    n_chunks = len(doc_chunks)
    max_sem = semantic_similarities.max()
    min_sem = semantic_similarities.min()
    semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)

    keyword_scores_norm = {}
    kw_vector = np.zeros(n_chunks)
    if keyword_scores:
        max_kw = max(keyword_scores.values())
        keyword_scores_norm = {idx: s / max_kw for idx, s in keyword_scores.items()}
        hit_indices = np.fromiter(keyword_scores_norm.keys(), dtype=np.int64, count=len(keyword_scores_norm))
        kw_vector[hit_indices] = np.fromiter(keyword_scores_norm.values(), dtype=float, count=len(keyword_scores_norm))
    # Same weights for every trial. (Previously non-hits kept their full
    # semantic score while hits got 0.5*kw + 0.5*sem, so a weak keyword hit
    # ranked BELOW an otherwise identical trial with no hit at all.)
    combined_scores = 0.5 * kw_vector + 0.5 * semantic_scores_norm[:n_chunks]

    # Over-fetch candidates so the recency sort below has a pool to choose from.
    candidate_k = max(top_k * 3, 10)
    if company_only:
        # STRICT: only trials that passed the company filter are eligible.
        pool = np.fromiter(keyword_scores.keys(), dtype=np.int64, count=len(keyword_scores))
        top_indices = pool[np.argsort(combined_scores[pool])[-candidate_k:][::-1]]
    else:
        top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
    logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}")
    logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}")
    logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}")

    # (score, text) pairs, as rank_trials_with_355m would expect
    candidates = [(combined_scores[i], _chunk_text_of(doc_chunks[i])) for i in top_indices]

    def extract_nct_number(trial_tuple):
        """Extract NCT number from trial text for sorting (higher = newer)"""
        _, text = trial_tuple
        match = re.search(r'NCT_ID:\s*NCT(\d+)', text)
        return int(match.group(1)) if match else 0

    # Sort candidates by NCT ID (descending = newest first)
    candidates.sort(key=extract_nct_number, reverse=True)
    top_ncts = []
    for _, text in candidates[:5]:
        match = re.search(r'NCT_ID:\s*(NCT\d+)', text)
        if match:
            top_ncts.append(match.group(1))
    logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}")

    # 355M re-ranking (rank_trials_with_355m) is intentionally skipped: per the
    # original note it gave ~0.50 to everything and cost ~10 s.
    logger.info(f"[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)")
    top_ranked = candidates[:top_k]
    logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)")

    raw_chunks = [trial_text for _, trial_text in top_ranked]
    context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks]
    if context_chunks:
        first_trial_preview = context_chunks[0][:200]
        logger.info(f"[HYBRID] First result (filtered): {first_trial_preview}")

    context = "\n\n---\n\n".join(context_chunks)  # Use --- as separator between trials
    logger.info(f"[HYBRID] TOTAL TIME: {time.time()-t0:.2f}s")
    logger.info(f"[HYBRID] Filtered context length: {len(context)} chars (was ~{sum(len(c) for c in raw_chunks)} chars)")
    return context
def keyword_search_query_text(query, max_results=10, hf_token=None):
    """Scan the dataset file and rank trials by how many query words they contain.

    The query is lowercased, split on whitespace and each word stripped of
    surrounding punctuation; words shorter than 3 characters or in the stop
    list are then discarded, and duplicates are removed. A trial's score is
    the number of search terms that occur as substrings of its lowercased text.

    Args:
        query: Free-text query.
        max_results: Maximum number of trials to return.
        hf_token: Unused; kept for call compatibility.

    Returns:
        list[tuple[int, str]]: Up to ``max_results`` (score, trial_text) pairs,
        highest score first (file order among equal scores). Empty when the
        dataset file is missing, no usable terms remain, nothing matches, or
        reading fails.
    """
    import heapq

    # NOTE(review): only DATASET_FILE next to this module is searched; a copy
    # downloaded to /tmp/foundation_data by load_embeddings() is not used.
    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return []

    # Remove common stopwords but keep medical/clinical terms
    stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should',
                 'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with',
                 'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on',
                 'what', 'you', 'know', 'that', 'relevant'}
    search_terms = []
    for word in query.lower().split():
        term = word.strip('?.,!;:()[]{}')
        # Filter AFTER stripping: "is," must not survive as "is", and "???"
        # must not become "" (an empty term matches every trial).
        if len(term) >= 3 and term not in stopwords and term not in search_terms:
            search_terms.append(term)
    if not search_terms:
        logger.warning("No search terms extracted from query")
        return []
    logger.info(f"Search terms from full query: {search_terms}")

    def iter_trials(lines):
        # A new trial starts at an "NCT_ID:" or "TRIAL NCT" line.
        current_trial = ""
        for line in lines:
            if line.startswith("NCT_ID:") or line.startswith("TRIAL NCT"):
                if current_trial:
                    yield current_trial
                current_trial = line
            else:
                current_trial += line
        if current_trial:
            yield current_trial

    def iter_scored(lines):
        for trial in iter_trials(lines):
            trial_lower = trial.lower()
            score = sum(1 for term in search_terms if term in trial_lower)
            if score > 0:
                yield (score, trial)

    try:
        with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
            # Streams the file keeping only the best max_results trials in
            # memory; same result as sorted(..., reverse=True)[:max_results].
            matching_trials = heapq.nlargest(max_results, iter_scored(f), key=lambda pair: pair[0])
    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return []

    if matching_trials:
        logger.info(f"Keyword search found {len(matching_trials)} trials")
    else:
        logger.warning("Keyword search found no matching trials")
    return matching_trials
def keyword_search_in_dataset(entities, max_results=10):
    """Legacy: rank dataset trials by substring matches against extracted entities.

    Scoring (on the lowercased trial text):
      * when drugs are given, a trial must match at least one drug and scores
        10 * drug_matches + condition_matches;
      * otherwise a trial scores condition_matches, which must be > 0.
    The best ``max_results`` trials are joined with '---' separator lines and
    the result is cut to 6000 characters (followed by "...").

    Returns:
        str: The joined context, or "" when the file is missing, no entities
        were given, nothing matched, or reading failed.
    """
    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return ""

    drug_terms = [d.lower() for d in entities.get('drugs', [])]
    condition_terms = [c.lower() for c in entities.get('conditions', [])]
    if not (drug_terms or condition_terms):
        logger.warning("No search terms for keyword search")
        return ""
    logger.info(f"Keyword search - Drugs: {drug_terms}, Conditions: {condition_terms}")

    def trials(lines):
        # Group file lines into trials; each trial begins with a header line.
        buffer = ""
        for line in lines:
            if line.startswith(("NCT_ID:", "TRIAL NCT")):
                if buffer:
                    yield buffer
                buffer = line
            else:
                buffer += line
        if buffer:
            yield buffer

    def score_of(trial):
        # Returns None for trials that must be excluded.
        lowered = trial.lower()
        n_drug = sum(1 for d in drug_terms if d in lowered)
        n_cond = sum(1 for c in condition_terms if c in lowered)
        if drug_terms:
            return n_drug * 10 + n_cond if n_drug > 0 else None
        return n_cond if n_cond > 0 else None

    scored = []
    try:
        with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as handle:
            for trial in trials(handle):
                score = score_of(trial)
                if score is not None:
                    scored.append((score, trial))

        scored.sort(reverse=True, key=lambda pair: pair[0])
        best = [trial for _, trial in scored[:max_results]]
        if not best:
            logger.warning("Keyword search found no trials matching drug")
            return ""
        context = "\n\n---\n\n".join(best)
        if len(context) > 6000:
            context = context[:6000] + "..."
        logger.info(f"Keyword search found {len(best)} trials (from {len(scored)} candidates)")
        return context
    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return ""
# ============================================================================
# ENTITY EXTRACTION
# ============================================================================
def parse_entities_from_query(conversation, hf_token=None):
    """Parse drug and condition entities from *conversation*.

    Combines three sources, in this order: the 355M model extractor, the 8B
    model extractor, and a regex fallback (-mab / -nib suffixes, alphanumeric
    codes such as F8IL10, and a few hard-coded conditions).

    Model output is read line by line. A line containing ``drug:`` /
    ``medication:`` (any case) contributes a drug, and a line containing
    ``condition:`` / ``disease:`` contributes a condition. The value is the
    text left after removing that label and any list bullets.

    Args:
        conversation: Free-text user query.
        hf_token: HuggingFace token forwarded to the 8B extractor.

    Returns:
        Dict with ``'drugs'`` and ``'conditions'`` lists. Names are
        de-duplicated case-insensitively across all sources (the first
        spelling seen wins). Placeholder answers such as "none" are dropped.
    """
    entities = {'drugs': [], 'conditions': []}
    # Lower-cased names already collected, for case-insensitive de-duplication.
    seen = {'drugs': set(), 'conditions': set()}
    # Placeholders a model may emit instead of leaving the field empty.
    placeholders = {'none', 'n/a', 'na', 'unknown'}

    def _add(kind, name):
        """Append *name* to entities[kind] unless blank, a placeholder, or already present."""
        key = name.lower()
        if name and key not in placeholders and key not in seen[kind]:
            seen[kind].add(key)
            entities[kind].append(name)

    # Use 355M model for entity extraction
    extracted_355m = extract_entities_with_small_model(conversation)
    # Also use 8B model for more reliable extraction
    extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token)
    # Combine both extractions (either may be None/empty)
    extracted = (extracted_355m or "") + "\n" + (extracted_8b or "")

    # Parse model output
    for line in extracted.split('\n'):
        lower_line = line.lower()
        if 'drug:' in lower_line or 'medication:' in lower_line:
            drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE)
            # Remove surrounding whitespace plus list bullets / markdown emphasis.
            _add('drugs', drug.strip().strip('-*•').strip())
        elif 'condition:' in lower_line or 'disease:' in lower_line:
            condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE)
            _add('conditions', condition.strip().strip('-*•').strip())

    # Regex fallback for standard drug naming patterns (case-sensitive on purpose:
    # drug names are expected to be capitalized in the query).
    drug_patterns = [
        r'\b([A-Z][a-z]+mab)\b',  # Monoclonal antibodies: -mab suffix
        r'\b([A-Z][a-z]+nib)\b',  # Kinase inhibitors: -nib suffix
        r'\b([A-Z]\d+[A-Z]+\d+)\b'  # Alphanumeric codes like F8IL10
    ]
    for pattern in drug_patterns:
        for match in re.findall(pattern, conversation):
            _add('drugs', match)

    condition_patterns = [
        r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b'
    ]
    for pattern in condition_patterns:
        for match in re.findall(pattern, conversation, re.IGNORECASE):
            # Compared case-insensitively (the old check compared the raw match
            # against lower-cased names, so "Lupus" was added twice).
            _add('conditions', match)

    logger.info(f"Extracted entities: {entities}")
    return entities
# ============================================================================
# MAIN QUERY PROCESSING
# ============================================================================
def extract_entities_simple(query):
    """Extract drug and condition names from *query* using regex only (no model).

    Every pattern is matched case-insensitively. Each name is kept once,
    compared case-insensitively, in the spelling first found in the query.
    Drugs come from -mab / -nib suffixes, alphanumeric codes (e.g. F8IL10)
    and a short list of common drugs. Conditions come from a short list of
    known diseases.

    Args:
        query: Free-text user question.

    Returns:
        Dict ``{'drugs': [...], 'conditions': [...]}``.

    Example:
        >>> extract_entities_simple("F8IL10 in myelofibrosis")  # doctest: +SKIP
        {'drugs': ['F8IL10'], 'conditions': ['myelofibrosis']}
    """
    # Drug patterns
    drug_patterns = [
        r'\b([A-Z][a-z]+mab)\b',  # Monoclonal antibodies: ianalumab, rituximab, etc.
        r'\b([A-Z][a-z]+nib)\b',  # Kinase inhibitors: imatinib, etc.
        r'\b([A-Z]\d+[A-Z]+\d+)\b',  # Alphanumeric codes
        r'\b(ianalumab|rituximab|tocilizumab|adalimumab|infliximab)\b',  # Common drugs
    ]
    # Condition patterns
    condition_patterns = [
        r'\b(sjogren\'?s?\s+syndrome)\b',
        r'\b(rheumatoid arthritis)\b',
        r'\b(lupus)\b',
        r'\b(myelofibrosis)\b',
        r'\b(diabetes)\b',
        r'\b(cancer|carcinoma|melanoma)\b',
    ]

    def _unique_matches(patterns):
        """Return all matches of *patterns* in *query*, first spelling wins, no case duplicates."""
        found = []
        seen = set()  # lower-cased names, O(1) membership instead of rebuilding a list
        for pattern in patterns:
            for match in re.findall(pattern, query, re.IGNORECASE):
                key = match.lower()
                if key not in seen:
                    seen.add(key)
                    found.append(match)
        return found

    entities = {
        'drugs': _unique_matches(drug_patterns),
        'conditions': _unique_matches(condition_patterns),
    }
    logger.info(f"Extracted entities: {entities}")
    return entities
def _split_llm_entity_list(value):
    """Split one LLM field value into a clean list of names.

    Accepts forms such as ``[A, B]``, ``"A", "B"`` or ``A, B``. Returns an
    empty list for a blank value or the ``none`` placeholder, whether bare,
    quoted or bracketed.

    >>> _split_llm_entity_list('["Comirnaty", "BNT162b2"]')
    ['Comirnaty', 'BNT162b2']
    >>> _split_llm_entity_list('[none]')
    []
    """
    value = value.strip().strip('[]').strip()
    if not value or value.strip('"\'').strip().lower() == 'none':
        return []
    names = (part.strip().strip('"\'').strip() for part in value.split(','))
    return [name for name in names if name]


def _regex_entity_fallback(query):
    """Extract drug and disease names from the raw *query* with regex.

    Used when the LLM returns no entities or the LLM call fails. Matching is
    case-insensitive, and each name is kept once (first spelling seen).

    Returns:
        Tuple ``(drugs, diseases)`` of lists of strings.
    """
    drug_patterns = [
        r'\b(ianalumab|pembrolizumab|nivolumab|rituximab|tocilizumab)\b',
        r'\b(keytruda|opdivo|humira|enbrel|remicade)\b',
        r'\b([A-Z][a-z]+mab)\b',  # -mab suffix (monoclonal antibodies)
        r'\b([A-Z][a-z]+nib)\b',  # -nib suffix (kinase inhibitors)
    ]
    disease_patterns = [
        r"\b(sjogren'?s?|sjogrens)\s*(syndrome|disease)?\b",
        r'\b(lupus|arthritis|melanoma|diabetes|cancer)\b',
        r'\b(rheumatoid\s+arthritis|multiple\s+sclerosis)\b',
    ]

    def _collect(patterns):
        found, seen = [], set()
        for pattern in patterns:
            for match in re.findall(pattern, query, re.IGNORECASE):
                # Two-group patterns yield tuples, e.g. ("Sjogren's", "syndrome").
                name = match if isinstance(match, str) else ' '.join(match).strip()
                if name and name.lower() not in seen:
                    seen.add(name.lower())
                    found.append(name)
        return found

    return _collect(drug_patterns), _collect(disease_patterns)


def parse_query_with_llm(query, hf_token=None):
    """
    Use fast LLM to parse query and extract structured information
    Extracts:
    - Drug names
    - Diseases/conditions
    - Companies (sponsors/manufacturers)
    - Endpoints (what's being measured)
    - Search terms (optimized for RAG)

    Args:
        query: The user's question.
        hf_token: HuggingFace token for the Llama-3.1-70B call.

    Returns:
        Dict with keys 'raw_parsed', 'drugs', 'diseases', 'companies',
        'endpoints' (lists) and 'search_terms' (str, defaults to *query*).
        If the LLM yields no drugs/diseases/companies, or the call fails
        (including huggingface_hub being unavailable), drugs and diseases are
        filled by a regex scan of the query instead.
    """
    try:
        from huggingface_hub import InferenceClient
        logger.info("[QUERY PARSER] Analyzing user query with LLM...")
        client = InferenceClient(token=hf_token, timeout=30)
        # The Pfizer example previously also listed "mRNA-1273", which is
        # Moderna's vaccine and steered synonym expansion to the wrong drug.
        parse_prompt = f"""You are an expert in clinical trial terminology. Extract and expand entities from this query.
Query: "{query}"
Your task is to think creatively about ALL possible ways these entities might appear in clinical trial databases.
For each entity type, brainstorm extensively:
DRUGS:
- Start with drugs explicitly mentioned
- Add ALL possible names: brand names, generic names, research codes (like BNT162b2),
manufacturer+drug combos (Pfizer-BioNTech), chemical names, common abbreviations
- Think: "What would a pharmaceutical company call this in a trial?"
- Example: "Pfizer COVID vaccine" → ["Comirnaty", "BNT162b2", "tozinameran", "Pfizer-BioNTech COVID-19 vaccine"]
DISEASES:
- Include the disease/condition mentioned
- Add medical synonyms, ICD-10 terms, related conditions
- Both technical and colloquial terms
- Example: "COVID" → ["COVID-19", "SARS-CoV-2", "coronavirus disease 2019", "severe acute respiratory syndrome coronavirus 2"]
COMPANIES:
- Company mentioned plus parent companies, subsidiaries
- Include previous names, merged entities, partnership names
- Example: "Pfizer" → ["Pfizer", "Pfizer Inc.", "Pfizer-BioNTech", "BioNTech SE"]
ENDPOINTS:
- Any specific outcomes, measures, or endpoints mentioned
- Include related clinical measures
SEARCH_TERMS:
- Comprehensive keywords combining above entities
- Include partial matches that might be relevant
Format EXACTLY as:
DRUGS: [list or "none"]
DISEASES: [list or "none"]
COMPANIES: [list or "none"]
ENDPOINTS: [list or "none"]
SEARCH_TERMS: [comprehensive keyword list]
Be expansive - more synonyms mean better trial matching."""
        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=[{"role": "user", "content": parse_prompt}],
            max_tokens=500,  # Increased for comprehensive synonyms
            temperature=0.3  # Slightly higher for creative synonym generation
        )
        parsed = response.choices[0].message.content.strip()
        logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}")

        # Parse the response into dict
        result = {
            'raw_parsed': parsed,
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query  # fallback
        }
        list_fields = {
            'DRUGS': 'drugs',
            'DISEASES': 'diseases',
            'COMPANIES': 'companies',
            'ENDPOINTS': 'endpoints',
        }
        for raw_line in parsed.split('\n'):
            # Tolerate markdown decoration such as "**DRUGS:**" or "- DRUGS:".
            line = raw_line.strip().replace('**', '').lstrip('-*#• ').strip()
            label, sep, value = line.partition(':')
            if not sep:
                continue
            label = label.strip().upper()
            if label in list_fields:
                result[list_fields[label]] = _split_llm_entity_list(value)
            elif label == 'SEARCH_TERMS':
                terms = value.strip().strip('[]').strip()
                result['search_terms'] = terms if terms and terms.lower() != 'none' else query

        # FALLBACK: If LLM returned empty, try regex extraction from query
        if not (result['drugs'] or result['diseases'] or result['companies']):
            logger.warning("[QUERY PARSER] LLM returned empty entities, using regex fallback")
            result['drugs'], result['diseases'] = _regex_entity_fallback(query)
            logger.info(f"[QUERY PARSER] Regex fallback found - Drugs: {result['drugs']}, Diseases: {result['diseases']}")

        logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}")
        return result
    except Exception as e:
        logger.warning(f"[QUERY PARSER] Failed: {e}, using regex fallback on query")
        # Emergency fallback - same regex extraction as the empty-LLM path
        # (previously only ianalumab / Sjogren's were recognised here).
        drugs, diseases = _regex_entity_fallback(query)
        return {
            'drugs': drugs,
            'diseases': diseases,
            'companies': [],
            'endpoints': [],
            'search_terms': query,
            'raw_parsed': ''
        }
def plan_query_action(query, parsed_entities, hf_token=None):
    """
    Use HuggingFace Llama-70B to decide the best action for this query.
    Actions:
    - SEARCH_TRIALS: Specific drug/disease questions (use RAG with top 30 trials)
    - COUNT_AGGREGATE: "How many" or "list all" questions (use index counts)
    - COMPARE: Compare two or more treatments
    - GENERAL_KNOWLEDGE: Definitions or general info (skip RAG, use LLM knowledge)

    Args:
        query: The user's question.
        parsed_entities: Output of the query parser (dict with 'drugs',
            'diseases', 'companies', 'endpoints'); None is treated as empty.
        hf_token: HuggingFace token for the planning call.

    Returns:
        Dict with 'action', 'reasoning', 'params' (search terms, default
        *query*), 'focus' and 'raw' (LLM text, '' on failure). On any failure
        the action defaults to SEARCH_TRIALS.
    """
    valid_actions = ('SEARCH_TRIALS', 'COUNT_AGGREGATE', 'COMPARE', 'GENERAL_KNOWLEDGE')
    entities = parsed_entities or {}
    try:
        from huggingface_hub import InferenceClient
        logger.info("[PLANNING AGENT] Deciding action with HuggingFace Llama-70B...")
        client = InferenceClient(token=hf_token, timeout=30)
        planning_prompt = f"""You are a clinical trial search strategist. Route this query to the best action.
Query: "{query}"
Extracted entities:
- Drugs: {entities.get('drugs', [])}
- Diseases: {entities.get('diseases', [])}
- Companies: {entities.get('companies', [])}
- Endpoints: {entities.get('endpoints', [])}
ROUTING RULES:
1. SEARCH_TRIALS (default): Any question about specific drugs, treatments, efficacy, safety, trial results, side effects, or when entities are extracted
2. COUNT_AGGREGATE: Only when explicitly asking "how many", "list all", "total number"
3. COMPARE: Only when explicitly comparing with "vs", "versus", "compare", "better than", "difference between"
4. GENERAL_KNOWLEDGE: Only for pure definitions with no trial data needed
When in doubt, choose SEARCH_TRIALS - real trial data is almost always helpful.
Analyze the user's intent:
- Are they asking about specific trial outcomes? → SEARCH_TRIALS
- Do they want data about a drug/disease? → SEARCH_TRIALS
- Are they asking for counts or lists? → COUNT_AGGREGATE
- Are they comparing treatments? → COMPARE
- Is this purely definitional? → GENERAL_KNOWLEDGE
Respond with:
ACTION: [choose one action]
REASONING: [one clear sentence explaining why]
SEARCH_TERMS: [refined search terms to find the most relevant trials]
FOCUS: [what aspect to emphasize in the final answer - efficacy, safety, trial status, etc.]"""
        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=[{"role": "user", "content": planning_prompt}],
            max_tokens=150,
            temperature=0.1  # Low temp for consistent routing
        )
        result_text = response.choices[0].message.content.strip()
        logger.info(f"[PLANNING AGENT] Decision:\n{result_text}")

        # Parse the response
        result = {
            'action': 'SEARCH_TRIALS',  # Default fallback
            'reasoning': 'Could not parse response',
            'params': query,
            'focus': 'comprehensive trial data',
            'raw': result_text
        }
        for raw_line in result_text.split('\n'):
            # Tolerate markdown decoration such as "**ACTION:**" or "- ACTION:".
            line = raw_line.strip().replace('**', '').lstrip('-*#• ').strip()
            label, sep, value = line.partition(':')
            if not sep:
                continue
            label = label.strip().upper()
            value = value.strip()
            if label == 'ACTION':
                # The prompt shows "ACTION: [choose one action]", so models often
                # answer "[SEARCH_TRIALS]" or "Search Trials." - normalize first.
                action = value.strip('[]"\'`*. ').upper().replace(' ', '_')
                if action in valid_actions:
                    result['action'] = action
            elif label == 'REASONING':
                result['reasoning'] = value
            elif label == 'SEARCH_TERMS':
                params = value.strip('[]').strip()
                # Keep the original query when the model gives nothing usable.
                if params and params.lower() != 'none':
                    result['params'] = params
            elif label == 'FOCUS':
                result['focus'] = value

        logger.info(f"[PLANNING AGENT] ✓ Action: {result['action']}, Focus: {result['focus']}, Reasoning: {result['reasoning']}")
        return result
    except Exception as e:
        logger.warning(f"[PLANNING AGENT] Failed: {e}, defaulting to SEARCH_TRIALS")
        return {
            'action': 'SEARCH_TRIALS',
            'reasoning': f'Planning failed: {e}',
            'params': query,
            'focus': 'available trial data',
            'raw': ''  # same keys as the success path
        }
def _build_llama_prompts(query, rag_context, parsed_entities=None, planning_context=None):
    """Build the (system_prompt, user_prompt) pair for answer generation.

    Up to 10 drugs, diseases and companies from *parsed_entities* are listed
    as search hints. The retrieved context is cut to its first 12000
    characters, and a None *rag_context* is treated as empty. The focus comes
    from ``planning_context['focus']`` when available.
    """
    # Build entity context string for better guidance
    entity_context = ""
    if parsed_entities:
        drugs_list = (parsed_entities.get('drugs') or [])[:10]
        diseases_list = (parsed_entities.get('diseases') or [])[:10]
        companies_list = (parsed_entities.get('companies') or [])[:10]
        if drugs_list or diseases_list or companies_list:
            entity_context = f"""
Key entities to look for (including synonyms):
- Drugs/Treatments: {', '.join(drugs_list) if drugs_list else 'none'}
- Diseases: {', '.join(diseases_list) if diseases_list else 'none'}
- Companies: {', '.join(companies_list) if companies_list else 'none'}"""
    # Focus area from planning
    focus_area = planning_context.get('focus', 'comprehensive analysis') if planning_context else 'comprehensive analysis'
    context_excerpt = (rag_context or "")[:12000]

    system_prompt = """You are a leading clinical trials analyst. Your role is to provide the most helpful, informative answer possible using available trial data. You excel at finding connections and insights even from imperfect data matches.
CORE PRINCIPLES:
1. ALWAYS provide substantive, useful answers
2. Find relevant information even in partially-matching trials
3. Extract specific numbers, dates, phases, outcomes wherever available
4. Connect information across trials to build comprehensive insights
5. Never say "no relevant trials found" - work with what you have"""
    user_prompt = f"""Question: {query}
Focus for this analysis: {focus_area}
{entity_context}
Clinical Trials Retrieved:
{context_excerpt}
YOUR MISSION:
Provide the most comprehensive, helpful answer possible by intelligently analyzing ALL available trials.
ANALYSIS APPROACH:
1. SCAN all trials for ANY relevance to the query:
- Direct matches (same drug + disease) → Primary focus
- Same drug, different disease → Still valuable (shows drug profile)
- Same disease, different drug → Provides treatment landscape context
- Same company → Shows research pipeline
- Similar mechanisms/drug classes → Offers comparative insights
2. EXTRACT concrete information:
- Trial phases, enrollment numbers, completion dates
- Efficacy percentages, response rates, survival data
- Safety profiles, adverse events, tolerability
- Dosing regimens, administration routes
- Patient populations, inclusion/exclusion criteria
3. SYNTHESIZE intelligently:
- If asking about Drug X for Disease Y but only find Drug X for Disease Z,
discuss what this reveals about Drug X's mechanism and potential
- Find patterns across trials (e.g., consistent safety profile)
- Note trial progression (Phase 1 → 2 → 3) showing development status
## YOUR RESPONSE STRUCTURE:
### DIRECT ANSWER
[Immediately address the query with the best available information. Be confident and helpful.
If asking about "Sinopharm COVID vaccine" and trials mention "BBIBP-CorV" - recognize these as the same.
Lead with what you KNOW, not what you don't know.]
### KEY CLINICAL TRIALS EVIDENCE
[For each relevant trial, extract meaningful information:]
- **NCT#####**: [Specific findings relevant to query - be detailed with numbers/outcomes]
- **NCT#####**: [What this tells us - phases, enrollment, results if available]
[Include even partially relevant trials with appropriate context]
### CLINICAL INSIGHTS
[Synthesize patterns and meaningful conclusions:]
- What do these trials collectively reveal?
- Treatment landscape and development status
- Efficacy signals or safety patterns
- How different trials complement each other
- Comparison with similar drugs/approaches if relevant
### ADDITIONAL CONTEXT
[Brief, if needed - but keep positive and informative:]
- If data is from different indications, explain transferable insights
- If only early phase data, discuss what this means for development
- Focus on what the data DOES tell us
REMEMBER:
- Users want actionable information, not disclaimers
- Even Phase 1 safety data is valuable information
- Cross-indication data provides mechanism insights
- Company trial portfolios reveal strategic priorities
- Similar drug classes offer comparative context
- ALWAYS find something valuable to report"""
    return system_prompt, user_prompt


def generate_llama_response(query, rag_context, hf_token=None, parsed_entities=None, planning_context=None):
    """
    Intelligent synthesis that ALWAYS provides substantive answers from available data

    Uses Groq when GROQ_API_KEY is set. If that is unset or the Groq call
    fails for any reason, it uses the HuggingFace Inference API.

    Args:
        query: User's question
        rag_context: Retrieved trial data (None is treated as empty)
        hf_token: HuggingFace API token
        parsed_entities: Dict with extracted entities (drugs, diseases, companies)
        planning_context: Dict with planning agent output (action, focus, reasoning)

    Returns:
        The generated answer text, or "Llama API error: <message>" if prompt
        building or the final (HuggingFace) call fails. Never raises.
    """
    try:
        system_prompt, user_prompt = _build_llama_prompts(query, rag_context, parsed_entities, planning_context)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        groq_api_key = os.getenv("GROQ_API_KEY")
        if groq_api_key:
            # Groq is much faster. On any error fall through to HuggingFace
            # instead of returning an error, as the comment always promised.
            try:
                logger.info("Generating response with Llama-3.1-70B via GROQ (fast)...")
                from groq import Groq
                client = Groq(api_key=groq_api_key)
                # NOTE(review): confirm Groq still serves this model id; if it
                # does not, every call now falls back to HuggingFace below.
                response = client.chat.completions.create(
                    model="llama-3.1-70b-versatile",
                    messages=messages,
                    max_tokens=2000,  # Increased for comprehensive answers
                    temperature=0.3,
                    timeout=30
                )
                return response.choices[0].message.content.strip()
            except Exception as groq_error:
                logger.warning(f"Groq generation failed ({groq_error}); falling back to HuggingFace")

        # Fallback to HuggingFace (slower)
        logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...")
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=hf_token, timeout=120)
        response = client.chat_completion(
            # Same model id as the parser and planner calls in this module.
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=messages,
            max_tokens=2000,  # Increased for comprehensive answers
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Llama error: {e}")
        return f"Llama API error: {str(e)}"
def process_query_simple_test(conversation):
    """TEST JUST THE RAG - no models.

    Checks that the module-level embeddings are loaded, runs one
    ``retrieve_context_with_embeddings(conversation, top_k=10)`` call, and
    reports its timing, the number of distinct trials retrieved, and the
    first 1000 characters of the context.

    Returns:
        A plain-text report. It is exactly "FAIL: Embeddings not loaded" when
        embeddings are missing, and ends with "FAIL: RAG returned empty" when
        retrieval returns nothing. Exceptions are reported with their
        traceback instead of being raised.
    """
    try:
        import time
        output = [f"QUERY: {conversation}\n"]
        # Check if embeddings loaded
        if doc_embeddings is None or len(doc_chunks) == 0:
            return "FAIL: Embeddings not loaded"
        output.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n")
        output.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n")

        # Try to search
        start = time.time()
        context = retrieve_context_with_embeddings(conversation, top_k=10)
        search_time = time.time() - start
        if not context:
            return "".join(output) + "\nFAIL: RAG returned empty"

        # Count distinct NCT IDs (NCT + 8 digits). A raw .count('NCT') also
        # hits "NCT_ID:" labels and repeated mentions, inflating the number.
        trial_count = len(set(re.findall(r"NCT\d{8}", context)))
        output.append(f"✓ RAG search took: {search_time:.2f}s\n")
        output.append(f"✓ Retrieved {trial_count} trials\n\n")
        output.append("FIRST 1000 CHARS:\n")
        output.append(context[:1000])
        return "".join(output)
    except Exception as e:
        import traceback
        return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}"
def _lookup_trials_in_index(search_text):
    """Return the union of trial IDs that ``inverted_index`` lists for any term.

    Terms are the whitespace-separated tokens of *search_text*, lower-cased
    and stripped of surrounding punctuation, so "ianalumab," or "(lupus)"
    still match. Repeated terms are looked up once. Returns an empty set when
    the module-level ``inverted_index`` is not loaded.
    """
    matching_trial_ids = set()
    if not inverted_index:
        return matching_trial_ids
    tokens = (token.strip('.,;:!?()[]{}"\'') for token in search_text.lower().split())
    for term in dict.fromkeys(t for t in tokens if t):  # ordered de-duplication
        if term in inverted_index:
            matching_trial_ids.update(inverted_index[term])
            logger.info(f" Found {len(inverted_index[term])} trials for '{term}'")
    return matching_trial_ids


def _extract_compare_treatments(conversation, parsed_drugs):
    """Pick the treatments a COMPARE query is about (first two are used).

    An explicit "X vs Y" / "compare X and Y" phrase in the query wins. The
    LLM parser deliberately expands each drug into brand/generic/code
    synonyms, so the first two parsed drugs are often two names for one drug.
    Otherwise the parsed drug list is used, de-duplicated case-insensitively.
    """
    compare_patterns = [
        r'(\w+)\s+(?:vs|versus|vs\.)\s+(\w+)',
        r'compare\s+(\w+)\s+(?:and|with|to)\s+(\w+)'
    ]
    lowered = conversation.lower()
    for pattern in compare_patterns:
        match = re.search(pattern, lowered)
        if match:
            return [match.group(1), match.group(2)]
    treatments, seen = [], set()
    for drug in parsed_drugs or []:
        if drug.lower() not in seen:
            seen.add(drug.lower())
            treatments.append(drug)
    return treatments


def process_query(conversation):
    """
    Complete pipeline with LLM query parsing, planning agent, and natural language generation
    Flow:
    0. LLM Parser - Extract drugs, diseases, companies, endpoints (~2-3s)
    0.5. Planning Agent - Decide action: SEARCH_TRIALS / COUNT_AGGREGATE / COMPARE / GENERAL_KNOWLEDGE (~1s)
    1. Execute Action - Based on plan: RAG search, index count, comparison search, or skip to LLM (~2s)
    2. Skipped - 355M ranking removed (was broken)
    3. LLM Response - Llama 70B generates natural language (~15s)
    Total: ~21 seconds (rough estimates)

    Args:
        conversation: The user's question.

    Returns:
        A plain-text report (step log, Llama answer, retrieved context, total
        time), or an error/diagnostic string when a step fails. Every return
        path, including early failures, records the query in
        ``query_analytics``.
    """
    import time
    import traceback
    import sys

    # Trials requested from hybrid retrieval (per treatment for COMPARE).
    rag_top_k = 10
    start_time = time.time()
    plan = {'action': 'UNKNOWN'}  # replaced in Step 0.5

    def _finish(message, success=True):
        """Record this query in analytics (keyed by the planned action) and return *message*."""
        query_analytics.record_query(plan.get('action', 'UNKNOWN'), time.time() - start_time, success=success)
        return message

    # MASTER try/except - catches EVERYTHING
    try:
        output_parts = [f"QUERY: {conversation}\n\n"]

        # Step 0: Parse query with LLM to extract structured info
        try:
            step0_start = time.time()
            logger.info("Step 0: Parsing query with LLM...")
            output_parts.append("✓ Step 0: LLM query parser started...\n")
            parsed_query = parse_query_with_llm(conversation, hf_token=hf_token)
            # Use optimized search terms from parser
            search_query = parsed_query['search_terms']
            step0_time = time.time() - step0_start
            output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n")
            output_parts.append(f" Drugs: {parsed_query['drugs']}\n")
            output_parts.append(f" Diseases: {parsed_query['diseases']}\n")
            output_parts.append(f" Companies: {parsed_query['companies']}\n")
            output_parts.append(f" Optimized search: {search_query}\n")
            logger.info(f"Query parsing successful in {step0_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query"
            logger.warning(error_msg)
            output_parts.append(f"{error_msg}\n")
            search_query = conversation  # Fallback to original
            parsed_query = {'drugs': [], 'diseases': [], 'companies': []}

        # Step 0.5: Planning agent decides action
        try:
            planning_start = time.time()
            logger.info("Step 0.5: Planning agent deciding action...")
            output_parts.append("✓ Step 0.5: Planning agent started...\n")
            plan = plan_query_action(conversation, parsed_query, hf_token=hf_token)
            planning_time = time.time() - planning_start
            output_parts.append(f"✓ Step 0.5 Complete: Action decided ({planning_time:.1f}s)\n")
            output_parts.append(f" Action: {plan['action']}\n")
            output_parts.append(f" Reasoning: {plan['reasoning']}\n")
            logger.info(f"Planning complete: {plan['action']} - {plan['reasoning']}")
        except Exception as e:
            error_msg = f"✗ Step 0.5 WARNING (Planning): {str(e)}, defaulting to SEARCH_TRIALS"
            logger.warning(error_msg)
            output_parts.append(f"{error_msg}\n")
            plan = {'action': 'SEARCH_TRIALS', 'reasoning': 'Planning failed', 'params': search_query}

        action = plan['action']

        # Step 1: Execute action based on plan
        if action == 'GENERAL_KNOWLEDGE':
            # Skip RAG entirely, go straight to LLM
            step1_start = time.time()
            logger.info("Step 1: GENERAL_KNOWLEDGE - Skipping RAG...")
            output_parts.append("✓ Step 1: Skipped RAG (general knowledge query)\n")
            context = ""  # Empty context
            step1_time = time.time() - step1_start
            output_parts.append(f"✓ Step 1 Complete: Using LLM knowledge only ({step1_time:.1f}s)\n")

        elif action == 'COUNT_AGGREGATE':
            # Use index to count, pass summary to LLM
            try:
                step1_start = time.time()
                logger.info("Step 1: COUNT_AGGREGATE - Using inverted index...")
                output_parts.append("✓ Step 1: Count/aggregation started...\n")
                matching_trial_ids = _lookup_trials_in_index(plan.get('params') or conversation)
                if matching_trial_ids:
                    context = (
                        f"Found {len(matching_trial_ids)} trials matching the query.\n\n"
                        "Note: This is an aggregate count. For detailed information about specific trials, "
                        "please ask a more specific question about individual drugs or treatments."
                    )
                else:
                    context = "No trials found matching the query."
                step1_time = time.time() - step1_start
                output_parts.append(f"✓ Step 1 Complete: Found {len(matching_trial_ids)} matching trials ({step1_time:.1f}s)\n")
                logger.info(f"Count aggregation complete - {len(matching_trial_ids)} trials in {step1_time:.1f}s")
            except Exception as e:
                error_msg = f"✗ Step 1 FAILED (Count): {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                return _finish(error_msg, success=False)

        elif action == 'COMPARE':
            # Compare treatments - retrieve trials for each and let LLM analyze
            try:
                step1_start = time.time()
                logger.info("Step 1: COMPARE - Retrieving trials for comparison...")
                output_parts.append("✓ Step 1: Comparison search started...\n")
                treatments = _extract_compare_treatments(conversation, parsed_query.get('drugs', []))
                if len(treatments) < 2:
                    context = "Could not identify two treatments to compare. Please specify which treatments you'd like to compare."
                else:
                    first, second = treatments[:2]  # Compare first 2
                    logger.info(f"[COMPARE] Comparing: {first} vs {second}")
                    context_parts = []
                    for treatment in (first, second):
                        logger.info(f"[COMPARE] Searching trials for {treatment}...")
                        treatment_trials = retrieve_context_with_embeddings(treatment, top_k=rag_top_k, entities=parsed_query)
                        section_body = treatment_trials if treatment_trials else "No trials found."
                        context_parts.append(f"=== TRIALS FOR {treatment.upper()} ===\n{section_body}\n")
                    # Combine all trials for LLM comparison
                    context = "\n".join(context_parts)
                    context += f"\n\nPLEASE COMPARE: {first} vs {second}\n"
                    context += "Analyze the trials above and provide a side-by-side comparison including:\n"
                    context += "- Number of trials for each\n"
                    context += "- Key indications/diseases studied\n"
                    context += "- Trial phases\n"
                    context += "- Notable efficacy or safety findings\n"
                    context += "- Head-to-head comparison trials (if any)"
                step1_time = time.time() - step1_start
                output_parts.append(f"✓ Step 1 Complete: Retrieved comparison data ({step1_time:.1f}s)\n")
                logger.info(f"Comparison search complete in {step1_time:.1f}s")
            except Exception as e:
                error_msg = f"✗ Step 1 FAILED (Compare): {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                return _finish(error_msg, success=False)

        else:  # SEARCH_TRIALS - normal RAG search (using optimized search query)
            try:
                step1_start = time.time()
                logger.info("Step 1: RAG search...")
                output_parts.append("✓ Step 1: RAG search started...\n")
                # Pass entities for STRICT company filtering
                context = retrieve_context_with_embeddings(search_query, top_k=rag_top_k, entities=parsed_query)
                if not context:
                    return _finish("No matching trials found in RAG search.")
                # Count distinct NCT IDs; a raw .count('NCT') also hits "NCT_ID:" labels.
                trial_count = len(set(re.findall(r"NCT\d{8}", context)))
                step1_time = time.time() - step1_start
                output_parts.append(f"✓ Step 1 Complete: Found {trial_count} trials ({step1_time:.1f}s)\n")
                logger.info(f"RAG search successful - found trials in {step1_time:.1f}s")
            except Exception as e:
                error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                return _finish(error_msg, success=False)

        # Step 2: Skipped (355M ranking removed - was broken)
        output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n")

        # Step 3: Llama 70B
        try:
            step3_start = time.time()
            logger.info("Step 3: Generating response with Llama-3.1-70B...")
            output_parts.append("✓ Step 3: Llama 70B generation started...\n")
            llama_response = generate_llama_response(
                conversation,
                context,
                hf_token=hf_token,
                parsed_entities=parsed_query,
                planning_context=plan
            )
            step3_time = time.time() - step3_start
            output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n")
            logger.info(f"Llama 70B generation successful in {step3_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            llama_response = f"[Llama 70B error: {str(e)}]"
            output_parts.append(f"✗ Step 3 Failed: {str(e)}\n")

        total_time = time.time() - start_time

        # Describe the context section per action (the old fixed label said "Top 3").
        context_labels = {
            'SEARCH_TRIALS': f"RAG RETRIEVED TRIALS (Top {rag_top_k} Most Relevant)",
            'COMPARE': f"RAG RETRIEVED TRIALS (Top {rag_top_k} per Treatment)",
            'COUNT_AGGREGATE': "INVERTED INDEX COUNT SUMMARY",
            'GENERAL_KNOWLEDGE': "RAG RETRIEVED TRIALS (skipped - general knowledge query)",
        }
        try:
            context_label = context_labels.get(action, context_labels['SEARCH_TRIALS'])
            context_display = context if context else "[No context retrieved]"
            output = f"""{''.join(output_parts)}
CLINICAL SUMMARY (Llama-3.1-70B-Instruct):
{llama_response}
---
{context_label}:
{context_display}
---
Total Time: {total_time:.1f}s
"""
        except Exception as e:
            # Absolute fallback
            error_info = f"""
CRITICAL ERROR IN OUTPUT FORMATTING:
{str(e)}
TRACEBACK:
{traceback.format_exc()}
OUTPUT PARTS:
{''.join(output_parts)}
Variables defined: {locals().keys()}
"""
            logger.error(error_info)
            return _finish(error_info, success=False)
        return _finish(output)

    # MASTER EXCEPTION HANDLER - catches ANY unhandled error
    except Exception as master_error:
        master_error_msg = f"""
========================================
MASTER ERROR HANDLER CAUGHT EXCEPTION
========================================
Error Type: {type(master_error).__name__}
Error Message: {str(master_error)}
FULL TRACEBACK:
{traceback.format_exc()}
System Info:
- Python version: {sys.version}
- Error at line: {sys.exc_info()[2].tb_lineno if sys.exc_info()[2] else 'unknown'}
========================================
"""
        logger.error(master_error_msg)
        # Record analytics for error
        query_analytics.record_query('ERROR', time.time() - start_time, success=False)
        return master_error_msg
def get_analytics_report():
    """Render the current ``query_analytics`` statistics as plain text.

    Returns:
        A multi-line report with uptime in hours, total query count and error
        rate. It also has one line per query type giving that type's count,
        its percentage of all queries (0 when there are none) and its average
        response time.
    """
    stats = query_analytics.get_stats()
    total_queries = stats['total_queries']
    pieces = [
        "\n=== ANALYTICS REPORT ===\n",
        f"Uptime: {stats['uptime_seconds'] / 3600:.1f} hours\n",
        f"Total Queries: {total_queries}\n",
        f"Error Rate: {stats['error_rate'] * 100:.1f}%\n",
        "Query Type Distribution:\n",
    ]
    for kind, count in stats['query_type_distribution'].items():
        share = 0 if total_queries <= 0 else count / total_queries * 100
        mean_time = stats['avg_response_times'].get(kind, 0)
        pieces.append(f" {kind}: {count} queries ({share:.1f}%) - avg {mean_time:.2f}s\n")
    pieces.append("\n=== END REPORT ===\n")
    return "".join(pieces)
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Gradio UI: a question box, a submit button and a response box. Clicking the
# button passes the question text to process_query and shows the returned
# report in the response box.
with gr.Blocks(title="Foundation 1.2") as demo:
    gr.Markdown("# Foundation 1.2 - Clinical Trial AI")
    query_input = gr.Textbox(
        label="Ask about clinical trials",
        placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?",
        lines=3
    )
    submit_btn = gr.Button("Generate Response", variant="primary")
    output = gr.Textbox(
        label="AI Response",
        lines=30
    )
    submit_btn.click(
        fn=process_query,  # Full pipeline: LLM parse -> planning -> retrieval -> Llama 70B (355M ranking removed)
        inputs=query_input,
        outputs=output
    )
    gr.Markdown("""
**Production Pipeline - Optimized for Clinical Accuracy**
""")
# ============================================================================
# STARTUP
# ============================================================================
# Embeddings are loaded by the FastAPI startup event in app.py, not at import
# time here (loading them on import caused Docker permission errors).
logger.info("=== Foundation 1.2 Module Loaded ===")
logger.info("Call load_embeddings() to initialize the system")
# Report whether the raw dataset file is present, with its size in MB.
if not DATASET_FILE.exists():
    logger.error("✗ Dataset file not found!")
else:
    file_size_mb = DATASET_FILE.stat().st_size / 1024 ** 2
    logger.info(f"✓ Dataset file found: {file_size_mb:.0f}MB")
logger.info("=== Startup Complete ===")

if __name__ == "__main__":
    demo.launch()