prithivMLmods committed on
Commit
e98b5fc
·
verified ·
1 Parent(s): 22e74e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -101
app.py CHANGED
@@ -186,7 +186,7 @@ def get_uitars_prompt(task, image):
186
 
187
  # --- Holo2 Prompt ---
188
  def get_holo2_prompt(task, image):
189
- # Holo2 typically uses standard chat formatting
190
  return [
191
  {
192
  "role": "user",
@@ -211,7 +211,7 @@ def get_image_proc_params(processor) -> Dict[str, int]:
211
  # -----------------------------------------------------------------------------
212
 
213
  def parse_uitars_response(text: str) -> List[Dict]:
214
- """Parse UI-TARS specific output formats"""
215
  actions = []
216
  text = text.strip()
217
 
@@ -221,13 +221,10 @@ def parse_uitars_response(text: str) -> List[Dict]:
221
  m = re.findall(r"point=\[\s*(\d+)\s*,\s*(\d+)\s*\]", text, re.IGNORECASE)
222
  for p in m: actions.append({"type": "click", "x": int(p[0]), "y": int(p[1]), "text": ""})
223
 
224
- m = re.search(r"start_box=['\"]?\(\s*(\d+)\s*,\s*(\d+)\s*\)['\"]?", text, re.IGNORECASE)
225
- if m: actions.append({"type": "click", "x": int(m[0]), "y": int(m[1]), "text": ""})
226
-
227
  return actions
228
 
229
  def parse_fara_response(response: str) -> List[Dict]:
230
- """Parse Fara <tool_call> JSON format"""
231
  actions = []
232
  matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
233
  for match in matches:
@@ -245,70 +242,56 @@ def parse_fara_response(response: str) -> List[Dict]:
245
  return actions
246
 
247
  def parse_holo2_response(generated_ids, processor, input_len) -> Tuple[str, str, List[Dict]]:
248
- """Parse Holo2 reasoning tokens and JSON content"""
249
  all_ids = generated_ids[0].tolist()
250
 
251
- # Token IDs for <|thought_start|> and <|thought_end|> (Qwen/Holo specific)
252
  THOUGHT_START = 151667
253
  THOUGHT_END = 151668
254
 
255
  thinking_content = ""
256
  content = ""
257
 
258
- try:
259
- if THOUGHT_START in all_ids:
260
- start_idx = all_ids.index(THOUGHT_START)
261
- try:
262
- end_idx = all_ids.index(THOUGHT_END)
263
- except ValueError:
264
- end_idx = len(all_ids)
265
-
266
- thinking_ids = all_ids[start_idx+1:end_idx]
267
- thinking_content = processor.decode(thinking_ids, skip_special_tokens=True).strip()
268
-
269
- # Content is everything after thought_end
270
- content_ids = all_ids[end_idx+1:]
271
- content = processor.decode(content_ids, skip_special_tokens=True).strip()
272
- else:
273
- # Fallback if no reasoning tokens found (just raw output)
274
- # Slice off input tokens first
275
- output_ids = all_ids[input_len:]
276
- content = processor.decode(output_ids, skip_special_tokens=True).strip()
277
- except Exception as e:
278
- print(f"Holo Parsing Error: {e}")
279
- content = processor.decode(all_ids[input_len:], skip_special_tokens=True).strip()
280
-
281
- # Parse JSON Content
282
  actions = []
283
- try:
284
- # Holo2 outputs strictly valid JSON usually
285
- # E.g. {"x": 500, "y": 300, "description": "search bar"}
286
- # Or {"action": "click", "point": [100, 200]}
287
- # Flattening to common format
288
- if "{" in content and "}" in content:
289
- # Find JSON block if surrounded by text
290
- json_str = re.search(r"(\{.*\})", content, re.DOTALL).group(1)
291
- data = json.loads(json_str)
292
-
293
- x, y = 0, 0
294
- if "x" in data and "y" in data:
295
- x, y = data["x"], data["y"]
296
- elif "point" in data:
297
- x, y = data["point"][0], data["point"][1]
298
- elif "coordinate" in data:
299
- x, y = data["coordinate"][0], data["coordinate"][1]
300
-
301
- if x or y:
302
- # Holo2 output is 0-1000 scale
303
- actions.append({
304
- "type": "click",
305
- "x": float(x),
306
- "y": float(y),
307
- "text": data.get("description", "") or data.get("text", ""),
308
- "scale_base": 1000 # Flag to indicate this needs normalization from 1000
309
- })
310
- except Exception as e:
311
- print(f"Holo JSON Parse Failed: {e}")
312
 
313
  return content, thinking_content, actions
314
 
@@ -325,38 +308,41 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
325
  x = act['x']
326
  y = act['y']
327
 
328
- # Holo2 Special Case (0-1000 scaling)
 
 
 
329
  if act.get('scale_base') == 1000:
330
- pixel_x = int((x / 1000) * width)
331
- pixel_y = int((y / 1000) * height)
332
- # Normalized (0-1)
333
  elif x <= 1.0 and y <= 1.0 and x > 0:
334
  pixel_x = int(x * width)
335
  pixel_y = int(y * height)
336
- # Absolute Pixels
337
  else:
338
  pixel_x = int(x)
339
  pixel_y = int(y)
340
 
341
- color = 'red' if 'click' in act['type'].lower() else 'blue'
342
 
343
- # Draw Visuals
344
  r = 15
345
- draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r], outline=color, width=4)
346
- draw.ellipse([pixel_x - 3, pixel_y - 3, pixel_x + 3, pixel_y + 3], fill=color)
347
 
348
- # Draw Cross
349
- draw.line([pixel_x - 10, pixel_y, pixel_x + 10, pixel_y], fill=color, width=2)
350
- draw.line([pixel_x, pixel_y - 10, pixel_x, pixel_y + 10], fill=color, width=2)
351
 
352
- # Label
353
- label = f"{act['type']}"
354
- if act['text']: label += f": {act['text']}"
355
 
356
- text_pos = (pixel_x + 20, pixel_y - 10)
357
- bbox = draw.textbbox(text_pos, label, font=font)
358
- draw.rectangle((bbox[0]-4, bbox[1]-2, bbox[2]+4, bbox[3]+2), fill="black")
359
- draw.text(text_pos, label, fill="white", font=font)
360
 
361
  return img_copy
362
 
@@ -366,18 +352,17 @@ def create_localized_image(original_image: Image.Image, actions: list[dict]) ->
366
 
367
  @spaces.GPU(duration=120)
368
  def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str):
369
- if input_numpy_image is None: return "⚠️ Please upload an image.", None, None
370
 
371
  input_pil_image = array_to_image(input_numpy_image)
372
  orig_w, orig_h = input_pil_image.size
373
  actions = []
374
  raw_response = ""
375
- reasoning_text = None
376
 
377
  # --- UI-TARS Logic ---
378
  if model_choice == "UI-TARS-1.5-7B":
379
- if model_x is None: return "Error: UI-TARS model failed to load.", None, None
380
- print("Running UI-TARS...")
381
 
382
  ip_params = get_image_proc_params(processor_x)
383
  resized_h, resized_w = smart_resize(
@@ -400,7 +385,7 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
400
 
401
  actions = parse_uitars_response(raw_response)
402
 
403
- # Rescale
404
  scale_x = orig_w / resized_w
405
  scale_y = orig_h / resized_h
406
  for a in actions:
@@ -409,8 +394,7 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
409
 
410
  # --- Holo2 Logic ---
411
  elif model_choice == "Holo2-8B":
412
- if model_h is None: return "Error: Holo2 model failed to load.", None, None
413
- print("Running Holo2...")
414
 
415
  messages = get_holo2_prompt(task, input_pil_image)
416
  text_prompt = processor_h.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
@@ -421,18 +405,12 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
421
  with torch.no_grad():
422
  generated_ids = model_h.generate(**inputs, max_new_tokens=512)
423
 
424
- # Parse Reasoning + Content
425
  input_len = len(inputs.input_ids[0])
426
- content, thinking, parsed_actions = parse_holo2_response(generated_ids, processor_h, input_len)
427
-
428
- raw_response = content
429
- reasoning_text = thinking
430
- actions = parsed_actions
431
 
432
  # --- Fara Logic ---
433
  else:
434
- if model_v is None: return "Error: Fara model failed to load.", None, None
435
- print("Running Fara...")
436
  messages = get_fara_prompt(task, input_pil_image)
437
  text_prompt = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
438
  image_inputs, video_inputs = process_vision_info(messages)
@@ -446,20 +424,17 @@ def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: s
446
  raw_response = processor_v.batch_decode(generated_ids, skip_special_tokens=True)[0]
447
  actions = parse_fara_response(raw_response)
448
 
449
- print(f"Raw: {raw_response}")
450
- if reasoning_text: print(f"Thinking: {reasoning_text}")
451
-
452
  # Visualize
453
  output_image = input_pil_image
454
  if actions:
455
  vis = create_localized_image(input_pil_image, actions)
456
  if vis: output_image = vis
457
 
458
- final_text_output = f"▶️ OUTPUT:\n{raw_response}"
459
  if reasoning_text:
460
- final_text_output = f"🧠 THINKING PROCESS:\n{reasoning_text}\n\n" + final_text_output
461
 
462
- return final_text_output, output_image
463
 
464
  # -----------------------------------------------------------------------------
465
  # 6. UI SETUP
 
186
 
187
  # --- Holo2 Prompt ---
188
  def get_holo2_prompt(task, image):
189
+ # Holo2 often expects a simple user prompt with the image
190
  return [
191
  {
192
  "role": "user",
 
211
  # -----------------------------------------------------------------------------
212
 
213
  def parse_uitars_response(text: str) -> List[Dict]:
214
+ """Parse UI-TARS output"""
215
  actions = []
216
  text = text.strip()
217
 
 
221
  m = re.findall(r"point=\[\s*(\d+)\s*,\s*(\d+)\s*\]", text, re.IGNORECASE)
222
  for p in m: actions.append({"type": "click", "x": int(p[0]), "y": int(p[1]), "text": ""})
223
 
 
 
 
224
  return actions
225
 
226
  def parse_fara_response(response: str) -> List[Dict]:
227
+ """Parse Fara output"""
228
  actions = []
229
  matches = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
230
  for match in matches:
 
242
  return actions
243
 
244
  def parse_holo2_response(generated_ids, processor, input_len) -> Tuple[str, str, List[Dict]]:
245
+ """Parse Holo2 reasoning and actions"""
246
  all_ids = generated_ids[0].tolist()
247
 
248
+ # Qwen/Holo specific reasoning tokens
249
  THOUGHT_START = 151667
250
  THOUGHT_END = 151668
251
 
252
  thinking_content = ""
253
  content = ""
254
 
255
+ # 1. Extract Thinking
256
+ if THOUGHT_START in all_ids:
257
+ start_idx = all_ids.index(THOUGHT_START)
258
+ try:
259
+ end_idx = all_ids.index(THOUGHT_END)
260
+ except ValueError:
261
+ end_idx = len(all_ids)
262
+ thinking_ids = all_ids[start_idx+1:end_idx]
263
+ thinking_content = processor.decode(thinking_ids, skip_special_tokens=True).strip()
264
+ # Content is after thought_end
265
+ output_ids = all_ids[end_idx+1:]
266
+ content = processor.decode(output_ids, skip_special_tokens=True).strip()
267
+ else:
268
+ output_ids = all_ids[input_len:]
269
+ content = processor.decode(output_ids, skip_special_tokens=True).strip()
270
+
271
+ # 2. Extract Coordinates (Robust parsing)
 
 
 
 
 
 
 
272
  actions = []
273
+
274
+ # Pattern A: point=[x, y] (Common in Holo)
275
+ points = re.findall(r"point=\[\s*(\d+)\s*,\s*(\d+)\s*\]", content)
276
+ for p in points:
277
+ actions.append({"type": "click", "x": float(p[0]), "y": float(p[1]), "scale_base": 1000})
278
+
279
+ # Pattern B: JSON {"point": [x, y]}
280
+ json_candidates = re.findall(r"\{.*?\}", content, re.DOTALL)
281
+ for jc in json_candidates:
282
+ try:
283
+ data = json.loads(jc)
284
+ if "point" in data:
285
+ actions.append({"type": "click", "x": float(data["point"][0]), "y": float(data["point"][1]), "scale_base": 1000})
286
+ if "coordinate" in data:
287
+ actions.append({"type": "click", "x": float(data["coordinate"][0]), "y": float(data["coordinate"][1]), "scale_base": 1000})
288
+ except: pass
289
+
290
+ # Pattern C: Plain [x, y] at end of string
291
+ if not actions:
292
+ plain_coords = re.findall(r"\[\s*(\d+)\s*,\s*(\d+)\s*\]", content)
293
+ for p in plain_coords:
294
+ actions.append({"type": "click", "x": float(p[0]), "y": float(p[1]), "scale_base": 1000})
 
 
 
 
 
 
 
295
 
296
  return content, thinking_content, actions
297
 
 
308
  x = act['x']
309
  y = act['y']
310
 
311
+ # Scaling Logic
312
+ pixel_x, pixel_y = 0, 0
313
+
314
+ # Case 1: Holo2 0-1000 scale
315
  if act.get('scale_base') == 1000:
316
+ pixel_x = int((x / 1000.0) * width)
317
+ pixel_y = int((y / 1000.0) * height)
318
+ # Case 2: Normalized 0-1
319
  elif x <= 1.0 and y <= 1.0 and x > 0:
320
  pixel_x = int(x * width)
321
  pixel_y = int(y * height)
322
+ # Case 3: Absolute Pixels
323
  else:
324
  pixel_x = int(x)
325
  pixel_y = int(y)
326
 
327
+ color = 'red'
328
 
329
+ # Draw Markers (Thicker for visibility)
330
  r = 15
331
+ draw.ellipse([pixel_x - r, pixel_y - r, pixel_x + r, pixel_y + r], outline=color, width=5)
332
+ draw.ellipse([pixel_x - 4, pixel_y - 4, pixel_x + 4, pixel_y + 4], fill=color)
333
 
334
+ # Crosshair
335
+ draw.line([pixel_x - 20, pixel_y, pixel_x + 20, pixel_y], fill=color, width=3)
336
+ draw.line([pixel_x, pixel_y - 20, pixel_x, pixel_y + 20], fill=color, width=3)
337
 
338
+ # Text Label
339
+ label = f"{act.get('type','Action')}"
340
+ text_pos = (pixel_x + 20, pixel_y - 15)
341
 
342
+ if font:
343
+ bbox = draw.textbbox(text_pos, label, font=font)
344
+ draw.rectangle((bbox[0]-4, bbox[1]-2, bbox[2]+4, bbox[3]+2), fill="black")
345
+ draw.text(text_pos, label, fill="white", font=font)
346
 
347
  return img_copy
348
 
 
352
 
353
  @spaces.GPU(duration=120)
354
  def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str):
355
+ if input_numpy_image is None: return "⚠️ Please upload an image.", None
356
 
357
  input_pil_image = array_to_image(input_numpy_image)
358
  orig_w, orig_h = input_pil_image.size
359
  actions = []
360
  raw_response = ""
361
+ reasoning_text = ""
362
 
363
  # --- UI-TARS Logic ---
364
  if model_choice == "UI-TARS-1.5-7B":
365
+ if model_x is None: return "Error: UI-TARS model failed to load.", None
 
366
 
367
  ip_params = get_image_proc_params(processor_x)
368
  resized_h, resized_w = smart_resize(
 
385
 
386
  actions = parse_uitars_response(raw_response)
387
 
388
+ # Rescale UI-TARS coords
389
  scale_x = orig_w / resized_w
390
  scale_y = orig_h / resized_h
391
  for a in actions:
 
394
 
395
  # --- Holo2 Logic ---
396
  elif model_choice == "Holo2-8B":
397
+ if model_h is None: return "Error: Holo2 model failed to load.", None
 
398
 
399
  messages = get_holo2_prompt(task, input_pil_image)
400
  text_prompt = processor_h.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
405
  with torch.no_grad():
406
  generated_ids = model_h.generate(**inputs, max_new_tokens=512)
407
 
 
408
  input_len = len(inputs.input_ids[0])
409
+ raw_response, reasoning_text, actions = parse_holo2_response(generated_ids, processor_h, input_len)
 
 
 
 
410
 
411
  # --- Fara Logic ---
412
  else:
413
+ if model_v is None: return "Error: Fara model failed to load.", None
 
414
  messages = get_fara_prompt(task, input_pil_image)
415
  text_prompt = processor_v.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
416
  image_inputs, video_inputs = process_vision_info(messages)
 
424
  raw_response = processor_v.batch_decode(generated_ids, skip_special_tokens=True)[0]
425
  actions = parse_fara_response(raw_response)
426
 
 
 
 
427
  # Visualize
428
  output_image = input_pil_image
429
  if actions:
430
  vis = create_localized_image(input_pil_image, actions)
431
  if vis: output_image = vis
432
 
433
+ final_output = f"▶️ OUTPUT:\n{raw_response}"
434
  if reasoning_text:
435
+ final_output = f"🧠 THINKING:\n{reasoning_text}\n\n" + final_output
436
 
437
+ return final_output, output_image
438
 
439
  # -----------------------------------------------------------------------------
440
  # 6. UI SETUP