| """ | |
| Real-time video classification using VJEPA2 model with streaming capabilities. | |
| This module implements a real-time video classification system that: | |
| 1. Captures video frames from a webcam | |
| 2. Processes batches of frames using the V-JEPA 2 model | |
| 3. Displays predictions overlaid on the video stream | |
| 4. Maintains a history of recent predictions | |
| The system uses FastRTC for video streaming and Gradio for the web interface. | |
| """ | |
import cv2
import time
import torch
import gradio as gr
import numpy as np

from fastrtc import Stream, VideoStreamHandler, AdditionalOutputs
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor

# Model configuration
CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2"  # Pre-trained V-JEPA 2 model checkpoint
TORCH_DTYPE = torch.float16  # Use half precision for faster inference
TORCH_DEVICE = "cuda"  # Use GPU for inference
UPDATE_EVERY_N_FRAMES = 64  # How often to update predictions (in frames)
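# Illustrative fallback (an assumption, not part of the original app): switch to CPU and
# full precision when CUDA is unavailable, so the script can still start on GPU-less
# machines (inference will be much slower).
if not torch.cuda.is_available():
    TORCH_DEVICE = "cpu"
    TORCH_DTYPE = torch.float32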
def add_text_on_image(image, text):
    """
    Overlays text on an image with a black background bar at the top.

    Args:
        image (np.ndarray): Input image to add text to
        text (str): Text to overlay on the image

    Returns:
        np.ndarray: Image with text overlaid
    """
    # Add a black background to the text
    image[:70] = 0

    line_spacing = 10
    top_margin = 20
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    color = (255, 255, 255)  # White

    words = text.split()
    lines = []
    current_line = ""
    img_width = image.shape[1]

    # Build lines that fit within the image width
    for word in words:
        test_line = current_line + (" " if current_line else "") + word
        (test_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
        if test_width > img_width - 20:  # 20 px margin
            lines.append(current_line)
            current_line = word
        else:
            current_line = test_line
    if current_line:
        lines.append(current_line)

    # Draw each line, centered
    y = top_margin
    for line in lines:
        (line_width, line_height), _ = cv2.getTextSize(line, font, font_scale, thickness)
        x = (img_width - line_width) // 2
        cv2.putText(image, line, (x, y + line_height), font, font_scale, color, thickness, cv2.LINE_AA)
        y += line_height + line_spacing

    return image
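# Illustrative sketch (not part of the streaming pipeline): exercising the overlay helper
# on a synthetic frame. The 480x640 size and the sample text are arbitrary assumptions.
def _demo_text_overlay() -> np.ndarray:
    blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    return add_text_on_image(blank_frame, "Pretending to pick something up")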
class RunningFramesCache:
    """
    Maintains a rolling buffer of video frames for model input.

    This class manages a fixed-size queue of frames, keeping only the most recent
    frames needed for model inference. It supports subsampling frames to reduce
    memory usage and processing requirements.

    Args:
        save_every_k_frame (int): Only save every k-th frame (for subsampling)
        max_frames (int): Maximum number of frames to keep in cache
    """

    def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
        self.save_every_k_frame = save_every_k_frame
        self.max_frames = max_frames
        self._frames = []
        self._seen_frames = 0

    def add_frame(self, frame: np.ndarray):
        # Subsample: keep only every k-th incoming frame
        self._seen_frames += 1
        if self._seen_frames % self.save_every_k_frame != 0:
            return
        self._frames.append(frame)
        if len(self._frames) > self.max_frames:
            self._frames.pop(0)

    def get_last_n_frames(self, n: int) -> list[np.ndarray]:
        return self._frames[-n:]

    def __len__(self) -> int:
        return len(self._frames)
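# Illustrative sketch (not part of the streaming pipeline): with save_every_k_frame=8 and
# max_frames=16, pushing 128 dummy frames leaves the 16 most recent subsampled frames.
def _demo_frames_cache() -> int:
    cache = RunningFramesCache(save_every_k_frame=8, max_frames=16)
    for i in range(128):
        cache.add_frame(np.full((2, 2, 3), i % 255, dtype=np.uint8))  # tiny stand-in frames
    return len(cache)  # 16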
class RunningResult:
    """
    Maintains a history of recent model predictions with timestamps.

    This class keeps track of the most recent predictions made by the model,
    including timestamps for each prediction. It provides formatted output
    for display in the UI.

    Args:
        max_predictions (int): Maximum number of predictions to keep in history
    """

    def __init__(self, max_predictions: int = 4):
        self.predictions = []
        self.max_predictions = max_predictions
    def add_prediction(self, prediction: str):
        # Record the prediction with a timestamp in HH:MM:SS format
        current_time_formatted = time.strftime("%H:%M:%S", time.localtime())
        self.predictions.append((current_time_formatted, prediction))
        if len(self.predictions) > self.max_predictions:
            self.predictions.pop(0)

    def get_formatted_predictions(self) -> str:
        if not self.predictions:
            return "Starting..."
        current, *past = self.predictions[::-1]
        text = f">>> {current[1]}\n\n" + "\n".join(
            [f"[{time_formatted}] {prediction}" for time_formatted, prediction in past]
        )
        return text

    def get_last_prediction(self) -> str:
        return self.predictions[-1][1] if self.predictions else "Starting..."
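# Illustrative sketch (not part of the streaming pipeline): the newest prediction is shown
# with a ">>>" prefix and older ones keep their timestamps. The label strings are arbitrary.
def _demo_running_result() -> str:
    history = RunningResult(max_predictions=4)
    history.add_prediction("Opening something")
    history.add_prediction("Closing something")
    return history.get_formatted_predictions()  # ">>> Closing something\n\n[HH:MM:SS] Opening something"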
class FrameProcessingCallback:
    """
    Handles real-time video frame processing and model inference.

    This class is responsible for:
    1. Loading and managing the V-JEPA 2 model
    2. Processing incoming video frames
    3. Running model inference at regular intervals
    4. Managing frame caching and prediction history
    5. Formatting output for display

    The callback is called for each frame from the video stream and handles
    the coordination between frame capture, model inference, and result display.
    """

    def __init__(self):
        # Load the model and processor
        self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=TORCH_DTYPE)
        self.model = self.model.to(TORCH_DEVICE)
        self.video_processor = AutoVideoProcessor.from_pretrained(CHECKPOINT)

        # Init frames cache: subsample so that `frames_per_clip` cached frames
        # span roughly the last 128 captured frames
        self.frames_per_clip = self.model.config.frames_per_clip
        self.running_frames_cache = RunningFramesCache(
            save_every_k_frame=128 // self.frames_per_clip,
            max_frames=self.frames_per_clip,
        )
        self.running_result = RunningResult(max_predictions=4)
        self.frame_count = 0
    def __call__(self, image: np.ndarray):
        # Mirror the frame for a more natural webcam view
        image = np.flip(image, axis=1).copy()
        self.running_frames_cache.add_frame(image)
        self.frame_count += 1

        if (
            self.frame_count % UPDATE_EVERY_N_FRAMES == 0
            and len(self.running_frames_cache) >= self.frames_per_clip
        ):
            # Prepare frames for the model
            frames = self.running_frames_cache.get_last_n_frames(self.frames_per_clip)
            frames = np.array(frames)
            inputs = self.video_processor(frames, device=TORCH_DEVICE, return_tensors="pt")
            inputs = inputs.to(dtype=TORCH_DTYPE)

            # Run the model
            with torch.no_grad():
                logits = self.model(**inputs).logits

            # Get the top prediction
            top_index = logits.argmax(dim=-1).item()
            class_name = self.model.config.id2label[top_index]
            self.running_result.add_prediction(class_name)

        formatted_predictions = self.running_result.get_formatted_predictions()
        last_prediction = self.running_result.get_last_prediction()
        image = add_text_on_image(image, last_prediction)
        return image, AdditionalOutputs(formatted_predictions)
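# Illustrative sketch (not part of the streaming pipeline): running the same callback
# offline over a list of pre-decoded RGB frames, e.g. from a prerecorded clip. Loading the
# checkpoint still requires a CUDA device unless TORCH_DEVICE/TORCH_DTYPE are changed.
def _demo_offline_classification(frames: list[np.ndarray]):
    callback = FrameProcessingCallback()
    annotated_frame, outputs = None, None
    for frame in frames:
        annotated_frame, outputs = callback(frame)
    return annotated_frame, outputs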
# Initialize the video stream with the processing callback
stream = Stream(
    handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
    modality="video",
    mode="send-receive",
    additional_outputs=[gr.TextArea(label="Actions", value="", lines=5)],
    additional_outputs_handler=lambda _, output: output,
)
if __name__ == "__main__":
    stream.ui.launch()
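# Alternative entrypoint (illustrative): the standard Gradio launch() arguments also apply
# here, e.g. binding to all interfaces when running inside a container:
#     stream.ui.launch(server_name="0.0.0.0", server_port=7860)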