# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import ProcessorMixin


class XVLAProcessor(ProcessorMixin):
"""
XVLAProcessor: Unified multimodal processor for XVLA models.
Handles:
- Multi-view image inputs (e.g., from multiple cameras).
- Batch processing for multiple samples.
- Joint tokenization and image tensor preparation.
This processor combines an image processor and a tokenizer under a single interface
so that users can call it directly like:
>>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
>>> inputs = processor(images=batch_images, language_instruction=batch_texts)
It is fully compatible with the Hugging Face AutoProcessor API.
Attributes
----------
num_views : int, default=3
Expected number of image views per sample. Missing views will be padded with zeros.
language_max_length : int, default=50
Maximum token length for text encoding.
attributes : list
Names of the sub-processors that ProcessorMixin stores and reloads.
image_processor_class : str
The name of the associated image processor class.
tokenizer_class : tuple of str
The names of compatible tokenizer classes.
"""
num_views: int = 3
language_max_length: int = 50
# Hugging Face ProcessorMixin-required metadata
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None):
"""
Initialize XVLAProcessor.
Parameters
----------
image_processor : PreTrainedImageProcessor, optional
The image processor used to normalize/resize images.
tokenizer : PreTrainedTokenizer, optional
The tokenizer used for text tokenization.
"""
# ProcessorMixin automatically saves these under self.image_processor / self.tokenizer
super().__init__(image_processor, tokenizer)
# ================== LANGUAGE ENCODING ==================
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
"""
Tokenize one or more language instructions.
Parameters
----------
language_instruction : str or List[str]
A single instruction or a batch of instructions.
Returns
-------
Dict[str, torch.Tensor]
{
"input_ids": tensor of shape [B, L]
}
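Example
-------
Illustrative sketch (assumes the processor was loaded via `from_pretrained`;
with the default `language_max_length=50`, inputs are padded/truncated to 50 tokens):

>>> out = processor.encode_language(["pick up the green cube"])
>>> out["input_ids"].shape
torch.Size([1, 50])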
"""
if isinstance(language_instruction, str):
language_instruction = [language_instruction]
inputs = self.tokenizer(
language_instruction,
return_tensors="pt",
padding="max_length",
max_length=self.language_max_length,
truncation=True,
)
return {"input_ids": inputs["input_ids"]}
# ================== IMAGE ENCODING ==================
def encode_image(
self,
images: Union[List, List[List]],
**kwargs
) -> Dict[str, torch.Tensor]:
"""
Preprocess one or more sets of multi-view images.
Parameters
----------
images : List or List[List]
Single sample: [img1, img2, ...]
Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
Each image may be a PIL.Image, NumPy array, or torch.Tensor.
kwargs : dict
Extra arguments passed to the underlying image processor
(e.g., `do_resize=False`, `size=(224,224)`).
Returns
-------
Dict[str, torch.Tensor]
{
"image_input": tensor [B, num_views, C, H, W],
"image_mask": tensor [B, num_views]
}
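Example
-------
Illustrative sketch (two camera views for one sample; with `num_views=3`,
the missing third view is zero-padded and masked out):

>>> out = processor.encode_image([front_img, wrist_img])
>>> out["image_input"].shape   # [1, 3, C, H, W]
>>> out["image_mask"]
tensor([[ True,  True, False]])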
"""
# Normalize to batch form
if not isinstance(images[0], (list, tuple)):
images = [images] # convert single sample to batch of size 1
batch_imgs, batch_masks = [], []
for sample_imgs in images:
processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
V_exist = processed.size(0)
# Pad to self.num_views
if V_exist < self.num_views:
processed = torch.cat(
[processed,
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
dim=0,
)
# Mask: True for valid slots, False for padding
image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
image_mask[:V_exist] = True
batch_imgs.append(processed)
batch_masks.append(image_mask)
image_input = torch.stack(batch_imgs, dim=0) # [B, num_views, C, H, W]
image_mask = torch.stack(batch_masks, dim=0) # [B, num_views]
return {"image_input": image_input, "image_mask": image_mask}
# ================== COMBINED CALL ==================
def __call__(
self,
images: Optional[Union[List, List[List]]] = None,
language_instruction: Optional[Union[str, List[str]]] = None,
**kwargs
) -> Dict[str, torch.Tensor]:
"""
Combine image and text encoding into a unified multimodal input.
Parameters
----------
images : List or List[List], optional
Single-sample or batched multi-view images.
language_instruction : str or List[str], optional
Corresponding text instructions.
kwargs : dict
Extra args passed to image processor.
Returns
-------
Dict[str, torch.Tensor]
{
"input_ids": [B, L], optional,
"image_input": [B, num_views, C, H, W], optional,
"image_mask": [B, num_views], optional
}
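Example
-------
Illustrative sketch for a batch of two samples with a different number of
views each (the second sample's views are padded up to `num_views`):

>>> inputs = processor(
...     images=[[img1a, img1b, img1c], [img2a, img2b]],
...     language_instruction=["close the drawer", "pick up the cup"],
... )
>>> sorted(inputs.keys())
['image_input', 'image_mask', 'input_ids']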
"""
outputs: Dict[str, Any] = {}
# Encode language if provided
if language_instruction is not None:
outputs.update(self.encode_language(language_instruction))
# Encode image if provided
if images is not None:
outputs.update(self.encode_image(images, **kwargs))
# Sanity check for batch alignment
if "input_ids" in outputs and "image_input" in outputs:
assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
f"!= image batch {outputs['image_input'].size(0)}"
)
return outputs
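

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only): "path/to/xvla" is a placeholder checkpoint
# directory and the blank PIL images stand in for real camera frames.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    processor = XVLAProcessor.from_pretrained("path/to/xvla")  # placeholder path

    # Two dummy views for a single sample; the third slot is zero-padded and masked.
    views = [Image.new("RGB", (224, 224)) for _ in range(2)]
    inputs = processor(images=views, language_instruction="pick up the red block")

    print(inputs["input_ids"].shape)    # torch.Size([1, 50])
    print(inputs["image_input"].shape)  # [1, 3, C, H, W]
    print(inputs["image_mask"])         # tensor([[ True,  True, False]])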