# actio-ui-7b-rlvr / sanity.py
# chc012's picture
# remove code in readme and keep only in sanity.py
# 09febf4
import base64
import sys
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
from PIL import Image
def encode_image(image_path: str) -> str:
    """Return the file at *image_path* as a base64-encoded text string.

    The file is opened in binary mode and read in a single call; the
    base64 bytes are then decoded to ``str`` so the result can be embedded
    in a text payload.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def load_model(
    model_path: str,
) -> tuple[AutoModelForVision2Seq, AutoTokenizer, AutoProcessor]:
    """Load the model, tokenizer, and processor found at *model_path*.

    All three ``from_pretrained`` calls pass ``trust_remote_code=True``.
    The model is additionally loaded with ``torch_dtype="auto"`` and
    ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple, in that order.
    """
    # Shared by every loader below.
    remote_code = {"trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(model_path, **remote_code)
    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        **remote_code,
    )
    processor = AutoProcessor.from_pretrained(model_path, **remote_code)
    return model, tokenizer, processor
def create_grounding_messages(image_path: str, instruction: str) -> list[dict]:
    """Build the system + user chat messages for a GUI grounding request.

    The user message carries two content parts: a text part (a fixed prompt
    prefix followed by *instruction*) and an image part holding the file at
    *image_path* as a base64 data URI. The data URI is always labelled
    ``image/png``, whatever the actual file type is.
    """
    system_message = {
        "role": "system",
        "content": (
            "You are a GUI agent. You are given a task and a screenshot of the screen. "
            "You need to perform a series of pyautogui actions to complete the task."
        ),
    }
    text_part = {
        "type": "text",
        "text": "Please perform the following task by providing the action and the coordinates in the format of <action>(x, y): "
        + instruction,
    }
    image_part = {
        "type": "image",
        "image": f"data:image/png;base64,{encode_image(image_path)}",
    }
    user_message = {"role": "user", "content": [text_part, image_part]}
    return [system_message, user_message]
def run_inference(
    model: AutoModelForVision2Seq,
    tokenizer: AutoTokenizer,
    image_processor: AutoProcessor,
    messages: list[dict],
    image_path: str,
) -> str:
    """Generate the model's reply to *messages* for the image at *image_path*.

    ``tokenizer`` is accepted but not used here: chat templating, input
    preparation, and decoding all go through ``image_processor``.

    Returns:
        The decoded text of the first (and only) sequence in the batch,
        with the prompt tokens removed.
    """
    # Render the chat into a prompt string ending with the generation prompt.
    prompt = image_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    screenshot = Image.open(image_path).convert("RGB")

    # Single-item batch: one prompt paired with one image.
    batch = image_processor(
        text=[prompt], images=[screenshot], padding=True, return_tensors="pt"
    )

    # Place every input on the device the model lives on.
    model_inputs = {}
    for name, value in batch.items():
        model_inputs[name] = value.to(model.device)

    # Sampling is disabled; no gradients are tracked during generation.
    with torch.no_grad():
        sequences = model.generate(
            **model_inputs,
            max_new_tokens=2048,
            do_sample=False,
        )

    # Drop the leading prompt-length slice of each output so only the
    # newly generated ids are decoded.
    completions = []
    for prompt_ids, full_ids in zip(model_inputs["input_ids"], sequences):
        completions.append(full_ids[len(prompt_ids) :])

    decoded = image_processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
def main():
    """Run the sanity check and return a process exit code.

    Uses a fixed model id and ``screenshot.png`` in the current working
    directory. Any command-line arguments are joined with spaces and
    replace the default instruction.

    Returns:
        0 when inference completes and the output has been printed;
        1 when the screenshot file is missing, the model fails to load,
        or inference raises.
    """
    import os
    import traceback

    # Configuration
    model_path = "Uniphore/actio-ui-7b-rlvr"  # or other model variants
    image_path = "screenshot.png"
    instruction = "Click on the submit button"

    # Check if custom instruction provided
    cli_words = sys.argv[1:]
    if cli_words:
        instruction = " ".join(cli_words)

    # Fail fast: the image is only needed after the (slow) model load, so
    # verify it up front instead of discovering it is missing afterwards.
    if not os.path.isfile(image_path):
        print(f"✗ Image not found: {image_path}")
        return 1

    print(f"Loading model from: {model_path}")
    try:
        model, tokenizer, image_processor = load_model(model_path)
    except Exception as e:
        # Top-level boundary: report and turn into a non-zero exit code.
        print(f"✗ Error loading model: {e}")
        return 1
    else:
        print("✓ Model loaded successfully")

    print(f"Processing image: {image_path}")
    print(f"Instruction: {instruction}")
    try:
        messages = create_grounding_messages(image_path, instruction)
        result = run_inference(model, tokenizer, image_processor, messages, image_path)
    except Exception as e:
        print(f"✗ Error during inference: {e}")
        traceback.print_exc()
        return 1

    banner = "=" * 60
    print("\n" + banner)
    print("MODEL OUTPUT:")
    print(banner)
    print(result)
    print(banner)
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())