# LightOnOCR / app.py
# (Hugging Face page header: commit ec11554 "Update app.py" by IFMedTechdemo, verified; 21.5 kB)
#!/usr/bin/env python3
import subprocess
import sys
import spaces
import torch
import gradio as gr
from PIL import Image
from io import BytesIO
import pypdfium2 as pdfium
from transformers import (
LightOnOCRForConditionalGeneration,
LightOnOCRProcessor,
)
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
# Attention kernel and weight precision depend on the device:
# SDPA + bfloat16 on GPU, eager attention + float32 on CPU.
attn_implementation, dtype = (
    ("sdpa", torch.bfloat16) if device == "cuda" else ("eager", torch.float32)
)
print("Using sdpa for GPU" if device == "cuda" else "Using eager attention for CPU")
print(f"Loading LightOnOCR model on {device} with {attn_implementation} attention...")
# The model and its processor are loaded once, at import time, from the same checkpoint.
_OCR_CHECKPOINT = "lightonai/LightOnOCR-1B-1025"
ocr_model = (
    LightOnOCRForConditionalGeneration.from_pretrained(
        _OCR_CHECKPOINT,
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    .to(device)
    .eval()  # evaluation mode: inference only
)
processor = LightOnOCRProcessor.from_pretrained(_OCR_CHECKPOINT, trust_remote_code=True)
print("LightOnOCR model loaded successfully!")
# -------- Clinical NER models (load ONCE) --------
# Loaded at import time so every request reuses the same tokenizer/model objects.
print("Loading clinical NER model...")
_NER_CHECKPOINT = "samrawal/bert-base-uncased_clinical-ner"
ner_tokenizer = AutoTokenizer.from_pretrained(_NER_CHECKPOINT)
ner_model = AutoModelForTokenClassification.from_pretrained(_NER_CHECKPOINT)
# aggregation_strategy="simple" makes the pipeline return grouped spans whose
# dicts carry an "entity_group" and a "word" key.
ner_pipeline = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)
print("Clinical NER model loaded successfully!")
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Rasterize a PDF page object into a PIL image.

    The page is rendered at ``scale`` unless that would make either side
    (``get_size()`` dimension times ``scale``) exceed ``max_resolution``.
    In that case the scale is shrunk so the larger side fits exactly. The
    scale is never increased above ``scale``.

    Args:
        page: page object exposing ``get_size()`` and ``render(scale=..., rev_byteorder=...)``.
        max_resolution: upper bound on the rendered width/height in pixels.
        scale: preferred render scale.

    Returns:
        The object returned by the rendered bitmap's ``to_pil()``.
    """
    page_w, page_h = page.get_size()
    shrink = min(
        1,
        max_resolution / (page_w * scale),
        max_resolution / (page_h * scale),
    )
    bitmap = page.render(scale=scale * shrink, rev_byteorder=True)
    return bitmap.to_pil()
def process_pdf(pdf_path, page_num=1):
    """Render one page of a PDF file to an image.

    Args:
        pdf_path: path to the PDF file.
        page_num: requested 1-based page number; values outside
            ``1..total_pages`` are clamped into that range.

    Returns:
        ``(image, total_pages, actual_page)`` where ``actual_page`` is the
        1-based page that was actually rendered after clamping.

    Raises:
        ValueError: if the document has no pages (or ``page_num`` is not
            convertible with ``int``).
    """
    pdf = pdfium.PdfDocument(pdf_path)
    try:
        total_pages = len(pdf)
        if total_pages == 0:
            # Without this guard the clamp below yields index -1.
            raise ValueError("PDF has no pages")
        page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
        img = render_pdf_page(pdf[page_idx])
    finally:
        # Release the document even when rendering fails.
        pdf.close()
    return img, total_pages, page_idx + 1
def clean_output_text(text):
    """Strip chat-template role markers from decoded model output.

    If some line consists solely of the ``assistant`` role marker
    (case- and surrounding-whitespace-insensitive), everything after the
    first such line is returned, since that is where a chat template puts
    the model's reply. Otherwise every line that is exactly ``system``,
    ``user`` or ``assistant`` is dropped. The result is stripped of
    leading/trailing whitespace.

    Only whole marker lines are considered. The word "assistant" inside
    regular text (e.g. "physician assistant") is never used as a split
    point, so OCR content before it is preserved.
    """
    markers_to_remove = {"system", "user", "assistant"}
    lines = text.split("\n")

    for idx, line in enumerate(lines):
        if line.strip().lower() == "assistant":
            return "\n".join(lines[idx + 1:]).strip()

    kept = [line for line in lines if line.strip().lower() not in markers_to_remove]
    return "\n".join(kept).strip()
def _collect_medications(entities):
    """Return unique "treatment" entity strings from NER pipeline output.

    Args:
        entities: iterable of dicts with "entity_group" and "word" keys.

    A word starting with "##" is treated as a continuation and appended to
    the previously collected treatment. Duplicates are removed while keeping
    the order of first appearance.
    """
    medications = []
    for ent in entities:
        if ent["entity_group"] != "treatment":
            continue
        word = ent["word"]
        if word.startswith("##") and medications:
            medications[-1] += word[2:]
        else:
            medications.append(word)
    # dict.fromkeys keeps first-seen order. set() iteration order depends on
    # per-process string hashing, so the displayed list would reorder between runs.
    return list(dict.fromkeys(medications))


@spaces.GPU
def extract_text_from_image(image, temperature=0.2):
    """OCR ``image`` with LightOnOCR, then extract treatment entities with clinical NER.

    Args:
        image: image placed in the chat template's image slot (process_input
            passes a PIL image).
        temperature: sampling temperature; values <= 0 select greedy decoding.

    Yields:
        A single ``(cleaned_text, medications_str)`` tuple. ``medications_str``
        is a comma-separated list of unique "treatment" entities, or
        "None detected".
    """
    chat = [
        {
            "role": "user",
            "content": [
                # NOTE(review): adjust to {"type": "image", "image": image} if LightOnOCR expects that.
                {"type": "image", "url": image},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Floating-point tensors are cast to the model dtype as well as moved to the
    # device. Other tensors (e.g. token ids) are only moved.
    floating = (torch.float32, torch.float16, torch.bfloat16)
    moved = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device=device, dtype=dtype) if value.dtype in floating else value.to(device)
        moved[key] = value
    inputs = moved

    # generate() returns the prompt tokens followed by the new tokens. Remember
    # the prompt length so only the model's answer is decoded.
    prompt_len = inputs["input_ids"].shape[-1]

    generation_kwargs = dict(**inputs, max_new_tokens=2048, use_cache=True)
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding. A temperature would be ignored and only trigger a warning.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = ocr_model.generate(**generation_kwargs)

    output_text = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Still strip role markers defensively in case any survive decoding.
    cleaned_text = clean_output_text(output_text)
    print("\n this is cleaned_text", cleaned_text)

    # Clinical NER on the full cleaned text (skipped when OCR produced nothing).
    entities = ner_pipeline(cleaned_text) if cleaned_text.strip() else []
    print("\n this is entity", entities)

    medications = _collect_medications(entities)
    medications_str = ", ".join(medications) if medications else "None detected"
    yield cleaned_text, medications_str
def process_input(file_input, temperature, page_num):
    """Gradio click handler: OCR an uploaded image or PDF page and list medications.

    Args:
        file_input: uploaded file as a path string, an object exposing the
            path as ``.name``, or None when nothing has been uploaded.
        temperature: sampling temperature forwarded to extract_text_from_image.
        page_num: 1-based page requested on the slider (only used for PDFs).

    Yields:
        6-tuples in output order
        (markdown_text, medications, raw_text, page_info, preview_image, slider).
        Errors are reported in the text fields instead of being raised.
    """
    if file_input is None:
        # 6 outputs: [output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
        yield "Please upload an image or PDF first.", "", "", "", None, 1
        return

    page_info = ""
    slider_value = page_num
    file_path = file_input if isinstance(file_input, str) else file_input.name

    if file_path.lower().endswith(".pdf"):
        try:
            image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
        except Exception as e:
            msg = f"Error processing PDF: {str(e)}"
            yield msg, "", msg, "", None, slider_value
            return
        page_info = f"Processing page {actual_page} of {total_pages}"
        # Reflect the page actually rendered (after clamping) back onto the slider.
        slider_value = actual_page
    else:
        try:
            # Image.open is lazy. convert() forces the pixel data to be decoded
            # here, so corrupt or truncated files are reported as open errors
            # rather than extraction errors. The context manager closes the file
            # handle, and the result is a standalone RGB copy.
            with Image.open(file_path) as opened:
                image_to_process = opened.convert("RGB")
        except Exception as e:
            msg = f"Error opening image: {str(e)}"
            yield msg, "", msg, "", None, slider_value
            return
        page_info = "Processing image"

    try:
        for extracted_text, medications in extract_text_from_image(image_to_process, temperature):
            # The rendered-markdown and raw views currently show the same text.
            yield (
                extracted_text,
                medications,
                extracted_text,
                page_info,
                image_to_process,
                gr.update(value=slider_value),
            )
    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        yield error_msg, "", error_msg, page_info, image_to_process, gr.update(value=slider_value)
def update_slider(file_input):
    """Fit the PDF page slider to the uploaded file.

    Args:
        file_input: uploaded file as a path string, an object exposing the
            path as ``.name``, or None.

    Returns:
        A ``gr.update`` for the slider, with its value reset to 1. The maximum is:
        - 20 when nothing is uploaded or the PDF cannot be read;
        - the page count (at least 1) for a readable PDF;
        - 1 for non-PDF uploads.
    """
    if file_input is None:
        return gr.update(maximum=20, value=1)

    file_path = file_input if isinstance(file_input, str) else file_input.name
    if not file_path.lower().endswith(".pdf"):
        return gr.update(maximum=1, value=1)

    pdf = None
    try:
        pdf = pdfium.PdfDocument(file_path)
        total_pages = len(pdf)
    except Exception:
        # Unreadable PDF: keep the default range instead of failing the upload event.
        # (Exception, not a bare except, so KeyboardInterrupt/SystemExit still propagate.)
        return gr.update(maximum=20, value=1)
    finally:
        if pdf is not None:
            pdf.close()

    # The slider's minimum is 1; an empty PDF must not push the maximum below it.
    return gr.update(maximum=max(total_pages, 1), value=1)
# Create Gradio interface
# Placeholder shown in the rendered-text panel before extraction and after Clear.
_OUTPUT_PLACEHOLDER = "*Extracted text will appear here...*"

with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 📖 Image/PDF to Text Extraction with LightOnOCR
**💡 How to use:**
1. Upload an image or PDF
2. For PDFs: select which page to extract (1-20)
3. Adjust temperature if needed
4. Click "Extract Text"
**Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!
**Model:** LightOnOCR-1B-1025 by LightOn AI
**Device:** {device.upper()}
**Attention:** {attn_implementation}
""")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="🖼️ Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath",
            )
            rendered_image = gr.Image(
                label="📄 Preview",
                type="pil",
                height=400,
                interactive=False,
            )
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract",
            )
            page_info = gr.Textbox(
                label="Processing Info",
                value="",
                interactive=False,
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic, Higher = more varied",
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value=_OUTPUT_PLACEHOLDER,
            )
            medications_output = gr.Textbox(
                label="💊 Extracted Medicines/Drugs",
                placeholder="Medicine/drug names will appear here...",
                lines=2,
                max_lines=5,
                interactive=False,
                show_copy_button=True,
            )

    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30,
                show_copy_button=True,
            )

    # Event handlers
    # process_input yields 6-tuples in exactly this output order.
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, temperature, num_pages],
        outputs=[output_text, medications_output, raw_output, page_info, rendered_image, num_pages],
    )
    file_input.change(
        fn=update_slider,
        inputs=[file_input],
        outputs=[num_pages],
    )
    # Reset every panel, including the medications box, which was previously left stale.
    clear_btn.click(
        fn=lambda: (None, _OUTPUT_PLACEHOLDER, "", "", "", None, 1),
        outputs=[
            file_input,
            output_text,
            medications_output,
            raw_output,
            page_info,
            rendered_image,
            num_pages,
        ],
    )

if __name__ == "__main__":
    demo.launch()
#################################### old code to be checked #############################################
# import sys
# import threading
# import spaces
# import torch
# import gradio as gr
# from PIL import Image
# from io import BytesIO
# import pypdfium2 as pdfium
# from transformers import (
# LightOnOCRForConditionalGeneration,
# LightOnOCRProcessor,
# TextIteratorStreamer,
# )
# # ---- CLINICAL NER IMPORTS ----
# import spacy
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # Choose best attention implementation based on device
# if device == "cuda":
# attn_implementation = "sdpa"
# dtype = torch.bfloat16
# print("Using sdpa for GPU")
# else:
# attn_implementation = "eager" # Best for CPU
# dtype = torch.float32
# print("Using eager attention for CPU")
# # Initialize the LightOnOCR model and processor
# print(f"Loading model on {device} with {attn_implementation} attention...")
# model = LightOnOCRForConditionalGeneration.from_pretrained(
# "lightonai/LightOnOCR-1B-1025",
# attn_implementation=attn_implementation,
# torch_dtype=dtype,
# trust_remote_code=True
# ).to(device).eval()
# processor = LightOnOCRProcessor.from_pretrained(
# "lightonai/LightOnOCR-1B-1025",
# trust_remote_code=True
# )
# print("Model loaded successfully!")
# # ---- LOAD CLINICAL NER MODEL (BC5CDR) ----
# print("Loading clinical NER model (bc5cdr)...")
# nlp_ner = spacy.load("en_ner_bc5cdr_md")
# print("Clinical NER loaded.")
# def render_pdf_page(page, max_resolution=1540, scale=2.77):
# """Render a PDF page to PIL Image."""
# width, height = page.get_size()
# pixel_width = width * scale
# pixel_height = height * scale
# resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
# target_scale = scale * resize_factor
# return page.render(scale=target_scale, rev_byteorder=True).to_pil()
# def process_pdf(pdf_path, page_num=1):
# """Extract a specific page from PDF."""
# pdf = pdfium.PdfDocument(pdf_path)
# total_pages = len(pdf)
# page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
# page = pdf[page_idx]
# img = render_pdf_page(page)
# pdf.close()
# return img, total_pages, page_idx + 1
# def clean_output_text(text):
# """Remove chat template artifacts from output."""
# markers_to_remove = ["system", "user", "assistant"]
# lines = text.split('\n')
# cleaned_lines = []
# for line in lines:
# stripped = line.strip()
# # Skip lines that are just template markers
# if stripped.lower() not in markers_to_remove:
# cleaned_lines.append(line)
# cleaned = '\n'.join(cleaned_lines).strip()
# if "assistant" in text.lower():
# parts = text.split("assistant", 1)
# if len(parts) > 1:
# cleaned = parts[1].strip()
# return cleaned
# def extract_medication_names(text):
# """Extract medication names using clinical NER (spacy: bc5cdr CHEMICAL)."""
# doc = nlp_ner(text)
# meds = [ent.text for ent in doc.ents if ent.label_ == "CHEMICAL"]
# meds_unique = list(dict.fromkeys(meds))
# return meds_unique
# @spaces.GPU
# def extract_text_from_image(image, temperature=0.2, stream=False):
# """Extract text from image using LightOnOCR model."""
# chat = [
# {
# "role": "user",
# "content": [
# {"type": "image", "url": image},
# ],
# }
# ]
# inputs = processor.apply_chat_template(
# chat,
# add_generation_prompt=True,
# tokenize=True,
# return_dict=True,
# return_tensors="pt"
# )
# inputs = {
# k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
# else v.to(device) if isinstance(v, torch.Tensor)
# else v
# for k, v in inputs.items()
# }
# generation_kwargs = dict(
# **inputs,
# max_new_tokens=2048,
# temperature=temperature if temperature > 0 else 0.0,
# use_cache=True,
# do_sample=temperature > 0,
# )
# if stream:
# # Streaming generation
# streamer = TextIteratorStreamer(
# processor.tokenizer,
# skip_prompt=True,
# skip_special_tokens=True
# )
# generation_kwargs["streamer"] = streamer
# thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
# thread.start()
# full_text = ""
# for new_text in streamer:
# full_text += new_text
# cleaned_text = clean_output_text(full_text)
# yield cleaned_text
# thread.join()
# else:
# # Non-streaming generation
# with torch.no_grad():
# outputs = model.generate(**generation_kwargs)
# output_text = processor.decode(outputs[0], skip_special_tokens=True)
# cleaned_text = clean_output_text(output_text)
# yield cleaned_text
# def process_input(file_input, temperature, page_num, enable_streaming):
# """Process uploaded file (image or PDF) and extract medication names via OCR+NER."""
# if file_input is None:
# yield "Please upload an image or PDF first.", "", "", None, gr.update()
# return
# image_to_process = None
# page_info = ""
# file_path = file_input if isinstance(file_input, str) else file_input.name
# # Handle PDF files
# if file_path.lower().endswith('.pdf'):
# try:
# image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
# page_info = f"Processing page {actual_page} of {total_pages}"
# except Exception as e:
# yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
# return
# # Handle image files
# else:
# try:
# image_to_process = Image.open(file_path)
# page_info = "Processing image"
# except Exception as e:
# yield f"Error opening image: {str(e)}", "", "", None, gr.update()
# return
# try:
# for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
# meds = extract_medication_names(extracted_text)
# meds_str = "\n".join(meds) if meds else "No medications found."
# yield meds_str, meds_str, page_info, image_to_process, gr.update()
# except Exception as e:
# error_msg = f"Error during text extraction: {str(e)}"
# yield error_msg, error_msg, page_info, image_to_process, gr.update()
# def update_slider(file_input):
# """Update page slider based on PDF page count."""
# if file_input is None:
# return gr.update(maximum=20, value=1)
# file_path = file_input if isinstance(file_input, str) else file_input.name
# if file_path.lower().endswith('.pdf'):
# try:
# pdf = pdfium.PdfDocument(file_path)
# total_pages = len(pdf)
# pdf.close()
# return gr.update(maximum=total_pages, value=1)
# except:
# return gr.update(maximum=20, value=1)
# else:
# return gr.update(maximum=1, value=1)
# # ----- GRADIO UI -----
# with gr.Blocks(title="📖 Image/PDF OCR + Clinical NER", theme=gr.themes.Soft()) as demo:
# gr.Markdown(f"""
# # 📖 Medication Extraction from Image/PDF with LightOnOCR + Clinical NER
# **💡 How to use:**
# 1. Upload an image or PDF
# 2. For PDFs: select which page to extract
# 3. Adjust temperature if needed
# 4. Click "Extract Medications"
# **Output:** Only medication names found in text (via NER)
# **Model:** LightOnOCR-1B-1025 by LightOn AI
# **Device:** {device.upper()}
# **Attention:** {attn_implementation}
# """)
# with gr.Row():
# with gr.Column(scale=1):
# file_input = gr.File(
# label="🖼️ Upload Image or PDF",
# file_types=[".pdf", ".png", ".jpg", ".jpeg"],
# type="filepath"
# )
# rendered_image = gr.Image(
# label="📄 Preview",
# type="pil",
# height=400,
# interactive=False
# )
# num_pages = gr.Slider(
# minimum=1,
# maximum=20,
# value=1,
# step=1,
# label="PDF: Page Number",
# info="Select which page to extract"
# )
# page_info = gr.Textbox(
# label="Processing Info",
# value="",
# interactive=False
# )
# temperature = gr.Slider(
# minimum=0.0,
# maximum=1.0,
# value=0.2,
# step=0.05,
# label="Temperature",
# info="0.0 = deterministic, Higher = more varied"
# )
# enable_streaming = gr.Checkbox(
# label="Enable Streaming",
# value=True,
# info="Show text progressively as it's generated"
# )
# submit_btn = gr.Button("Extract Medications", variant="primary")
# clear_btn = gr.Button("Clear", variant="secondary")
# with gr.Column(scale=2):
# output_text = gr.Markdown(
# label="🩺 Extracted Medication Names",
# value="*Medication names will appear here...*"
# )
# with gr.Row():
# with gr.Column():
# raw_output = gr.Textbox(
# label="Extracted Medication Names (Raw)",
# placeholder="Medication list will appear here...",
# lines=20,
# max_lines=30,
# show_copy_button=True
# )
# # Event handlers
# submit_btn.click(
# fn=process_input,
# inputs=[file_input, temperature, num_pages, enable_streaming],
# outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
# )
# file_input.change(
# fn=update_slider,
# inputs=[file_input],
# outputs=[num_pages]
# )
# clear_btn.click(
# fn=lambda: (None, "*Medication names will appear here...*", "", "", None, 1),
# outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
# )
# if __name__ == "__main__":
# demo.launch()