prithivMLmods committed
Commit 639d6a2 (verified)
Parent: ed65219

update app

Files changed (1):
  1. app.py +53 -213
app.py CHANGED
@@ -12,7 +12,7 @@ import spaces
 import torch
 import numpy as np
 from PIL import Image
-import cv2
+# cv2 is no longer needed as video processing is removed
 
 from transformers import (
     Qwen2VLForConditionalGeneration,
@@ -27,35 +27,29 @@ from gradio.themes.utils import colors, fonts, sizes
 
 # --- Theme and CSS Definition ---
 
-# Define the Thistle color palette
-colors.thistle = colors.Color(
-    name="thistle",
-    c50="#F9F5F9",
-    c100="#F0E8F1",
-    c200="#E7DBE8",
-    c300="#DECEE0",
-    c400="#D2BFD8",
-    c500="#D8BFD8", # Thistle base color
-    c600="#B59CB7",
-    c700="#927996",
-    c800="#6F5675",
-    c900="#4C3454",
-    c950="#291233",
+# Define the new SpringGreen color palette
+colors.spring_green = colors.Color(
+    name="spring_green",
+    c50="#E5FFF2",
+    c100="#CCFFEC",
+    c200="#99FFD9",
+    c300="#66FFC6",
+    c400="#33FFB3",
+    c500="#00FF7F", # SpringGreen base color
+    c600="#00E672",
+    c700="#00CC66",
+    c800="#00B359",
+    c900="#00994D",
+    c950="#008040",
 )
 
-colors.red_gray = colors.Color(
-    name="red_gray",
-    c50="#f7eded", c100="#f5dcdc", c200="#efb4b4", c300="#e78f8f",
-    c400="#d96a6a", c500="#c65353", c600="#b24444", c700="#8f3434",
-    c800="#732d2d", c900="#5f2626", c950="#4d2020",
-)
 
-class ThistleTheme(Soft):
+class SpringGreenTheme(Soft):
     def __init__(
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.thistle, # Use the new color
+        secondary_hue: colors.Color | str = colors.spring_green, # Use the new color
         neutral_hue: colors.Color | str = colors.slate,
         text_size: sizes.Size | str = sizes.text_lg,
         font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -84,12 +78,6 @@ class ThistleTheme(Soft):
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *secondary_600)",
             button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
             button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            button_secondary_text_color="black",
-            button_secondary_text_color_hover="white",
-            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
-            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
-            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
-            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
             slider_color="*secondary_400",
             slider_color_dark="*secondary_600",
             block_title_text_weight="600",
@@ -102,7 +90,8 @@ class ThistleTheme(Soft):
         )
 
 # Instantiate the new theme
-thistle_theme = ThistleTheme()
+spring_green_theme = SpringGreenTheme()
+
 
 css = """
 #main-title h1 {
@@ -111,56 +100,12 @@ css = """
 #output-title h2 {
     font-size: 2.1em !important;
 }
-:root {
-    --color-grey-50: #f9fafb;
-    --banner-background: var(--secondary-400);
-    --banner-text-color: var(--primary-100);
-    --banner-background-dark: var(--secondary-800);
-    --banner-text-color-dark: var(--primary-100);
-    --banner-chrome-height: calc(16px + 43px);
-    --chat-chrome-height-wide-no-banner: 320px;
-    --chat-chrome-height-narrow-no-banner: 450px;
-    --chat-chrome-height-wide: calc(var(--chat-chrome-height-wide-no-banner) + var(--banner-chrome-height));
-    --chat-chrome-height-narrow: calc(var(--chat-chrome-height-narrow-no-banner) + var(--banner-chrome-height));
-}
-.banner-message { background-color: var(--banner-background); padding: 5px; margin: 0; border-radius: 5px; border: none; }
-.banner-message-text { font-size: 13px; font-weight: bolder; color: var(--banner-text-color) !important; }
-body.dark .banner-message { background-color: var(--banner-background-dark) !important; }
-body.dark .gradio-container .contain .banner-message .banner-message-text { color: var(--banner-text-color-dark) !important; }
-.toast-body { background-color: var(--color-grey-50); }
-.html-container:has(.css-styles) { padding: 0; margin: 0; }
-.css-styles { height: 0; }
-.model-message { text-align: end; }
-.model-dropdown-container { display: flex; align-items: center; gap: 10px; padding: 0; }
-.user-input-container .multimodal-textbox{ border: none !important; }
-.control-button { height: 51px; }
-button.cancel { border: var(--button-border-width) solid var(--button-cancel-border-color); background: var(--button-cancel-background-fill); color: var(--button-cancel-text-color); box-shadow: var(--button-cancel-shadow); }
-button.cancel:hover, .cancel[disabled] { background: var(--button-cancel-background-fill-hover); color: var(--button-cancel-text-color-hover); }
-.opt-out-message { top: 8px; }
-.opt-out-message .html-container, .opt-out-checkbox label { font-size: 14px !important; padding: 0 !important; margin: 0 !important; color: var(--neutral-400) !important; }
-div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; max-height: 900px !important; }
-div.no-padding { padding: 0 !important; }
-@media (max-width: 1280px) { div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; } }
-@media (max-width: 1024px) {
-    .responsive-row { flex-direction: column; }
-    .model-message { text-align: start; font-size: 10px !important; }
-    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
-    div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-narrow)) !important; }
-}
-@media (max-width: 400px) {
-    .responsive-row { flex-direction: column; }
-    .model-message { text-align: start; font-size: 10px !important; }
-    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
-    div.block.chatbot { max-height: 360px !important; }
-}
-@media (max-height: 932px) { .chatbot { max-height: 500px !important; } }
-@media (max-height: 1280px) { div.block.chatbot { max-height: 800px !important; } }
 """
 
 # Constants for text generation
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
-# Increased max_length to accommodate more complex inputs, especially with multiple images
+# Increased max_length to accommodate more complex inputs
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -208,7 +153,7 @@ model_a = AutoModelForImageTextToText.from_pretrained(
 MODEL_ID_W = "allenai/olmOCR-7B-0725"
 processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
 model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_W,
+    MODEL_ID_W,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
@@ -222,27 +167,6 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-def downsample_video(video_path):
-    """
-    Downsamples the video to evenly spaced frames.
-    Each frame is returned as a PIL image along with its timestamp.
-    """
-    vidcap = cv2.VideoCapture(video_path)
-    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
-    # Use a maximum of 10 frames to avoid excessive memory usage
-    frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
-    for i in frame_indices:
-        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, image = vidcap.read()
-        if success:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
-    vidcap.release()
-    return frames
 
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
@@ -286,8 +210,9 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         ]
     }]
     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    # FIX: Set truncation to False to avoid the ValueError
+
+    # FIX: Set truncation to False and rely on the model's context length.
+    # The increased MAX_INPUT_TOKEN_LENGTH at the top also helps.
     inputs = processor(
         text=[prompt_full],
         images=[image],
@@ -306,138 +231,53 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         time.sleep(0.01)
         yield buffer, buffer
 
-@spaces.GPU
-def generate_video(model_name: str, text: str, video_path: str,
-                   max_new_tokens: int = 1024,
-                   temperature: float = 0.6,
-                   top_p: float = 0.9,
-                   top_k: int = 50,
-                   repetition_penalty: float = 1.2):
-    """
-    Generates responses using the selected model for video input.
-    Yields raw text and Markdown-formatted text.
-    """
-    if model_name == "RolmOCR-7B":
-        processor = processor_m
-        model = model_m
-    elif model_name == "Qwen2-VL-OCR-2B":
-        processor = processor_x
-        model = model_x
-    elif model_name == "Nanonets-OCR2-3B":
-        processor = processor_v
-        model = model_v
-    elif model_name == "Aya-Vision-8B":
-        processor = processor_a
-        model = model_a
-    elif model_name == "olmOCR-7B-0725":
-        processor = processor_w
-        model = model_w
-    else:
-        yield "Invalid model selected.", "Invalid model selected."
-        return
-
-    if video_path is None:
-        yield "Please upload a video.", "Please upload a video."
-        return
-
-    frames_with_ts = downsample_video(video_path)
-    images_for_processor = [frame for frame, ts in frames_with_ts]
-
-    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    for frame in images_for_processor:
-        messages[0]["content"].insert(0, {"type": "image"})
-
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    inputs = processor(
-        text=[prompt_full],
-        images=images_for_processor,
-        return_tensors="pt",
-        padding=True
-    ).to(device)
-
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "do_sample": True,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
-        yield buffer, buffer
 
-# Define examples for image and video inference
+# Define examples for image inference
 image_examples = [
-    ["Extract the full page.", "images/ocr.png"],
-    ["Extract the content.", "images/4.png"],
+    ["Extract the full page.", "images/ocr.png"],
+    ["Extract the content.", "images/4.png"],
     ["Convert this page to doc [table] precisely for markdown.", "images/0.png"]
 ]
 
-video_examples = [
-    ["Explain the Ad in Detail.", "videos/1.mp4"],
-]
 
 # Create the Gradio Interface
-with gr.Blocks(css=css, theme=thistle_theme) as demo:
+with gr.Blocks(css=css, theme=spring_green_theme) as demo:
     gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
     with gr.Row():
         with gr.Column(scale=2):
-            with gr.Tabs():
-                with gr.TabItem("Image Inference"):
-                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
-                    image_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(
-                        examples=image_examples,
-                        inputs=[image_query, image_upload]
-                    )
-                with gr.TabItem("Video Inference"):
-                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    video_upload = gr.Video(label="Upload Video", height=290)
-                    video_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(
-                        examples=video_examples,
-                        inputs=[video_query, video_upload]
-                    )
-            gr.Markdown("> Only the olmOCR and RolmOCR models currently support video inference (max video length: 30 secs).")
+            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
+            image_submit = gr.Button("Submit", variant="primary")
+            gr.Examples(
+                examples=image_examples,
+                inputs=[image_query, image_upload]
+            )
             with gr.Accordion("Advanced options", open=False):
-                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1,
+                                           value=DEFAULT_MAX_NEW_TOKENS)
                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
+                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05,
+                                               value=1.2)
+
         with gr.Column(scale=3):
-            gr.Markdown("## Output", elem_id="output-title")
-            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
-            with gr.Accordion("(Result.md)", open=False):
-                markdown_output = gr.Markdown(label="(Result.Md)")
-
-    model_choice = gr.Radio(
-        choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
+            gr.Markdown("## Output", elem_id="output-title")
+            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
+            with gr.Accordion("(Result.md)", open=False):
+                markdown_output = gr.Markdown(label="(Result.Md)")
+
+    model_choice = gr.Radio(
+        choices=["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B",
                  "Aya-Vision-8B", "Qwen2-VL-OCR-2B"],
-        label="Select Model",
-        value="Nanonets-OCR2-3B"
-    )
-
+        label="Select Model",
+        value="Nanonets-OCR2-3B"
+    )
+
     image_submit.click(
         fn=generate_image,
-        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[output, markdown_output]
-    )
-    video_submit.click(
-        fn=generate_video,
-        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k,
+                repetition_penalty],
         outputs=[output, markdown_output]
     )
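Note on the truncation fix above: with truncation disabled, nothing caps the prompt length at the processor level. A minimal guard sketch, not part of this commit; check_input_length is a hypothetical helper that assumes the standard transformers processor output (inputs["input_ids"] is a [batch, seq_len] tensor) and the MAX_INPUT_TOKEN_LENGTH constant defined in app.py:

def check_input_length(inputs, max_len: int) -> None:
    # Fail fast rather than truncate: cutting tokens from a vision-language
    # prompt can drop image placeholder tokens and desync text from pixels.
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > max_len:
        raise ValueError(f"Prompt is {n_tokens} tokens; limit is {max_len}.")

# Usage, after the processor(...) call in generate_image:
# check_input_length(inputs, MAX_INPUT_TOKEN_LENGTH)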