# test_option_b.py — Deploy Option B: Query Parser + RAG + 355M Ranking
"""
Test Option B System with Physician Query
Tests: "what should a physician considering prescribing ianalumab for sjogren's disease know"
"""
import json
import logging
import os
import sys

# One-shot root-logger setup: timestamped records for this test run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Warn up front when the Hugging Face token is absent; the script continues
# with whatever functionality does not need it.
if not os.environ.get("HF_TOKEN"):
    for notice in (
        "⚠️ HF_TOKEN not set! Query parsing will fail.",
        " Set it with: export HF_TOKEN=your_token_here",
        " Continuing with limited functionality...",
    ):
        logger.warning(notice)
try:
    # foundation_engine owns data/model loading and can auto-download its
    # assets from Hugging Face when they are missing locally.
    logger.info("Loading foundation_engine (with auto-download)...")
    import foundation_engine

    logger.info("=" * 80)
    logger.info("TESTING OPTION B SYSTEM")
    logger.info("=" * 80)

    # Load RAG data (downloads from HF on first run).
    logger.info("Loading RAG data (will download from HF if needed)...")
    foundation_engine.load_embeddings()

    logger.info("=" * 80)
    logger.info("DATA LOADED SUCCESSFULLY")
    logger.info("=" * 80)
    logger.info(f"✓ Trials loaded: {len(foundation_engine.doc_chunks):,}")
    logger.info(f"✓ Embeddings shape: {foundation_engine.doc_embeddings.shape if foundation_engine.doc_embeddings is not None else 'None'}")
    # FIX: the original logged a bare "None" with no label when the inverted
    # index was missing; keep the label in both branches.
    if foundation_engine.inverted_index:
        logger.info(f"✓ Inverted index terms: {len(foundation_engine.inverted_index):,}")
    else:
        logger.info("✓ Inverted index terms: None")

    # The physician-style natural-language query under test.
    test_query = "what should a physician considering prescribing ianalumab for sjogren's disease know"
    logger.info("=" * 80)
    logger.info(f"TEST QUERY: {test_query}")
    logger.info("=" * 80)

    # Option B pipeline: query parsing -> RAG retrieval -> 355M re-ranking.
    logger.info("Processing with Option B pipeline...")
    result = foundation_engine.process_query_structured(test_query, top_k=5)

    logger.info("=" * 80)
    logger.info("RESULTS")
    logger.info("=" * 80)

    # Per-stage timing breakdown, if the engine reported one.
    if 'benchmarking' in result:
        bench = result['benchmarking']
        logger.info(f"\n⏱️ PERFORMANCE:")
        logger.info(f" Query Parsing: {bench.get('query_parsing_time', 0):.2f}s")
        logger.info(f" RAG Search: {bench.get('rag_search_time', 0):.2f}s")
        logger.info(f" 355M Ranking: {bench.get('355m_ranking_time', 0):.2f}s")
        logger.info(f" TOTAL: {result.get('processing_time', 0):.2f}s")

    # Entities the query parser extracted from the free-text question.
    if 'query_analysis' in result:
        qa = result['query_analysis']
        logger.info(f"\n🔍 QUERY ANALYSIS:")
        entities = qa.get('extracted_entities', {})
        logger.info(f" Drugs: {entities.get('drugs', [])}")
        logger.info(f" Diseases: {entities.get('diseases', [])}")
        logger.info(f" Companies: {entities.get('companies', [])}")
        logger.info(f" Endpoints: {entities.get('endpoints', [])}")
        logger.info(f" Optimized: {qa.get('optimized_search', 'N/A')}")

    # Aggregate retrieval statistics.
    if 'results' in result:
        res = result['results']
        logger.info(f"\n📊 SEARCH RESULTS:")
        logger.info(f" Total Found: {res.get('total_found', 0)}")
        logger.info(f" Returned: {res.get('returned', 0)}")
        logger.info(f" Top Relevance: {res.get('top_relevance_score', 0):.3f}")

    # Per-trial detail for the top 5 hits, including the 355M model's
    # re-ranking impact relative to the raw RAG order.
    if 'trials' in result and len(result['trials']) > 0:
        logger.info(f"\n🏥 TOP TRIALS:\n")
        for i, trial in enumerate(result['trials'][:5], 1):
            logger.info(f"{i}. NCT ID: {trial['nct_id']}")
            logger.info(f" Title: {trial.get('title', 'N/A')}")
            logger.info(f" Status: {trial.get('status', 'N/A')}")
            logger.info(f" Phase: {trial.get('phase', 'N/A')}")
            if 'scoring' in trial:
                scoring = trial['scoring']
                logger.info(f" Scoring:")
                logger.info(f" Relevance: {scoring.get('relevance_score', 0):.3f}")
                logger.info(f" Perplexity: {scoring.get('perplexity', 'N/A')}")
                logger.info(f" Rank before: {scoring.get('rank_before_355m', 'N/A')}")
                logger.info(f" Rank after: {scoring.get('rank_after_355m', 'N/A')}")
                rank_change = ""
                # FIX: compare against None so a rank of 0 is not treated as
                # missing (the original truthiness test skipped zero ranks).
                rank_before = scoring.get('rank_before_355m')
                rank_after = scoring.get('rank_after_355m')
                if rank_before is not None and rank_after is not None:
                    change = rank_before - rank_after
                    if change > 0:
                        rank_change = f" (↑ improved by {change})"
                    elif change < 0:
                        rank_change = f" (↓ dropped by {-change})"
                    else:
                        rank_change = " (→ no change)"
                logger.info(f" Impact: {rank_change}")
            logger.info(f" URL: {trial.get('url', 'N/A')}")
            logger.info("")

    # Persist the raw API payload for offline inspection.
    output_file = "test_results_option_b.json"
    # FIX: write UTF-8 explicitly so non-ASCII content (e.g. "Sjögren") does
    # not depend on the platform's default encoding.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)
    logger.info(f"💾 Full results saved to: {output_file}")

    logger.info("=" * 80)
    logger.info("TEST COMPLETED SUCCESSFULLY ✅")
    logger.info("=" * 80)

    # Human-readable recap: what a downstream chatbot LLM could generate
    # from the structured JSON this API returns.
    logger.info("\n📋 SUMMARY FOR PHYSICIAN:")
    logger.info(" Based on the ranked trials, here's what the API returns:")
    logger.info(f" - Found {result['results']['returned']} relevant trials")
    logger.info(f" - Top trial has {result['results']['top_relevance_score']:.1%} relevance")
    logger.info("")
    logger.info(" ⚠️ NOTE: This API returns STRUCTURED DATA only")
    logger.info(" The chatbot company would use their LLM to generate a response like:")
    logger.info("")
    logger.info(" 'Based on clinical trial data, physicians prescribing ianalumab")
    logger.info(" for Sjögren's disease should know:'")
    logger.info(f" '- {len(result['trials'])} clinical trials are available'")
    if result['trials']:
        trial = result['trials'][0]
        logger.info(f" '- Primary trial: {trial.get('title', 'N/A')}'")
        logger.info(f" '- Status: {trial.get('status', 'N/A')}'")
        logger.info(f" '- Phase: {trial.get('phase', 'N/A')}'")
    logger.info("")
    logger.info(" The client's LLM would generate this response using the JSON data.")
    logger.info("")
except ImportError as e:
    logger.error(f"❌ Import failed: {e}")
    logger.error(" Make sure you're in the correct directory with foundation_engine.py")
    sys.exit(1)
except Exception as e:
    # FIX: logger.exception attaches the traceback to the record in one call
    # instead of hand-formatting it with the traceback module.
    logger.exception(f"❌ Test failed: {e}")
    sys.exit(1)