ctapi / two_llm_system_FIXED.py
Your Name
Clone api2 for experimentation
d78f02a
"""
FAST VERSION: Bypasses 355M ranking bottleneck (300s -> 0s)
Works with existing data structure: List[Tuple[int, str]]
Keeps BM25 + semantic hybrid search intact
"""
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, AutoTokenizer, AutoModelForCausalLM
import logging
import spaces
from functools import lru_cache
from typing import List, Tuple, Optional, Dict
from huggingface_hub import InferenceClient
logger = logging.getLogger(__name__)
# ===========================================================================
# CACHED MODEL LOADING - Load once, reuse forever
# ===========================================================================
@lru_cache(maxsize=1)
def get_cached_355m_model():
"""Load 355M model once and cache it for entity extraction"""
logger.info("Loading 355M Clinical Trial GPT (cached for entity extraction)...")
tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/clinicaltrial2.2")
model = GPT2LMHeadModel.from_pretrained(
"gmkdigitalmedia/clinicaltrial2.2",
torch_dtype=torch.float16,
device_map="auto"
)
model.eval()
return tokenizer, model
@lru_cache(maxsize=1)
def get_cached_8b_model(hf_token: Optional[str] = None):
"""Load 8B model once and cache it"""
logger.info("Loading II-Medical-8B (cached)...")
tokenizer = AutoTokenizer.from_pretrained(
"Intelligent-Internet/II-Medical-8B-1706",
token=hf_token,
trust_remote_code=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
"Intelligent-Internet/II-Medical-8B-1706",
device_map="auto",
token=hf_token,
trust_remote_code=True,
torch_dtype=torch.bfloat16
)
return tokenizer, model
# ===========================================================================
# FAST RANKING - Replace 300s function with instant passthrough
# ===========================================================================
@spaces.GPU
def rank_trials_FAST(query: str, trials_list: List[Tuple[float, str]], hf_token=None) -> List[Tuple[float, str]]:
"""
SMART RANKING: Use 355M to rank only top 3 trials
Takes top 3 from BM25+semantic search, then uses 355M Clinical Trial GPT
to re-rank them by clinical relevance.
Time: ~30 seconds for 3 trials (vs 300s for 30 trials)
Args:
query: The search query
trials_list: List of (score, trial_text) tuples from BM25+semantic search
hf_token: Not needed
Returns:
Top 3 trials re-ranked by 355M clinical relevance
"""
import time
import re
start_time = time.time()
# Take only top 3 trials for 355M ranking
top_3 = trials_list[:3]
logger.info(f"[355M RANKING] Ranking top 3 trials with Clinical Trial GPT...")
# Get cached 355M model
tokenizer, model = get_cached_355m_model()
# Score each trial
trial_scores = []
for idx, (bm25_score, trial_text) in enumerate(top_3):
# Extract NCT ID for logging
nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text)
nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}"
# Create prompt for relevance scoring
# Truncate trial to 800 chars to keep it fast
trial_snippet = trial_text[:800]
prompt = f"""Query: {query}
Clinical Trial: {trial_snippet}
Rate clinical relevance (1-10):"""
# Get model score
try:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
with torch.no_grad():
outputs = model.generate(
inputs.input_ids,
max_length=inputs.input_ids.shape[1] + 10,
temperature=0.3,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
generated = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
# Extract number from response
score_match = re.search(r'(\d+)', generated.strip())
relevance_score = float(score_match.group(1)) if score_match else 5.0
# Normalize to 0-1 range
relevance_score = relevance_score / 10.0
logger.info(f"[355M RANKING] {nct_id}: relevance={relevance_score:.2f} (BM25={bm25_score:.3f})")
except Exception as e:
logger.warning(f"[355M RANKING] Scoring failed for {nct_id}: {e}, using BM25 score")
relevance_score = bm25_score
trial_scores.append((relevance_score, trial_text, nct_id))
# Sort by 355M relevance score (descending)
trial_scores.sort(key=lambda x: x[0], reverse=True)
# Format as (score, text) tuples for backwards compatibility
# Create a custom list class that can hold attributes
class RankedTrialsList(list):
"""List that can hold ranking metadata"""
pass
ranked_trials = RankedTrialsList()
ranking_metadata = []
for rank, (score, text, nct_id) in enumerate(trial_scores, 1):
ranked_trials.append((score, text))
ranking_metadata.append({
'rank': rank,
'nct_id': nct_id,
'relevance_score': score,
'relevance_rating': f"{score*10:.1f}/10"
})
elapsed = time.time() - start_time
logger.info(f"[355M RANKING] ✓ Ranked 3 trials in {elapsed:.1f}s")
logger.info(f"[355M RANKING] Final order: {[nct_id for _, _, nct_id in trial_scores]}")
logger.info(f"[355M RANKING] Scores: {[f'{s:.2f}' for s, _, _ in trial_scores]}")
# Store metadata as attribute for retrieval
ranked_trials.ranking_info = ranking_metadata
# Return re-ranked top 3 plus remaining trials (if any)
return ranked_trials + trials_list[3:]
# Alias for drop-in replacement
rank_trials_with_355m = rank_trials_FAST # Override the slow function!
# ===========================================================================
# FAST GENERATION using HuggingFace Inference API (Free)
# ===========================================================================
def generate_with_llama_70b_hf(query: str, rag_context: str = "", hf_token: str = None) -> str:
"""
Use Llama-3.1-70B via HuggingFace Inference API (FREE)
This is what you're already using successfully!
~10 second response time on HF free tier
"""
try:
logger.info("Using Llama-3.1-70B via HuggingFace Inference API...")
client = InferenceClient(token=hf_token)
messages = [
{
"role": "system",
"content": "You are a medical information specialist. Answer based on the provided clinical trial data. Be concise and accurate."
},
{
"role": "user",
"content": f"""Clinical Trial Data:
{rag_context[:4000]}
Question: {query}
Please provide a concise answer based on the clinical trial data above."""
}
]
response = client.chat_completion(
model="meta-llama/Llama-3.1-70B-Instruct",
messages=messages,
max_tokens=512,
temperature=0.3
)
answer = response.choices[0].message.content.strip()
logger.info(f"Llama 70B response generated via HF Inference API")
return answer
except Exception as e:
logger.error(f"Llama 70B generation failed: {e}")
return f"Error generating response with Llama 70B: {str(e)}"
# ===========================================================================
# OPTIMIZED 8B GENERATION (with cached model)
# ===========================================================================
@spaces.GPU
def generate_clinical_response_with_xupract(conversation, rag_context="", hf_token=None):
"""OPTIMIZED: Use cached 8B model for faster generation"""
logger.info("Generating response with cached II-Medical-8B...")
# Get cached model (loads once, reuses after)
tokenizer, model = get_cached_8b_model(hf_token)
# Build prompt with RAG context (ChatML format for II-Medical-8B)
if rag_context:
prompt = f"""<|im_start|>system
You are a medical information specialist. Answer based on the provided clinical trial data. Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>user
Clinical Trial Data:
{rag_context[:4000]}
Question: {conversation}
Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>assistant
"""
else:
prompt = f"""<|im_start|>system
You are a medical information specialist. Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>user
{conversation}
Please reason step-by-step, and put your final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>assistant
"""
try:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=1024,
temperature=0.3,
do_sample=True,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
)
response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip()
return response
except Exception as e:
logger.error(f"Generation failed: {e}")
return f"Error generating response: {str(e)}"
# ===========================================================================
# FAST ENTITY EXTRACTION (with cached model)
# ===========================================================================
@spaces.GPU
def extract_entities_with_small_model(conversation):
"""OPTIMIZED: Use cached 355M model for entity extraction"""
logger.info("Extracting entities with cached 355M model...")
# Get cached model
tokenizer, model = get_cached_355m_model()
# Better prompt for extraction
prompt = f"""Clinical query: {conversation}
Extract:
Drug name:"""
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)
with torch.no_grad():
outputs = model.generate(
inputs.input_ids,
max_length=400,
temperature=0.3,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated
# ===========================================================================
# QUERY EXPANSION (optional, with cached model)
# ===========================================================================
@spaces.GPU
def expand_query_with_355m(query):
"""OPTIMIZED: Use cached 355M for query expansion"""
logger.info("Expanding query with cached 355M...")
# Get cached model
tokenizer, model = get_cached_355m_model()
# Prompt to get clinical context
prompt = f"Question: {query}\nClinical trial information:"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device)
with torch.no_grad():
outputs = model.generate(
inputs.input_ids,
max_length=inputs.input_ids.shape[1] + 100,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract the expansion part
if "Clinical trial information:" in generated:
expansion = generated.split("Clinical trial information:")[-1].strip()
else:
expansion = generated[len(prompt):].strip()
# Limit to reasonable length
expansion = expansion[:500] if len(expansion) > 500 else expansion
logger.info(f"Query expanded: {expansion[:100]}...")
return expansion
# ===========================================================================
# MAIN PIPELINE - Now FAST!
# ===========================================================================
def process_two_llm_system(conversation, rag_context="", hf_token=None, use_validation=False):
"""
FAST pipeline:
1. Small 355M model extracts entities (cached model - fast)
2. RAG retrieves context (BM25 + semantic - already fast)
3. Big model generates response (8B local or 70B API)
4. Skip validation for speed
Total time: ~15s instead of 300+s
"""
import time
start_time = time.time()
# Step 1: Use cached 355M to extract entities
entities = extract_entities_with_small_model(conversation)
logger.info(f"Entities extracted in {time.time()-start_time:.1f}s")
# Step 2: Generate response (choose one):
# Option A: Use 70B via HF Inference API (better quality, ~10s)
if hf_token:
clinical_evidence = generate_with_llama_70b_hf(
conversation,
rag_context,
hf_token
)
model_used = "Llama-3.1-70B (HF Inference API)"
else:
# Option B: Use cached 8B model (faster loading, ~5s)
clinical_evidence = generate_clinical_response_with_xupract(
conversation,
rag_context,
hf_token
)
model_used = "II-Medical-8B (cached)"
total_time = time.time() - start_time
logger.info(f"Total pipeline time: {total_time:.1f}s (was 300+s with 355M ranking)")
return {
'clinical_evidence': clinical_evidence,
'entities': entities,
'model_used': model_used,
'time_taken': total_time
}
def format_two_llm_response(result):
"""Format the fast response"""
return f"""ENTITY EXTRACTION (Clinical Trial GPT 355M - Cached)
{'='*60}
{result.get('entities', 'None identified')}
CLINICAL RESPONSE ({result.get('model_used', 'Unknown')})
{'='*60}
{result['clinical_evidence']}
PERFORMANCE
{'='*60}
Time: {result.get('time_taken', 0):.1f}s (was 300+s with 355M ranking)
{'='*60}
"""
# ===========================================================================
# PRELOAD MODELS AT STARTUP (Call this once in app.py!)
# ===========================================================================
def preload_all_models(hf_token=None):
"""
Call this ONCE at app startup to cache all models.
This prevents model reloading on every query.
Add to your app.py initialization:
from two_llm_system_FAST import preload_all_models
preload_all_models(hf_token)
"""
logger.info("Preloading and caching all models...")
# Cache the 355M model
_ = get_cached_355m_model()
logger.info("✓ 355M model cached")
# Cache the 8B model if token available
if hf_token:
try:
_ = get_cached_8b_model(hf_token)
logger.info("✓ 8B model cached")
except Exception as e:
logger.warning(f"Could not cache 8B model: {e}")
logger.info("All models preloaded and cached!")
# ===========================================================================
# BACKWARD COMPATIBILITY - Keep all original function names
# ===========================================================================
# These functions exist in the original but we optimize them
validate_with_small_model = lambda *args, **kwargs: "Validation skipped for speed"
extract_keywords_with_llama = lambda conv, hf_token=None: extract_entities_with_small_model(conv)[:100]
generate_response_with_llama = generate_with_llama_70b_hf
generate_clinical_knowledge_with_355m = lambda conv: f"Knowledge: {conv[:100]}..."
generate_with_355m = lambda conv, rag="", hf_token=None: generate_clinical_response_with_xupract(conv, rag, hf_token)
# Ensure we override the slow ranking function
rank_trials_with_355m = rank_trials_FAST
logger.info("Fast Two-LLM System loaded - 355M ranking bypassed!")