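"""Sanity check for a GUI-grounding vision-language model.

Loads the model with trust_remote_code, wraps a screenshot and an instruction
in a grounding prompt, and prints the raw model response.

Usage (the script filename is whatever you saved this file as):

    python sanity_check.py "Click on the submit button"

Expects an image at ./screenshot.png in the working directory.
"""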
import base64
import sys
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
from PIL import Image


def encode_image(image_path: str) -> str:
    """Encode image to base64 string for model input."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def load_model(
    model_path: str,
) -> tuple[AutoModelForVision2Seq, AutoTokenizer, AutoProcessor]:
    """Load OpenCUA model, tokenizer, and image processor."""
    # trust_remote_code=True downloads and runs the custom modeling code shipped
    # in the model repository; pass revision="<commit>" to the from_pretrained
    # calls below if you need to pin a specific version.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )
    image_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    return model, tokenizer, image_processor


def create_grounding_messages(image_path: str, instruction: str) -> list[dict]:
    """Create chat messages for GUI grounding task."""
    system_prompt = (
        "You are a GUI agent. You are given a task and a screenshot of the screen. "
        "You need to perform a series of pyautogui actions to complete the task."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Please perform the following task by providing the action and the coordinates in the format of <action>(x, y): "
                    + instruction,
                },
                {
                    "type": "image",
                    "image": f"data:image/png;base64,{encode_image(image_path)}",
                },
            ],
        },
    ]
    return messages
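
# Note: the base64 data URL in the messages above is what apply_chat_template
# sees when it inserts the image placeholder tokens; the actual pixel data is
# supplied separately as a PIL image in run_inference() below. (This split is
# an assumption about how this model's remote-code processor works.)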


def run_inference(
    model: AutoModelForVision2Seq,
    tokenizer: AutoTokenizer,
    image_processor: AutoProcessor,
    messages: list[dict],
    image_path: str,
) -> str:
    """Run inference on the model."""
    # Prepare text from messages
    text = image_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Open image
    image = Image.open(image_path).convert("RGB")

    # Process inputs using the processor
    inputs = image_processor(
        text=[text], images=[image], padding=True, return_tensors="pt"
    )

    # Move inputs to model device
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate response
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
        )

    # Decode output (skip the input tokens)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = image_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return output_text
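

# A minimal sketch for pulling coordinates out of the reply. It assumes the
# model honors the "<action>(x, y)" format requested in the prompt (e.g.
# "pyautogui.click(512, 384)"); real outputs may deviate, so treat a None
# result as "could not parse". This helper is illustrative and is not called
# by main() below.
import re


def parse_action(output_text: str) -> tuple[str, int, int] | None:
    """Extract (action, x, y) from text like 'pyautogui.click(512, 384)'."""
    match = re.search(
        r"([\w.]+)\(\s*(?:x\s*=\s*)?(\d+)\s*,\s*(?:y\s*=\s*)?(\d+)", output_text
    )
    if match is None:
        return None
    return match.group(1), int(match.group(2)), int(match.group(3))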


def main():
    """Main function to run the sanity check."""
    # Configuration
    model_path = "Uniphore/actio-ui-7b-rlvr"  # or other model variants
    image_path = "screenshot.png"
    instruction = "Click on the submit button"

    # Check if custom instruction provided
    if len(sys.argv) > 1:
        instruction = " ".join(sys.argv[1:])

    print(f"Loading model from: {model_path}")
    try:
        model, tokenizer, image_processor = load_model(model_path)
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        return 1

    print(f"Processing image: {image_path}")
    print(f"Instruction: {instruction}")

    try:
        messages = create_grounding_messages(image_path, instruction)
        result = run_inference(model, tokenizer, image_processor, messages, image_path)

        print("\n" + "=" * 60)
        print("MODEL OUTPUT:")
        print("=" * 60)
        print(result)
        print("=" * 60)
        return 0
    except Exception as e:
        print(f"✗ Error during inference: {e}")
        import traceback

        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())