| | from transformers import AutoProcessor |
| | from PIL import Image |
| | import numpy as np |
| | import onnxruntime as ort |
| | import time |
| | import argparse |
| | import random |
| |
|
| | |
| | import ztu_somemodelruntime_rknnlite2 as rknnort |
| | |
| | |
| |
|
| | |
| | import os |
| |
|
# Switch the working directory to the folder containing this script so the
# relative model file names used by run() resolve no matter where the script
# is launched from. NOTE: a relative image_path argument is therefore also
# resolved against the script's folder, not the caller's original cwd.
_script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(_script_dir)
| |
|
| |
|
# The merged decoder graph is fed past_key_values.0 ... past_key_values.5.
_NUM_DECODER_LAYERS = 6
# Used only if the tokenizer does not report an end-of-sequence id.
_FALLBACK_EOS_TOKEN_ID = 2


def _elapsed_ms(start):
    """Return the milliseconds elapsed since *start* (a ``time.perf_counter()`` reading)."""
    return (time.perf_counter() - start) * 1000


def _select_next_token(next_token_logits, temperature):
    """Choose the next token id from last-position logits.

    Args:
        next_token_logits: Array indexed as ``[batch, vocab]``; only row 0 is
            used for the final choice.
        temperature: ``0`` selects the arg-max token; any other value divides
            the logits before a softmax and samples with ``np.random.choice``.

    Returns:
        The chosen token id as a plain ``int``.
    """
    if temperature == 0:
        return int(np.argmax(next_token_logits, axis=-1)[0])

    # Work on a new array so the decoder's output buffer is never mutated.
    scaled = next_token_logits / temperature
    # Subtract the maximum before exponentiating for numerical stability.
    scaled = scaled - np.max(scaled)
    exp_scaled = np.exp(scaled)
    probs = exp_scaled / np.sum(exp_scaled)
    return int(np.random.choice(len(probs[0]), p=probs[0]))


def _build_decode_feed(next_input_embeds, encoder_hidden_states, decoder_kv, encoder_kv):
    """Assemble the input dict for one cached step of the merged decoder.

    Both ``decoder_kv`` and ``encoder_kv`` are indexed four entries per layer:
    decoder key, decoder value, encoder key, encoder value. Decoder entries
    come from ``decoder_kv`` and encoder entries from ``encoder_kv``.
    """
    feed = {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
    }
    for layer in range(_NUM_DECODER_LAYERS):
        base = 4 * layer
        prefix = f"past_key_values.{layer}"
        feed[f"{prefix}.decoder.key"] = decoder_kv[base]
        feed[f"{prefix}.decoder.value"] = decoder_kv[base + 1]
        feed[f"{prefix}.encoder.key"] = encoder_kv[base + 2]
        feed[f"{prefix}.encoder.value"] = encoder_kv[base + 3]
    return feed


def run(image_path, prompt, max_new_tokens, output_image_path, temperature, seed):
    """Run the image + prompt pipeline end to end and print timings and results.

    The stages are: vision encoder, prompt embedding, encoder, a one-token
    decoder prefill starting from the tokenizer's BOS token, then a cached
    token-by-token decode loop. Model files are opened by relative name from
    the current working directory.

    Args:
        image_path: Path of the image to load.
        prompt: Task prompt passed to the processor; the text up to the first
            ``>`` is also used as the task name for post-processing.
        max_new_tokens: Upper bound on the number of generated tokens.
        output_image_path: Accepted for interface compatibility; this function
            does not currently write any image.
        temperature: ``0`` for greedy decoding, a positive value for sampling.
        seed: If not ``None``, seeds both ``random`` and ``np.random``.

    Returns:
        The value returned by ``processor.post_process_generation``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError("temperature must be >= 0")

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    total_time = 0

    # Vision encoder, encoder and prefill decoder go through the rknnort
    # runtime; token embedding and the cached decoder use onnxruntime.
    vision_encoder = rknnort.InferenceSession(
        "vision_encoder.onnx", providers=["CPUExecutionProvider"]
    )
    encoder = rknnort.InferenceSession(
        "encoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_prefill = rknnort.InferenceSession(
        "decoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    text_embed = ort.InferenceSession(
        "embed_tokens.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_decode = ort.InferenceSession(
        "decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
    )

    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )

    # Close the file handle as soon as the RGB copy exists.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    # Keep the original size: post-processing is given it as image_size.
    original_size = image.size
    # Resized here and the processor's own resize is disabled; presumably
    # 64x64 is the fixed input size of the converted vision encoder - TODO confirm.
    image = image.resize((64, 64))

    inputs = processor(
        text=prompt, images=image, return_tensors="np", do_resize=False
    )
    for k, v in inputs.items():
        print(k, v.shape)

    # --- Vision encoder ---
    start_time = time.perf_counter()
    image_features = vision_encoder.run(
        None, {"pixel_values": inputs["pixel_values"]}
    )[0]
    vision_encoder_time = _elapsed_ms(start_time)
    total_time += vision_encoder_time
    print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
    print(image_features.shape)

    # --- Prompt token embedding ---
    start_time = time.perf_counter()
    task_prefix_embeds = text_embed.run(None, {"input_ids": inputs["input_ids"]})[0]
    text_embed_time = _elapsed_ms(start_time)
    total_time += text_embed_time
    print(f"Text embed time: {text_embed_time:.2f} ms")
    print(task_prefix_embeds.shape)

    # Image features come first, then the prompt embeddings; every position
    # is attended to, so the mask is all ones.
    batch_size, image_token_length = image_features.shape[:-1]
    inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
    attention_mask = np.ones(
        (batch_size, image_token_length + task_prefix_embeds.shape[1]),
        dtype=np.int64,
    )

    # --- Encoder ---
    start_time = time.perf_counter()
    encoder_out = encoder.run(
        None,
        {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        },
    )
    encoder_time = _elapsed_ms(start_time)
    total_time += encoder_time
    print(f"Encoder time: {encoder_time:.2f} ms")
    encoder_hidden_states = encoder_out[0]
    print(encoder_hidden_states.shape)

    # --- Decoder prefill with the BOS token ---
    start_time = time.perf_counter()
    next_token = processor.tokenizer.bos_token_id
    next_input_embeds = text_embed.run(
        None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
    )[0]
    decoder_outs = decoder_prefill.run(
        None,
        {
            "inputs_embeds": next_input_embeds,
            "encoder_hidden_states": encoder_hidden_states,
        },
    )
    decoder_prefill_time = _elapsed_ms(start_time)
    total_time += decoder_prefill_time
    print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")

    # The encoder key/value entries are taken from the prefill outputs once
    # and reused for every cached decode step.
    encoder_kv = decoder_outs[1:]

    eos_token_id = processor.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = _FALLBACK_EOS_TOKEN_ID

    # --- Token-by-token decode loop ---
    generated_tokens = []
    decoder_decode_total_time = 0
    while len(generated_tokens) < max_new_tokens:
        logits = decoder_outs[0]
        decoder_kv = decoder_outs[1:]

        next_token = _select_next_token(logits[:, -1, :], temperature)
        print("next_token: ", processor.decode([next_token]))
        generated_tokens.append(next_token)

        # Stop on end-of-sequence, or once the budget is used up, so no
        # decoder step is run whose output would be thrown away.
        if next_token == eos_token_id or len(generated_tokens) >= max_new_tokens:
            break

        start_time = time.perf_counter()
        next_input_embeds = text_embed.run(
            None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
        )[0]
        decoder_decode_total_time += _elapsed_ms(start_time)

        start_time = time.perf_counter()
        decoder_outs = decoder_decode.run(
            None,
            _build_decode_feed(
                next_input_embeds, encoder_hidden_states, decoder_kv, encoder_kv
            ),
        )
        decoder_decode_total_time += _elapsed_ms(start_time)

    total_time += decoder_decode_total_time
    print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")

    # --- Decode and post-process the generated ids ---
    print("generated_tokens: ", generated_tokens)
    generated_text = processor.batch_decode(
        [generated_tokens], skip_special_tokens=False
    )[0]
    print("Generated Text:", generated_text)
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt.split(">")[0].strip() + ">",
        image_size=original_size,
    )
    print("Parsed Answer:", parsed_answer)

    print(f"Total inference time: {total_time:.2f} ms")
    return parsed_answer
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: one positional image path plus optional
    # generation settings, forwarded to run() with the fixed "<CAPTION>" prompt.
    cli = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    cli.add_argument("image_path", type=str, help="Path to the input image.")

    # (flag, type, default, help) for every optional argument, in display order.
    optional_flags = (
        (
            "--max_new_tokens",
            int,
            512,
            "Maximum number of new tokens to generate.",
        ),
        (
            "--output_image_path",
            str,
            "result_image.jpg",
            "Path to save the output image with visualizations.",
        ),
        (
            "--temperature",
            float,
            0,
            "Temperature for sampling. Set to 0 for greedy decoding.",
        ),
        (
            "--seed",
            int,
            None,
            "Random seed for reproducibility.",
        ),
    )
    for flag, value_type, default_value, help_text in optional_flags:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    opts = cli.parse_args()
    run(
        image_path=opts.image_path,
        prompt="<CAPTION>",
        max_new_tokens=opts.max_new_tokens,
        output_image_path=opts.output_image_path,
        temperature=opts.temperature,
        seed=opts.seed,
    )
| |
|