from typing import List, Union, Dict, Any, Optional

import torch
from transformers import ProcessorMixin


class XVLAProcessor(ProcessorMixin):
""" |
|
|
XVLAProcessor: Unified multimodal processor for XVLA models. |
|
|
|
|
|
Handles: |
|
|
- Multi-view image inputs (e.g., from multiple cameras). |
|
|
- Batch processing for multiple samples. |
|
|
- Joint tokenization and image tensor preparation. |
|
|
|
|
|
This processor combines an image processor and a tokenizer under a single interface |
|
|
so that users can call it directly like: |
|
|
|
|
|
>>> processor = XVLAProcessor.from_pretrained("path/to/xvla") |
|
|
>>> inputs = processor(images=batch_images, language_instruction=batch_texts) |
|
|
|
|
|
It is fully compatible with the Hugging Face AutoProcessor API. |
|
|
|
|
|
Attributes |
|
|
---------- |
|
|
num_views : int, default=3 |
|
|
Expected number of image views per sample. Missing views will be padded with zeros. |
|
|
language_max_length : int, default=50 |
|
|
Maximum token length for text encoding. |
|
|
attributes : list |
|
|
Required by ProcessorMixin to know which submodules are stored and reloaded. |
|
|
image_processor_class : str |
|
|
The name of the associated image processor class. |
|
|
tokenizer_class : tuple(str) |
|
|
The names of compatible tokenizer classes. |
|
|
""" |
|
|
|
|
|
    # Processing defaults; may be overridden on an instance or subclass.
    num_views: int = 3
    language_max_length: int = 50

    # ProcessorMixin metadata: which sub-processors are saved and reloaded with the processor.
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None):
""" |
|
|
Initialize XVLAProcessor. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
image_processor : PreTrainedImageProcessor, optional |
|
|
The image processor used to normalize/resize images. |
|
|
tokenizer : PreTrainedTokenizer, optional |
|
|
The tokenizer used for text tokenization. |
|
|
""" |
|
|
|
|
|
super().__init__(image_processor, tokenizer) |
|
|
|
|
|
|
|
|

    def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """
        Tokenize one or more language instructions.

        Parameters
        ----------
        language_instruction : str or List[str]
            A single instruction or a batch of instructions.

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "input_ids": tensor of shape [B, L]
            }
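
        Examples
        --------
        Illustrative only; assumes the processor's tokenizer is loaded and the
        default language_max_length of 50.

        >>> out = processor.encode_language("pick up the red block")
        >>> out["input_ids"].shape
        torch.Size([1, 50])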
        """
        if isinstance(language_instruction, str):
            language_instruction = [language_instruction]

        inputs = self.tokenizer(
            language_instruction,
            return_tensors="pt",
            padding="max_length",
            max_length=self.language_max_length,
            truncation=True,
        )
        return {"input_ids": inputs["input_ids"]}

    def encode_image(
        self,
        images: Union[List, List[List]],
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Preprocess one or more sets of multi-view images.

        Parameters
        ----------
        images : List or List[List]
            Single sample: [img1, img2, ...]
            Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
            Each image may be a PIL.Image, NumPy array, or torch.Tensor.
        kwargs : dict
            Extra arguments passed to the underlying image processor
            (e.g., `do_resize=False`, `size=(224, 224)`).

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "image_input": tensor [B, num_views, C, H, W],
                "image_mask": tensor [B, num_views]
            }
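
        Examples
        --------
        Illustrative sketch of the zero-padding behaviour; `np` is NumPy, and
        the 224x224 output size is an assumption that depends on the configured
        image processor.

        >>> views = [np.zeros((224, 224, 3), dtype=np.uint8)] * 2  # 2 of num_views=3
        >>> out = processor.encode_image(views)
        >>> out["image_input"].shape
        torch.Size([1, 3, 3, 224, 224])
        >>> out["image_mask"]
        tensor([[ True,  True, False]])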
        """
        # Promote a single sample to a batch of one.
        if not isinstance(images[0], (list, tuple)):
            images = [images]

        batch_imgs, batch_masks = [], []

        for sample_imgs in images:
            # [V_exist, C, H, W]; assumes each sample has at most `num_views` views.
            processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
            V_exist = processed.size(0)

            # Zero-pad missing views up to `num_views`.
            if V_exist < self.num_views:
                processed = torch.cat(
                    [processed,
                     processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
                    dim=0,
                )

            # Mark which view slots hold real (non-padded) images.
            image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
            image_mask[:V_exist] = True

            batch_imgs.append(processed)
            batch_masks.append(image_mask)

        image_input = torch.stack(batch_imgs, dim=0)  # [B, num_views, C, H, W]
        image_mask = torch.stack(batch_masks, dim=0)  # [B, num_views]

        return {"image_input": image_input, "image_mask": image_mask}

    def __call__(
        self,
        images: Optional[Union[List, List[List]]] = None,
        language_instruction: Optional[Union[str, List[str]]] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Combine image and text encoding into a unified multimodal input.

        Parameters
        ----------
        images : List or List[List], optional
            Single-sample or batched multi-view images.
        language_instruction : str or List[str], optional
            Corresponding text instructions.
        kwargs : dict
            Extra arguments passed to the image processor.

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "input_ids": [B, L], optional,
                "image_input": [B, num_views, C, H, W], optional,
                "image_mask": [B, num_views], optional
            }
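
        Examples
        --------
        Illustrative only; `cam_*` are placeholder view images, and the shapes
        assume the defaults num_views=3, language_max_length=50 and a 224x224
        image processor.

        >>> views = [cam_high, cam_left, cam_right]
        >>> out = processor(images=[views], language_instruction=["stack the cups"])
        >>> out["input_ids"].shape, out["image_input"].shape, out["image_mask"].shape
        (torch.Size([1, 50]), torch.Size([1, 3, 3, 224, 224]), torch.Size([1, 3]))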
        """
        outputs: Dict[str, Any] = {}

        if language_instruction is not None:
            outputs.update(self.encode_language(language_instruction))

        if images is not None:
            outputs.update(self.encode_image(images, **kwargs))

        # When both modalities are provided, their batch sizes must agree.
        if "input_ids" in outputs and "image_input" in outputs:
            assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
                f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
                f"!= image batch {outputs['image_input'].size(0)}"
            )
        return outputs
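

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the processor): the
# checkpoint path is a placeholder, and the printed shapes assume a 224x224
# image processor with the default num_views=3 and language_max_length=50.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    processor = XVLAProcessor.from_pretrained("path/to/xvla")  # placeholder path
    views = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(2)]  # 2 of 3 views present
    inputs = processor(images=[views], language_instruction=["pick up the red block"])
    print(inputs["input_ids"].shape)    # e.g. torch.Size([1, 50])
    print(inputs["image_input"].shape)  # e.g. torch.Size([1, 3, 3, 224, 224])
    print(inputs["image_mask"])         # tensor([[ True,  True, False]])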