Spaces:
Runtime error
Runtime error
Add comments and docstrings
Browse files
app.py
CHANGED
|
@@ -1,3 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import cv2
|
| 2 |
import time
|
| 3 |
import torch
|
|
@@ -8,13 +19,24 @@ from fastrtc import Stream, VideoStreamHandler, AdditionalOutputs
|
|
| 8 |
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
|
| 9 |
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def add_text_on_image(image, text):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Add a black background to the text
|
| 19 |
image[:70] = 0
|
| 20 |
|
|
@@ -56,6 +78,17 @@ def add_text_on_image(image, text):
|
|
| 56 |
|
| 57 |
|
| 58 |
class RunningFramesCache:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
|
| 60 |
self.save_every_k_frame = save_every_k_frame
|
| 61 |
self.max_frames = max_frames
|
|
@@ -74,6 +107,16 @@ class RunningFramesCache:
|
|
| 74 |
|
| 75 |
|
| 76 |
class RunningResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def __init__(self, max_predictions: int = 4):
|
| 78 |
self.predictions = []
|
| 79 |
self.max_predictions = max_predictions
|
|
@@ -100,6 +143,19 @@ class RunningResult:
|
|
| 100 |
|
| 101 |
|
| 102 |
class FrameProcessingCallback:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
def __init__(self):
|
| 104 |
# Loading model and processor
|
| 105 |
self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)
|
|
@@ -146,6 +202,7 @@ class FrameProcessingCallback:
|
|
| 146 |
return image, AdditionalOutputs(formatted_predictions)
|
| 147 |
|
| 148 |
|
|
|
|
| 149 |
stream = Stream(
|
| 150 |
handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
|
| 151 |
modality="video",
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Real-time video classification using VJEPA2 model with streaming capabilities.
|
| 3 |
+
|
| 4 |
+
This module implements a real-time video classification system that:
|
| 5 |
+
1. Captures video frames from a webcam
|
| 6 |
+
2. Processes batches of frames using the V-JEPA 2 model
|
| 7 |
+
3. Displays predictions overlaid on the video stream
|
| 8 |
+
4. Maintains a history of recent predictions
|
| 9 |
+
|
| 10 |
+
The system uses FastRTC for video streaming and Gradio for the web interface.
|
| 11 |
+
"""
|
| 12 |
import cv2
|
| 13 |
import time
|
| 14 |
import torch
|
|
|
|
| 19 |
from transformers import VJEPA2ForVideoClassification, AutoVideoProcessor
|
| 20 |
|
| 21 |
|
| 22 |
+
# Model configuration
|
| 23 |
+
CHECKPOINT = "qubvel-hf/vjepa2-vitl-fpc16-256-ssv2" # Pre-trained VJEPA2 model checkpoint
|
| 24 |
+
TORCH_DTYPE = torch.float16 # Use half precision for faster inference
|
| 25 |
+
TORCH_DEVICE = "cuda" # Use GPU for inference
|
| 26 |
+
UPDATE_EVERY_N_FRAMES = 64 # How often to update predictions (in frames)
|
| 27 |
|
| 28 |
|
| 29 |
def add_text_on_image(image, text):
|
| 30 |
+
"""
|
| 31 |
+
Overlays text on an image with a black background bar at the top.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
image (np.ndarray): Input image to add text to
|
| 35 |
+
text (str): Text to overlay on the image
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
np.ndarray: Image with text overlaid
|
| 39 |
+
"""
|
| 40 |
# Add a black background to the text
|
| 41 |
image[:70] = 0
|
| 42 |
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
class RunningFramesCache:
|
| 81 |
+
"""
|
| 82 |
+
Maintains a rolling buffer of video frames for model input.
|
| 83 |
+
|
| 84 |
+
This class manages a fixed-size queue of frames, keeping only the most recent
|
| 85 |
+
frames needed for model inference. It supports subsampling frames to reduce
|
| 86 |
+
memory usage and processing requirements.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
save_every_k_frame (int): Only save every k-th frame (for subsampling)
|
| 90 |
+
max_frames (int): Maximum number of frames to keep in cache
|
| 91 |
+
"""
|
| 92 |
def __init__(self, save_every_k_frame: int = 1, max_frames: int = 16):
|
| 93 |
self.save_every_k_frame = save_every_k_frame
|
| 94 |
self.max_frames = max_frames
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
class RunningResult:
|
| 110 |
+
"""
|
| 111 |
+
Maintains a history of recent model predictions with timestamps.
|
| 112 |
+
|
| 113 |
+
This class keeps track of the most recent predictions made by the model,
|
| 114 |
+
including timestamps for each prediction. It provides formatted output
|
| 115 |
+
for display in the UI.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
max_predictions (int): Maximum number of predictions to keep in history
|
| 119 |
+
"""
|
| 120 |
def __init__(self, max_predictions: int = 4):
|
| 121 |
self.predictions = []
|
| 122 |
self.max_predictions = max_predictions
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
class FrameProcessingCallback:
|
| 146 |
+
"""
|
| 147 |
+
Handles real-time video frame processing and model inference.
|
| 148 |
+
|
| 149 |
+
This class is responsible for:
|
| 150 |
+
1. Loading and managing the V-JEPA 2 model
|
| 151 |
+
2. Processing incoming video frames
|
| 152 |
+
3. Running model inference at regular intervals
|
| 153 |
+
4. Managing frame caching and prediction history
|
| 154 |
+
5. Formatting output for display
|
| 155 |
+
|
| 156 |
+
The callback is called for each frame from the video stream and handles
|
| 157 |
+
the coordination between frame capture, model inference, and result display.
|
| 158 |
+
"""
|
| 159 |
def __init__(self):
|
| 160 |
# Loading model and processor
|
| 161 |
self.model = VJEPA2ForVideoClassification.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)
|
|
|
|
| 202 |
return image, AdditionalOutputs(formatted_predictions)
|
| 203 |
|
| 204 |
|
| 205 |
+
# Initialize the video stream with processing callback
|
| 206 |
stream = Stream(
|
| 207 |
handler=VideoStreamHandler(FrameProcessingCallback(), skip_frames=True),
|
| 208 |
modality="video",
|