yonigozlan (HF Staff) committed
Commit c266ce5 · verified · 1 parent: 3bba6db

Update README.md
Files changed (1): README.md (+439 −131)

README.md (changed):
---
license: apache-2.0
pipeline_tag: mask-generation
library_name: transformers
---
# Model Details

[\[📃 Tech Report\]](https://arxiv.org/abs/2501.07256)
[\[📂 Github\]](https://github.com/facebookresearch/EdgeTAM)
[\[🤗 Demo\]](https://huggingface.co/spaces/yonigozlan/EdgeTAM-hf)

EdgeTAM is an on-device variant of SAM 2 for promptable segmentation and tracking in videos. It runs 22× faster than SAM 2 and reaches 16 FPS on an iPhone 15 Pro Max without quantization.

# How to use with Transformers

### Automatic Mask Generation with Pipeline

EdgeTAM can be used for automatic mask generation to segment all objects in an image using the `mask-generation` pipeline:

```python
>>> from transformers import pipeline

>>> generator = pipeline("mask-generation", model="yonigozlan/EdgeTAM-hf", device=0)
>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
>>> outputs = generator(image_url, points_per_batch=64)

>>> len(outputs["masks"])  # Number of masks generated
39
```
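The pipeline returns one mask per detected object. As a quick sanity check you can overlay them on the input image; this is a minimal sketch (not part of the original example) that assumes each entry in `outputs["masks"]` is a boolean array with the same height and width as the image:

```python
>>> import numpy as np
>>> import requests
>>> from PIL import Image

>>> image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
>>> overlay = np.array(image, dtype=np.float32)

>>> rng = np.random.default_rng(0)
>>> for mask in outputs["masks"]:
...     m = np.asarray(mask, dtype=bool)
...     overlay[m] = 0.5 * overlay[m] + 0.5 * rng.uniform(0, 255, size=3)  # blend a random color into each mask region

>>> Image.fromarray(overlay.astype(np.uint8)).save("all_masks_overlay.png")
```

The exact mask format may vary between pipeline versions, so adapt the boolean conversion if needed.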
### Basic Image Segmentation

#### Single Point Click

You can segment objects by providing a single point click on the object you want to segment:

```python
>>> from transformers import Sam2Processor, EdgeTamModel
>>> import torch
>>> from PIL import Image
>>> import requests

>>> device = "cuda" if torch.cuda.is_available() else "cpu"

>>> model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device)
>>> processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf")

>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
>>> raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

>>> input_points = [[[[500, 375]]]]  # Single point click, 4 dimensions (image_dim, object_dim, point_per_object_dim, coordinates)
>>> input_labels = [[[1]]]  # 1 for positive click, 0 for negative click, 3 dimensions (image_dim, object_dim, point_label)

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]

>>> # The model outputs multiple mask predictions ranked by quality score
>>> print(f"Generated {masks.shape[1]} masks with shape {masks.shape}")
Generated 3 masks with shape torch.Size([1, 3, 1500, 2250])
```
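By default the model returns several candidate masks per prompt, so a common follow-up is to keep only the highest-scoring one. A minimal sketch, assuming `outputs.iou_scores` squeezes down to one score per candidate (the same `argmax` pattern is used in the mask-refinement example further down):

```python
>>> best_idx = int(torch.argmax(outputs.iou_scores.squeeze()))
>>> best_mask = masks[0, best_idx]  # binarized mask at the original image resolution
>>> print(f"Best mask index: {best_idx}, covering {int(best_mask.sum())} pixels")
```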
#### Multiple Points for Refinement

You can provide multiple points to refine the segmentation:

```python
>>> # Add a second positive point to refine the mask
>>> input_points = [[[[500, 375], [1125, 625]]]]  # Multiple points for refinement
>>> input_labels = [[[1, 1]]]  # Both positive clicks

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
```
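A negative click (label `0`) uses the same layout and tells the model to carve a region out of the mask; the second point below is illustrative rather than taken from the original example:

```python
>>> input_points = [[[[500, 375], [1125, 625]]]]
>>> input_labels = [[[1, 0]]]  # keep the first region, exclude the second

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
```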
#### Bounding Box Input

EdgeTAM also supports bounding box inputs for segmentation:

```python
>>> # Define bounding box as [x_min, y_min, x_max, y_max]
>>> input_boxes = [[[75, 275, 1725, 850]]]

>>> inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
```

#### Multiple Objects Segmentation

You can segment multiple objects simultaneously:

```python
>>> # Define points for two different objects
>>> input_points = [[[[500, 375]], [[650, 750]]]]  # Points for two objects in same image
>>> input_labels = [[[1], [1]]]  # Positive clicks for both objects

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> # Each object gets its own mask
>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
>>> print(f"Generated masks for {masks.shape[0]} objects")
Generated masks for 2 objects
```
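With `multimask_output=False` the post-processed tensor holds one mask per object, so the first dimension can be iterated to handle each object separately. A short hedged follow-up reusing `masks` from the block above:

```python
>>> for obj_idx in range(masks.shape[0]):
...     obj_mask = masks[obj_idx, 0]  # a single mask per object when multimask_output=False
...     print(f"Object {obj_idx}: {int(obj_mask.sum())} mask pixels")
```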
### Batch Inference

#### Batched Images

Process multiple images simultaneously for improved efficiency:

```python
>>> from transformers import Sam2Processor, EdgeTamModel
>>> import torch
>>> from PIL import Image
>>> import requests

>>> device = "cuda" if torch.cuda.is_available() else "cpu"

>>> model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device)
>>> processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf")

>>> # Load multiple images
>>> image_urls = [
...     "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg",
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
... ]
>>> raw_images = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in image_urls]

>>> # Single point per image
>>> input_points = [[[[500, 375]]], [[[770, 200]]]]  # One point for each image
>>> input_labels = [[[1]], [[1]]]  # Positive clicks for both images

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> # Post-process masks for each image
>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
>>> print(f"Processed {len(all_masks)} images, each with {all_masks[0].shape[0]} object(s)")
Processed 2 images, each with 1 object(s)
```
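Because every image keeps its own original resolution, the post-processed result is a list with one mask tensor per image; a quick hedged check of the shapes:

```python
>>> for i, image_masks in enumerate(all_masks):
...     print(f"Image {i}: mask tensor of shape {tuple(image_masks.shape)}")
```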
#### Batched Objects per Image

Segment multiple objects within each image using batch inference:

```python
>>> # Multiple objects per image - different numbers of objects per image
>>> input_points = [
...     [[[500, 375]], [[650, 750]]],  # Truck image: 2 objects
...     [[[770, 200]]]  # Dog image: 1 object
... ]
>>> input_labels = [
...     [[1], [1]],  # Truck image: positive clicks for both objects
...     [[1]]  # Dog image: positive click for the object
... ]

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
```

#### Batched Images with Batched Objects and Multiple Points

Handle complex batch scenarios with multiple points per object:

```python
>>> # Add groceries image for more complex example
>>> groceries_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg"
>>> groceries_image = Image.open(requests.get(groceries_url, stream=True).raw).convert("RGB")
>>> raw_images = [raw_images[0], groceries_image]  # Use truck and groceries images

>>> # Complex batching: multiple images, multiple objects, multiple points per object
>>> input_points = [
...     [[[500, 375]], [[650, 750]]],  # Truck image: 2 objects with 1 point each
...     [[[400, 300]], [[630, 300], [550, 300]]]  # Groceries image: obj1 has 1 point, obj2 has 2 points
... ]
>>> input_labels = [
...     [[1], [1]],  # Truck image: positive clicks
...     [[1], [1, 1]]  # Groceries image: positive clicks for refinement
... ]

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
```

#### Batched Bounding Boxes

Process multiple images with bounding box inputs:

```python
>>> # Multiple bounding boxes per image (using truck and groceries images)
>>> input_boxes = [
...     [[75, 275, 1725, 850], [425, 600, 700, 875], [1375, 550, 1650, 800], [1240, 675, 1400, 750]],  # Truck image: 4 boxes
...     [[450, 170, 520, 350], [350, 190, 450, 350], [500, 170, 580, 350], [580, 170, 640, 350]]  # Groceries image: 4 boxes
... ]

>>> # Update images for this example
>>> raw_images = [raw_images[0], groceries_image]  # truck and groceries

>>> inputs = processor(images=raw_images, input_boxes=input_boxes, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
>>> print(f"Processed {len(input_boxes)} images with {len(input_boxes[0])} and {len(input_boxes[1])} boxes respectively")
Processed 2 images with 4 and 4 boxes respectively
```
### Using Previous Masks as Input

EdgeTAM can use masks from previous predictions as input to refine segmentation:

```python
>>> # Get initial segmentation
>>> input_points = [[[[500, 375]]]]
>>> input_labels = [[[1]]]
>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Use the best mask as input for refinement
>>> mask_input = outputs.pred_masks[:, :, torch.argmax(outputs.iou_scores.squeeze())]

>>> # Add additional points with the mask input
>>> new_input_points = [[[[500, 375], [450, 300]]]]
>>> new_input_labels = [[[1, 1]]]
>>> inputs = processor(
...     input_points=new_input_points,
...     input_labels=new_input_labels,
...     original_sizes=inputs["original_sizes"],
...     return_tensors="pt",
... ).to(device)

>>> with torch.no_grad():
...     refined_outputs = model(
...         **inputs,
...         input_masks=mask_input,
...         image_embeddings=outputs.image_embeddings,
...         multimask_output=False,
...     )
```
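A quick hedged way to check whether the refinement helped is to compare the predicted IoU scores of the two passes:

```python
>>> print(f"Best IoU before: {outputs.iou_scores.max().item():.3f}, after: {refined_outputs.iou_scores.max().item():.3f}")
```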
### Video Segmentation and Tracking

EdgeTAM's key strength is its ability to track objects across video frames. Here's how to use it for video segmentation:

#### Basic Video Tracking

```python
>>> from transformers import EdgeTamVideoModel, Sam2VideoProcessor
>>> import torch

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = EdgeTamVideoModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device, dtype=torch.bfloat16)
>>> processor = Sam2VideoProcessor.from_pretrained("yonigozlan/EdgeTAM-hf")

>>> # Load video frames (example assumes you have a list of PIL Images)
>>> # video_frames = [Image.open(f"frame_{i:05d}.jpg") for i in range(num_frames)]

>>> # For this example, we'll use the video loading utility
>>> from transformers.video_utils import load_video
>>> video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
>>> video_frames, _ = load_video(video_url)

>>> # Initialize video inference session
>>> inference_session = processor.init_video_session(
...     video=video_frames,
...     inference_device=device,
...     torch_dtype=torch.bfloat16,
... )

>>> # Add click on first frame to select object
>>> ann_frame_idx = 0
>>> ann_obj_id = 1
>>> points = [[[[210, 350]]]]
>>> labels = [[[1]]]

>>> processor.add_inputs_to_inference_session(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
...     obj_ids=ann_obj_id,
...     input_points=points,
...     input_labels=labels,
... )

>>> # Segment the object on the first frame
>>> outputs = model(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
... )
>>> video_res_masks = processor.post_process_masks(
...     [outputs.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
... )[0]
>>> print(f"Segmentation shape: {video_res_masks.shape}")
Segmentation shape: torch.Size([1, 1, 480, 854])

>>> # Propagate through the entire video
>>> video_segments = {}
>>> for edgetam_video_output in model.propagate_in_video_iterator(inference_session):
...     video_res_masks = processor.post_process_masks(
...         [edgetam_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
...     )[0]
...     video_segments[edgetam_video_output.frame_idx] = video_res_masks

>>> print(f"Tracked object through {len(video_segments)} frames")
Tracked object through 180 frames
```
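The propagated masks above are kept as logits because of `binarize=False`. A minimal hedged post-processing sketch, assuming each `video_segments[frame_idx]` entry has shape `(num_objects, 1, height, width)` as in the printed output above, thresholds them at zero and writes one PNG per frame:

```python
>>> import numpy as np
>>> from PIL import Image

>>> for frame_idx, frame_masks in list(video_segments.items())[:3]:  # first few frames only
...     binary = (frame_masks[0, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255
...     Image.fromarray(binary).save(f"mask_frame_{frame_idx:05d}.png")
```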
#### Multi-Object Video Tracking

Track multiple objects simultaneously across video frames:

```python
>>> # Reset for new tracking session
>>> inference_session.reset_inference_session()

>>> # Add multiple objects on the first frame
>>> ann_frame_idx = 0
>>> obj_ids = [2, 3]
>>> input_points = [[[[200, 300]], [[400, 150]]]]  # Points for two objects (batched)
>>> input_labels = [[[1], [1]]]

>>> processor.add_inputs_to_inference_session(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
...     obj_ids=obj_ids,
...     input_points=input_points,
...     input_labels=input_labels,
... )

>>> # Get masks for both objects on first frame
>>> outputs = model(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
... )

>>> # Propagate both objects through video
>>> video_segments = {}
>>> for edgetam_video_output in model.propagate_in_video_iterator(inference_session):
...     video_res_masks = processor.post_process_masks(
...         [edgetam_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
...     )[0]
...     video_segments[edgetam_video_output.frame_idx] = {
...         obj_id: video_res_masks[i]
...         for i, obj_id in enumerate(inference_session.obj_ids)
...     }

>>> print(f"Tracked {len(inference_session.obj_ids)} objects through {len(video_segments)} frames")
Tracked 2 objects through 180 frames
```

#### Refining Video Segmentation

You can add additional clicks on any frame to refine the tracking:

```python
>>> # Add refinement click on a later frame
>>> refine_frame_idx = 50
>>> ann_obj_id = 2  # Refining first object
>>> points = [[[[220, 280]]]]  # Additional point
>>> labels = [[[1]]]  # Positive click

>>> processor.add_inputs_to_inference_session(
...     inference_session=inference_session,
...     frame_idx=refine_frame_idx,
...     obj_ids=ann_obj_id,
...     input_points=points,
...     input_labels=labels,
... )

>>> # Re-propagate with the additional information
>>> video_segments = {}
>>> for edgetam_video_output in model.propagate_in_video_iterator(inference_session):
...     video_res_masks = processor.post_process_masks(
...         [edgetam_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
...     )[0]
...     video_segments[edgetam_video_output.frame_idx] = video_res_masks
```
### Streaming Video Inference

For real-time applications, EdgeTAM supports processing video frames as they arrive:

```python
>>> # Initialize session for streaming
>>> inference_session = processor.init_video_session(
...     inference_device=device,
...     torch_dtype=torch.bfloat16,
... )

>>> # Process frames one by one
>>> for frame_idx, frame in enumerate(video_frames[:10]):  # Process first 10 frames
...     inputs = processor(images=frame, device=device, return_tensors="pt")
...
...     if frame_idx == 0:
...         # Add point input on first frame
...         processor.add_inputs_to_inference_session(
...             inference_session=inference_session,
...             frame_idx=0,
...             obj_ids=1,
...             input_points=[[[[210, 350], [250, 220]]]],
...             input_labels=[[[1, 1]]],
...             original_size=inputs.original_sizes[0],  # needs to be provided when using streaming video inference
...         )
...
...     # Process current frame
...     edgetam_video_output = model(inference_session=inference_session, frame=inputs.pixel_values[0])
...
...     video_res_masks = processor.post_process_masks(
...         [edgetam_video_output.pred_masks], original_sizes=inputs.original_sizes, binarize=False
...     )[0]
...     print(f"Frame {frame_idx}: mask shape {video_res_masks.shape}")
```
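As a rough, hedged way to gauge throughput on your own hardware (this does not reproduce the paper's on-device iPhone benchmark), you can time the same streaming loop over a few more frames of the session started above:

```python
>>> import time

>>> start = time.perf_counter()
>>> n_frames = 0
>>> for frame in video_frames[10:30]:  # continue the streaming session from the previous block
...     inputs = processor(images=frame, device=device, return_tensors="pt")
...     _ = model(inference_session=inference_session, frame=inputs.pixel_values[0])
...     n_frames += 1
>>> if device == "cuda":
...     torch.cuda.synchronize()  # make sure queued GPU work is finished before reading the clock
>>> print(f"~{n_frames / (time.perf_counter() - start):.1f} frames/s on this machine")
```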
#### Video Batch Processing for Multiple Objects

Track multiple objects simultaneously in video by adding them all at once:

```python
>>> # Initialize video session
>>> inference_session = processor.init_video_session(
...     video=video_frames,
...     inference_device=device,
...     torch_dtype=torch.bfloat16,
... )

>>> # Add multiple objects on the first frame using batch processing
>>> ann_frame_idx = 0
>>> obj_ids = [2, 3]  # Track two different objects
>>> input_points = [
...     [[[200, 300], [230, 250], [275, 175]], [[400, 150]]]
... ]  # Object 2: 3 points (2 positive, 1 negative); Object 3: 1 point
>>> input_labels = [
...     [[1, 1, 0], [1]]
... ]  # Object 2: positive, positive, negative; Object 3: positive

>>> processor.add_inputs_to_inference_session(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
...     obj_ids=obj_ids,
...     input_points=input_points,
...     input_labels=input_labels,
... )

>>> # Get masks for all objects on the first frame
>>> outputs = model(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
... )
>>> video_res_masks = processor.post_process_masks(
...     [outputs.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
... )[0]
>>> print(f"Generated masks for {video_res_masks.shape[0]} objects")
Generated masks for 2 objects

>>> # Propagate all objects through the video
>>> video_segments = {}
>>> for edgetam_video_output in model.propagate_in_video_iterator(inference_session):
...     video_res_masks = processor.post_process_masks(
...         [edgetam_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
...     )[0]
...     video_segments[edgetam_video_output.frame_idx] = {
...         obj_id: video_res_masks[i]
...         for i, obj_id in enumerate(inference_session.obj_ids)
...     }

>>> print(f"Tracked {len(inference_session.obj_ids)} objects through {len(video_segments)} frames")
Tracked 2 objects through 180 frames
```
# Citation

If you find our code useful for your research, please consider citing:

```
@article{zhou2025edgetam,
  title={EdgeTAM: On-Device Track Anything Model},
  author={Zhou, Chong and Zhu, Chenchen and Xiong, Yunyang and Suri, Saksham and Xiao, Fanyi and Wu, Lemeng and Krishnamoorthi, Raghuraman and Dai, Bo and Loy, Chen Change and Chandra, Vikas and Soran, Bilge},
  journal={arXiv preprint arXiv:2501.07256},
  year={2025}
}
```