Spaces:
Runtime error
| from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
| from typing import List | |
| import os | |
| import numpy as np | |
| import supervision as sv | |
| import uuid | |
| import torch | |
| from tqdm import tqdm | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import spaces | |
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OWLv2 open-vocabulary detector. The processor prepares both the text
# prompts and the images; the model is moved to `device` once at start-up.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)

# supervision renamed `BoundingBoxAnnotator` to `BoxAnnotator` (0.22) and later
# removed the old name. Prefer the old name when it exists: in releases before
# 0.22, `BoxAnnotator` was a different class that also draws labels. Otherwise
# fall back to the new name so the app still imports on current releases.
_BOX_ANNOTATOR_CLS = getattr(sv, "BoundingBoxAnnotator", None) or sv.BoxAnnotator
BOUNDING_BOX_ANNOTATOR = _BOX_ANNOTATOR_CLS()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
def calculate_end_frame_index(source_video_path, max_seconds=2):
    """Return how many leading frames of ``source_video_path`` to process.

    Processing is capped at ``max_seconds`` of footage (two seconds by
    default, presumably to keep demo latency bounded). Shorter videos are
    processed in full.

    Args:
        source_video_path: Path to a video readable by ``sv.VideoInfo``.
        max_seconds: Upper bound on the processed duration, in seconds.

    Returns:
        ``min(total_frames, fps * max_seconds)`` as an ``int``, suitable for
        ``range()`` and as the generator's ``end`` index.
    """
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    # int() keeps the result usable with range() even if fps or max_seconds
    # is fractional.
    return min(video_info.total_frames, int(video_info.fps * max_seconds))
def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    """Overlay ``detections`` on ``input_image`` and return the drawn frame.

    Layers go on in a fixed order: masks first, then boxes, then the text
    labels, so the labels sit on top of everything else.

    Args:
        input_image: Frame to draw on.
        detections: ``sv.Detections`` for this frame.
        labels: One display string per detection, in detection order.

    Returns:
        The annotated frame produced by the label annotator.
    """
    canvas = input_image
    for layer in (MASK_ANNOTATOR, BOUNDING_BOX_ANNOTATOR):
        canvas = layer.annotate(canvas, detections)
    return LABEL_ANNOTATOR.annotate(canvas, detections, labels=labels)
@spaces.GPU
def process_video(
    input_video,
    labels,
    progress=gr.Progress(track_tqdm=True)
):
    """Detect the comma-separated ``labels`` in the opening frames of a video.

    Each processed frame is sent through OWLv2 (via ``query``), annotated
    with masks, boxes and label names, and written to a new ``.mp4`` file in
    the current working directory. Only the first
    ``calculate_end_frame_index(input_video)`` frames are processed.

    ``@spaces.GPU`` is required on ZeroGPU Spaces, where CUDA is only
    attached inside decorated functions. Elsewhere it has no effect.

    Args:
        input_video: Path to the uploaded video, as supplied by ``gr.Video``.
        labels: Comma-separated candidate label names, e.g. ``"dog, cat"``.
        progress: Gradio progress tracker. ``track_tqdm=True`` mirrors the
            ``tqdm`` bar below into the UI. It is not referenced directly;
            Gradio injects it.

    Returns:
        Path of the annotated output video.

    Raises:
        gr.Error: If no video was uploaded or no non-empty label was given.
    """
    if not input_video:
        raise gr.Error("Please upload a video.")
    # Strip whitespace so "dog, cat" behaves like "dog,cat"; drop empty entries.
    label_names = [name.strip() for name in (labels or "").split(",") if name.strip()]
    if not label_names:
        raise gr.Error("Please enter at least one label.")

    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )
    result_file_path = os.path.join("./", f"{uuid.uuid4()}.mp4")

    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # Iterate the generator itself. The container's frame count may
        # overstate the frames that actually decode, and calling next() past
        # the end would raise StopIteration out of this function.
        for frame in tqdm(frame_generator, total=total, desc="Processing video.."):
            # Frames are decoded by OpenCV in BGR order, but OWLv2 expects RGB.
            # Annotation and writing stay in BGR, matching the sink.
            rgb_frame = np.ascontiguousarray(frame[:, :, ::-1])
            results = query(rgb_frame, label_names)
            detections = sv.Detections.from_transformers(results[0])
            # results[0]["labels"] holds indices into label_names.
            final_labels = [label_names[int(label_id)] for label_id in results[0]["labels"]]
            annotated = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(annotated)
    return result_file_path
def query(image, texts, threshold=0.3):
    """Run OWLv2 zero-shot detection of ``texts`` on a single image.

    Args:
        image: Image as a NumPy array of shape (height, width, channels). It
            is handed to the processor unchanged, so the caller is responsible
            for channel order (OWLv2 expects RGB).
        texts: Candidate label strings to look for.
        threshold: Minimum confidence score for a detection to be kept.

    Returns:
        The list from ``processor.post_process_object_detection``: one dict
        per image, whose ``"labels"`` entries index into ``texts`` and whose
        boxes are in pixel coordinates of ``image``.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # The OWLv2 processor pads every image to a square (bottom/right) before
    # resizing, so predicted boxes are normalized to that padded square.
    # Rescaling by the longer side maps them back onto the original pixels.
    # Using (height, width) would misplace boxes on non-square frames. A
    # square size is correct whether or not the installed transformers
    # version already compensates for the padding internally.
    height, width = image.shape[:2]
    side = max(height, width)
    target_sizes = torch.tensor([[side, side]], dtype=torch.float32)
    return processor.post_process_object_detection(
        outputs=outputs, threshold=threshold, target_sizes=target_sizes
    )
with gr.Blocks() as demo:
    # Page header: title, model attribution and a usage hint, rendered in order.
    for header_text in (
        "## Zero-shot Object Tracking with OWLv2 🦉",
        "This is a demo for zero-shot object tracking using [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) model by Google.",
        "Simply upload a video and enter the candidate labels, or try the example below. 👇",
    ):
        gr.Markdown(header_text)

    with gr.Tab(label="Video"):
        # Input and output players side by side.
        with gr.Row():
            input_video = gr.Video(label="Input Video")
            output_video = gr.Video(label="Output Video")
        # Free-text label entry next to the run button.
        with gr.Row():
            candidate_labels = gr.Textbox(
                label="Labels",
                placeholder="Labels separated by a comma",
            )
            submit = gr.Button()

        # Both the examples and the button feed the same (video, labels) pair
        # into process_video and show the result in the output player.
        pipeline_inputs = [input_video, candidate_labels]
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "dog,cat"]],
            inputs=pipeline_inputs,
            outputs=output_video,
        )
        submit.click(
            fn=process_video,
            inputs=pipeline_inputs,
            outputs=output_video,
        )
# Start the Gradio server only when executed as a script (Hugging Face Spaces
# runs `python app.py`). Importing this module no longer blocks on a server.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)