# NOTE: the original scrape of this file carried a Hugging Face "Spaces:
# Runtime error" banner here; it was page chrome, not source content, and is
# preserved only as this comment so the module stays parseable.
| import asyncio | |
| import weave | |
| from medrag_multi_modal.assistant import LLMClient, MedQAAssistant | |
| from medrag_multi_modal.metrics import MMLUOptionAccuracy | |
| from medrag_multi_modal.retrieval.text_retrieval import ( | |
| BM25sRetriever, | |
| ContrieverRetriever, | |
| MedCPTRetriever, | |
| NVEmbed2Retriever, | |
| ) | |
def test_mmlu_correctness_anatomy_bm25s(model_name: str):
    """Evaluate MedQAAssistant with a BM25s retriever on MMLU anatomy.

    Runs a Weave evaluation of the assistant (backed by the given LLM) over the
    ``mmlu-anatomy-test:v2`` dataset and asserts that correctly answered
    questions outnumber incorrect ones.

    Args:
        model_name: Identifier of the LLM passed to ``LLMClient``.
    """
    weave.init("ml-colabs/medrag-multi-modal")
    bm25_retriever = BM25sRetriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=bm25_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()
    # Tag every traced call with the retriever/LLM pair for later filtering.
    run_attributes = {
        "retriever": bm25_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-BM25s",
        )
        display_name = evaluation.name + ":" + client.model_name
        summary = asyncio.run(
            evaluation.evaluate(
                assistant,
                __weave={"display_name": display_name},
            )
        )
    option_scores = summary["MMLUOptionAccuracy"]["correct"]
    assert option_scores["true_count"] > option_scores["false_count"]
def test_mmlu_correctness_anatomy_contriever(model_name: str):
    """Evaluate MedQAAssistant with a Contriever retriever on MMLU anatomy.

    Runs a Weave evaluation of the assistant (backed by the given LLM) over the
    ``mmlu-anatomy-test:v2`` dataset and asserts that correctly answered
    questions outnumber incorrect ones.

    Args:
        model_name: Identifier of the LLM passed to ``LLMClient``.
    """
    weave.init("ml-colabs/medrag-multi-modal")
    dense_retriever = ContrieverRetriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
        chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=dense_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()
    # Tag every traced call with the retriever/LLM pair for later filtering.
    run_attributes = {
        "retriever": dense_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-Contriever",
        )
        display_name = evaluation.name + ":" + client.model_name
        summary = asyncio.run(
            evaluation.evaluate(
                assistant,
                __weave={"display_name": display_name},
            )
        )
    option_scores = summary["MMLUOptionAccuracy"]["correct"]
    assert option_scores["true_count"] > option_scores["false_count"]
def test_mmlu_correctness_anatomy_medcpt(model_name: str):
    """Evaluate MedQAAssistant with a MedCPT retriever on MMLU anatomy.

    Runs a Weave evaluation of the assistant (backed by the given LLM) over the
    ``mmlu-anatomy-test:v2`` dataset and asserts that correctly answered
    questions outnumber incorrect ones.

    Args:
        model_name: Identifier of the LLM passed to ``LLMClient``.
    """
    weave.init("ml-colabs/medrag-multi-modal")
    medcpt_retriever = MedCPTRetriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
        chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=medcpt_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()
    # Tag every traced call with the retriever/LLM pair for later filtering.
    run_attributes = {
        "retriever": medcpt_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-MedCPT",
        )
        display_name = evaluation.name + ":" + client.model_name
        summary = asyncio.run(
            evaluation.evaluate(
                assistant,
                __weave={"display_name": display_name},
            )
        )
    option_scores = summary["MMLUOptionAccuracy"]["correct"]
    assert option_scores["true_count"] > option_scores["false_count"]
def test_mmlu_correctness_anatomy_nvembed2(model_name: str):
    """Evaluate MedQAAssistant with an NV-Embed-v2 retriever on MMLU anatomy.

    Runs a Weave evaluation of the assistant (backed by the given LLM) over the
    ``mmlu-anatomy-test:v2`` dataset and asserts that correctly answered
    questions outnumber incorrect ones.

    Args:
        model_name: Identifier of the LLM passed to ``LLMClient``.
    """
    weave.init("ml-colabs/medrag-multi-modal")
    nvembed_retriever = NVEmbed2Retriever().from_index(
        index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
        chunk_dataset="ashwiniai/medrag-text-corpus-chunks",
    )
    client = LLMClient(model_name=model_name)
    assistant = MedQAAssistant(
        llm_client=client,
        retriever=nvembed_retriever,
        top_k_chunks_for_query=5,
        top_k_chunks_for_options=3,
    )
    anatomy_dataset = weave.ref("mmlu-anatomy-test:v2").get()
    # Tag every traced call with the retriever/LLM pair for later filtering.
    run_attributes = {
        "retriever": nvembed_retriever.__class__.__name__,
        "llm": client.model_name,
    }
    with weave.attributes(run_attributes):
        evaluation = weave.Evaluation(
            dataset=anatomy_dataset,
            scorers=[MMLUOptionAccuracy()],
            name="MMLU-Anatomy-NVEmbed2",
        )
        display_name = evaluation.name + ":" + client.model_name
        summary = asyncio.run(
            evaluation.evaluate(
                assistant,
                __weave={"display_name": display_name},
            )
        )
    option_scores = summary["MMLUOptionAccuracy"]["correct"]
    assert option_scores["true_count"] > option_scores["false_count"]