IFMedTechdemo committed on
Commit c2a331b · verified · 1 Parent(s): 068d019

Update app.py

Files changed (1):
  1. app.py +78 -62

app.py CHANGED
@@ -21,30 +21,40 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipelin
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Choose best attention implementation based on device
 if device == "cuda":
-    attn_implementation = "sdpa"
+    attn_implementation = "sdpa"
     dtype = torch.bfloat16
     print("Using sdpa for GPU")
 else:
-    attn_implementation = "eager"  # Best for CPU
+    attn_implementation = "eager"
     dtype = torch.float32
     print("Using eager attention for CPU")
 
-# Initialize the LightOnOCR model and processor
-print(f"Loading model on {device} with {attn_implementation} attention...")
-model = LightOnOCRForConditionalGeneration.from_pretrained(
+print(f"Loading LightOnOCR model on {device} with {attn_implementation} attention...")
+ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
     "lightonai/LightOnOCR-1B-1025",
     attn_implementation=attn_implementation,
     torch_dtype=dtype,
-    trust_remote_code=True
+    trust_remote_code=True,
 ).to(device).eval()
 
 processor = LightOnOCRProcessor.from_pretrained(
     "lightonai/LightOnOCR-1B-1025",
-    trust_remote_code=True
+    trust_remote_code=True,
 )
-print("Model loaded successfully!")
+print("LightOnOCR model loaded successfully!")
+
+# -------- Clinical NER models (load ONCE) --------
+print("Loading clinical NER model...")
+ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
+ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
+ner_pipeline = pipeline(
+    "ner",
+    model=ner_model,
+    tokenizer=ner_tokenizer,
+    aggregation_strategy="simple",
+)
+print("Clinical NER model loaded successfully!")
 
 
 def render_pdf_page(page, max_resolution=1540, scale=2.77):
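
The hunk above moves the clinical NER models to module scope, so they load once at startup instead of on every request. A minimal sketch of what the resulting ner_pipeline returns with aggregation_strategy="simple" (the sample text and printed scores are illustrative; problem/test/treatment are the i2b2-style groups this model is described with):

    # Aggregated NER output: one dict per merged entity span.
    entities = ner_pipeline("Patient was started on metformin for type 2 diabetes.")
    for ent in entities:
        # Keys produced by aggregation_strategy="simple":
        # entity_group, score, word, start, end
        print(ent["entity_group"], ent["word"], round(float(ent["score"]), 3))
    # A "treatment" entity_group is what the extraction step below filters on.
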
@@ -99,34 +109,38 @@ def clean_output_text(text):
 
 @spaces.GPU
 def extract_text_from_image(image, temperature=0.2, stream=False):
-    """Extract text from image using LightOnOCR model."""
+    """Extract text from image using LightOnOCR model, and run clinical NER."""
     # Prepare the chat format
     chat = [
         {
             "role": "user",
             "content": [
-                {"type": "image", "url": image},
+                {"type": "image", "url": image},  # adjust to {"type": "image", "image": image} if LightOnOCR expects that
             ],
         }
     ]
-
-    # Apply chat template and tokenize
+
+    # Tokenize
     inputs = processor.apply_chat_template(
         chat,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
-        return_tensors="pt"
+        return_tensors="pt",
     )
-
-    # Move inputs to device AND convert to the correct dtype
+
+    # Move inputs to device
     inputs = {
-        k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
-        else v.to(device) if isinstance(v, torch.Tensor)
-        else v
+        k: (
+            v.to(device=device, dtype=dtype)
+            if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
+            else v.to(device)
+            if isinstance(v, torch.Tensor)
+            else v
+        )
         for k, v in inputs.items()
     }
-
+
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=2048,
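
The rewritten comprehension above casts only floating-point tensors (such as pixel_values) to the model dtype and moves everything else unchanged: integer tensors like input_ids and attention_mask must keep their integer dtype, or embedding lookup would fail. A hypothetical helper expressing the same rule more explicitly (the name and structure are illustrative, not part of the app):

    import torch

    def move_inputs(inputs: dict, device: str, dtype: torch.dtype) -> dict:
        # Float tensors -> target device AND target dtype (fp32/fp16/bf16);
        # other tensors (input_ids, attention_mask) -> device only;
        # non-tensors pass through untouched.
        moved = {}
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor) and v.is_floating_point():
                moved[k] = v.to(device=device, dtype=dtype)
            elif isinstance(v, torch.Tensor):
                moved[k] = v.to(device)
            else:
                moved[k] = v
        return moved
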
@@ -134,49 +148,39 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
         use_cache=True,
         do_sample=temperature > 0,
     )
-
+
     if stream:
-        # Setup streamer for streaming generation
+        # Streaming generation
         streamer = TextIteratorStreamer(
             processor.tokenizer,
             skip_prompt=True,
-            skip_special_tokens=True
+            skip_special_tokens=True,
         )
         generation_kwargs["streamer"] = streamer
-
-        # Run generation in a separate thread
-        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+
+        thread = threading.Thread(target=ocr_model.generate, kwargs=generation_kwargs)
         thread.start()
-
-        # Yield chunks as they arrive
+
         full_text = ""
         for new_text in streamer:
             full_text += new_text
-            # Clean the accumulated text
             cleaned_text = clean_output_text(full_text)
-            yield cleaned_text
-
+
+            # For streaming, we’ll only show text progressively,
+            # and keep medications empty (or compute at the end if you prefer).
+            yield cleaned_text, ""
+
         thread.join()
     else:
         # Non-streaming generation
         with torch.no_grad():
-            outputs = model.generate(**generation_kwargs)
-
-        # Decode the output
+            outputs = ocr_model.generate(**generation_kwargs)
+
         output_text = processor.decode(outputs[0], skip_special_tokens=True)
-
-        # Clean the output
         cleaned_text = clean_output_text(output_text)
 
-        ######### clinical NER ##############
-
-        tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
-        model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
-        ner = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
-
-
-        #Clinical NER process
-        entities = ner(cleaned_text)
+        # Clinical NER on the full cleaned text
+        entities = ner_pipeline(cleaned_text)
         medications = []
         for ent in entities:
             if ent["entity_group"] == "treatment":
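
The streaming branch in the hunk above uses the standard TextIteratorStreamer pattern: generate() blocks, so it runs in a worker thread while the caller drains the streamer, which is an iterator over decoded text chunks. A condensed sketch of the pattern in isolation (assumes any Hugging Face model/tokenizer pair; the generation arguments are illustrative):

    import threading
    from transformers import TextIteratorStreamer

    def stream_text(model, tokenizer, inputs, max_new_tokens=2048):
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = threading.Thread(
            target=model.generate,
            kwargs={**inputs, "max_new_tokens": max_new_tokens, "streamer": streamer},
        )
        thread.start()
        text = ""
        for chunk in streamer:  # blocks until generate() pushes the next decoded piece
            text += chunk
            yield text          # cumulative text, like full_text in the hunk above
        thread.join()
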
@@ -185,50 +189,62 @@
                     medications[-1] += word[2:]
                 else:
                     medications.append(word)
+
         medications_str = ", ".join(set(medications)) if medications else "None detected"
-
-        yield cleaned_text
-        yield medications_s
+
+        yield cleaned_text, medications_str
+
 
 
 
-
 def process_input(file_input, temperature, page_num, enable_streaming):
     """Process uploaded file (image or PDF) and extract text with optional streaming."""
     if file_input is None:
-        yield "Please upload an image or PDF first.", "", "", None, gr.update()
+        # 6 outputs: [output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
+        yield "Please upload an image or PDF first.", "", "", "", None, 1
         return
-
+
     image_to_process = None
     page_info = ""
-
+    slider_value = page_num
+
     file_path = file_input if isinstance(file_input, str) else file_input.name
-
+
     # Handle PDF files
-    if file_path.lower().endswith('.pdf'):
+    if file_path.lower().endswith(".pdf"):
         try:
             image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
             page_info = f"Processing page {actual_page} of {total_pages}"
+            slider_value = actual_page
         except Exception as e:
-            yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
+            msg = f"Error processing PDF: {str(e)}"
+            yield msg, "", msg, "", None, slider_value
             return
-    # Handle image files
     else:
+        # Handle image files
         try:
            image_to_process = Image.open(file_path)
            page_info = "Processing image"
         except Exception as e:
-            yield f"Error opening image: {str(e)}", "", "", None, gr.update()
+            msg = f"Error opening image: {str(e)}"
+            yield msg, "", msg, "", None, slider_value
             return
-
+
     try:
         # Extract text using LightOnOCR with optional streaming
-        for extracted_text, medications in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
-            yield extracted_text, medications, page_info, image_to_process, gr.update()
-
+        for extracted_text, medications in extract_text_from_image(
+            image_to_process, temperature, stream=enable_streaming
+        ):
+            raw_md = extracted_text  # or you can keep a different raw version
+            # 6 outputs: markdown_text, medications, raw_output, page_info, image, slider
+            yield extracted_text, medications, raw_md, page_info, image_to_process, gr.update(
+                value=slider_value
+            )
+
     except Exception as e:
         error_msg = f"Error during text extraction: {str(e)}"
-        yield error_msg, error_msg, page_info, image_to_process, gr.update()
+        # 6 outputs
+        yield error_msg, "", error_msg, page_info, image_to_process, gr.update(value=slider_value)
 
 
 def update_slider(file_input):
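
The medication loop spanning the last two hunks re-joins BERT WordPiece subwords: an aggregated token starting with "##" continues the previous word. A self-contained sketch of that merge (the entity dicts are hand-written for illustration):

    def merge_treatment_words(entities):
        medications = []
        for ent in entities:
            if ent["entity_group"] != "treatment":
                continue
            word = ent["word"]
            if word.startswith("##") and medications:
                medications[-1] += word[2:]  # drop "##", glue onto the previous piece
            else:
                medications.append(word)
        return medications

    pieces = [
        {"entity_group": "treatment", "word": "metf"},
        {"entity_group": "treatment", "word": "##ormin"},
        {"entity_group": "problem", "word": "diabetes"},
    ]
    print(merge_treatment_words(pieces))  # ['metformin']
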
 
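Every yield in the new process_input is a 6-tuple, so the Gradio event that calls it must list six output components in the same order. A hypothetical wiring sketch (component names and widget choices are illustrative, not taken from the app):

    import gradio as gr

    with gr.Blocks() as demo:
        file_input = gr.File(label="Image or PDF")
        temperature = gr.Slider(0.0, 1.0, value=0.2, label="Temperature")
        page_num = gr.Slider(1, 50, value=1, step=1, label="PDF page")
        enable_streaming = gr.Checkbox(label="Stream output")
        run_btn = gr.Button("Extract")

        output_text = gr.Markdown(label="Extracted text")
        medications_output = gr.Textbox(label="Medications")
        raw_output = gr.Textbox(label="Raw output")
        page_info = gr.Textbox(label="Page info")
        rendered_image = gr.Image(label="Rendered page")

        # process_input is a generator: Gradio streams each yielded 6-tuple
        # into these outputs in order; the slider receives gr.update(value=...).
        run_btn.click(
            process_input,
            inputs=[file_input, temperature, page_num, enable_streaming],
            outputs=[output_text, medications_output, raw_output, page_info, rendered_image, page_num],
        )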