2toINF committed on
Commit cb94537 · verified · 1 Parent(s): 84abbcf

Initial upload for X-VLA-Google-Robot

action_hub.py ADDED
@@ -0,0 +1,275 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from __future__ import annotations
18
+ from typing import Iterable, Tuple, Dict, Type
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ # =============================================================================
23
+ # Registry
24
+ # =============================================================================
25
+ ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
26
+
27
+
28
+ def register_action(name: str):
29
+ """Decorator for registering a new action space."""
30
+ def _wrap(cls):
31
+ key = name.lower()
32
+ if key in ACTION_REGISTRY:
33
+ raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
34
+ ACTION_REGISTRY[key] = cls
35
+ cls.name = key
36
+ return cls
37
+ return _wrap
38
+
39
+
40
+ def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
41
+ """Instantiate a registered action space by name."""
42
+ key = name.lower()
43
+ if key not in ACTION_REGISTRY:
44
+ raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
45
+ return ACTION_REGISTRY[key](**kwargs)
46
+
47
+
48
+ # =============================================================================
49
+ # Base class
50
+ # =============================================================================
51
+ class BaseActionSpace(nn.Module):
52
+ """
53
+ Abstract base class for all action-space definitions.
54
+
55
+ Each subclass defines:
56
+ - `dim_action`: dimension of the action vector.
57
+ - `gripper_idx`: indices of gripper channels.
58
+ - `compute_loss(pred, target)`: supervised loss for this space.
59
+ - `preprocess(proprio, action, mode)`: pre-step modifications.
60
+ - `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
61
+ """
62
+
63
+ name: str = "base"
64
+ dim_action: int = 0
65
+ gripper_idx: Tuple[int, ...] = ()
66
+
67
+ def __init__(self):
68
+ super().__init__()
69
+
70
+ # ---------------------------------------------------------------------
71
+ # Core supervised loss
72
+ # ---------------------------------------------------------------------
73
+ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
74
+ raise NotImplementedError
75
+
76
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
77
+ """Alias for compute_loss."""
78
+ return self.compute_loss(pred, target)
79
+
80
+ # ---------------------------------------------------------------------
81
+ # Space-level hooks
82
+ # ---------------------------------------------------------------------
83
+ def preprocess(
84
+ self,
85
+ proprio: torch.Tensor,
86
+ action: torch.Tensor,
87
+ mode: str = "train",
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Default: return unchanged."""
90
+ return proprio, action
91
+
92
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
93
+ """Default: return unchanged."""
94
+ return action
95
+
96
+
97
+ # =============================================================================
98
+ # Utilities
99
+ # =============================================================================
100
+ def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None:
101
+ bad = [i for i in idx if i < 0 or i >= D]
102
+ if bad:
103
+ raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}")
104
+
105
+
106
+ # =============================================================================
107
+ # Implementations
108
+ # =============================================================================
109
+ @register_action("ee6d")
110
+ class EE6DActionSpace(BaseActionSpace):
111
+ """End-effector layout with xyz, 6D rotation, and gripper channels."""
112
+
113
+ dim_action = 20
114
+ gripper_idx = (9, 19)
115
+ GRIPPER_SCALE = 1.0
116
+ XYZ_SCALE = 500.0
117
+ ROT_SCALE = 10.0
118
+
119
+ POS_IDX_1 = (0, 1, 2)
120
+ POS_IDX_2 = (10, 11, 12)
121
+ ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
122
+ ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
123
+
124
+ def __init__(self):
125
+ super().__init__()
126
+ self.mse = nn.MSELoss()
127
+ self.bce = nn.BCEWithLogitsLoss()
128
+
129
+ def compute_loss(self, pred, target):
130
+ assert pred.shape == target.shape, "pred/target shapes must match"
131
+ B, T, D = pred.shape
132
+ _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
133
+
134
+ # Gripper BCE
135
+ g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
136
+ gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
137
+
138
+ # XYZ position
139
+ pos_loss = (
140
+ self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
141
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
142
+ ) * self.XYZ_SCALE
143
+
144
+ # Rotation 6D
145
+ rot_loss = (
146
+ self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
147
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
148
+ ) * self.ROT_SCALE
149
+
150
+ return {
151
+ "position_loss": pos_loss,
152
+ "rotate6D_loss": rot_loss,
153
+ "gripper_loss": gripper_loss,
154
+ }
155
+
156
+ def preprocess(self, proprio, action, mode="train"):
157
+ """Zero-out gripper channels in proprio/action."""
158
+ proprio_m = proprio.clone()
159
+ action_m = action.clone()
160
+ proprio_m[..., self.gripper_idx] = 0.0
161
+ action_m[..., self.gripper_idx] = 0.0
162
+ return proprio_m, action_m
163
+
164
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
165
+ """Apply sigmoid to gripper logits."""
166
+ if action.size(-1) > max(self.gripper_idx):
167
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
168
+ return action
169
+
170
+
171
+ @register_action("joint")
172
+ class JointActionSpace(BaseActionSpace):
173
+ """Joint-space layout with joints + gripper only."""
174
+
175
+ dim_action = 14
176
+ gripper_idx = (6, 13)
177
+ GRIPPER_SCALE = 0.1
178
+ JOINTS_SCALE = 1.0
179
+
180
+ def __init__(self):
181
+ super().__init__()
182
+ self.mse = nn.MSELoss()
183
+ self.bce = nn.BCEWithLogitsLoss()
184
+
185
+ def compute_loss(self, pred, target):
186
+ assert pred.shape == target.shape
187
+ B, T, D = pred.shape
188
+ _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
189
+
190
+ g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
191
+ gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
192
+
193
+ joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx))
194
+ joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
195
+
196
+ return {
197
+ "joints_loss": joints_loss,
198
+ "gripper_loss": gripper_loss,
199
+ }
200
+
201
+ def preprocess(self, proprio, action, mode="train"):
202
+ """Zero-out gripper channels in proprio/action."""
203
+ proprio_m = proprio.clone()
204
+ action_m = action.clone()
205
+ proprio_m[..., self.gripper_idx] = 0.0
206
+ action_m[..., self.gripper_idx] = 0.0
207
+ return proprio_m, action_m
208
+
209
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
210
+ """Apply sigmoid to gripper logits."""
211
+ if action.size(-1) > max(self.gripper_idx):
212
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
213
+ return action
214
+
215
+
216
+ @register_action("agibot_ee6d")
217
+ class AGIBOTEE6DActionSpace(BaseActionSpace):
218
+ """AGI-bot variant of EE6DActionSpace using MSE for all components."""
219
+
220
+ dim_action = 20
221
+ gripper_idx = (9, 19)
222
+ GRIPPER_SCALE = 10.0
223
+ XYZ_SCALE = 500.0
224
+ ROT_SCALE = 10.0
225
+ POS_IDX_1 = (0, 1, 2)
226
+ POS_IDX_2 = (10, 11, 12)
227
+ ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
228
+ ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
229
+
230
+ def __init__(self):
231
+ super().__init__()
232
+ self.mse = nn.MSELoss()
233
+
234
+ def compute_loss(self, pred, target):
235
+ assert pred.shape == target.shape
236
+ B, T, D = pred.shape
237
+ _ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
238
+
239
+ gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
240
+ pos_loss = (
241
+ self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
242
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
243
+ ) * self.XYZ_SCALE
244
+ rot_loss = (
245
+ self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
246
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
247
+ ) * self.ROT_SCALE
248
+
249
+ return {
250
+ "position_loss": pos_loss,
251
+ "rotate6D_loss": rot_loss,
252
+ "gripper_loss": gripper_loss,
253
+ }
254
+
255
+ def preprocess(self, proprio, action, mode="train"):
256
+ """No preprocessing applied in AGIBOT variant."""
257
+ return proprio, action
258
+
259
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
260
+ """AGIBOT does not postprocess."""
261
+ return action
262
+
263
+
264
+ # =============================================================================
265
+ # Exports
266
+ # =============================================================================
267
+ __all__ = [
268
+ "BaseActionSpace",
269
+ "build_action_space",
270
+ "register_action",
271
+ "EE6DActionSpace",
272
+ "JointActionSpace",
273
+ "AGIBOTEE6DActionSpace",
274
+ "ACTION_REGISTRY",
275
+ ]
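A minimal usage sketch of the registry above, assuming `action_hub.py` is importable as a flat module named `action_hub` (the import path and the "toy" space are assumptions, not part of the repo): it registers a hypothetical 7-DoF layout and exercises the loss and `forward` alias.

```python
# Sketch only: the `action_hub` import path and the "toy" space are assumptions.
import torch
import torch.nn as nn
from action_hub import BaseActionSpace, register_action, build_action_space

@register_action("toy")
class ToyActionSpace(BaseActionSpace):
    """Hypothetical 7-DoF layout: 6 joints + 1 gripper logit."""
    dim_action = 7
    gripper_idx = (6,)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        gripper_loss = self.bce(pred[..., 6], target[..., 6])
        joints_loss = self.mse(pred[..., :6], target[..., :6])
        return {"joints_loss": joints_loss, "gripper_loss": gripper_loss}

space = build_action_space("toy")
pred = torch.randn(2, 30, space.dim_action)   # [B, T, D]
target = torch.rand_like(pred)                # gripper targets in [0, 1] for BCE
losses = space(pred, target)                  # forward() aliases compute_loss()
print({k: float(v) for k, v in losses.items()})
```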
config.json ADDED
@@ -0,0 +1,78 @@
1
+ {
2
+ "_name_or_path": "xvla",
3
+ "model_type": "xvla",
4
+ "architectures": ["XVLA"],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_xvla.XVLAConfig",
7
+ "AutoModel": "modeling_xvla.XVLA"
8
+ },
9
+
10
+ "action_mode": "ee6d",
11
+ "use_proprio": true,
12
+ "num_actions": 30,
13
+
14
+ "hidden_size": 1024,
15
+ "depth": 24,
16
+ "num_heads": 16,
17
+ "mlp_ratio": 4.0,
18
+ "num_domains": 30,
19
+ "len_soft_prompts": 32,
20
+ "dim_time": 32,
21
+ "max_len_seq": 512,
22
+ "use_hetero_proj": false,
23
+ "soft_prompt_length": 32,
24
+
25
+ "florence_config": {
26
+ "model_type": "florence2",
27
+ "bos_token_id": 0,
28
+ "eos_token_id": 2,
29
+ "ignore_index": -100,
30
+ "pad_token_id": 1,
31
+ "projection_dim": 1024,
32
+
33
+ "text_config": {
34
+ "vocab_size": 51289,
35
+ "activation_dropout": 0.1,
36
+ "activation_function": "gelu",
37
+ "attention_dropout": 0.1,
38
+ "d_model": 1024,
39
+ "decoder_attention_heads": 16,
40
+ "decoder_layers": 12,
41
+ "encoder_attention_heads": 16,
42
+ "encoder_layers": 12,
43
+ "dropout": 0.1,
44
+ "max_position_embeddings": 4096,
45
+ "num_hidden_layers": 12,
46
+ "num_beams": 3
47
+ },
48
+
49
+ "vision_config": {
50
+ "model_type": "davit",
51
+ "drop_path_rate": 0.1,
52
+ "patch_size": [7, 3, 3, 3],
53
+ "patch_stride": [4, 2, 2, 2],
54
+ "patch_padding": [3, 1, 1, 1],
55
+ "patch_prenorm": [false, true, true, true],
56
+ "enable_checkpoint": false,
57
+ "dim_embed": [256, 512, 1024, 2048],
58
+ "num_heads": [8, 16, 32, 64],
59
+ "num_groups": [8, 16, 32, 64],
60
+ "depths": [1, 1, 9, 1],
61
+ "window_size": 12,
62
+ "projection_dim": 1024,
63
+ "visual_temporal_embedding": {
64
+ "type": "COSINE",
65
+ "max_temporal_embeddings": 100
66
+ },
67
+ "image_pos_embed": {
68
+ "type": "learned_abs_2d",
69
+ "max_pos_embeddings": 50
70
+ },
71
+ "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
72
+ },
73
+
74
+ "vocab_size": 51289,
75
+ "torch_dtype": "float16",
76
+ "is_encoder_decoder": true
77
+ }
78
+ }
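Because `auto_map` above points `AutoConfig`/`AutoModel` at the custom classes in this repo, the checkpoint can be loaded with `trust_remote_code=True`. A short sketch follows; the repo id is a placeholder assumption, substitute the actual Hub path.

```python
# Repo id below is a placeholder assumption.
from transformers import AutoConfig, AutoModel

repo = "2toINF/X-VLA-Google-Robot"
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
print(config.action_mode, config.num_actions, config.hidden_size)  # ee6d 30 1024

model = AutoModel.from_pretrained(repo, trust_remote_code=True)
```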
configuration_florence2.py ADDED
@@ -0,0 +1,340 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import warnings
15
+ """ Florence-2 configuration"""
16
+
17
+ from typing import Optional
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ class Florence2VisionConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
28
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
36
+ The dropout rate of the drop path layer.
37
+ patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
38
+ The patch size of the image.
39
+ patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
40
+ The patch stride of the image.
41
+ patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
42
+ The patch padding of the image.
43
+ patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
44
+ Whether to apply layer normalization before the patch embedding layer.
45
+ enable_checkpoint (`bool`, *optional*, defaults to False):
46
+ Whether to enable checkpointing.
47
+ dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
48
+ The dimension of the embedding layer.
49
+ num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
50
+ The number of attention heads.
51
+ num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
52
+ The number of groups.
53
+ depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
54
+ The depth of the model.
55
+ window_size (`int`, *optional*, defaults to 12):
56
+ The window size of the model.
57
+ projection_dim (`int`, *optional*, defaults to 1024):
58
+ The dimension of the projection layer.
59
+ visual_temporal_embedding (`dict`, *optional*):
60
+ The configuration of the visual temporal embedding.
61
+ image_pos_embed (`dict`, *optional*):
62
+ The configuration of the image position embedding.
63
+ image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
64
+ The source of the image feature.
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import Florence2VisionConfig, Florence2VisionModel
69
+
70
+ >>> # Initializing a Florence2 Vision style configuration
71
+ >>> configuration = Florence2VisionConfig()
72
+
73
+ >>> # Initializing a model (with random weights)
74
+ >>> model = Florence2VisionModel(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "davit"
81
+ keys_to_ignore_at_inference = ["past_key_values"]
82
+
83
+ def __init__(
84
+ self,
85
+ drop_path_rate=0.1,
86
+ patch_size=[7, 3, 3, 3],
87
+ patch_stride=[4, 2, 2, 2],
88
+ patch_padding=[3, 1, 1, 1],
89
+ patch_prenorm=[False, True, True, True],
90
+ enable_checkpoint=False,
91
+ dim_embed=[256, 512, 1024, 2048],
92
+ num_heads=[8, 16, 32, 64],
93
+ num_groups=[8, 16, 32, 64],
94
+ depths=[1, 1, 9, 1],
95
+ window_size=12,
96
+ projection_dim=1024,
97
+ visual_temporal_embedding=None,
98
+ image_pos_embed=None,
99
+ image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
100
+ **kwargs,
101
+ ):
102
+ self.drop_path_rate = drop_path_rate
103
+ self.patch_size = patch_size
104
+ self.patch_stride = patch_stride
105
+ self.patch_padding = patch_padding
106
+ self.patch_prenorm = patch_prenorm
107
+ self.enable_checkpoint = enable_checkpoint
108
+ self.dim_embed = dim_embed
109
+ self.num_heads = num_heads
110
+ self.num_groups = num_groups
111
+ self.depths = depths
112
+ self.window_size = window_size
113
+ self.projection_dim = projection_dim
114
+ self.visual_temporal_embedding = visual_temporal_embedding
115
+ self.image_pos_embed = image_pos_embed
116
+ self.image_feature_source = image_feature_source
117
+
118
+ super().__init__(**kwargs)
119
+
120
+
121
+
122
+ class Florence2LanguageConfig(PretrainedConfig):
123
+ r"""
124
+ This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
125
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
126
+ defaults will yield a similar configuration to that of the BART
127
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
128
+
129
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
130
+ documentation from [`PretrainedConfig`] for more information.
131
+
132
+
133
+ Args:
134
+ vocab_size (`int`, *optional*, defaults to 51289):
135
+ Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
136
+ `input_ids` passed when calling [`Florence2LanguageModel`].
137
+ d_model (`int`, *optional*, defaults to 1024):
138
+ Dimensionality of the layers and the pooler layer.
139
+ encoder_layers (`int`, *optional*, defaults to 12):
140
+ Number of encoder layers.
141
+ decoder_layers (`int`, *optional*, defaults to 12):
142
+ Number of decoder layers.
143
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
144
+ Number of attention heads for each attention layer in the Transformer encoder.
145
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
146
+ Number of attention heads for each attention layer in the Transformer decoder.
147
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
148
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
149
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
150
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
151
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
154
+ dropout (`float`, *optional*, defaults to 0.1):
155
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ activation_dropout (`float`, *optional*, defaults to 0.0):
159
+ The dropout ratio for activations inside the fully connected layer.
160
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
161
+ The dropout ratio for classifier.
162
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
163
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
164
+ just in case (e.g., 512 or 1024 or 2048).
165
+ init_std (`float`, *optional*, defaults to 0.02):
166
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
167
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
168
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
169
+ for more details.
170
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
171
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
172
+ for more details.
173
+ scale_embedding (`bool`, *optional*, defaults to `False`):
174
+ Scale embeddings by dividing by sqrt(d_model).
175
+ use_cache (`bool`, *optional*, defaults to `True`):
176
+ Whether or not the model should return the last key/values attentions (not used by all models).
177
+ num_labels (`int`, *optional*, defaults to 3):
178
+ The number of labels to use in [`Florence2LanguageForSequenceClassification`].
179
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
180
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
181
+ `eos_token_id`.
182
+
183
+ Example:
184
+
185
+ ```python
186
+ >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
187
+
188
+ >>> # Initializing a Florence2 Language style configuration
189
+ >>> configuration = Florence2LanguageConfig()
190
+
191
+ >>> # Initializing a model (with random weights)
192
+ >>> model = Florence2LanguageModel(configuration)
193
+
194
+ >>> # Accessing the model configuration
195
+ >>> configuration = model.config
196
+ ```"""
197
+
198
+ model_type = "florence2_language"
199
+ keys_to_ignore_at_inference = ["past_key_values"]
200
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=51289,
205
+ max_position_embeddings=1024,
206
+ encoder_layers=12,
207
+ encoder_ffn_dim=4096,
208
+ encoder_attention_heads=16,
209
+ decoder_layers=12,
210
+ decoder_ffn_dim=4096,
211
+ decoder_attention_heads=16,
212
+ encoder_layerdrop=0.0,
213
+ decoder_layerdrop=0.0,
214
+ activation_function="gelu",
215
+ d_model=1024,
216
+ dropout=0.1,
217
+ attention_dropout=0.0,
218
+ activation_dropout=0.0,
219
+ init_std=0.02,
220
+ classifier_dropout=0.0,
221
+ scale_embedding=False,
222
+ use_cache=True,
223
+ num_labels=3,
224
+ pad_token_id=1,
225
+ bos_token_id=0,
226
+ eos_token_id=2,
227
+ is_encoder_decoder=True,
228
+ decoder_start_token_id=2,
229
+ forced_eos_token_id=2,
230
+ **kwargs,
231
+ ):
232
+ self.vocab_size = vocab_size
233
+ self.max_position_embeddings = max_position_embeddings
234
+ self.d_model = d_model
235
+ self.encoder_ffn_dim = encoder_ffn_dim
236
+ self.encoder_layers = encoder_layers
237
+ self.encoder_attention_heads = encoder_attention_heads
238
+ self.decoder_ffn_dim = decoder_ffn_dim
239
+ self.decoder_layers = decoder_layers
240
+ self.decoder_attention_heads = decoder_attention_heads
241
+ self.dropout = dropout
242
+ self.attention_dropout = attention_dropout
243
+ self.activation_dropout = activation_dropout
244
+ self.activation_function = activation_function
245
+ self.init_std = init_std
246
+ self.encoder_layerdrop = encoder_layerdrop
247
+ self.decoder_layerdrop = decoder_layerdrop
248
+ self.classifier_dropout = classifier_dropout
249
+ self.use_cache = use_cache
250
+ self.num_hidden_layers = encoder_layers
251
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
252
+
253
+ super().__init__(
254
+ num_labels=num_labels,
255
+ pad_token_id=pad_token_id,
256
+ bos_token_id=bos_token_id,
257
+ eos_token_id=eos_token_id,
258
+ is_encoder_decoder=is_encoder_decoder,
259
+ decoder_start_token_id=decoder_start_token_id,
260
+ forced_eos_token_id=forced_eos_token_id,
261
+ **kwargs,
262
+ )
263
+
264
+ # ensure backward compatibility for BART CNN models
265
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
266
+ self.forced_bos_token_id = self.bos_token_id
267
+ warnings.warn(
268
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
269
+ "The config can simply be saved and uploaded again to be fixed."
270
+ )
271
+
272
+ class Florence2Config(PretrainedConfig):
273
+ r"""
274
+ This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
275
+ Florence-2 model according to the specified arguments, defining the model architecture.
276
+
277
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
278
+ documentation from [`PretrainedConfig`] for more information.
279
+
280
+ Args:
281
+ vision_config (`Florence2VisionConfig`, *optional*):
282
+ Custom vision config or dict
283
+ text_config (`Union[AutoConfig, dict]`, *optional*):
284
+ The config object of the text backbone.
285
+ ignore_index (`int`, *optional*, defaults to -100):
286
+ The ignore index for the loss function.
287
+ vocab_size (`int`, *optional*, defaults to 51289):
288
+ Vocabulary size of the Florence2 model. Defines the number of different tokens that can be represented by the
289
+ `input_ids` passed when calling [`~Florence2ForConditionalGeneration`]
290
+ projection_dim (`int`, *optional*, defaults to 1024):
291
+ Dimension of the multimodal projection space.
292
+
293
+ Example:
294
+
295
+ ```python
296
+ >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
297
+
298
+ >>> # Initializing a clip-like vision config
299
+ >>> vision_config = CLIPVisionConfig()
300
+
301
+ >>> # Initializing a Bart config
302
+ >>> text_config = BartConfig()
303
+
304
+ >>> # Initializing a Florence-2 configuration
305
+ >>> configuration = Florence2Config(vision_config, text_config)
306
+
307
+ >>> # Initializing a model from the florence-2 configuration
308
+ >>> model = Florence2ForConditionalGeneration(configuration)
309
+
310
+ >>> # Accessing the model configuration
311
+ >>> configuration = model.config
312
+ ```"""
313
+
314
+ model_type = "florence2"
315
+ is_composition = False
316
+
317
+ def __init__(
318
+ self,
319
+ vision_config=None,
320
+ text_config=None,
321
+ ignore_index=-100,
322
+ vocab_size=51289,
323
+ projection_dim=1024,
324
+ **kwargs,
325
+ ):
326
+ self.ignore_index = ignore_index
327
+ self.vocab_size = vocab_size
328
+ self.projection_dim = projection_dim
329
+ if vision_config is not None:
330
+ vision_config = Florence2VisionConfig(**vision_config)
331
+ self.vision_config = vision_config
332
+ self.vocab_size = self.vocab_size
333
+
334
+ self.text_config = text_config
335
+ if text_config is not None:
336
+ self.text_config = Florence2LanguageConfig(**text_config)
337
+
338
+
339
+ super().__init__(**kwargs)
340
+
configuration_xvla.py ADDED
@@ -0,0 +1,95 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from .configuration_florence2 import Florence2Config
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
+ class XVLAConfig(PretrainedConfig):
22
+ """
23
+ Configuration class for the **XVLA (Extended Vision-Language-Action)** model.
24
+
25
+ This configuration defines all submodules of XVLA in a single place:
26
+ - The visual-language backbone (Florence2)
27
+ - The temporal/action transformer
28
+ - The action/proprio setup
29
+ """
30
+
31
+ model_type = "xvla"
32
+
33
+ def __init__(
34
+ self,
35
+ # === Florence backbone ===
36
+ florence_config: dict | None = None,
37
+
38
+ # === Transformer head ===
39
+ hidden_size: int = 1024,
40
+ depth: int = 24,
41
+ num_heads: int = 16,
42
+ mlp_ratio: float = 4.0,
43
+ num_domains: int = 30,
44
+ len_soft_prompts: int = 32,
45
+ dim_time: int = 32,
46
+ max_len_seq: int = 512,
47
+ use_hetero_proj: bool = False,
48
+ soft_prompt_length: int = 32,
49
+
50
+ # === Action & proprio ===
51
+ num_actions: int = 30,
52
+ action_mode: str = "ee6d",
53
+ use_proprio: bool = True,
54
+
55
+ **kwargs,
56
+ ):
57
+ # Florence2 backbone configuration
58
+ if isinstance(florence_config, dict):
59
+ self.florence_config = Florence2Config(**florence_config)
60
+ elif isinstance(florence_config, Florence2Config):
61
+ self.florence_config = florence_config
62
+ else:
63
+ self.florence_config = Florence2Config()
64
+
65
+ # Transformer hyperparameters
66
+ self.hidden_size = hidden_size
67
+ self.depth = depth
68
+ self.num_heads = num_heads
69
+ self.mlp_ratio = mlp_ratio
70
+ self.num_domains = num_domains
71
+ self.len_soft_prompts = len_soft_prompts
72
+ self.dim_time = dim_time
73
+ self.max_len_seq = max_len_seq
74
+ self.use_hetero_proj = use_hetero_proj
75
+ self.soft_prompt_length = soft_prompt_length
76
+
77
+ # Action/proprioception settings
78
+ self.num_actions = num_actions
79
+ self.action_mode = action_mode
80
+ self.use_proprio = use_proprio
81
+
82
+ # Initialize base HF config attributes (e.g. name_or_path)
83
+ super().__init__(**kwargs)
84
+
85
+ # -------------------------------------------------------------------------
86
+ # Serialization helpers
87
+ # -------------------------------------------------------------------------
88
+ def to_dict(self):
89
+ """
90
+ Convert this configuration (and its Florence sub-config)
91
+ into a fully serializable dictionary for HF save/load.
92
+ """
93
+ output = super().to_dict()
94
+ output["florence_config"] = self.florence_config.to_dict()
95
+ return output
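A small round-trip sketch of how the nested Florence2 sub-config is handled, assuming these files are importable as a package (here called `xvla_model`, an assumed name): passing a dict for `florence_config` re-instantiates `Florence2Config`, and the `to_dict()` override flattens it back for saving.

```python
# Package name `xvla_model` is an assumption for illustration.
from xvla_model.configuration_xvla import XVLAConfig

cfg = XVLAConfig(action_mode="ee6d", num_actions=30, hidden_size=1024)
d = cfg.to_dict()
assert isinstance(d["florence_config"], dict)            # sub-config is serialized too

cfg2 = XVLAConfig(
    florence_config=d["florence_config"],                # dict -> Florence2Config again
    action_mode=d["action_mode"],
    num_actions=d["num_actions"],
)
print(cfg2.florence_config.projection_dim)               # 1024 by default
```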
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ea279d74b9a5878da79f7dae949a1d8e92cead2cf0f58612f9d11e4ba89788e
3
+ size 3519068172
modeling_florence2.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_xvla.py ADDED
@@ -0,0 +1,287 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ import traceback
21
+ from typing import Any, Dict
22
+
23
+ import numpy as np
24
+ import torch
25
+ from fastapi import FastAPI
26
+ from fastapi.responses import JSONResponse
27
+ from PIL import Image
28
+ import uvicorn
29
+ import json_numpy
30
+ import cv2
31
+
32
+ from transformers import PreTrainedModel
33
+ from .modeling_florence2 import Florence2ForConditionalGeneration
34
+ from .transformer import SoftPromptedTransformer
35
+ from .action_hub import build_action_space
36
+ from .configuration_xvla import XVLAConfig
37
+
38
+
39
+ class XVLA(PreTrainedModel):
40
+ """
41
+ XVLA: HuggingFace-compatible Vision-Language-Action policy.
42
+
43
+ Components:
44
+ • Florence2 encoder-only backbone (vision-language)
45
+ • SoftPromptedTransformer (temporal/action head)
46
+ • Action space (pre/post-processing + loss)
47
+ """
48
+ config_class = XVLAConfig
49
+ base_model_prefix = "xvla"
50
+ supports_gradient_checkpointing = True
51
+
52
+ def __init__(self, config: XVLAConfig, *args, **kwargs):
53
+ super().__init__(config, *args, **kwargs)
54
+
55
+ # Core settings
56
+ self.num_actions: int = config.num_actions
57
+ self.use_proprio: bool = config.use_proprio
58
+ self.action_mode: str = config.action_mode.lower()
59
+ # Action space (dimensions + hooks)
60
+ self.action_space = build_action_space(config.action_mode.lower())
61
+ dim_action = self.action_space.dim_action
62
+ dim_proprio = getattr(self.action_space, "dim_proprio", dim_action)
63
+
64
+ # Florence2 backbone (encoder only)
65
+ self.vlm = Florence2ForConditionalGeneration(config.florence_config)
66
+ if hasattr(self.vlm, "language_model"):
67
+ lm = self.vlm.language_model
68
+ if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
69
+ del lm.model.decoder
70
+ if hasattr(lm, "lm_head"):
71
+ del lm.lm_head
72
+
73
+ projection_dim = getattr(self.vlm.config, "projection_dim", None)
74
+ if projection_dim is None:
75
+ raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
76
+
77
+ # Temporal/action head
78
+ self.transformer = SoftPromptedTransformer(
79
+ hidden_size=config.hidden_size,
80
+ multi_modal_input_size=projection_dim,
81
+ depth=config.depth,
82
+ num_heads=config.num_heads,
83
+ mlp_ratio=config.mlp_ratio,
84
+ num_domains=config.num_domains,
85
+ dim_action=dim_action,
86
+ dim_propio=dim_proprio,
87
+ len_soft_prompts=config.len_soft_prompts,
88
+ dim_time=config.dim_time,
89
+ max_len_seq=config.max_len_seq,
90
+ use_hetero_proj=config.use_hetero_proj,
91
+ )
92
+
93
+ # Deferred FastAPI app
94
+ self.app: FastAPI | None = None
95
+
96
+ # ============================= Florence2 encoder =============================
97
+ def forward_vlm(
98
+ self,
99
+ input_ids: torch.LongTensor, # [B, L]
100
+ pixel_values: torch.FloatTensor, # [B, V, C, H, W]
101
+ image_mask: torch.Tensor, # [B, V] (bool or 0/1)
102
+ ) -> Dict[str, torch.Tensor]:
103
+ """
104
+ Encode text + multi-view images via Florence2 encoder.
105
+
106
+ Returns:
107
+ { "vlm_features": [B, T_enc, D], "aux_visual_inputs": [B, (V-1)*N, D] }
108
+ """
109
+ B, V = pixel_values.shape[:2]
110
+ flat_mask = image_mask.view(-1).to(torch.bool) # [B*V]
111
+ flat_images = pixel_values.flatten(0, 1) # [B*V, C, H, W]
112
+
113
+ num_valid = int(flat_mask.sum().item())
114
+ if num_valid == 0:
115
+ raise ValueError("At least one image view must be valid per batch.")
116
+
117
+ valid_images = flat_images[flat_mask] # [#valid, C, H, W]
118
+ valid_feats = self.vlm._encode_image(valid_images) # [#valid, N, D]
119
+ N, D = valid_feats.shape[1:]
120
+
121
+ image_features = valid_feats.new_zeros((B * V, N, D))
122
+ image_features[flat_mask] = valid_feats
123
+ image_features = image_features.view(B, V, N, D) # [B, V, N, D]
124
+
125
+ inputs_embeds = self.vlm.get_input_embeddings()(input_ids) # [B, L, D]
126
+
127
+ merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
128
+ image_features[:, 0], # first view: [B, N, D]
129
+ inputs_embeds, # [B, L, D]
130
+ )
131
+
132
+ enc_out = self.vlm.language_model.model.encoder(
133
+ attention_mask=attention_mask,
134
+ inputs_embeds=merged_embeds,
135
+ )[0] # [B, T_enc, D]
136
+
137
+ aux_visual_inputs = image_features[:, 1:].reshape(B, -1, D) # remaining views flattened
138
+ return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
139
+
140
+ # ================================= training =================================
141
+ def forward(
142
+ self,
143
+ input_ids: torch.LongTensor,
144
+ image_input: torch.FloatTensor,
145
+ image_mask: torch.Tensor,
146
+ domain_id: torch.LongTensor,
147
+ proprio: torch.Tensor,
148
+ action: torch.Tensor, # [B, T=num_actions, D=dim_action]
149
+ ) -> Dict[str, torch.Tensor]:
150
+ """
151
+ 1) Encode multimodal inputs.
152
+ 2) Diffusion-style noisy mixture of actions: x_t = t*noise + (1-t)*gt.
153
+ 3) Space-specific preprocessing, prediction, and supervised loss.
154
+ """
155
+ enc = self.forward_vlm(input_ids, image_input, image_mask)
156
+
157
+ B = input_ids.shape[0]
158
+ t = (torch.rand(1, device=input_ids.device)
159
+ + torch.arange(B, device=input_ids.device) / B) % (1 - 1e-5)
160
+
161
+ action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
162
+ proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
163
+
164
+ pred_action = self.transformer(
165
+ domain_id=domain_id,
166
+ action_with_noise=action_noisy_m,
167
+ t=t,
168
+ proprio=proprio_m,
169
+ **enc,
170
+ )
171
+ return self.action_space.compute_loss(pred_action, action)
172
+
173
+ # ================================= inference =================================
174
+ @torch.no_grad()
175
+ def generate_actions(
176
+ self,
177
+ input_ids: torch.LongTensor,
178
+ image_input: torch.FloatTensor,
179
+ image_mask: torch.Tensor,
180
+ domain_id: torch.LongTensor,
181
+ proprio: torch.Tensor,
182
+ steps: int = 10,
183
+ ) -> torch.Tensor:
184
+ """
185
+ Iterative denoising (linear schedule).
186
+ Applies action_space.postprocess at the end (e.g., sigmoid on gripper).
187
+ """
188
+ self.eval()
189
+ enc = self.forward_vlm(input_ids, image_input, image_mask)
190
+
191
+ B = input_ids.shape[0]
192
+ D = self.action_space.dim_action
193
+
194
+ x1 = torch.randn(B, self.num_actions, D, device=proprio.device, dtype=proprio.dtype)
195
+ action = torch.zeros_like(x1)
196
+
197
+ steps = max(1, int(steps))
198
+ for i in range(steps, 0, -1):
199
+ t = torch.full((B,), i / steps, device=proprio.device, dtype=proprio.dtype)
200
+ x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
201
+ proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
202
+ action = self.transformer(
203
+ domain_id=domain_id,
204
+ action_with_noise=x_t_m,
205
+ proprio=proprio_m,
206
+ t=t,
207
+ **enc,
208
+ )
209
+ return self.action_space.postprocess(action)
210
+
211
+ # =============================== FastAPI service =============================
212
+ def _build_app(self, processor):
213
+ """
214
+ Minimal FastAPI app for XVLA inference.
215
+
216
+ Args:
217
+ processor: callable(images, text) -> Dict[str, torch.Tensor]
218
+ expected keys: "input_ids", "image_input", "image_mask"
219
+ """
220
+ if self.app is not None:
221
+ return
222
+
223
+ app = FastAPI()
224
+
225
+ @app.post("/act")
226
+ def act(payload: Dict[str, Any]):
227
+ try:
228
+ self.eval()
229
+ # Decode up to 3 image inputs
230
+ images = []
231
+ for key in ("image0", "image1", "image2"):
232
+ if key not in payload: continue
233
+ v = json_numpy.loads(payload[key])
234
+ if isinstance(v, np.ndarray):
235
+ if v.ndim == 1: # encoded bytes
236
+ v = cv2.imdecode(v, cv2.IMREAD_COLOR)
237
+ images.append(Image.fromarray(v))
238
+ elif isinstance(v, (list, tuple)):
239
+ images.append(Image.fromarray(np.array(v)))
240
+ elif isinstance(v, str):
241
+ images.append(Image.open(v))
242
+ if not images:
243
+ return JSONResponse({"error": "No valid images found."}, status_code=400)
244
+
245
+ # Multimodal preprocessing by processor
246
+ inputs = processor(images, payload["language_instruction"])
247
+ if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
248
+ return JSONResponse({"error": "Processor returned incomplete inputs."}, status_code=400)
249
+
250
+ # Build proprio/domain tensors
251
+ proprio = torch.as_tensor(np.asarray(json_numpy.loads(payload["proprio"])))
252
+ domain_id = torch.tensor([int(payload["domain_id"])], dtype=torch.long)
253
+
254
+ # Align to model's device/dtype
255
+ device = next(self.parameters()).device
256
+ dtype = next(self.parameters()).dtype
257
+
258
+ def to_model(t: torch.Tensor) -> torch.Tensor:
259
+ if not isinstance(t, torch.Tensor):
260
+ t = torch.as_tensor(t)
261
+ # cast floats to model dtype, keep integral/bool as-is
262
+ return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
263
+
264
+ inputs = {k: to_model(v) for k, v in inputs.items()}
265
+ inputs.update({
266
+ "proprio": to_model(proprio.unsqueeze(0)),
267
+ "domain_id": domain_id.to(device),
268
+ })
269
+
270
+ # Inference
271
+ steps = int(payload.get("steps", 10))
272
+ action = self.generate_actions(**inputs, steps=steps).squeeze(0).float().cpu().numpy()
273
+ return JSONResponse({"action": action.tolist()})
274
+
275
+ except Exception:
276
+ logging.error(traceback.format_exc())
277
+ return JSONResponse({"error": "Request failed"}, status_code=400)
278
+
279
+ self.app = app
280
+
281
+ def run(self, processor, host: str = "0.0.0.0", port: int = 8000):
282
+ """
283
+ Launch the FastAPI service.
284
+ """
285
+ self._build_app(processor)
286
+ assert self.app is not None
287
+ uvicorn.run(self.app, host=host, port=port)
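A shapes-only inference sketch for `generate_actions`, under several assumptions: the repo id is a placeholder, the weights are loaded in float32 to match the random float32 placeholder inputs, and the proprio vector is assumed to share the action dimensionality (matching the `dim_proprio` fallback above).

```python
# All inputs are random placeholders; this only illustrates the expected shapes.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "2toINF/X-VLA-Google-Robot",          # placeholder repo id
    trust_remote_code=True,
    torch_dtype=torch.float32,
).eval()

B, V, L = 1, 3, 50                         # batch, camera views, token length
D = model.action_space.dim_action          # 20 for the "ee6d" space
inputs = {
    "input_ids": torch.ones(B, L, dtype=torch.long),   # pad tokens only
    "image_input": torch.randn(B, V, 3, 224, 224),
    "image_mask": torch.ones(B, V, dtype=torch.bool),
    "domain_id": torch.zeros(B, dtype=torch.long),
    "proprio": torch.zeros(B, D),                       # assumed [B, dim_action]
}
with torch.no_grad():
    actions = model.generate_actions(**inputs, steps=10)
print(actions.shape)                        # [B, num_actions, D] = [1, 30, 20]
```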
preprocessor_config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_xvla.XVLAProcessor"
4
+ },
5
+ "_valid_processor_keys": [
6
+ "images",
7
+ "do_resize",
8
+ "size",
9
+ "resample",
10
+ "do_rescale",
11
+ "rescale_factor",
12
+ "do_normalize",
13
+ "image_mean",
14
+ "image_std",
15
+ "return_tensors",
16
+ "data_format",
17
+ "input_data_format",
18
+ "do_convert_rgb"
19
+ ],
20
+ "do_convert_rgb": null,
21
+ "do_normalize": true,
22
+ "do_rescale": true,
23
+ "do_resize": true,
24
+ "do_center_crop": false,
25
+ "image_processor_type": "CLIPImageProcessor",
26
+ "image_mean": [0.485, 0.456, 0.406],
27
+ "image_std": [0.229, 0.224, 0.225],
28
+ "processor_class": "XVLAProcessor",
29
+ "resample": 3,
30
+ "size": {
31
+ "height": 224,
32
+ "width": 224
33
+ },
34
+ "crop_size": {
35
+ "height": 224,
36
+ "width": 224
37
+ }
38
+ }
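For reference, the settings above describe a standard `CLIPImageProcessor` resize-and-normalize pipeline; the short sketch below (values copied from this file) shows the resulting tensor shape.

```python
# Standalone sketch using transformers' CLIPImageProcessor with the values above.
import numpy as np
from PIL import Image
from transformers import CLIPImageProcessor

ip = CLIPImageProcessor(
    do_resize=True,
    size={"height": 224, "width": 224},
    resample=3,                      # PIL bicubic
    do_center_crop=False,
    do_rescale=True,
    do_normalize=True,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)
img = Image.fromarray(np.full((480, 640, 3), 128, dtype=np.uint8))
pixel_values = ip(img, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)            # torch.Size([1, 3, 224, 224])
```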
processing_xvla.py ADDED
@@ -0,0 +1,206 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from transformers import ProcessorMixin
18
+ from typing import List, Union, Dict, Any, Optional
19
+ import torch
20
+
21
+
22
+ class XVLAProcessor(ProcessorMixin):
23
+ """
24
+ XVLAProcessor: Unified multimodal processor for XVLA models.
25
+
26
+ Handles:
27
+ - Multi-view image inputs (e.g., from multiple cameras).
28
+ - Batch processing for multiple samples.
29
+ - Joint tokenization and image tensor preparation.
30
+
31
+ This processor combines an image processor and a tokenizer under a single interface
32
+ so that users can call it directly like:
33
+
34
+ >>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
35
+ >>> inputs = processor(images=batch_images, language_instruction=batch_texts)
36
+
37
+ It is fully compatible with the Hugging Face AutoProcessor API.
38
+
39
+ Attributes
40
+ ----------
41
+ num_views : int, default=3
42
+ Expected number of image views per sample. Missing views will be padded with zeros.
43
+ language_max_length : int, default=50
44
+ Maximum token length for text encoding.
45
+ attributes : list
46
+ Required by ProcessorMixin to know which submodules are stored and reloaded.
47
+ image_processor_class : str
48
+ The name of the associated image processor class.
49
+ tokenizer_class : tuple(str)
50
+ The names of compatible tokenizer classes.
51
+ """
52
+
53
+ num_views: int = 3
54
+ language_max_length: int = 50
55
+
56
+ # Hugging Face ProcessorMixin-required metadata
57
+ attributes = ["image_processor", "tokenizer"]
58
+ image_processor_class = "AutoImageProcessor"
59
+ tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
60
+
61
+ def __init__(self, image_processor=None, tokenizer=None):
62
+ """
63
+ Initialize XVLAProcessor.
64
+
65
+ Parameters
66
+ ----------
67
+ image_processor : PreTrainedImageProcessor, optional
68
+ The image processor used to normalize/resize images.
69
+ tokenizer : PreTrainedTokenizer, optional
70
+ The tokenizer used for text tokenization.
71
+ """
72
+ # ProcessorMixin automatically saves these under self.image_processor / self.tokenizer
73
+ super().__init__(image_processor, tokenizer)
74
+
75
+ # ================== LANGUAGE ENCODING ==================
76
+ def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
77
+ """
78
+ Tokenize one or more language instructions.
79
+
80
+ Parameters
81
+ ----------
82
+ language_instruction : str or List[str]
83
+ A single instruction or a batch of instructions.
84
+
85
+ Returns
86
+ -------
87
+ Dict[str, torch.Tensor]
88
+ {
89
+ "input_ids": tensor of shape [B, L]
90
+ }
91
+ """
92
+ if isinstance(language_instruction, str):
93
+ language_instruction = [language_instruction]
94
+
95
+ inputs = self.tokenizer(
96
+ language_instruction,
97
+ return_tensors="pt",
98
+ padding="max_length",
99
+ max_length=self.language_max_length,
100
+ truncation=True,
101
+ )
102
+ return {"input_ids": inputs["input_ids"]}
103
+
104
+ # ================== IMAGE ENCODING ==================
105
+ def encode_image(
106
+ self,
107
+ images: Union[List, List[List]],
108
+ **kwargs
109
+ ) -> Dict[str, torch.Tensor]:
110
+ """
111
+ Preprocess one or more sets of multi-view images.
112
+
113
+ Parameters
114
+ ----------
115
+ images : List or List[List]
116
+ Single sample: [img1, img2, ...]
117
+ Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
118
+ Each image may be a PIL.Image, NumPy array, or torch.Tensor.
119
+
120
+ kwargs : dict
121
+ Extra arguments passed to the underlying image processor
122
+ (e.g., `do_resize=False`, `size=(224,224)`).
123
+
124
+ Returns
125
+ -------
126
+ Dict[str, torch.Tensor]
127
+ {
128
+ "image_input": tensor [B, num_views, C, H, W],
129
+ "image_mask": tensor [B, num_views]
130
+ }
131
+ """
132
+ # Normalize to batch form
133
+ if not isinstance(images[0], (list, tuple)):
134
+ images = [images] # convert single sample to batch of size 1
135
+
136
+ batch_imgs, batch_masks = [], []
137
+
138
+ for sample_imgs in images:
139
+ processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
140
+ V_exist = processed.size(0)
141
+
142
+ # Pad to self.num_views
143
+ if V_exist < self.num_views:
144
+ processed = torch.cat(
145
+ [processed,
146
+ processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
147
+ dim=0,
148
+ )
149
+
150
+ # Mask: True for valid slots, False for padding
151
+ image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
152
+ image_mask[:V_exist] = True
153
+
154
+ batch_imgs.append(processed)
155
+ batch_masks.append(image_mask)
156
+
157
+ image_input = torch.stack(batch_imgs, dim=0) # [B, num_views, C, H, W]
158
+ image_mask = torch.stack(batch_masks, dim=0) # [B, num_views]
159
+
160
+ return {"image_input": image_input, "image_mask": image_mask}
161
+
162
+ # ================== COMBINED CALL ==================
163
+ def __call__(
164
+ self,
165
+ images: Optional[Union[List, List[List]]] = None,
166
+ language_instruction: Optional[Union[str, List[str]]] = None,
167
+ **kwargs
168
+ ) -> Dict[str, torch.Tensor]:
169
+ """
170
+ Combine image and text encoding into a unified multimodal input.
171
+
172
+ Parameters
173
+ ----------
174
+ images : List or List[List], optional
175
+ Single-sample or batched multi-view images.
176
+ language_instruction : str or List[str], optional
177
+ Corresponding text instructions.
178
+ kwargs : dict
179
+ Extra args passed to image processor.
180
+
181
+ Returns
182
+ -------
183
+ Dict[str, torch.Tensor]
184
+ {
185
+ "input_ids": [B, L], optional,
186
+ "image_input": [B, num_views, C, H, W], optional,
187
+ "image_mask": [B, num_views], optional
188
+ }
189
+ """
190
+ outputs: Dict[str, Any] = {}
191
+
192
+ # Encode language if provided
193
+ if language_instruction is not None:
194
+ outputs.update(self.encode_language(language_instruction))
195
+
196
+ # Encode image if provided
197
+ if images is not None:
198
+ outputs.update(self.encode_image(images, **kwargs))
199
+
200
+ # Sanity check for batch alignment
201
+ if "input_ids" in outputs and "image_input" in outputs:
202
+ assert outputs["input_ids"].size(0) == outputs["image_input"].size(0), (
203
+ f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
204
+ f"!= image batch {outputs['image_input'].size(0)}"
205
+ )
206
+ return outputs
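A usage sketch for the processor, assuming it is loaded through the `auto_map` entry in `preprocessor_config.json` (placeholder repo id, synthetic images). It shows the padding behaviour when a sample provides fewer than `num_views` images.

```python
# Placeholder repo id; images are synthetic.
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("2toINF/X-VLA-Google-Robot", trust_remote_code=True)

views = [Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8)) for _ in range(2)]
batch = processor(images=[views], language_instruction=["pick up the coke can"])

print(batch["input_ids"].shape)    # [1, 50]  (language_max_length)
print(batch["image_input"].shape)  # [1, 3, 3, 224, 224]  (padded to num_views=3)
print(batch["image_mask"])         # tensor([[ True,  True, False]])
```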
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "model_max_length": 1024
3
+ }
4
+
transformer.py ADDED
@@ -0,0 +1,403 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright 2025 2toINF (https://github.com/2toINF)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------
16
+
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from functools import partial
21
+ from typing import Final, Iterable, Tuple
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+
28
+ # ------------------------------- Small utils ----------------------------------
29
+
30
+ def _to_2tuple(x) -> Tuple:
31
+ """Minimal replacement for timm.layers.to_2tuple."""
32
+ if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
33
+ t = tuple(x)
34
+ return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
35
+ return (x, x)
36
+
37
+
38
+ def _has_sdp_attention() -> bool:
39
+ """Check if we can use PyTorch fused scaled_dot_product_attention."""
40
+ return hasattr(F, "scaled_dot_product_attention")
41
+
42
+
43
+ # ---------------------------------- MLP --------------------------------------
44
+
45
+ class Mlp(nn.Module):
46
+ """
47
+ MLP used in ViT-style blocks.
48
+
49
+ Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ in_features: int,
55
+ hidden_features: int | None = None,
56
+ out_features: int | None = None,
57
+ norm_layer: type[nn.Module] | None = None,
58
+ bias: bool | Tuple[bool, bool] = True,
59
+ drop: float | Tuple[float, float] = 0.0,
60
+ use_conv: bool = False,
61
+ ) -> None:
62
+ super().__init__()
63
+ out_features = out_features or in_features
64
+ hidden_features = hidden_features or in_features
65
+ bias = _to_2tuple(bias)
66
+ drop_probs = _to_2tuple(drop)
67
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
68
+
69
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
70
+ self.act = nn.GELU(approximate="tanh")
71
+ self.drop1 = nn.Dropout(drop_probs[0])
72
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
73
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
74
+ self.drop2 = nn.Dropout(drop_probs[1])
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ # Expect [B, T, C] for Linear variant; caller is responsible for shapes.
78
+ x = self.fc1(x)
79
+ x = self.act(x)
80
+ x = self.drop1(x)
81
+ x = self.norm(x)
82
+ x = self.fc2(x)
83
+ x = self.drop2(x)
84
+ return x
85
+
+
+ # -------------------------------- Attention ----------------------------------
+
+ class Attention(nn.Module):
+ """
+ Multi-Head Self-Attention with optional fused SDPA fallback.
+
+ If PyTorch provides `scaled_dot_product_attention`, it will be used
+ (usually faster and more stable); otherwise we use a manual implementation.
+ """
+
+ fused_attn: Final[bool]
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: type[nn.Module] = nn.LayerNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.fused_attn = _has_sdp_attention()
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ x : Tensor, shape [B, T, C]
+ Input sequence.
+
+ Returns
+ -------
+ Tensor, shape [B, T, C]
+ Output sequence after MHSA + projection.
+ """
+ B, T, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, T, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4) # 3 x [B, H, T, Dh]
+ )
+ q, k, v = qkv.unbind(0) # each: [B, H, T, Dh]
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ ) # [B, H, T, Dh]
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1) # [B, H, T, T]
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v # [B, H, T, Dh]
+
+ x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
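Illustrative usage sketch (not part of the committed module), with hypothetical sizes: `Attention` preserves the [B, T, C] shape and transparently uses the fused SDPA kernel when the installed PyTorch provides it.

import torch

attn = Attention(dim=768, num_heads=16, qkv_bias=True).eval()  # eval() disables dropout paths
x = torch.randn(2, 10, 768)                                    # [B, T, C]
with torch.no_grad():
    y = attn(x)
assert y.shape == (2, 10, 768)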
+
+ # ------------------------------- Utilities -----------------------------------
+
+ def basic_init(module: nn.Module) -> None:
+ """
+ Apply a basic initialization scheme to Linear layers.
+
+ - Weight: Xavier uniform initialization.
+ - Bias: Set to zero.
+ """
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0.0)
+
+
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
+ """
+ Create sinusoidal timestep embeddings.
+
+ Parameters
+ ----------
+ t : torch.Tensor
+ Shape [B]. Each element is a timestep index; values may be fractional.
+ dim : int
+ Dimensionality of the output embedding.
+ max_period : int, default=100
+ Controls the minimum frequency of the sinusoids.
+
+ Returns
+ -------
+ torch.Tensor
+ Shape [B, dim]. Sinusoidal embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
+ / half
+ )
+ args = t[:, None] * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2 == 1:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
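Illustrative sketch (not part of the committed module): `timestep_embedding` packs cosines into the first `dim // 2` channels and sines into the rest, with a default `max_period` of 100 rather than the 10,000 used in classic language-model positional encodings; the timesteps below are hypothetical.

import torch

t = torch.tensor([0.0, 0.25, 1.0])      # fractional timesteps, e.g. denoising steps
emb = timestep_embedding(t, dim=32)     # -> [3, 32]
assert emb.shape == (3, 32)
assert torch.allclose(emb[0, :16], torch.ones(16))   # cos(0) = 1 in the first half
assert torch.allclose(emb[0, 16:], torch.zeros(16))  # sin(0) = 0 in the second half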
+
+ # ------------------------------- Core Layers ----------------------------------
+
+ class DomainAwareLinear(nn.Module):
+ """
+ Linear layer with domain-conditioned parameters (per-sample).
+
+ Each domain has its own flattened weight matrix and bias vector, stored in embedding tables.
+ """
+
+ def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.fc = nn.Embedding(num_domains, output_size * input_size)
+ self.bias = nn.Embedding(num_domains, output_size)
+ nn.init.xavier_uniform_(self.fc.weight)
+ nn.init.zeros_(self.bias.weight)
+
+ def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ x : Tensor
+ [B, I] or [B, T, I]
+ domain_id : LongTensor
+ [B], domain indices.
+
+ Returns
+ -------
+ Tensor
+ [B, O] or [B, T, O]
+ """
+ B = domain_id.shape[0]
+ squeeze_T = False
+ if x.dim() == 2:
+ x = x.unsqueeze(1)
+ squeeze_T = True
+ W = self.fc(domain_id).view(B, self.input_size, self.output_size)
+ b = self.bias(domain_id).view(B, self.output_size)
+ y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
+ if squeeze_T:
+ y = y.squeeze(1)
+ return y
+
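Illustrative usage sketch (not part of the committed module), with hypothetical sizes: each sample pulls its own weight matrix and bias out of the embedding tables according to `domain_id`, so two samples in the same batch can be transformed by different parameters.

import torch

layer = DomainAwareLinear(input_size=16, output_size=8, num_domains=4)
x = torch.randn(2, 5, 16)                 # [B, T, I]
domain_id = torch.tensor([0, 3])          # one domain index per sample
y = layer(x, domain_id)
assert y.shape == (2, 5, 8)               # [B, T, O]
assert layer(torch.randn(2, 16), domain_id).shape == (2, 8)   # 2-D inputs are supported too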
+
+ class TransformerBlock(nn.Module):
+ """
+ Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
+ """
+
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ drop=0.1,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ x : Tensor, [B, T, H]
+
+ Returns
+ -------
+ Tensor, [B, T, H]
+ """
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
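Illustrative usage sketch (not part of the committed module): the pre-LN block keeps the [B, T, H] shape; since both the attention and MLP sub-layers carry 0.1 dropout, `.eval()` is needed for deterministic outputs.

import torch

block = TransformerBlock(hidden_size=768, num_heads=16).eval()
x = torch.randn(2, 48, 768)               # [B, T, H]
with torch.no_grad():
    assert block(x).shape == x.shape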
+
+ # --------------------------- Main Model ---------------------------------------
+
+ class SoftPromptedTransformer(nn.Module):
+ """
+ Multi-modal, domain-aware Transformer with optional soft prompts.
+
+ See the `__init__` parameters and the `forward` docstring for I/O details.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ multi_modal_input_size: int = 768,
+ depth: int = 24,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.0,
+ num_domains: int = 20,
+ dim_action: int = 20,
+ dim_propio: int = 20,
+ dim_time: int = 32,
+ len_soft_prompts: int = 32,
+ max_len_seq: int = 512,
+ use_hetero_proj: bool = False,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dim_action = dim_action
+ self.dim_time = dim_time
+ self.len_soft_prompts = len_soft_prompts
+ self.use_hetero_proj = use_hetero_proj
+
+ self.blocks = nn.ModuleList(
+ [TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
+ )
+
+ if use_hetero_proj:
+ self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
+ self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
+ else:
+ self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
+ self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
+
+ self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
+ nn.init.normal_(self.pos_emb, std=0.02)
+
+ self.norm = nn.LayerNorm(hidden_size)
+ self.action_encoder = DomainAwareLinear(
+ dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
+ )
+ self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
+
+ if len_soft_prompts > 0:
+ self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
+ nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
+
+ self.apply(basic_init)
+
+ def forward(
+ self,
+ domain_id: torch.LongTensor,
+ vlm_features: torch.Tensor,
+ aux_visual_inputs: torch.Tensor,
+ action_with_noise: torch.Tensor,
+ proprio: torch.Tensor,
+ t: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Forward pass.
+
+ Inputs
+ ------
+ domain_id : [B]
+ vlm_features : [B, T_vlm, D]
+ aux_visual_inputs : [B, T_aux, D]
+ action_with_noise : [B, T_action, dim_action]
+ proprio : [B, dim_propio]
+ t : [B]
+
+ Returns
+ -------
+ Tensor
+ Predicted actions, [B, T_action, dim_action]
+ """
+ B, num_actions = action_with_noise.shape[:2]
+
+ # Encode (action + proprio + time) → tokens
+ time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
+ time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
+ proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
+ action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
+ x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
+
+ # Project visual streams and concatenate
+ if self.use_hetero_proj:
+ x = torch.cat(
+ [x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
+ dim=1,
+ )
+ else:
+ x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
+
+ # Add positional embeddings (error if the sequence exceeds max_len_seq)
+ seq_len = x.shape[1]
+ if seq_len > self.pos_emb.shape[1]:
+ raise ValueError(
+ f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
+ )
+ x = x + self.pos_emb[:, :seq_len, :]
+
+ # Append soft prompts
+ if self.len_soft_prompts > 0:
+ soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
+ x = torch.cat([x, soft_prompts], dim=1)
+
+ # Transformer backbone
+ for block in self.blocks:
+ x = block(x)
+
+ # Decode only the action segment
+ return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
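Illustrative end-to-end sketch (not part of the committed module): a forward pass with hypothetical batch size and token counts, using the constructor defaults except for a reduced `depth` to keep the example light.

import torch

model = SoftPromptedTransformer(depth=2).eval()   # depth reduced only for this sketch
B, T_action, T_vlm, T_aux = 2, 8, 64, 16
with torch.no_grad():
    pred = model(
        domain_id=torch.randint(0, 20, (B,)),
        vlm_features=torch.randn(B, T_vlm, 768),
        aux_visual_inputs=torch.randn(B, T_aux, 768),
        action_with_noise=torch.randn(B, T_action, 20),   # noised action chunk
        proprio=torch.randn(B, 20),
        t=torch.rand(B),                                  # one denoising timestep per sample
    )
assert pred.shape == (B, T_action, 20)                    # [B, T_action, dim_action]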
vocab.json ADDED
The diff for this file is too large to render. See raw diff