|
|
import base64 |
|
|
import sys |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
def encode_image(image_path: str) -> str:
    """Return the raw bytes of ``image_path`` as a base64-encoded ``str``.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's full contents, decoded to text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
|
|
|
|
|
|
|
|
def load_model(
    model_path: str,
) -> tuple[AutoModelForVision2Seq, AutoTokenizer, AutoProcessor]:
    """Load the tokenizer, vision-to-text model and processor from ``model_path``.

    Every component is loaded with ``trust_remote_code=True``. The model is
    loaded with ``torch_dtype="auto"`` and ``device_map="auto"``.

    Args:
        model_path: Hub id or local directory passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer, processor)`` tuple, in that order.
    """
    # Shared by all three loaders: the checkpoint may ship custom code.
    remote_code_kwargs = {"trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(model_path, **remote_code_kwargs)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        **remote_code_kwargs,
    )
    processor = AutoProcessor.from_pretrained(model_path, **remote_code_kwargs)

    return model, tokenizer, processor
|
|
|
|
|
|
|
|
def create_grounding_messages(image_path: str, instruction: str) -> list[dict]:
    """Build the chat message list for a single GUI grounding request.

    The result has two messages: a fixed system prompt, then one user turn.
    The user turn holds the task text followed by the screenshot, which is
    embedded as a base64 ``data:`` URL.

    Args:
        image_path: Screenshot whose bytes are base64-encoded into the message.
        instruction: Task text appended to the fixed grounding prompt.

    Returns:
        ``[system_message, user_message]`` in chat-template format.
    """
    system_prompt = (
        "You are a GUI agent. You are given a task and a screenshot of the screen. "
        "You need to perform a series of pyautogui actions to complete the task."
    )
    task_prefix = (
        "Please perform the following task by providing the action and the "
        "coordinates in the format of <action>(x, y): "
    )

    # Build the text part first, then the image part (the same order as the
    # content list), so any error surfaces in the same order as before.
    text_part = {"type": "text", "text": task_prefix + instruction}
    image_part = {
        "type": "image",
        "image": f"data:image/png;base64,{encode_image(image_path)}",
    }

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": [text_part, image_part]}
    return [system_message, user_message]
|
|
|
|
|
|
|
|
def run_inference(
    model: AutoModelForVision2Seq,
    tokenizer: AutoTokenizer,
    image_processor: AutoProcessor,
    messages: list[dict],
    image_path: str,
    max_new_tokens: int = 2048,
) -> str:
    """Generate the model's reply for one chat conversation plus screenshot.

    Steps:
    1. Render the chat template to text with ``image_processor``.
    2. Load the screenshot as RGB.
    3. Encode the text and image together with the processor.
    4. Generate greedily (``do_sample=False``).
    5. Decode only the newly generated tokens.

    Args:
        model: Loaded vision-to-text model; inputs are moved to ``model.device``.
        tokenizer: Accepted for interface compatibility but not used here.
            Decoding goes through ``image_processor.batch_decode``.
        image_processor: Processor providing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        messages: Chat messages, e.g. as built by ``create_grounding_messages``.
        image_path: Path of the screenshot given to the processor.
        max_new_tokens: Upper bound on generated tokens. Defaults to 2048,
            the value that was previously hard-coded.

    Returns:
        Decoded text of the first sequence in the batch, with special tokens
        removed.
    """
    prompt = image_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Use a context manager so the underlying file handle is closed
    # deterministically. convert() returns a new in-memory image, so it
    # stays usable after the with block exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    encoded = image_processor(
        text=[prompt], images=[image], padding=True, return_tensors="pt"
    )
    # Move tensors to the model's device. Pass through any entry without a
    # ``.to`` method instead of crashing on it.
    inputs = {
        key: value.to(model.device) if hasattr(value, "to") else value
        for key, value in encoded.items()
    }

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # Each output row starts with its prompt tokens; drop them so only the
    # model's answer is decoded.
    completion_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(inputs["input_ids"], generated_ids)
    ]
    return image_processor.batch_decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
|
|
|
|
|
|
|
|
def main():
    """Run a one-shot grounding sanity check on a fixed model and screenshot.

    If command-line arguments are given, they are joined with spaces and used
    as the instruction; otherwise a default instruction is used. The
    screenshot path is checked before the (potentially slow) model load, so
    a missing file is reported immediately.

    Returns:
        Process exit status:
        - 0 on success.
        - 1 if the screenshot is missing, the model fails to load, or
          inference raises.
    """
    import os
    import traceback

    model_path = "Uniphore/actio-ui-7b-rlvr"
    image_path = "screenshot.png"
    instruction = "Click on the submit button"

    if len(sys.argv) > 1:
        instruction = " ".join(sys.argv[1:])

    # Fail fast: without this check, a missing screenshot only surfaces
    # after the model has been fully loaded.
    if not os.path.isfile(image_path):
        print(f"✗ Image not found: {image_path}", file=sys.stderr)
        return 1

    print(f"Loading model from: {model_path}")
    try:
        model, tokenizer, image_processor = load_model(model_path)
        print("✓ Model loaded successfully")
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit code.
        print(f"✗ Error loading model: {e}", file=sys.stderr)
        return 1

    print(f"Processing image: {image_path}")
    print(f"Instruction: {instruction}")

    try:
        messages = create_grounding_messages(image_path, instruction)
        result = run_inference(model, tokenizer, image_processor, messages, image_path)

        print("\n" + "=" * 60)
        print("MODEL OUTPUT:")
        print("=" * 60)
        print(result)
        print("=" * 60)
        return 0
    except Exception as e:
        print(f"✗ Error during inference: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|
|
|