Initial upload for X-VLA-Google-Robot
- action_hub.py +275 -0
- config.json +78 -0
- configuration_florence2.py +340 -0
- configuration_xvla.py +95 -0
- model.safetensors +3 -0
- modeling_florence2.py +0 -0
- modeling_xvla.py +287 -0
- preprocessor_config.json +38 -0
- processing_xvla.py +206 -0
- tokenizer.json +0 -0
- tokenizer_config.json +4 -0
- transformer.py +403 -0
- vocab.json +0 -0
action_hub.py
ADDED
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from __future__ import annotations
from typing import Iterable, Tuple, Dict, Type
import torch
import torch.nn as nn

# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}


def register_action(name: str):
    """Decorator for registering a new action space."""
    def _wrap(cls):
        key = name.lower()
        if key in ACTION_REGISTRY:
            raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
        ACTION_REGISTRY[key] = cls
        cls.name = key
        return cls
    return _wrap


def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
    """Instantiate a registered action space by name."""
    key = name.lower()
    if key not in ACTION_REGISTRY:
        raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
    return ACTION_REGISTRY[key](**kwargs)


# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
    """
    Abstract base class for all action-space definitions.

    Each subclass defines:
      - `dim_action`: dimension of the action vector.
      - `gripper_idx`: indices of gripper channels.
      - `compute_loss(pred, target)`: supervised loss for this space.
      - `preprocess(proprio, action, mode)`: pre-step modifications.
      - `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
    """

    name: str = "base"
    dim_action: int = 0
    gripper_idx: Tuple[int, ...] = ()

    def __init__(self):
        super().__init__()

    # ---------------------------------------------------------------------
    # Core supervised loss
    # ---------------------------------------------------------------------
    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Alias for compute_loss."""
        return self.compute_loss(pred, target)

    # ---------------------------------------------------------------------
    # Space-level hooks
    # ---------------------------------------------------------------------
    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Default: return unchanged."""
        return proprio, action

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Default: return unchanged."""
        return action


# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
    bad = [i for i in idx if i < 0 or i >= D]
    if bad:
        raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")


# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
    """End-effector layout with xyz, 6D rotation, and gripper channels."""

    dim_action = 20
    gripper_idx = (9, 19)
    GRIPPER_SCALE = 1.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0

    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        assert pred.shape == target.shape, "pred/target shapes must match"
        B, T, D = pred.shape
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        # Gripper BCE
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # XYZ position
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
            self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE

        # Rotation 6D
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
            self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Zero out gripper channels in proprio/action."""
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to gripper logits."""
        if action.size(-1) > max(self.gripper_idx):
            action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        return action


@register_action("joint")
class JointActionSpace(BaseActionSpace):
    """Joint-space layout with joints + gripper only."""

    dim_action = 14
    gripper_idx = (6, 13)
    GRIPPER_SCALE = 0.1
    JOINTS_SCALE = 1.0

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        assert pred.shape == target.shape
        B, T, D = pred.shape
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx))
        joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE

        return {
            "joints_loss": joints_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Zero out gripper channels in proprio/action."""
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to gripper logits."""
        if action.size(-1) > max(self.gripper_idx):
            action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        return action


@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
    """AGIBOT variant of EE6DActionSpace using MSE for all components."""

    dim_action = 20
    gripper_idx = (9, 19)
    GRIPPER_SCALE = 10.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def compute_loss(self, pred, target):
        assert pred.shape == target.shape
        B, T, D = pred.shape
        _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")

        gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
            self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
            self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """No preprocessing applied in the AGIBOT variant."""
        return proprio, action

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """AGIBOT does not postprocess."""
        return action


# =============================================================================
# Exports
# =============================================================================
__all__ = [
    "BaseActionSpace",
    "build_action_space",
    "register_action",
    "EE6DActionSpace",
    "JointActionSpace",
    "AGIBOTEE6DActionSpace",
    "ACTION_REGISTRY",
]
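For reference, a minimal usage sketch of the registry above; the import path and tensor shapes are illustrative assumptions, not part of the upload:

```python
# Minimal sketch (assumed local import path): build a registered space and use its hooks.
import torch
from action_hub import build_action_space

space = build_action_space("ee6d")              # -> EE6DActionSpace, dim_action = 20
pred = torch.randn(2, 30, space.dim_action)     # [B=2, T=30, D=20]; gripper channels are logits
target = torch.rand_like(pred)                  # targets in [0, 1] for the BCE gripper terms
losses = space(pred, target)                    # {"position_loss", "rotate6D_loss", "gripper_loss"}
total_loss = sum(losses.values())
actions = space.postprocess(pred.clone())       # sigmoid applied to gripper channels 9 and 19
```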
config.json
ADDED
{
  "_name_or_path": "xvla",
  "model_type": "xvla",
  "architectures": ["XVLA"],
  "auto_map": {
    "AutoConfig": "configuration_xvla.XVLAConfig",
    "AutoModel": "modeling_xvla.XVLA"
  },

  "action_mode": "ee6d",
  "use_proprio": true,
  "num_actions": 30,

  "hidden_size": 1024,
  "depth": 24,
  "num_heads": 16,
  "mlp_ratio": 4.0,
  "num_domains": 30,
  "len_soft_prompts": 32,
  "dim_time": 32,
  "max_len_seq": 512,
  "use_hetero_proj": false,
  "soft_prompt_length": 32,

  "florence_config": {
    "model_type": "florence2",
    "bos_token_id": 0,
    "eos_token_id": 2,
    "ignore_index": -100,
    "pad_token_id": 1,
    "projection_dim": 1024,

    "text_config": {
      "vocab_size": 51289,
      "activation_dropout": 0.1,
      "activation_function": "gelu",
      "attention_dropout": 0.1,
      "d_model": 1024,
      "decoder_attention_heads": 16,
      "decoder_layers": 12,
      "encoder_attention_heads": 16,
      "encoder_layers": 12,
      "dropout": 0.1,
      "max_position_embeddings": 4096,
      "num_hidden_layers": 12,
      "num_beams": 3
    },

    "vision_config": {
      "model_type": "davit",
      "drop_path_rate": 0.1,
      "patch_size": [7, 3, 3, 3],
      "patch_stride": [4, 2, 2, 2],
      "patch_padding": [3, 1, 1, 1],
      "patch_prenorm": [false, true, true, true],
      "enable_checkpoint": false,
      "dim_embed": [256, 512, 1024, 2048],
      "num_heads": [8, 16, 32, 64],
      "num_groups": [8, 16, 32, 64],
      "depths": [1, 1, 9, 1],
      "window_size": 12,
      "projection_dim": 1024,
      "visual_temporal_embedding": {
        "type": "COSINE",
        "max_temporal_embeddings": 100
      },
      "image_pos_embed": {
        "type": "learned_abs_2d",
        "max_pos_embeddings": 50
      },
      "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
    },

    "vocab_size": 51289,
    "torch_dtype": "float16",
    "is_encoder_decoder": true
  }
}
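Given the `auto_map` entries above, the checkpoint is meant to be loaded with custom code enabled; a hedged sketch (the repository id is a placeholder):

```python
# Sketch: resolve XVLAConfig / XVLA through the auto_map above (repo id is a placeholder).
from transformers import AutoConfig, AutoModel

repo_id = "<org>/X-VLA-Google-Robot"
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)   # configuration_xvla.XVLAConfig
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)     # modeling_xvla.XVLA
print(config.action_mode, config.num_actions)                          # "ee6d", 30
```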
configuration_florence2.py
ADDED
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Florence-2 configuration"""

import warnings
from typing import Optional

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Florence2VisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Florence2VisionModel architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            The dropout rate of the drop path layer.
        patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
            The patch size of the image.
        patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
            The patch stride of the image.
        patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
            The patch padding of the image.
        patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
            Whether to apply layer normalization before the patch embedding layer.
        enable_checkpoint (`bool`, *optional*, defaults to False):
            Whether to enable checkpointing.
        dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
            The dimension of the embedding layer.
        num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
            The number of attention heads.
        num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
            The number of groups.
        depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
            The depth of the model.
        window_size (`int`, *optional*, defaults to 12):
            The window size of the model.
        projection_dim (`int`, *optional*, defaults to 1024):
            The dimension of the projection layer.
        visual_temporal_embedding (`dict`, *optional*):
            The configuration of the visual temporal embedding.
        image_pos_embed (`dict`, *optional*):
            The configuration of the image position embedding.
        image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
            The source of the image feature.
    Example:

    ```python
    >>> from transformers import Florence2VisionConfig, Florence2VisionModel

    >>> # Initializing a Florence2 Vision style configuration
    >>> configuration = Florence2VisionConfig()

    >>> # Initializing a model (with random weights)
    >>> model = Florence2VisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "davit"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        drop_path_rate=0.1,
        patch_size=[7, 3, 3, 3],
        patch_stride=[4, 2, 2, 2],
        patch_padding=[3, 1, 1, 1],
        patch_prenorm=[False, True, True, True],
        enable_checkpoint=False,
        dim_embed=[256, 512, 1024, 2048],
        num_heads=[8, 16, 32, 64],
        num_groups=[8, 16, 32, 64],
        depths=[1, 1, 9, 1],
        window_size=12,
        projection_dim=1024,
        visual_temporal_embedding=None,
        image_pos_embed=None,
        image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
        **kwargs,
    ):
        self.drop_path_rate = drop_path_rate
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.patch_padding = patch_padding
        self.patch_prenorm = patch_prenorm
        self.enable_checkpoint = enable_checkpoint
        self.dim_embed = dim_embed
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.depths = depths
        self.window_size = window_size
        self.projection_dim = projection_dim
        self.visual_temporal_embedding = visual_temporal_embedding
        self.image_pos_embed = image_pos_embed
        self.image_feature_source = image_feature_source

        super().__init__(**kwargs)


class Florence2LanguageConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the BART
    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Florence2LanguageModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        num_labels (`int`, *optional*, defaults to 3):
            The number of labels to use in [`Florence2LanguageForSequenceClassification`].
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.

    Example:

    ```python
    >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel

    >>> # Initializing a Florence2 Language style configuration
    >>> configuration = Florence2LanguageConfig()

    >>> # Initializing a model (with random weights)
    >>> model = Florence2LanguageModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "florence2_language"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=51289,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )


class Florence2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate a
    Florence-2 model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Florence2VisionConfig`, *optional*):
            Custom vision config or dict.
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`].
        projection_dim (`int`, *optional*, defaults to 1024):
            Dimension of the multimodal projection space.

    Example:

    ```python
    >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig

    >>> # Initializing a clip-like vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Bart config
    >>> text_config = BartConfig()

    >>> # Initializing a Florence-2 configuration
    >>> configuration = Florence2Config(vision_config, text_config)

    >>> # Initializing a model from the florence-2 configuration
    >>> model = Florence2ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "florence2"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        vocab_size=51289,
        projection_dim=1024,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        if vision_config is not None:
            vision_config = Florence2VisionConfig(**vision_config)
        self.vision_config = vision_config
        self.vocab_size = self.vocab_size

        self.text_config = text_config
        if text_config is not None:
            self.text_config = Florence2LanguageConfig(**text_config)

        super().__init__(**kwargs)
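As a small illustration of how these configs nest (a sketch under the defaults documented above; the local import path is an assumption):

```python
# Sketch: nested dicts are promoted to config objects, as in Florence2Config.__init__ above.
from configuration_florence2 import Florence2Config  # assumed local import path

cfg = Florence2Config(
    vision_config={"window_size": 12, "projection_dim": 1024},
    text_config={"vocab_size": 51289, "d_model": 1024, "encoder_layers": 12, "decoder_layers": 12},
    projection_dim=1024,
)
print(type(cfg.vision_config).__name__)   # Florence2VisionConfig
print(cfg.text_config.encoder_layers)     # 12
```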
configuration_xvla.py
ADDED
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from .configuration_florence2 import Florence2Config
from transformers.configuration_utils import PretrainedConfig


class XVLAConfig(PretrainedConfig):
    """
    Configuration class for the **XVLA (Extended Vision-Language-Action)** model.

    This configuration defines all submodules of XVLA in a single place:
      - The visual-language backbone (Florence2)
      - The temporal/action transformer
      - The action/proprio setup
    """

    model_type = "xvla"

    def __init__(
        self,
        # === Florence backbone ===
        florence_config: dict | None = None,

        # === Transformer head ===
        hidden_size: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        num_domains: int = 30,
        len_soft_prompts: int = 32,
        dim_time: int = 32,
        max_len_seq: int = 512,
        use_hetero_proj: bool = False,
        soft_prompt_length: int = 32,

        # === Action & proprio ===
        num_actions: int = 30,
        action_mode: str = "ee6d",
        use_proprio: bool = True,

        **kwargs,
    ):
        # Florence2 backbone configuration
        if isinstance(florence_config, dict):
            self.florence_config = Florence2Config(**florence_config)
        elif isinstance(florence_config, Florence2Config):
            self.florence_config = florence_config
        else:
            self.florence_config = Florence2Config()

        # Transformer hyperparameters
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_domains = num_domains
        self.len_soft_prompts = len_soft_prompts
        self.dim_time = dim_time
        self.max_len_seq = max_len_seq
        self.use_hetero_proj = use_hetero_proj
        self.soft_prompt_length = soft_prompt_length

        # Action/proprioception settings
        self.num_actions = num_actions
        self.action_mode = action_mode
        self.use_proprio = use_proprio

        # Initialize base HF config attributes (e.g. name_or_path)
        super().__init__(**kwargs)

    # -------------------------------------------------------------------------
    # Serialization helpers
    # -------------------------------------------------------------------------
    def to_dict(self):
        """
        Convert this configuration (and its Florence sub-config)
        into a fully serializable dictionary for HF save/load.
        """
        output = super().to_dict()
        output["florence_config"] = self.florence_config.to_dict()
        return output
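A short sketch of how this config is built and serialized (the local import path is assumed):

```python
# Sketch: default XVLAConfig plus the to_dict() override defined above.
from configuration_xvla import XVLAConfig  # assumed local import path

cfg = XVLAConfig(action_mode="ee6d", num_actions=30, use_proprio=True)
print(cfg.hidden_size, cfg.depth, cfg.num_heads)   # 1024 24 16 (defaults above)
blob = cfg.to_dict()                               # florence_config is expanded to a plain dict
print(type(blob["florence_config"]))               # <class 'dict'>
```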
model.safetensors
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:2ea279d74b9a5878da79f7dae949a1d8e92cead2cf0f58612f9d11e4ba89788e
size 3519068172
modeling_florence2.py
ADDED
The diff for this file is too large to render.
modeling_xvla.py
ADDED
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from __future__ import annotations

import logging
import traceback
from typing import Any, Dict

import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from PIL import Image
import uvicorn
import json_numpy
import cv2

from transformers import PreTrainedModel
from .modeling_florence2 import Florence2ForConditionalGeneration
from .transformer import SoftPromptedTransformer
from .action_hub import build_action_space
from .configuration_xvla import XVLAConfig


class XVLA(PreTrainedModel):
    """
    XVLA: HuggingFace-compatible Vision-Language-Action policy.

    Components:
      • Florence2 encoder-only backbone (vision-language)
      • SoftPromptedTransformer (temporal/action head)
      • Action space (pre/post-processing + loss)
    """
    config_class = XVLAConfig
    base_model_prefix = "xvla"
    supports_gradient_checkpointing = True

    def __init__(self, config: XVLAConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

        # Core settings
        self.num_actions: int = config.num_actions
        self.use_proprio: bool = config.use_proprio
        self.action_mode: str = config.action_mode.lower()
        # Action space (dimensions + hooks)
        self.action_space = build_action_space(config.action_mode.lower())
        dim_action = self.action_space.dim_action
        dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)

        # Florence2 backbone (encoder only)
        self.vlm = Florence2ForConditionalGeneration(config.florence_config)
        if hasattr(self.vlm, "language_model"):
            lm = self.vlm.language_model
            if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
                del lm.model.decoder
            if hasattr(lm, "lm_head"):
                del lm.lm_head

        projection_dim = getattr(self.vlm.config, "projection_dim", None)
        if projection_dim is None:
            raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")

        # Temporal/action head
        self.transformer = SoftPromptedTransformer(
            hidden_size=config.hidden_size,
            multi_modal_input_size=projection_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            num_domains=config.num_domains,
            dim_action=dim_action,
            dim_propio=dim_proprio,
            len_soft_prompts=config.len_soft_prompts,
            dim_time=config.dim_time,
            max_len_seq=config.max_len_seq,
            use_hetero_proj=config.use_hetero_proj,
        )

        # Deferred FastAPI app
        self.app: FastAPI | None = None

    # ============================= Florence2 encoder =============================
    def forward_vlm(
        self,
        input_ids: torch.LongTensor,       # [B, L]
        pixel_values: torch.FloatTensor,   # [B, V, C, H, W]
        image_mask: torch.Tensor,          # [B, V] (bool or 0/1)
    ) -> Dict[str, torch.Tensor]:
        """
        Encode text + multi-view images via Florence2 encoder.

        Returns:
            { "vlm_features": [B, T_enc, D], "aux_visual_inputs": [B, (V-1)*N, D] }
        """
        B, V = pixel_values.shape[:2]
        flat_mask = image_mask.view(-1).to(torch.bool)   # [B*V]
        flat_images = pixel_values.flatten(0, 1)         # [B*V, C, H, W]

        num_valid = int(flat_mask.sum().item())
        if num_valid == 0:
            raise ValueError("At least one image view must be valid per batch.")

        valid_images = flat_images[flat_mask]                # [#valid, C, H, W]
        valid_feats = self.vlm._encode_image(valid_images)   # [#valid, N, D]
        N, D = valid_feats.shape[1:]

        image_features = valid_feats.new_zeros((B * V, N, D))
        image_features[flat_mask] = valid_feats
        image_features = image_features.view(B, V, N, D)     # [B, V, N, D]

        inputs_embeds = self.vlm.get_input_embeddings()(input_ids)  # [B, L, D]

        merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
            image_features[:, 0],  # first view: [B, N, D]
            inputs_embeds,         # [B, L, D]
        )

        enc_out = self.vlm.language_model.model.encoder(
            attention_mask=attention_mask,
            inputs_embeds=merged_embeds,
        )[0]  # [B, T_enc, D]

        aux_visual_inputs = image_features[:, 1:].reshape(B, -1, D)  # remaining views flattened
        return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}

    # ================================= training =================================
    def forward(
        self,
        input_ids: torch.LongTensor,
        image_input: torch.FloatTensor,
        image_mask: torch.Tensor,
        domain_id: torch.LongTensor,
        proprio: torch.Tensor,
        action: torch.Tensor,  # [B, T=num_actions, D=dim_action]
    ) -> Dict[str, torch.Tensor]:
        """
        1) Encode multimodal inputs.
        2) Diffusion-style noisy mixture of actions: x_t = t*noise + (1-t)*gt.
        3) Space-specific preprocessing, prediction, and supervised loss.
        """
        enc = self.forward_vlm(input_ids, image_input, image_mask)

        B = input_ids.shape[0]
        t = (torch.rand(1, device=input_ids.device)
             + torch.arange(B, device=input_ids.device) / B) % (1 - 1e-5)

        action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
        proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)

        pred_action = self.transformer(
            domain_id=domain_id,
            action_with_noise=action_noisy_m,
            t=t,
            proprio=proprio_m,
            **enc,
        )
        return self.action_space.compute_loss(pred_action, action)

    # ================================= inference =================================
    @torch.no_grad()
    def generate_actions(
        self,
        input_ids: torch.LongTensor,
        image_input: torch.FloatTensor,
        image_mask: torch.Tensor,
        domain_id: torch.LongTensor,
        proprio: torch.Tensor,
        steps: int = 10,
    ) -> torch.Tensor:
        """
        Iterative denoising (linear schedule).
        Applies action_space.postprocess at the end (e.g., sigmoid on gripper).
        """
        self.eval()
        enc = self.forward_vlm(input_ids, image_input, image_mask)

        B = input_ids.shape[0]
        D = self.action_space.dim_action

        x1 = torch.randn(B, self.num_actions, D, device=proprio.device, dtype=proprio.dtype)
        action = torch.zeros_like(x1)

        steps = max(1, int(steps))
        for i in range(steps, 0, -1):
            t = torch.full((B,), i / steps, device=proprio.device, dtype=proprio.dtype)
            x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
            proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
            action = self.transformer(
                domain_id=domain_id,
                action_with_noise=x_t_m,
                proprio=proprio_m,
                t=t,
                **enc,
            )
        return self.action_space.postprocess(action)

    # =============================== FastAPI service =============================
    def _build_app(self, processor):
        """
        Minimal FastAPI app for XVLA inference.

        Args:
            processor: callable(images, text) -> Dict[str, torch.Tensor]
                expected keys: "input_ids", "image_input", "image_mask"
        """
        if self.app is not None:
            return

        app = FastAPI()

        @app.post("/act")
        def act(payload: Dict[str, Any]):
            try:
                self.eval()
                # Decode up to 3 image inputs
                images = []
                for key in ("image0", "image1", "image2"):
                    if key not in payload:
                        continue
                    v = json_numpy.loads(payload[key])
                    if isinstance(v, np.ndarray):
                        if v.ndim == 1:  # encoded bytes
                            v = cv2.imdecode(v, cv2.IMREAD_COLOR)
                        images.append(Image.fromarray(v))
                    elif isinstance(v, (list, tuple)):
                        images.append(Image.fromarray(np.array(v)))
                    elif isinstance(v, str):
                        images.append(Image.open(v))
                if not images:
                    return JSONResponse({"error": "No valid images found."}, status_code=400)

                # Multimodal preprocessing by processor
                inputs = processor(images, payload["language_instruction"])
                if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
                    return JSONResponse({"error": "Processor returned incomplete inputs."}, status_code=400)

                # Build proprio/domain tensors
                proprio = torch.as_tensor(np.asarray(json_numpy.loads(payload["proprio"])))
                domain_id = torch.tensor([int(payload["domain_id"])], dtype=torch.long)

                # Align to model's device/dtype
                device = next(self.parameters()).device
                dtype = next(self.parameters()).dtype

                def to_model(t: torch.Tensor) -> torch.Tensor:
                    if not isinstance(t, torch.Tensor):
                        t = torch.as_tensor(t)
                    # cast floats to model dtype, keep integral/bool as-is
                    return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)

                inputs = {k: to_model(v) for k, v in inputs.items()}
                inputs.update({
                    "proprio": to_model(proprio.unsqueeze(0)),
                    "domain_id": domain_id.to(device),
                })

                # Inference
                steps = int(payload.get("steps", 10))
                action = self.generate_actions(**inputs, steps=steps).squeeze(0).float().cpu().numpy()
                return JSONResponse({"action": action.tolist()})

            except Exception:
                logging.error(traceback.format_exc())
                return JSONResponse({"error": "Request failed"}, status_code=400)

        self.app = app

    def run(self, processor, host: str = "0.0.0.0", port: int = 8000):
        """
        Launch the FastAPI service.
        """
        self._build_app(processor)
        assert self.app is not None
        uvicorn.run(self.app, host=host, port=port)
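The `/act` route above expects `json_numpy`-encoded arrays; a hedged client sketch (server URL, proprio values, and `domain_id` are placeholders):

```python
# Sketch of a client for the /act endpoint defined above (values are placeholders).
import json_numpy
import numpy as np
import requests

payload = {
    "image0": json_numpy.dumps(np.zeros((224, 224, 3), dtype=np.uint8)),  # one RGB camera view
    "language_instruction": "pick up the coke can",
    "proprio": json_numpy.dumps(np.zeros(20, dtype=np.float32)),          # ee6d dim_action = 20
    "domain_id": 0,
    "steps": 10,
}
resp = requests.post("http://localhost:8000/act", json=payload, timeout=60)
action = np.array(resp.json()["action"])   # [num_actions, dim_action] action chunk
```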
preprocessor_config.json
ADDED
{
  "auto_map": {
    "AutoProcessor": "processing_xvla.XVLAProcessor"
  },
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format",
    "do_convert_rgb"
  ],
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "do_center_crop": false,
  "image_processor_type": "CLIPImageProcessor",
  "image_mean": [0.485, 0.456, 0.406],
  "image_std": [0.229, 0.224, 0.225],
  "processor_class": "XVLAProcessor",
  "resample": 3,
  "size": {
    "height": 224,
    "width": 224
  },
  "crop_size": {
    "height": 224,
    "width": 224
  }
}
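With the `auto_map` above, the processor can be resolved through `AutoProcessor`; a hedged sketch (the repo id is a placeholder, and the returned keys follow what `modeling_xvla.py` expects):

```python
# Sketch: load XVLAProcessor via AutoProcessor and encode one sample (repo id is a placeholder).
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("<org>/X-VLA-Google-Robot", trust_remote_code=True)
images = [Image.new("RGB", (224, 224))]              # missing views are padded up to num_views
inputs = processor(images, "pick up the coke can")   # same call pattern as the /act route
print(sorted(inputs.keys()))                         # ['image_input', 'image_mask', 'input_ids']
```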
processing_xvla.py
ADDED
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from transformers import ProcessorMixin
from typing import List, Union, Dict, Any, Optional
import torch


class XVLAProcessor(ProcessorMixin):
    """
    XVLAProcessor: unified multimodal processor for XVLA models.

    Handles:
    - Multi-view image inputs (e.g., from multiple cameras).
    - Batch processing for multiple samples.
    - Joint tokenization and image tensor preparation.

    This processor combines an image processor and a tokenizer under a single
    interface so that users can call it directly:

    >>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
    >>> inputs = processor(images=batch_images, language_instruction=batch_texts)

    It is fully compatible with the Hugging Face AutoProcessor API.

    Attributes
    ----------
    num_views : int, default=3
        Expected number of image views per sample. Missing views are padded with zeros.
    language_max_length : int, default=50
        Maximum token length for text encoding.
    attributes : list
        Required by ProcessorMixin to know which submodules are stored and reloaded.
    image_processor_class : str
        Name of the associated image processor class.
    tokenizer_class : tuple of str
        Names of compatible tokenizer classes.
    """

    num_views: int = 3
    language_max_length: int = 50

    # Hugging Face ProcessorMixin-required metadata
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None):
        """
        Initialize XVLAProcessor.

        Parameters
        ----------
        image_processor : PreTrainedImageProcessor, optional
            The image processor used to normalize/resize images.
        tokenizer : PreTrainedTokenizer, optional
            The tokenizer used for text tokenization.
        """
        # ProcessorMixin automatically saves these under self.image_processor / self.tokenizer
        super().__init__(image_processor, tokenizer)

    # ================== LANGUAGE ENCODING ==================
    def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """
        Tokenize one or more language instructions.

        Parameters
        ----------
        language_instruction : str or List[str]
            A single instruction or a batch of instructions.

        Returns
        -------
        Dict[str, torch.Tensor]
            {"input_ids": tensor of shape [B, L]}
        """
        if isinstance(language_instruction, str):
            language_instruction = [language_instruction]

        inputs = self.tokenizer(
            language_instruction,
            return_tensors="pt",
            padding="max_length",
            max_length=self.language_max_length,
            truncation=True,
        )
        return {"input_ids": inputs["input_ids"]}

    # ================== IMAGE ENCODING ==================
    def encode_image(
        self,
        images: Union[List, List[List]],
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Preprocess one or more sets of multi-view images.

        Parameters
        ----------
        images : List or List[List]
            Single sample: [img1, img2, ...]
            Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
            Each image may be a PIL.Image, NumPy array, or torch.Tensor.
        kwargs : dict
            Extra arguments passed to the underlying image processor
            (e.g., `do_resize=False`, `size=(224, 224)`).

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "image_input": tensor [B, num_views, C, H, W],
                "image_mask": tensor [B, num_views]
            }
        """
        # Normalize to batch form
        if not isinstance(images[0], (list, tuple)):
            images = [images]  # convert single sample to batch of size 1

        batch_imgs, batch_masks = [], []

        for sample_imgs in images:
            processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
            V_exist = processed.size(0)

            # Pad to self.num_views
            if V_exist < self.num_views:
                processed = torch.cat(
                    [processed,
                     processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
                    dim=0,
                )

            # Mask: True for valid slots, False for padding
            image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
            image_mask[:V_exist] = True

            batch_imgs.append(processed)
            batch_masks.append(image_mask)

        image_input = torch.stack(batch_imgs, dim=0)  # [B, num_views, C, H, W]
        image_mask = torch.stack(batch_masks, dim=0)  # [B, num_views]

        return {"image_input": image_input, "image_mask": image_mask}

    # ================== COMBINED CALL ==================
    def __call__(
        self,
        images: Optional[Union[List, List[List]]] = None,
        language_instruction: Optional[Union[str, List[str]]] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Combine image and text encoding into a unified multimodal input.

        Parameters
        ----------
        images : List or List[List], optional
            Single-sample or batched multi-view images.
        language_instruction : str or List[str], optional
            Corresponding text instructions.
        kwargs : dict
            Extra arguments passed to the image processor.

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "input_ids": [B, L], optional,
                "image_input": [B, num_views, C, H, W], optional,
                "image_mask": [B, num_views], optional
            }
        """
        outputs: Dict[str, Any] = {}

        # Encode language if provided
        if language_instruction is not None:
            outputs.update(self.encode_language(language_instruction))

        # Encode images if provided
        if images is not None:
            outputs.update(self.encode_image(images, **kwargs))

        # Sanity check for batch alignment
        if "input_ids" in outputs and "image_input" in outputs:
            assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
                f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
                f"!= image batch {outputs['image_input'].size(0)}"
            )
        return outputs
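A usage sketch for the processor defined above, assuming the repository is loaded with `trust_remote_code=True` so the `auto_map` entry in preprocessor_config.json resolves to `XVLAProcessor` (the local path is a placeholder):

import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("./X-VLA-Google-Robot", trust_remote_code=True)

# Two camera views for one sample; the third expected view is zero-padded.
views = [Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8)) for _ in range(2)]
inputs = processor(images=[views], language_instruction=["pick up the coke can"])

print(inputs["input_ids"].shape)    # [1, 50]   padded to language_max_length
print(inputs["image_input"].shape)  # [1, 3, 3, 224, 224]
print(inputs["image_mask"])         # tensor([[ True,  True, False]])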
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json
ADDED
@@ -0,0 +1,4 @@
{
  "model_max_length": 1024
}
transformer.py
ADDED
@@ -0,0 +1,403 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from __future__ import annotations

import math
from functools import partial
from typing import Final, Iterable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ------------------------------- Small utils ----------------------------------

def _to_2tuple(x) -> Tuple:
    """Minimal replacement for timm.layers.to_2tuple."""
    if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
        t = tuple(x)
        return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
    return (x, x)


def _has_sdp_attention() -> bool:
    """Check whether PyTorch's fused scaled_dot_product_attention is available."""
    return hasattr(F, "scaled_dot_product_attention")


# ---------------------------------- MLP --------------------------------------

class Mlp(nn.Module):
    """
    MLP used in ViT-style blocks.

    Supports a Linear or 1x1 Conv `linear_layer` for token/channel mixing.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        norm_layer: type[nn.Module] | None = None,
        bias: bool | Tuple[bool, bool] = True,
        drop: float | Tuple[float, float] = 0.0,
        use_conv: bool = False,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = _to_2tuple(bias)
        drop_probs = _to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = nn.GELU(approximate="tanh")
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expect [B, T, C] for the Linear variant; the caller is responsible for shapes.
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# -------------------------------- Attention ----------------------------------

class Attention(nn.Module):
    """
    Multi-head self-attention with optional fused SDPA fallback.

    If PyTorch provides `scaled_dot_product_attention`, it is used (usually
    faster and more stable); otherwise a manual implementation is used.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = _has_sdp_attention()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape [B, T, C]
            Input sequence.

        Returns
        -------
        Tensor, shape [B, T, C]
            Output sequence after MHSA + projection.
        """
        B, T, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, T, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)  # 3 x [B, H, T, Dh]
        )
        q, k, v = qkv.unbind(0)  # each: [B, H, T, Dh]
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )  # [B, H, T, Dh]
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)  # [B, H, T, T]
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v  # [B, H, T, Dh]

        x = x.transpose(1, 2).reshape(B, T, C)  # [B, T, C]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# ------------------------------- Utilities -----------------------------------

def basic_init(module: nn.Module) -> None:
    """
    Apply a basic initialization scheme to Linear layers.

    - Weight: Xavier uniform initialization.
    - Bias: set to zero.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)


def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings.

    Parameters
    ----------
    t : torch.Tensor
        Shape [B]. Each element is a timestep index and may be fractional.
    dim : int
        Dimensionality of the output embedding.
    max_period : int, default=100
        Controls the minimum frequency of the sinusoids.

    Returns
    -------
    torch.Tensor
        Shape [B, dim]. Sinusoidal embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
        / half
    )
    args = t[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


# ------------------------------- Core Layers ----------------------------------

class DomainAwareLinear(nn.Module):
    """
    Linear layer with domain-conditioned parameters (per sample).

    Each domain has its own weight and bias vectors, stored in embeddings.
    """

    def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.fc = nn.Embedding(num_domains, output_size * input_size)
        self.bias = nn.Embedding(num_domains, output_size)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.bias.weight)

    def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor
            [B, I] or [B, T, I]
        domain_id : LongTensor
            [B], domain indices.

        Returns
        -------
        Tensor
            [B, O] or [B, T, O]
        """
        B = domain_id.shape[0]
        squeeze_T = False
        if x.dim() == 2:
            x = x.unsqueeze(1)
            squeeze_T = True
        W = self.fc(domain_id).view(B, self.input_size, self.output_size)
        b = self.bias(domain_id).view(B, self.output_size)
        y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
        if squeeze_T:
            y = y.squeeze(1)
        return y


class TransformerBlock(nn.Module):
    """
    Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            drop=0.1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, [B, T, H]

        Returns
        -------
        Tensor, [B, T, H]
        """
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


# --------------------------- Main Model ---------------------------------------

class SoftPromptedTransformer(nn.Module):
    """
    Multi-modal, domain-aware Transformer with optional soft prompts.

    See parameter and forward I/O descriptions in the docstrings below.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        multi_modal_input_size: int = 768,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        num_domains: int = 20,
        dim_action: int = 20,
        dim_propio: int = 20,
        dim_time: int = 32,
        len_soft_prompts: int = 32,
        max_len_seq: int = 512,
        use_hetero_proj: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.dim_action = dim_action
        self.dim_time = dim_time
        self.len_soft_prompts = len_soft_prompts
        self.use_hetero_proj = use_hetero_proj

        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
        )

        if use_hetero_proj:
            self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
            self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
        else:
            self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
            self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)

        self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
        nn.init.normal_(self.pos_emb, std=0.02)

        self.norm = nn.LayerNorm(hidden_size)
        self.action_encoder = DomainAwareLinear(
            dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
        )
        self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)

        if len_soft_prompts > 0:
            self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
            nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)

        self.apply(basic_init)

    def forward(
        self,
        domain_id: torch.LongTensor,
        vlm_features: torch.Tensor,
        aux_visual_inputs: torch.Tensor,
        action_with_noise: torch.Tensor,
        proprio: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass.

        Inputs
        ------
        domain_id : [B]
        vlm_features : [B, T_vlm, D]
        aux_visual_inputs : [B, T_aux, D]
        action_with_noise : [B, T_action, dim_action]
        proprio : [B, dim_propio]
        t : [B]

        Returns
        -------
        Tensor
            Predicted actions, [B, T_action, dim_action]
        """
        B, num_actions = action_with_noise.shape[:2]

        # Encode (action + proprio + time) → tokens
        time_emb = timestep_embedding(t, self.dim_time)  # [B, dim_time]
        time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
        proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
        action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
        x = self.action_encoder(action_tokens, domain_id)  # [B, T_action, H]

        # Project visual streams and concatenate
        if self.use_hetero_proj:
            x = torch.cat(
                [x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
                dim=1,
            )
        else:
            x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)

        # Add positional embeddings (raise if the sequence exceeds max_len_seq)
        seq_len = x.shape[1]
        if seq_len > self.pos_emb.shape[1]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
            )
        x = x + self.pos_emb[:, :seq_len, :]

        # Append soft prompts
        if self.len_soft_prompts > 0:
            soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
            x = torch.cat([x, soft_prompts], dim=1)

        # Transformer backbone
        for block in self.blocks:
            x = block(x)

        # Decode only the action segment
        return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
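A small shape-check sketch for SoftPromptedTransformer, using illustrative dimensions rather than the shipped configuration (the import assumes this file is on the Python path as transformer.py):

import torch
from transformer import SoftPromptedTransformer

model = SoftPromptedTransformer(
    hidden_size=128,
    multi_modal_input_size=64,
    depth=2,
    num_heads=4,
    dim_action=7,
    dim_propio=7,
    len_soft_prompts=8,
    max_len_seq=128,
)

B, action_horizon = 2, 10
pred = model(
    domain_id=torch.zeros(B, dtype=torch.long),
    vlm_features=torch.randn(B, 16, 64),      # fused vision-language tokens
    aux_visual_inputs=torch.randn(B, 8, 64),  # auxiliary visual tokens
    action_with_noise=torch.randn(B, action_horizon, 7),
    proprio=torch.randn(B, 7),
    t=torch.rand(B),                          # denoising timestep per sample
)
print(pred.shape)  # torch.Size([2, 10, 7])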
vocab.json
ADDED
The diff for this file is too large to render. See raw diff.