Spaces:
Paused
Paused
"""
Smoke test for the Option B (structured query) pipeline.

Sends one physician-style question through ``foundation_engine`` and logs the
timing breakdown, extracted query entities, ranked trials, and an example of
the summary a client LLM could build from the returned JSON:

    "what should a physician considering prescribing ianalumab for sjogren's disease know"
"""
import json
import logging
import os
import sys

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Query parsing needs a Hugging Face token. Warn up front, but keep going so
# the rest of the pipeline can still be exercised without it.
if not os.environ.get("HF_TOKEN"):
    logger.warning("⚠️ HF_TOKEN not set! Query parsing will fail.")
    logger.warning(" Set it with: export HF_TOKEN=your_token_here")
    logger.warning(" Continuing with limited functionality...")
def _banner(title):
    """Log *title* between two 80-character '=' rules."""
    logger.info("=" * 80)
    logger.info(title)
    logger.info("=" * 80)


def _import_foundation_engine():
    """Import and return ``foundation_engine``, the module that can download RAG data.

    Only the import itself is guarded. An ImportError raised later in the run
    is therefore reported as a test failure with a full traceback, not as a
    wrong-directory problem.

    Exits the process with status 1 if the import raises ImportError.
    """
    logger.info("Loading foundation_engine (with auto-download)...")
    try:
        import foundation_engine
    except ImportError as e:
        logger.error(f"❌ Import failed: {e}")
        logger.error(" Make sure you're in the correct directory with foundation_engine.py")
        sys.exit(1)
    return foundation_engine


def _log_data_summary(engine):
    """Log the size of the RAG data held by *engine* after ``load_embeddings()``."""
    embeddings = engine.doc_embeddings
    inverted_index = engine.inverted_index
    logger.info(f"✓ Trials loaded: {len(engine.doc_chunks):,}")
    shape = embeddings.shape if embeddings is not None else 'None'
    logger.info(f"✓ Embeddings shape: {shape}")
    # Keep the label even when the index is empty or missing; only the value changes.
    index_terms = f"{len(inverted_index):,}" if inverted_index else "None"
    logger.info(f"✓ Inverted index terms: {index_terms}")


def _log_benchmarks(result):
    """Log the per-stage timing breakdown, if the pipeline reported one."""
    if 'benchmarking' not in result:
        return
    bench = result['benchmarking']
    logger.info("\n⏱️ PERFORMANCE:")
    logger.info(f" Query Parsing: {bench.get('query_parsing_time', 0):.2f}s")
    logger.info(f" RAG Search: {bench.get('rag_search_time', 0):.2f}s")
    logger.info(f" 355M Ranking: {bench.get('355m_ranking_time', 0):.2f}s")
    logger.info(f" TOTAL: {result.get('processing_time', 0):.2f}s")


def _log_query_analysis(result):
    """Log the entities and the optimized search string extracted from the query."""
    if 'query_analysis' not in result:
        return
    qa = result['query_analysis']
    entities = qa.get('extracted_entities', {})
    logger.info("\n🔍 QUERY ANALYSIS:")
    logger.info(f" Drugs: {entities.get('drugs', [])}")
    logger.info(f" Diseases: {entities.get('diseases', [])}")
    logger.info(f" Companies: {entities.get('companies', [])}")
    logger.info(f" Endpoints: {entities.get('endpoints', [])}")
    logger.info(f" Optimized: {qa.get('optimized_search', 'N/A')}")


def _log_search_results(result):
    """Log the result counts and the best relevance score."""
    if 'results' not in result:
        return
    res = result['results']
    logger.info("\n📊 SEARCH RESULTS:")
    logger.info(f" Total Found: {res.get('total_found', 0)}")
    logger.info(f" Returned: {res.get('returned', 0)}")
    logger.info(f" Top Relevance: {res.get('top_relevance_score', 0):.3f}")


def _describe_rank_change(before, after):
    """Return a suffix describing how 355M re-ranking moved a trial.

    A smaller rank number is better, so a positive ``before - after`` means the
    trial moved up. Returns "" when either rank is missing (None). A rank of 0
    is still compared rather than treated as missing.
    """
    if before is None or after is None:
        return ""
    change = before - after
    if change > 0:
        return f" (↑ improved by {change})"
    if change < 0:
        return f" (↓ dropped by {-change})"
    return " (→ no change)"


def _log_top_trials(result, limit=5):
    """Log up to *limit* ranked trials, with their scoring details when present."""
    trials = result.get('trials') or []
    if not trials:
        return
    logger.info("\n🏥 TOP TRIALS:\n")
    for i, trial in enumerate(trials[:limit], 1):
        logger.info(f"{i}. NCT ID: {trial.get('nct_id', 'N/A')}")
        logger.info(f" Title: {trial.get('title', 'N/A')}")
        logger.info(f" Status: {trial.get('status', 'N/A')}")
        logger.info(f" Phase: {trial.get('phase', 'N/A')}")
        if 'scoring' in trial:
            scoring = trial['scoring']
            logger.info(" Scoring:")
            logger.info(f" Relevance: {scoring.get('relevance_score', 0):.3f}")
            logger.info(f" Perplexity: {scoring.get('perplexity', 'N/A')}")
            logger.info(f" Rank before: {scoring.get('rank_before_355m', 'N/A')}")
            logger.info(f" Rank after: {scoring.get('rank_after_355m', 'N/A')}")
            rank_change = _describe_rank_change(
                scoring.get('rank_before_355m'),
                scoring.get('rank_after_355m'),
            )
            logger.info(f" Impact: {rank_change}")
        logger.info(f" URL: {trial.get('url', 'N/A')}")
        logger.info("")


def _save_results(result, output_file="test_results_option_b.json"):
    """Write the full pipeline *result* to *output_file* as indented JSON."""
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2)
    logger.info(f"💾 Full results saved to: {output_file}")


def _log_physician_summary(result):
    """Log an example of the answer a client LLM could build from *result*.

    Uses the same guarded lookups as the other sections, so a result missing
    'results' or 'trials' cannot fail after success has been reported.
    """
    res = result.get('results') or {}
    trials = result.get('trials') or []
    logger.info("\n📋 SUMMARY FOR PHYSICIAN:")
    logger.info(" Based on the ranked trials, here's what the API returns:")
    logger.info(f" - Found {res.get('returned', 0)} relevant trials")
    logger.info(f" - Top trial has {res.get('top_relevance_score', 0):.1%} relevance")
    logger.info("")
    logger.info(" ⚠️ NOTE: This API returns STRUCTURED DATA only")
    logger.info(" The chatbot company would use their LLM to generate a response like:")
    logger.info("")
    logger.info(" 'Based on clinical trial data, physicians prescribing ianalumab")
    logger.info(" for Sjögren's disease should know:'")
    logger.info(f" '- {len(trials)} clinical trials are available'")
    if trials:
        top = trials[0]
        logger.info(f" '- Primary trial: {top.get('title', 'N/A')}'")
        logger.info(f" '- Status: {top.get('status', 'N/A')}'")
        logger.info(f" '- Phase: {top.get('phase', 'N/A')}'")
    logger.info("")
    logger.info(" The client's LLM would generate this response using the JSON data.")
    logger.info("")


def _run_option_b_test(
    engine,
    query="what should a physician considering prescribing ianalumab for sjogren's disease know",
    top_k=5,
):
    """Load RAG data into *engine*, run *query* through Option B, and report the results.

    Args:
        engine: the imported ``foundation_engine`` module.
        query: question passed to ``process_query_structured``.
        top_k: number of trials requested from the pipeline. Also caps how many
            trials are logged.
    """
    _banner("TESTING OPTION B SYSTEM")
    # Load data (will auto-download if needed).
    logger.info("Loading RAG data (will download from HF if needed)...")
    engine.load_embeddings()
    _banner("DATA LOADED SUCCESSFULLY")
    _log_data_summary(engine)

    _banner(f"TEST QUERY: {query}")
    # Use the structured query processor (Option B).
    logger.info("Processing with Option B pipeline...")
    result = engine.process_query_structured(query, top_k=top_k)

    _banner("RESULTS")
    _log_benchmarks(result)
    _log_query_analysis(result)
    _log_search_results(result)
    _log_top_trials(result, limit=top_k)
    _save_results(result)
    _banner("TEST COMPLETED SUCCESSFULLY ✅")
    _log_physician_summary(result)


try:
    _run_option_b_test(_import_foundation_engine())
except Exception as e:
    # Top-level boundary: log the failure with its traceback and exit non-zero.
    # sys.exit() inside _import_foundation_engine raises SystemExit, which is
    # not an Exception subclass and so passes through here unchanged.
    logger.exception(f"❌ Test failed: {e}")
    sys.exit(1)