IFMedTechdemo committed on
Commit 0c53c1b (verified) · 1 parent: d20ccd6

Update app.py

Files changed (1):
  1. app.py (+395 −35)
app.py CHANGED
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+
 import subprocess
 import sys
 import threading
@@ -16,8 +17,7 @@ from transformers import (
     TextIteratorStreamer,
 )
 
-# ---- CLINICAL NER IMPORTS ----
-import spacy
+from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -46,10 +46,6 @@ processor = LightOnOCRProcessor.from_pretrained(
 )
 print("Model loaded successfully!")
 
-# ---- LOAD CLINICAL NER MODEL (BC5CDR) ----
-print("Loading clinical NER model (bc5cdr)...")
-nlp_ner = spacy.load("en_ner_bc5cdr_md")
-print("Clinical NER loaded.")
 
 def render_pdf_page(page, max_resolution=1540, scale=2.77):
     """Render a PDF page to PIL Image."""
@@ -76,32 +72,35 @@ def process_pdf(pdf_path, page_num=1):
 
 def clean_output_text(text):
     """Remove chat template artifacts from output."""
+    # Remove common chat template markers
     markers_to_remove = ["system", "user", "assistant"]
+
+    # Split by lines and filter
     lines = text.split('\n')
     cleaned_lines = []
+
     for line in lines:
         stripped = line.strip()
         # Skip lines that are just template markers
         if stripped.lower() not in markers_to_remove:
             cleaned_lines.append(line)
+
+    # Join back and strip leading/trailing whitespace
     cleaned = '\n'.join(cleaned_lines).strip()
+
+    # Alternative approach: if there's an "assistant" marker, take everything after it
     if "assistant" in text.lower():
         parts = text.split("assistant", 1)
         if len(parts) > 1:
             cleaned = parts[1].strip()
+
     return cleaned
 
-def extract_medication_names(text):
-    """Extract medication names using clinical NER (spacy: bc5cdr CHEMICAL)."""
-    doc = nlp_ner(text)
-    meds = [ent.text for ent in doc.ents if ent.label_ == "CHEMICAL"]
-    meds_unique = list(dict.fromkeys(meds))
-    return meds_unique
-
 
 @spaces.GPU
 def extract_text_from_image(image, temperature=0.2, stream=False):
     """Extract text from image using LightOnOCR model."""
+    # Prepare the chat format
     chat = [
         {
             "role": "user",
@@ -110,6 +109,8 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
             ],
         }
     ]
+
+    # Apply chat template and tokenize
     inputs = processor.apply_chat_template(
         chat,
         add_generation_prompt=True,
@@ -117,12 +118,15 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
         return_dict=True,
         return_tensors="pt"
     )
+
+    # Move inputs to device AND convert to the correct dtype
     inputs = {
         k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
         else v.to(device) if isinstance(v, torch.Tensor)
         else v
         for k, v in inputs.items()
     }
+
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=2048,
@@ -130,38 +134,76 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
         use_cache=True,
         do_sample=temperature > 0,
     )
+
     if stream:
-        # Streaming generation
+        # Setup streamer for streaming generation
         streamer = TextIteratorStreamer(
             processor.tokenizer,
             skip_prompt=True,
             skip_special_tokens=True
         )
         generation_kwargs["streamer"] = streamer
+
+        # Run generation in a separate thread
         thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
+
+        # Yield chunks as they arrive
        full_text = ""
         for new_text in streamer:
             full_text += new_text
+            # Clean the accumulated text
             cleaned_text = clean_output_text(full_text)
             yield cleaned_text
+
         thread.join()
     else:
         # Non-streaming generation
         with torch.no_grad():
             outputs = model.generate(**generation_kwargs)
+
+        # Decode the output
         output_text = processor.decode(outputs[0], skip_special_tokens=True)
+
+        # Clean the output
        cleaned_text = clean_output_text(output_text)
+
+        # ---- Clinical NER: extract medication ("treatment") mentions from the OCR text ----
+        # Load the token-classification model under separate names so the OCR
+        # `model` used by generate() above is not shadowed.
+        ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
+        ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
+        ner = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
+
+        entities = ner(cleaned_text)
+        medications = []
+        for ent in entities:
+            if ent["entity_group"] == "treatment":
+                word = ent["word"]
+                if word.startswith("##") and medications:
+                    medications[-1] += word[2:]
+                else:
+                    medications.append(word)
+        medications_str = ", ".join(set(medications)) if medications else "None detected"
+
         yield cleaned_text
+        # Also yield a (text, medications) pair for the UI handler
+        yield cleaned_text, medications_str
+
 
 def process_input(file_input, temperature, page_num, enable_streaming):
-    """Process uploaded file (image or PDF) and extract medication names via OCR+NER."""
+    """Process uploaded file (image or PDF) and extract text with optional streaming."""
     if file_input is None:
         yield "Please upload an image or PDF first.", "", "", None, gr.update()
         return
+
     image_to_process = None
     page_info = ""
+
     file_path = file_input if isinstance(file_input, str) else file_input.name
+
     # Handle PDF files
     if file_path.lower().endswith('.pdf'):
         try:
@@ -178,20 +220,24 @@ def process_input(file_input, temperature, page_num, enable_streaming):
     except Exception as e:
         yield f"Error opening image: {str(e)}", "", "", None, gr.update()
         return
+
     try:
-        for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
-            meds = extract_medication_names(extracted_text)
-            meds_str = "\n".join(meds) if meds else "No medications found."
-            yield meds_str, meds_str, page_info, image_to_process, gr.update()
+        # Extract text using LightOnOCR with optional streaming; the non-streaming
+        # branch also yields a (text, medications) tuple.
+        for result in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
+            extracted_text, medications = result if isinstance(result, tuple) else (result, "")
+            yield extracted_text, medications, extracted_text, page_info, image_to_process, gr.update()
+
     except Exception as e:
         error_msg = f"Error during text extraction: {str(e)}"
         yield error_msg, error_msg, page_info, image_to_process, gr.update()
 
+
 def update_slider(file_input):
     """Update page slider based on PDF page count."""
     if file_input is None:
         return gr.update(maximum=20, value=1)
+
     file_path = file_input if isinstance(file_input, str) else file_input.name
+
     if file_path.lower().endswith('.pdf'):
         try:
             pdf = pdfium.PdfDocument(file_path)
@@ -203,23 +249,25 @@ def update_slider(file_input):
     else:
         return gr.update(maximum=1, value=1)
 
-# ----- GRADIO UI -----
-with gr.Blocks(title="📖 Image/PDF OCR + Clinical NER", theme=gr.themes.Soft()) as demo:
+
+# Create Gradio interface
+with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
     gr.Markdown(f"""
-    # 📖 Medication Extraction from Image/PDF with LightOnOCR + Clinical NER
+    # 📖 Image/PDF to Text Extraction with LightOnOCR
 
     **💡 How to use:**
     1. Upload an image or PDF
-    2. For PDFs: select which page to extract
+    2. For PDFs: select which page to extract (1-20)
     3. Adjust temperature if needed
-    4. Click "Extract Medications"
+    4. Click "Extract Text"
 
-    **Output:** Only medication names found in text (via NER)
+    **Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!
 
     **Model:** LightOnOCR-1B-1025 by LightOn AI
     **Device:** {device.upper()}
     **Attention:** {attn_implementation}
     """)
+
     with gr.Row():
         with gr.Column(scale=1):
             file_input = gr.File(
@@ -259,37 +307,349 @@ with gr.Blocks(title="📖 Image/PDF OCR + Clinical NER", theme=gr.themes.Soft()
                 value=True,
                 info="Show text progressively as it's generated"
             )
-            submit_btn = gr.Button("Extract Medications", variant="primary")
+            submit_btn = gr.Button("Extract Text", variant="primary")
             clear_btn = gr.Button("Clear", variant="secondary")
+
         with gr.Column(scale=2):
             output_text = gr.Markdown(
-                label="🩺 Extracted Medication Names",
-                value="*Medication names will appear here...*"
+                label="📄 Extracted Text (Rendered)",
+                value="*Extracted text will appear here...*"
             )
+            medications_output = gr.Textbox(
+                label="💊 Extracted Medicines/Drugs",
+                placeholder="Medicine/drug names will appear here...",
+                lines=2,
+                max_lines=5,
+                interactive=False,
+                show_copy_button=True
+            )
+
     with gr.Row():
         with gr.Column():
             raw_output = gr.Textbox(
-                label="Extracted Medication Names (Raw)",
-                placeholder="Medication list will appear here...",
+                label="Raw Markdown Output",
+                placeholder="Raw text will appear here...",
                 lines=20,
                 max_lines=30,
                 show_copy_button=True
             )
+
     # Event handlers
     submit_btn.click(
-        fn=process_input,
-        inputs=[file_input, temperature, num_pages, enable_streaming],
-        outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
-    )
+        fn=process_input,
+        inputs=[file_input, temperature, num_pages, enable_streaming],
+        outputs=[output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
+    )
+
     file_input.change(
         fn=update_slider,
         inputs=[file_input],
         outputs=[num_pages]
     )
+
     clear_btn.click(
-        fn=lambda: (None, "*Medication names will appear here...*", "", "", None, 1),
+        fn=lambda: (None, "*Extracted text will appear here...*", "", "", None, 1),
         outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
     )
 
+
 if __name__ == "__main__":
     demo.launch()
+
359
+
360
+
361
+ #################################### old code to be checked #############################################
362
+
363
+ # import sys
364
+ # import threading
365
+
366
+ # import spaces
367
+ # import torch
368
+
369
+ # import gradio as gr
370
+ # from PIL import Image
371
+ # from io import BytesIO
372
+ # import pypdfium2 as pdfium
373
+ # from transformers import (
374
+ # LightOnOCRForConditionalGeneration,
375
+ # LightOnOCRProcessor,
376
+ # TextIteratorStreamer,
377
+ # )
378
+
379
+ # # ---- CLINICAL NER IMPORTS ----
380
+ # import spacy
381
+
382
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
383
+
384
+ # # Choose best attention implementation based on device
385
+ # if device == "cuda":
386
+ # attn_implementation = "sdpa"
387
+ # dtype = torch.bfloat16
388
+ # print("Using sdpa for GPU")
389
+ # else:
390
+ # attn_implementation = "eager" # Best for CPU
391
+ # dtype = torch.float32
392
+ # print("Using eager attention for CPU")
393
+
394
+ # # Initialize the LightOnOCR model and processor
395
+ # print(f"Loading model on {device} with {attn_implementation} attention...")
396
+ # model = LightOnOCRForConditionalGeneration.from_pretrained(
397
+ # "lightonai/LightOnOCR-1B-1025",
398
+ # attn_implementation=attn_implementation,
399
+ # torch_dtype=dtype,
400
+ # trust_remote_code=True
401
+ # ).to(device).eval()
402
+
403
+ # processor = LightOnOCRProcessor.from_pretrained(
404
+ # "lightonai/LightOnOCR-1B-1025",
405
+ # trust_remote_code=True
406
+ # )
407
+ # print("Model loaded successfully!")
408
+
409
+ # # ---- LOAD CLINICAL NER MODEL (BC5CDR) ----
410
+ # print("Loading clinical NER model (bc5cdr)...")
411
+ # nlp_ner = spacy.load("en_ner_bc5cdr_md")
412
+ # print("Clinical NER loaded.")
413
+
414
+ # def render_pdf_page(page, max_resolution=1540, scale=2.77):
415
+ # """Render a PDF page to PIL Image."""
416
+ # width, height = page.get_size()
417
+ # pixel_width = width * scale
418
+ # pixel_height = height * scale
419
+ # resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
420
+ # target_scale = scale * resize_factor
421
+ # return page.render(scale=target_scale, rev_byteorder=True).to_pil()
422
+
423
+
424
+ # def process_pdf(pdf_path, page_num=1):
425
+ # """Extract a specific page from PDF."""
426
+ # pdf = pdfium.PdfDocument(pdf_path)
427
+ # total_pages = len(pdf)
428
+ # page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
429
+
430
+ # page = pdf[page_idx]
431
+ # img = render_pdf_page(page)
432
+
433
+ # pdf.close()
434
+ # return img, total_pages, page_idx + 1
435
+
436
+
437
+ # def clean_output_text(text):
438
+ # """Remove chat template artifacts from output."""
439
+ # markers_to_remove = ["system", "user", "assistant"]
440
+ # lines = text.split('\n')
441
+ # cleaned_lines = []
442
+ # for line in lines:
443
+ # stripped = line.strip()
444
+ # # Skip lines that are just template markers
445
+ # if stripped.lower() not in markers_to_remove:
446
+ # cleaned_lines.append(line)
447
+ # cleaned = '\n'.join(cleaned_lines).strip()
448
+ # if "assistant" in text.lower():
449
+ # parts = text.split("assistant", 1)
450
+ # if len(parts) > 1:
451
+ # cleaned = parts[1].strip()
452
+ # return cleaned
453
+
454
+ # def extract_medication_names(text):
455
+ # """Extract medication names using clinical NER (spacy: bc5cdr CHEMICAL)."""
456
+ # doc = nlp_ner(text)
457
+ # meds = [ent.text for ent in doc.ents if ent.label_ == "CHEMICAL"]
458
+ # meds_unique = list(dict.fromkeys(meds))
459
+ # return meds_unique
460
+
461
+
462
+ # @spaces.GPU
463
+ # def extract_text_from_image(image, temperature=0.2, stream=False):
464
+ # """Extract text from image using LightOnOCR model."""
465
+ # chat = [
466
+ # {
467
+ # "role": "user",
468
+ # "content": [
469
+ # {"type": "image", "url": image},
470
+ # ],
471
+ # }
472
+ # ]
473
+ # inputs = processor.apply_chat_template(
474
+ # chat,
475
+ # add_generation_prompt=True,
476
+ # tokenize=True,
477
+ # return_dict=True,
478
+ # return_tensors="pt"
479
+ # )
480
+ # inputs = {
481
+ # k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
482
+ # else v.to(device) if isinstance(v, torch.Tensor)
483
+ # else v
484
+ # for k, v in inputs.items()
485
+ # }
486
+ # generation_kwargs = dict(
487
+ # **inputs,
488
+ # max_new_tokens=2048,
489
+ # temperature=temperature if temperature > 0 else 0.0,
490
+ # use_cache=True,
491
+ # do_sample=temperature > 0,
492
+ # )
493
+ # if stream:
494
+ # # Streaming generation
495
+ # streamer = TextIteratorStreamer(
496
+ # processor.tokenizer,
497
+ # skip_prompt=True,
498
+ # skip_special_tokens=True
499
+ # )
500
+ # generation_kwargs["streamer"] = streamer
501
+ # thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
502
+ # thread.start()
503
+ # full_text = ""
504
+ # for new_text in streamer:
505
+ # full_text += new_text
506
+ # cleaned_text = clean_output_text(full_text)
507
+ # yield cleaned_text
508
+ # thread.join()
509
+ # else:
510
+ # # Non-streaming generation
511
+ # with torch.no_grad():
512
+ # outputs = model.generate(**generation_kwargs)
513
+ # output_text = processor.decode(outputs[0], skip_special_tokens=True)
514
+ # cleaned_text = clean_output_text(output_text)
515
+ # yield cleaned_text
516
+
517
+ # def process_input(file_input, temperature, page_num, enable_streaming):
518
+ # """Process uploaded file (image or PDF) and extract medication names via OCR+NER."""
519
+ # if file_input is None:
520
+ # yield "Please upload an image or PDF first.", "", "", None, gr.update()
521
+ # return
522
+ # image_to_process = None
523
+ # page_info = ""
524
+ # file_path = file_input if isinstance(file_input, str) else file_input.name
525
+ # # Handle PDF files
526
+ # if file_path.lower().endswith('.pdf'):
527
+ # try:
528
+ # image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
529
+ # page_info = f"Processing page {actual_page} of {total_pages}"
530
+ # except Exception as e:
531
+ # yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
532
+ # return
533
+ # # Handle image files
534
+ # else:
535
+ # try:
536
+ # image_to_process = Image.open(file_path)
537
+ # page_info = "Processing image"
538
+ # except Exception as e:
539
+ # yield f"Error opening image: {str(e)}", "", "", None, gr.update()
540
+ # return
541
+ # try:
542
+ # for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
543
+ # meds = extract_medication_names(extracted_text)
544
+ # meds_str = "\n".join(meds) if meds else "No medications found."
545
+ # yield meds_str, meds_str, page_info, image_to_process, gr.update()
546
+ # except Exception as e:
547
+ # error_msg = f"Error during text extraction: {str(e)}"
548
+ # yield error_msg, error_msg, page_info, image_to_process, gr.update()
549
+
550
+ # def update_slider(file_input):
551
+ # """Update page slider based on PDF page count."""
552
+ # if file_input is None:
553
+ # return gr.update(maximum=20, value=1)
554
+ # file_path = file_input if isinstance(file_input, str) else file_input.name
555
+ # if file_path.lower().endswith('.pdf'):
556
+ # try:
557
+ # pdf = pdfium.PdfDocument(file_path)
558
+ # total_pages = len(pdf)
559
+ # pdf.close()
560
+ # return gr.update(maximum=total_pages, value=1)
561
+ # except:
562
+ # return gr.update(maximum=20, value=1)
563
+ # else:
564
+ # return gr.update(maximum=1, value=1)
565
+
566
+ # # ----- GRADIO UI -----
567
+ # with gr.Blocks(title="📖 Image/PDF OCR + Clinical NER", theme=gr.themes.Soft()) as demo:
568
+ # gr.Markdown(f"""
569
+ # # 📖 Medication Extraction from Image/PDF with LightOnOCR + Clinical NER
570
+
571
+ # **💡 How to use:**
572
+ # 1. Upload an image or PDF
573
+ # 2. For PDFs: select which page to extract
574
+ # 3. Adjust temperature if needed
575
+ # 4. Click "Extract Medications"
576
+
577
+ # **Output:** Only medication names found in text (via NER)
578
+
579
+ # **Model:** LightOnOCR-1B-1025 by LightOn AI
580
+ # **Device:** {device.upper()}
581
+ # **Attention:** {attn_implementation}
582
+ # """)
583
+ # with gr.Row():
584
+ # with gr.Column(scale=1):
585
+ # file_input = gr.File(
586
+ # label="🖼️ Upload Image or PDF",
587
+ # file_types=[".pdf", ".png", ".jpg", ".jpeg"],
588
+ # type="filepath"
589
+ # )
590
+ # rendered_image = gr.Image(
591
+ # label="📄 Preview",
592
+ # type="pil",
593
+ # height=400,
594
+ # interactive=False
595
+ # )
596
+ # num_pages = gr.Slider(
597
+ # minimum=1,
598
+ # maximum=20,
599
+ # value=1,
600
+ # step=1,
601
+ # label="PDF: Page Number",
602
+ # info="Select which page to extract"
603
+ # )
604
+ # page_info = gr.Textbox(
605
+ # label="Processing Info",
606
+ # value="",
607
+ # interactive=False
608
+ # )
609
+ # temperature = gr.Slider(
610
+ # minimum=0.0,
611
+ # maximum=1.0,
612
+ # value=0.2,
613
+ # step=0.05,
614
+ # label="Temperature",
615
+ # info="0.0 = deterministic, Higher = more varied"
616
+ # )
617
+ # enable_streaming = gr.Checkbox(
618
+ # label="Enable Streaming",
619
+ # value=True,
620
+ # info="Show text progressively as it's generated"
621
+ # )
622
+ # submit_btn = gr.Button("Extract Medications", variant="primary")
623
+ # clear_btn = gr.Button("Clear", variant="secondary")
624
+ # with gr.Column(scale=2):
625
+ # output_text = gr.Markdown(
626
+ # label="🩺 Extracted Medication Names",
627
+ # value="*Medication names will appear here...*"
628
+ # )
629
+ # with gr.Row():
630
+ # with gr.Column():
631
+ # raw_output = gr.Textbox(
632
+ # label="Extracted Medication Names (Raw)",
633
+ # placeholder="Medication list will appear here...",
634
+ # lines=20,
635
+ # max_lines=30,
636
+ # show_copy_button=True
637
+ # )
638
+ # # Event handlers
639
+ # submit_btn.click(
640
+ # fn=process_input,
641
+ # inputs=[file_input, temperature, num_pages, enable_streaming],
642
+ # outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
643
+ # )
644
+ # file_input.change(
645
+ # fn=update_slider,
646
+ # inputs=[file_input],
647
+ # outputs=[num_pages]
648
+ # )
649
+ # clear_btn.click(
650
+ # fn=lambda: (None, "*Medication names will appear here...*", "", "", None, 1),
651
+ # outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
652
+ # )
653
+
654
+ # if __name__ == "__main__":
655
+ # demo.launch()
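
For reference, a minimal standalone sketch of the clinical NER step this commit wires into the OCR app. It reuses the samrawal/bert-base-uncased_clinical-ner checkpoint referenced above and, like the committed code, assumes medication mentions come back under the "treatment" entity group; the helper names and the sample sentence are illustrative only, not part of the app.

# Standalone sketch of the clinical NER step (hypothetical helper names).
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

NER_MODEL_ID = "samrawal/bert-base-uncased_clinical-ner"  # checkpoint used in this commit


def build_clinical_ner():
    """Load the token-classification pipeline once, separate from the OCR model."""
    tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_ID)
    model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_ID)
    # aggregation_strategy="simple" groups word pieces into whole entity spans
    return pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")


def extract_medications(ner, text):
    """Return unique spans tagged as 'treatment' (assumed to cover drug names)."""
    meds = [ent["word"] for ent in ner(text) if ent["entity_group"] == "treatment"]
    return list(dict.fromkeys(meds))  # de-duplicate while keeping order


if __name__ == "__main__":
    ner = build_clinical_ner()
    sample = "Patient was started on metformin 500 mg and lisinopril 10 mg daily."
    print(", ".join(extract_medications(ner, sample)) or "None detected")

Note that with aggregation_strategy="simple" the pipeline already merges word pieces into whole words, so the manual "##" stitching in the committed loop should rarely trigger.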