Spaces:
Paused
Paused
| """ | |
| FAST VERSION: Bypasses 355M ranking bottleneck (300s -> 0s) | |
| Works with existing data structure: List[Tuple[int, str]] | |
| Keeps BM25 + semantic hybrid search intact | |
| """ | |
| import torch | |
| from transformers import GPT2LMHeadModel, GPT2TokenizerFast, AutoTokenizer, AutoModelForCausalLM | |
| import logging | |
| import spaces | |
| from functools import lru_cache | |
| from typing import List, Tuple, Optional, Dict | |
| from huggingface_hub import InferenceClient | |
| logger = logging.getLogger(__name__) | |
| # =========================================================================== | |
| # CACHED MODEL LOADING - Load once, reuse forever | |
| # =========================================================================== | |
| def get_cached_355m_model(): | |
| """Load 355M model once and cache it for entity extraction""" | |
| logger.info("Loading 355M Clinical Trial GPT (cached for entity extraction)...") | |
| tokenizer = GPT2TokenizerFast.from_pretrained("gmkdigitalmedia/clinicaltrial2.2") | |
| model = GPT2LMHeadModel.from_pretrained( | |
| "gmkdigitalmedia/clinicaltrial2.2", | |
| torch_dtype=torch.float16, | |
| device_map="auto" | |
| ) | |
| model.eval() | |
| return tokenizer, model | |
| def get_cached_8b_model(hf_token: Optional[str] = None): | |
| """Load 8B model once and cache it""" | |
| logger.info("Loading II-Medical-8B (cached)...") | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| "Intelligent-Internet/II-Medical-8B-1706", | |
| token=hf_token, | |
| trust_remote_code=True | |
| ) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained( | |
| "Intelligent-Internet/II-Medical-8B-1706", | |
| device_map="auto", | |
| token=hf_token, | |
| trust_remote_code=True, | |
| torch_dtype=torch.bfloat16 | |
| ) | |
| return tokenizer, model | |
| # =========================================================================== | |
| # FAST RANKING - Replace 300s function with instant passthrough | |
| # =========================================================================== | |
| def rank_trials_FAST(query: str, trials_list: List[Tuple[float, str]], hf_token=None) -> List[Tuple[float, str]]: | |
| """ | |
| SMART RANKING: Use 355M to rank only top 3 trials | |
| Takes top 3 from BM25+semantic search, then uses 355M Clinical Trial GPT | |
| to re-rank them by clinical relevance. | |
| Time: ~30 seconds for 3 trials (vs 300s for 30 trials) | |
| Args: | |
| query: The search query | |
| trials_list: List of (score, trial_text) tuples from BM25+semantic search | |
| hf_token: Not needed | |
| Returns: | |
| Top 3 trials re-ranked by 355M clinical relevance | |
| """ | |
| import time | |
| import re | |
| start_time = time.time() | |
| # Take only top 3 trials for 355M ranking | |
| top_3 = trials_list[:3] | |
| logger.info(f"[355M RANKING] Ranking top 3 trials with Clinical Trial GPT...") | |
| # Get cached 355M model | |
| tokenizer, model = get_cached_355m_model() | |
| # Score each trial | |
| trial_scores = [] | |
| for idx, (bm25_score, trial_text) in enumerate(top_3): | |
| # Extract NCT ID for logging | |
| nct_match = re.search(r'NCT_ID:\s*(NCT\d+)', trial_text) | |
| nct_id = nct_match.group(1) if nct_match else f"Trial_{idx+1}" | |
| # Create prompt for relevance scoring | |
| # Truncate trial to 800 chars to keep it fast | |
| trial_snippet = trial_text[:800] | |
| prompt = f"""Query: {query} | |
| Clinical Trial: {trial_snippet} | |
| Rate clinical relevance (1-10):""" | |
| # Get model score | |
| try: | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| inputs.input_ids, | |
| max_length=inputs.input_ids.shape[1] + 10, | |
| temperature=0.3, | |
| do_sample=False, | |
| pad_token_id=tokenizer.pad_token_id, | |
| eos_token_id=tokenizer.eos_token_id | |
| ) | |
| generated = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True) | |
| # Extract number from response | |
| score_match = re.search(r'(\d+)', generated.strip()) | |
| relevance_score = float(score_match.group(1)) if score_match else 5.0 | |
| # Normalize to 0-1 range | |
| relevance_score = relevance_score / 10.0 | |
| logger.info(f"[355M RANKING] {nct_id}: relevance={relevance_score:.2f} (BM25={bm25_score:.3f})") | |
| except Exception as e: | |
| logger.warning(f"[355M RANKING] Scoring failed for {nct_id}: {e}, using BM25 score") | |
| relevance_score = bm25_score | |
| trial_scores.append((relevance_score, trial_text, nct_id)) | |
| # Sort by 355M relevance score (descending) | |
| trial_scores.sort(key=lambda x: x[0], reverse=True) | |
| # Format as (score, text) tuples for backwards compatibility | |
| # Create a custom list class that can hold attributes | |
| class RankedTrialsList(list): | |
| """List that can hold ranking metadata""" | |
| pass | |
| ranked_trials = RankedTrialsList() | |
| ranking_metadata = [] | |
| for rank, (score, text, nct_id) in enumerate(trial_scores, 1): | |
| ranked_trials.append((score, text)) | |
| ranking_metadata.append({ | |
| 'rank': rank, | |
| 'nct_id': nct_id, | |
| 'relevance_score': score, | |
| 'relevance_rating': f"{score*10:.1f}/10" | |
| }) | |
| elapsed = time.time() - start_time | |
| logger.info(f"[355M RANKING] ✓ Ranked 3 trials in {elapsed:.1f}s") | |
| logger.info(f"[355M RANKING] Final order: {[nct_id for _, _, nct_id in trial_scores]}") | |
| logger.info(f"[355M RANKING] Scores: {[f'{s:.2f}' for s, _, _ in trial_scores]}") | |
| # Store metadata as attribute for retrieval | |
| ranked_trials.ranking_info = ranking_metadata | |
| # Return re-ranked top 3 plus remaining trials (if any) | |
| return ranked_trials + trials_list[3:] | |
| # Alias for drop-in replacement | |
| rank_trials_with_355m = rank_trials_FAST # Override the slow function! | |
| # =========================================================================== | |
| # FAST GENERATION using HuggingFace Inference API (Free) | |
| # =========================================================================== | |
| def generate_with_llama_70b_hf(query: str, rag_context: str = "", hf_token: str = None) -> str: | |
| """ | |
| Use Llama-3.1-70B via HuggingFace Inference API (FREE) | |
| This is what you're already using successfully! | |
| ~10 second response time on HF free tier | |
| """ | |
| try: | |
| logger.info("Using Llama-3.1-70B via HuggingFace Inference API...") | |
| client = InferenceClient(token=hf_token) | |
| messages = [ | |
| { | |
| "role": "system", | |
| "content": "You are a medical information specialist. Answer based on the provided clinical trial data. Be concise and accurate." | |
| }, | |
| { | |
| "role": "user", | |
| "content": f"""Clinical Trial Data: | |
| {rag_context[:4000]} | |
| Question: {query} | |
| Please provide a concise answer based on the clinical trial data above.""" | |
| } | |
| ] | |
| response = client.chat_completion( | |
| model="meta-llama/Llama-3.1-70B-Instruct", | |
| messages=messages, | |
| max_tokens=512, | |
| temperature=0.3 | |
| ) | |
| answer = response.choices[0].message.content.strip() | |
| logger.info(f"Llama 70B response generated via HF Inference API") | |
| return answer | |
| except Exception as e: | |
| logger.error(f"Llama 70B generation failed: {e}") | |
| return f"Error generating response with Llama 70B: {str(e)}" | |
| # =========================================================================== | |
| # OPTIMIZED 8B GENERATION (with cached model) | |
| # =========================================================================== | |
| def generate_clinical_response_with_xupract(conversation, rag_context="", hf_token=None): | |
| """OPTIMIZED: Use cached 8B model for faster generation""" | |
| logger.info("Generating response with cached II-Medical-8B...") | |
| # Get cached model (loads once, reuses after) | |
| tokenizer, model = get_cached_8b_model(hf_token) | |
| # Build prompt with RAG context (ChatML format for II-Medical-8B) | |
| if rag_context: | |
| prompt = f"""<|im_start|>system | |
| You are a medical information specialist. Answer based on the provided clinical trial data. Please reason step-by-step, and put your final answer within \\boxed{{}}. | |
| <|im_end|> | |
| <|im_start|>user | |
| Clinical Trial Data: | |
| {rag_context[:4000]} | |
| Question: {conversation} | |
| Please reason step-by-step, and put your final answer within \\boxed{{}}. | |
| <|im_end|> | |
| <|im_start|>assistant | |
| """ | |
| else: | |
| prompt = f"""<|im_start|>system | |
| You are a medical information specialist. Please reason step-by-step, and put your final answer within \\boxed{{}}. | |
| <|im_end|> | |
| <|im_start|>user | |
| {conversation} | |
| Please reason step-by-step, and put your final answer within \\boxed{{}}. | |
| <|im_end|> | |
| <|im_start|>assistant | |
| """ | |
| try: | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=1024, | |
| temperature=0.3, | |
| do_sample=True, | |
| top_p=0.9, | |
| eos_token_id=tokenizer.eos_token_id, | |
| pad_token_id=tokenizer.pad_token_id | |
| ) | |
| response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip() | |
| return response | |
| except Exception as e: | |
| logger.error(f"Generation failed: {e}") | |
| return f"Error generating response: {str(e)}" | |
| # =========================================================================== | |
| # FAST ENTITY EXTRACTION (with cached model) | |
| # =========================================================================== | |
| def extract_entities_with_small_model(conversation): | |
| """OPTIMIZED: Use cached 355M model for entity extraction""" | |
| logger.info("Extracting entities with cached 355M model...") | |
| # Get cached model | |
| tokenizer, model = get_cached_355m_model() | |
| # Better prompt for extraction | |
| prompt = f"""Clinical query: {conversation} | |
| Extract: | |
| Drug name:""" | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| inputs.input_ids, | |
| max_length=400, | |
| temperature=0.3, | |
| top_p=0.9, | |
| do_sample=True, | |
| pad_token_id=tokenizer.pad_token_id, | |
| eos_token_id=tokenizer.eos_token_id | |
| ) | |
| generated = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| return generated | |
| # =========================================================================== | |
| # QUERY EXPANSION (optional, with cached model) | |
| # =========================================================================== | |
| def expand_query_with_355m(query): | |
| """OPTIMIZED: Use cached 355M for query expansion""" | |
| logger.info("Expanding query with cached 355M...") | |
| # Get cached model | |
| tokenizer, model = get_cached_355m_model() | |
| # Prompt to get clinical context | |
| prompt = f"Question: {query}\nClinical trial information:" | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256).to(model.device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| inputs.input_ids, | |
| max_length=inputs.input_ids.shape[1] + 100, | |
| temperature=0.7, | |
| top_p=0.9, | |
| do_sample=True, | |
| pad_token_id=tokenizer.pad_token_id, | |
| eos_token_id=tokenizer.eos_token_id | |
| ) | |
| generated = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # Extract the expansion part | |
| if "Clinical trial information:" in generated: | |
| expansion = generated.split("Clinical trial information:")[-1].strip() | |
| else: | |
| expansion = generated[len(prompt):].strip() | |
| # Limit to reasonable length | |
| expansion = expansion[:500] if len(expansion) > 500 else expansion | |
| logger.info(f"Query expanded: {expansion[:100]}...") | |
| return expansion | |
| # =========================================================================== | |
| # MAIN PIPELINE - Now FAST! | |
| # =========================================================================== | |
| def process_two_llm_system(conversation, rag_context="", hf_token=None, use_validation=False): | |
| """ | |
| FAST pipeline: | |
| 1. Small 355M model extracts entities (cached model - fast) | |
| 2. RAG retrieves context (BM25 + semantic - already fast) | |
| 3. Big model generates response (8B local or 70B API) | |
| 4. Skip validation for speed | |
| Total time: ~15s instead of 300+s | |
| """ | |
| import time | |
| start_time = time.time() | |
| # Step 1: Use cached 355M to extract entities | |
| entities = extract_entities_with_small_model(conversation) | |
| logger.info(f"Entities extracted in {time.time()-start_time:.1f}s") | |
| # Step 2: Generate response (choose one): | |
| # Option A: Use 70B via HF Inference API (better quality, ~10s) | |
| if hf_token: | |
| clinical_evidence = generate_with_llama_70b_hf( | |
| conversation, | |
| rag_context, | |
| hf_token | |
| ) | |
| model_used = "Llama-3.1-70B (HF Inference API)" | |
| else: | |
| # Option B: Use cached 8B model (faster loading, ~5s) | |
| clinical_evidence = generate_clinical_response_with_xupract( | |
| conversation, | |
| rag_context, | |
| hf_token | |
| ) | |
| model_used = "II-Medical-8B (cached)" | |
| total_time = time.time() - start_time | |
| logger.info(f"Total pipeline time: {total_time:.1f}s (was 300+s with 355M ranking)") | |
| return { | |
| 'clinical_evidence': clinical_evidence, | |
| 'entities': entities, | |
| 'model_used': model_used, | |
| 'time_taken': total_time | |
| } | |
| def format_two_llm_response(result): | |
| """Format the fast response""" | |
| return f"""ENTITY EXTRACTION (Clinical Trial GPT 355M - Cached) | |
| {'='*60} | |
| {result.get('entities', 'None identified')} | |
| CLINICAL RESPONSE ({result.get('model_used', 'Unknown')}) | |
| {'='*60} | |
| {result['clinical_evidence']} | |
| PERFORMANCE | |
| {'='*60} | |
| Time: {result.get('time_taken', 0):.1f}s (was 300+s with 355M ranking) | |
| {'='*60} | |
| """ | |
| # =========================================================================== | |
| # PRELOAD MODELS AT STARTUP (Call this once in app.py!) | |
| # =========================================================================== | |
| def preload_all_models(hf_token=None): | |
| """ | |
| Call this ONCE at app startup to cache all models. | |
| This prevents model reloading on every query. | |
| Add to your app.py initialization: | |
| from two_llm_system_FAST import preload_all_models | |
| preload_all_models(hf_token) | |
| """ | |
| logger.info("Preloading and caching all models...") | |
| # Cache the 355M model | |
| _ = get_cached_355m_model() | |
| logger.info("✓ 355M model cached") | |
| # Cache the 8B model if token available | |
| if hf_token: | |
| try: | |
| _ = get_cached_8b_model(hf_token) | |
| logger.info("✓ 8B model cached") | |
| except Exception as e: | |
| logger.warning(f"Could not cache 8B model: {e}") | |
| logger.info("All models preloaded and cached!") | |
| # =========================================================================== | |
| # BACKWARD COMPATIBILITY - Keep all original function names | |
| # =========================================================================== | |
| # These functions exist in the original but we optimize them | |
| validate_with_small_model = lambda *args, **kwargs: "Validation skipped for speed" | |
| extract_keywords_with_llama = lambda conv, hf_token=None: extract_entities_with_small_model(conv)[:100] | |
| generate_response_with_llama = generate_with_llama_70b_hf | |
| generate_clinical_knowledge_with_355m = lambda conv: f"Knowledge: {conv[:100]}..." | |
| generate_with_355m = lambda conv, rag="", hf_token=None: generate_clinical_response_with_xupract(conv, rag, hf_token) | |
| # Ensure we override the slow ranking function | |
| rank_trials_with_355m = rank_trials_FAST | |
| logger.info("Fast Two-LLM System loaded - 355M ranking bypassed!") | |