Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import base64 | |
| from io import BytesIO | |
| import re | |
| import os | |
# Bundled demo inputs: each entry pairs a screenshot path with a sample instruction.
examples = [
    {"image": screenshot_path, "prompt": instruction}
    for screenshot_path, instruction in (
        ("./assets/example_desktop.png", "switch off the wired connection"),
        ("./assets/example_web.png", "view all branches"),
        ("./assets/example_mobile.jpg", "share the screenshot"),
    )
]
# Credentials and base URL of the OpenAI-compatible endpoint are read from the
# environment; a missing variable raises KeyError while the module loads.
openai_api_key = os.environ["aria_ui_api_key"]
openai_api_base = os.environ["aria_ui_api_base"]

from openai import OpenAI  # Assuming the OpenAI client library is installed

client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

# Ask the endpoint which models it lists and use the first one for every request.
models = client.models.list()
model = models.data[0].id
def encode_pil_image_to_base64(image: Image.Image) -> str:
    """Return ``image`` as a base64 string of its JPEG encoding.

    The image is converted to RGB first, written to an in-memory buffer as
    JPEG, and the buffer's bytes are base64-encoded and decoded as UTF-8 text.
    """
    jpeg_buffer = BytesIO()
    image.convert("RGB").save(jpeg_buffer, format="JPEG")
    return base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
def request_aria_ui(image: Image.Image, prompt: str) -> str:
    """Send a grounding question about ``image`` to the chat endpoint.

    The screenshot travels inline as a base64 JPEG data URL next to a text
    part that asks for the relative (0-1000) point of the element described
    by ``prompt``. The content of the first returned choice is handed back
    without any post-processing.
    """
    encoded_screenshot = encode_pil_image_to_base64(image)

    # The two content parts of the single user message, text first.
    text_part = {
        "type": "text",
        "text": "<image>Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: " + prompt,
    }
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{encoded_screenshot}"},
    }

    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": [text_part, image_part]}],
        model=model,
        max_tokens=512,
        stop=["<|im_end|>"],
        # Extra fields passed through in the request body alongside the
        # standard parameters.
        extra_body={"split_image": True, "image_max_size": 980, "temperature": 0, "top_k": 1},
    )
    return completion.choices[0].message.content
| def _extract_coords_from_response(response: str) -> tuple[int, int]: | |
| resp = response.replace("```", "").strip() | |
| numbers = re.findall(r'\d+', resp) | |
| if len(numbers) != 2: | |
| raise ValueError(f"Expected exactly 2 coordinates, found {len(numbers)} numbers in response: {response}") | |
| return int(numbers[0]), int(numbers[1]) | |
def image_grounding(image: Image.Image, prompt: str) -> Image.Image:
    """Ground ``prompt`` in ``image`` and return a copy with the target marked.

    The model reply is parsed into a point expressed on a 0-1000 scale per
    axis, which is scaled to pixel coordinates of ``image``. A click-indicator
    icon (``assets/click.png``) is resized relative to the image, outlined
    with a red rectangle and pasted so that the point 30% in from the icon's
    top-left corner sits on the predicted location.

    Args:
        image: Screenshot to search; it is not modified.
        prompt: Instruction or description of the element to locate.

    Returns:
        A copy of ``image`` with the indicator and its red outline drawn on it.

    Raises:
        ValueError: If any step fails (request, parsing, asset loading or
            drawing). The original exception is attached as ``__cause__``.
    """
    try:
        # Request processing from API
        response = request_aria_ui(image, prompt)

        # Extract normalized (0-1000) coordinates
        norm_x, norm_y = _extract_coords_from_response(response)

        # Convert normalized coordinates to absolute pixel coordinates
        width, height = image.size
        abs_x = int(norm_x * width / 1000)
        abs_y = int(norm_y * height / 1000)

        # Load the indicator inside a context manager so the file handle is
        # released, and force RGBA because the icon doubles as its own paste
        # mask below.
        with Image.open("assets/click.png") as click_file:
            click_image = click_file.convert("RGBA")

        # Size the indicator at 3% of the image's longer side, keeping its
        # aspect ratio. Clamp to one pixel so small inputs never produce a
        # zero-sized resize target.
        target_width = max(1, int(max(width, height) * 0.03))
        aspect_ratio = click_image.width / click_image.height
        target_height = max(1, int(target_width / aspect_ratio))
        click_image = click_image.resize((target_width, target_height))

        # Align the point 30% from the icon's left/top edge with the target.
        click_x = abs_x - int(click_image.width * 0.3)
        click_y = abs_y - int(click_image.height * 0.3)

        # Draw on a copy so the caller's image stays untouched.
        output_image = image.copy()
        draw = ImageDraw.Draw(output_image)
        bbox = [
            click_x,                       # left
            click_y,                       # top
            click_x + click_image.width,   # right
            click_y + click_image.height,  # bottom
        ]
        # Outline is 10% of the icon width; that truncates to 0 for icons
        # narrower than 10 px, so keep at least one pixel.
        outline_width = max(1, int(click_image.width * 0.1))
        draw.rectangle(bbox, outline='red', width=outline_width)
        output_image.paste(click_image, (click_x, click_y), click_image)
        return output_image
    except Exception as e:
        # Keep the ValueError contract but preserve the underlying traceback.
        raise ValueError(f"An error occurred: {e}") from e
def resize_image_with_max_size(image: Image.Image, max_size: int = 1920) -> Image.Image:
    """Shrink ``image`` so neither side exceeds ``max_size``, keeping aspect ratio.

    Args:
        image: Image to constrain.
        max_size: Largest allowed width or height in pixels.

    Returns:
        ``image`` itself when it already fits; otherwise a new image whose
        longer side equals ``max_size``, resampled with LANCZOS.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image

    # The longer side is pinned to max_size; the shorter side is scaled by the
    # same factor. For extreme aspect ratios the scaled side would truncate to
    # 0, which cannot be resized to, so it is clamped to one pixel.
    if width > height:
        new_width = max_size
        new_height = max(1, int(height * (max_size / width)))
    else:
        new_height = max_size
        new_width = max(1, int(width * (max_size / height)))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
| # Gradio app | |
def gradio_interface(input_image, prompt):
    """Gradio callback: downscale the uploaded screenshot and ground ``prompt`` in it.

    Args:
        input_image: PIL image from the image component, or ``None`` when the
            user has not provided one.
        prompt: Instruction text from the textbox.

    Returns:
        The annotated image produced by ``image_grounding``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of failing on ``None``.
    """
    if input_image is None:
        raise gr.Error("Please upload a GUI image before processing.")

    # Log the size before and after downscaling.
    print(input_image.size)
    input_image = resize_image_with_max_size(input_image)
    print(input_image.size)
    return image_grounding(input_image, prompt)
# Page layout: logo and links on top, then a three-column row holding the
# inputs, the grounding result and the clickable examples.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <div style="display: flex; justify-content: center;">
                <img src="https://raw.githubusercontent.com/AriaUI/Aria-UI/refs/heads/main/assets/logo_long.png" alt="Aria-UI" style="height: 76px; margin-bottom: 10px;"/>
            </div>
        </div>
        """
    )
    # NOTE(review): the emoji and bullet characters in the Markdown below were
    # restored from mis-encoded text. The hugging-face, bullet and target
    # glyphs decode unambiguously; the globe, page and rocket glyphs are best
    # guesses - confirm against the intended design.
    gr.Markdown("""| [🤗 Aria-UI Models](https://huggingface.co/Aria-UI/Aria-UI-base) • [🤗 Aria-UI Dataset](https://huggingface.co/datasets/Aria-UI/Aria-UI_Data) • [🌐 Project Page](https://ariaui.github.io) • [📄 Paper](https://arxiv.org/abs/2412.16256) |
|:---------------------------------------------------------------------------------------------------------:|""")
    gr.Markdown("# Aria-UI: Visual Grounding for GUI Instructions")
    gr.Markdown("🚀🚀 Upload a GUI image and enter a instruction. Aria-UI will try its best to ground the instruction to specific element in the image. 🎯🎯")

    with gr.Row():
        # Inputs: screenshot, instruction and the trigger button.
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload GUI Image", height=600)
            prompt_input = gr.Textbox(label="Enter GUI Instruction")
            submit_button = gr.Button("Process")

        # Result gets the widest column.
        with gr.Column(scale=3):
            output_image = gr.Image(label="Grounding Result", height=500)

        # Examples listed vertically in their own column.
        with gr.Column(scale=2):
            gr.Examples(
                examples=[
                    [example["image"], example["prompt"]]
                    for example in examples
                ],
                inputs=[image_input, prompt_input],
                outputs=[output_image],
                fn=gradio_interface,
                cache_examples=False,
                label="Example Tasks",
                examples_per_page=5,
            )

    submit_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input],
        outputs=[output_image],
    )
# Start the server on all interfaces at port 7860, with server-side rendering
# turned off and debug mode on.
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, debug=True)