# X-VLA-Google-Robot / transformer.py
# Uploaded by 2toINF: "Initial upload for X-VLA-Google-Robot" (commit cb94537, verified)
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import math
from functools import partial
from typing import Final, Iterable, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> Tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
return (x, x)
def _has_sdp_attention() -> bool:
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
return hasattr(F, "scaled_dot_product_attention")
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
    """
    Two-layer feed-forward network used in ViT-style blocks.

    Pipeline: ``fc1 -> GELU(tanh approximation) -> dropout -> norm -> fc2 -> dropout``.
    With ``use_conv=True`` both projections are 1x1 ``nn.Conv2d`` layers
    instead of ``nn.Linear`` (token/channel mixing).

    Parameters
    ----------
    in_features : int
        Input feature size.
    hidden_features : int, optional
        Hidden size; defaults to ``in_features``.
    out_features : int, optional
        Output size; defaults to ``in_features``.
    norm_layer : type[nn.Module], optional
        Applied to the hidden activations; ``nn.Identity`` when omitted.
    bias : bool or (bool, bool)
        Bias flags for (fc1, fc2); a single bool applies to both.
    drop : float or (float, float)
        Dropout probabilities after (activation, fc2); a single float applies to both.
    use_conv : bool
        Use 1x1 convolutions rather than linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        norm_layer: type[nn.Module] | None = None,
        bias: bool | Tuple[bool, bool] = True,
        drop: float | Tuple[float, float] = 0.0,
        use_conv: bool = False,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        bias_in, bias_out = _to_2tuple(bias)
        drop_in, drop_out = _to_2tuple(drop)
        if use_conv:
            make_proj = partial(nn.Conv2d, kernel_size=1)
        else:
            make_proj = nn.Linear
        # Attribute names are state_dict keys. Creation order (fc1, norm, fc2)
        # also fixes the order in which parameters are initialized.
        self.fc1 = make_proj(in_features, hidden_dim, bias=bias_in)
        self.act = nn.GELU(approximate="tanh")
        self.drop1 = nn.Dropout(drop_in)
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_dim)
        self.fc2 = make_proj(hidden_dim, output_dim, bias=bias_out)
        self.drop2 = nn.Dropout(drop_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; shape handling is left to the caller ([B, T, C] for the Linear variant)."""
        hidden = self.drop1(self.act(self.fc1(x)))
        hidden = self.norm(hidden)
        return self.drop2(self.fc2(hidden))
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
    """
    Multi-head self-attention over a ``[B, T, C]`` sequence.

    If PyTorch exposes ``F.scaled_dot_product_attention`` the fused kernel is
    used. Otherwise an explicit ``softmax(Q K^T / sqrt(head_dim)) V`` is computed.
    Optional per-head LayerNorm (``qk_norm``) is applied to queries and keys
    before attention.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = _has_sdp_attention()
        # Submodules are registered in the same order as before (qkv, q_norm,
        # k_norm, attn_drop, proj, proj_drop); the names are state_dict keys.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(self.head_dim)
            self.k_norm = norm_layer(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape [B, T, C]
            Input sequence.

        Returns
        -------
        Tensor, shape [B, T, C]
            Attended sequence after the output projection and dropout.
        """
        batch, seq_len, channels = x.shape
        # [B, T, 3C] -> [B, T, 3, H, Dh] -> [3, B, H, T, Dh]
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if not self.fused_attn:
            weights = (q * self.scale) @ k.transpose(-2, -1)  # [B, H, T, T]
            weights = self.attn_drop(weights.softmax(dim=-1))
            out = weights @ v
        else:
            # SDPA's default scale equals self.scale (1 / sqrt(head_dim)).
            drop_p = self.attn_drop.p if self.training else 0.0
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)

        out = out.transpose(1, 2).reshape(batch, seq_len, channels)  # [B, T, C]
        return self.proj_drop(self.proj(out))
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
    """
    Initialize an ``nn.Linear`` in place: Xavier-uniform weight, zero bias.

    Modules of any other type are left untouched, which makes this usable
    with ``nn.Module.apply``.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_uniform_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
    """
    Build sinusoidal embeddings for (possibly fractional) timesteps.

    With ``h = dim // 2`` and frequencies ``f_i = exp(-log(max_period) * i / h)``
    for ``i in [0, h)``, the output is ``[cos(t * f), sin(t * f)]``. An odd
    ``dim`` gets one trailing zero column.

    Parameters
    ----------
    t : torch.Tensor
        Shape [B]. Timestep values; fractional values are allowed.
    dim : int
        Width of the returned embedding.
    max_period : int, default=100
        Sets the lowest frequency of the sinusoids.

    Returns
    -------
    torch.Tensor
        Shape [B, dim]. Frequencies are built with ``t``'s dtype and device.
    """
    half = dim // 2
    # Same operation order as the reference: (-log(max_period) * i) / half.
    scaled_idx = -math.log(max_period) * torch.arange(half, dtype=t.dtype, device=t.device)
    freqs = torch.exp(scaled_idx / half)
    phases = t.unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat((torch.cos(phases), torch.sin(phases)), dim=-1)
    if dim % 2:
        emb = torch.cat((emb, torch.zeros_like(emb[:, :1])), dim=-1)
    return emb
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
    """
    Linear layer with domain-conditioned parameters (per-sample).

    Each domain has its own weight matrix and bias vector, stored as rows of
    two ``nn.Embedding`` tables:
    - row ``d`` of ``self.fc`` is domain ``d``'s ``[input_size, output_size]``
      weight, flattened;
    - row ``d`` of ``self.bias`` is its ``[output_size]`` bias.
    """

    def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.fc = nn.Embedding(num_domains, output_size * input_size)
        self.bias = nn.Embedding(num_domains, output_size)
        # Xavier init is applied to the whole [num_domains, O*I] table.
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.bias.weight)

    def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
        """
        Apply each sample's domain-specific affine map to its features.

        Parameters
        ----------
        x : Tensor
            [B, I], [B, T, I], or more generally [B, *, I]. Every middle
            dimension is treated as a token axis.
        domain_id : LongTensor
            [B], domain indices. As with ``torch.matmul`` batch broadcasting,
            either batch size may be 1.

        Returns
        -------
        Tensor
            [B, O], [B, T, O], or [B, *, O]: the same leading dims as ``x``.

        Raises
        ------
        ValueError
            If ``x`` has fewer than 2 dims or its last dim is not ``input_size``.
        """
        if x.dim() < 2:
            raise ValueError(
                f"DomainAwareLinear expects x of shape [B, ..., {self.input_size}], "
                f"got {tuple(x.shape)}."
            )
        if x.shape[-1] != self.input_size:
            raise ValueError(
                f"DomainAwareLinear expects last dim {self.input_size}, got {x.shape[-1]}."
            )
        B = domain_id.shape[0]
        middle = tuple(x.shape[1:-1])
        # Flatten all middle dims into one token axis so one batched matmul
        # applies each sample's own weight; [B, I] becomes [B, 1, I]. An
        # explicit length (not -1) keeps empty sequences reshapeable.
        tokens = x.reshape(x.shape[0], math.prod(middle), self.input_size)
        W = self.fc(domain_id).view(B, self.input_size, self.output_size)
        b = self.bias(domain_id).view(B, 1, self.output_size)
        y = torch.matmul(tokens, W) + b  # [B, N, O]
        return y.reshape(y.shape[0], *middle, self.output_size)
class TransformerBlock(nn.Module):
    """
    Pre-LayerNorm Transformer block.

    Computes ``x + attn(norm1(x))`` and then ``x + mlp(norm2(x))``.
    - Attention uses a qkv bias and 0.1 dropout on the attention weights.
    - The MLP hidden width is ``int(hidden_size * mlp_ratio)``, with 0.1 dropout.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        # Registration order (norm1, norm2, attn, mlp) and names are preserved:
        # they define state_dict keys and parameter-initialization order.
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden, drop=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, [B, T, H]

        Returns
        -------
        Tensor, [B, T, H]
        """
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
    """
    Multi-modal, domain-aware Transformer with optional per-domain soft prompts.

    Token sequence fed to the backbone:
    ``[action tokens | VLM tokens | aux visual tokens] + positional embedding``,
    followed by ``len_soft_prompts`` domain soft-prompt tokens, which receive
    no positional embedding. Only the action-token positions are decoded.

    Parameters
    ----------
    hidden_size : int
        Model width.
    multi_modal_input_size : int
        Feature width of ``vlm_features`` and ``aux_visual_inputs``.
    depth, num_heads, mlp_ratio :
        Backbone depth and per-block attention/MLP settings.
    num_domains : int
        Number of domain ids for the domain-conditioned layers and soft prompts.
    dim_action, dim_propio, dim_time :
        Widths of the noisy action, proprioception and time-embedding features.
    len_soft_prompts : int
        Number of soft-prompt tokens per domain; 0 disables them.
    max_len_seq : int
        Capacity of the learned positional embedding.
    use_hetero_proj : bool
        Use domain-conditioned projections for the visual streams instead of
        shared ``nn.Linear`` layers.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        multi_modal_input_size: int = 768,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        num_domains: int = 20,
        dim_action: int = 20,
        dim_propio: int = 20,
        dim_time: int = 32,
        len_soft_prompts: int = 32,
        max_len_seq: int = 512,
        use_hetero_proj: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.dim_action = dim_action
        self.dim_time = dim_time
        self.len_soft_prompts = len_soft_prompts
        self.use_hetero_proj = use_hetero_proj

        # NOTE: submodule creation order below is significant. It fixes both
        # the RNG consumption order at init and the traversal order used by
        # self.apply(basic_init).
        self.blocks = nn.ModuleList(
            TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        )

        if use_hetero_proj:
            make_visual_proj = partial(DomainAwareLinear, num_domains=num_domains)
        else:
            make_visual_proj = nn.Linear
        self.vlm_proj = make_visual_proj(multi_modal_input_size, hidden_size)
        self.aux_visual_proj = make_visual_proj(multi_modal_input_size, hidden_size)

        self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
        nn.init.normal_(self.pos_emb, std=0.02)
        self.norm = nn.LayerNorm(hidden_size)

        step_feature_dim = dim_action + dim_time + dim_propio
        self.action_encoder = DomainAwareLinear(step_feature_dim, hidden_size, num_domains=num_domains)
        self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)

        if len_soft_prompts > 0:
            self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
            nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)

        self.apply(basic_init)

    def forward(
        self,
        domain_id: torch.LongTensor,
        vlm_features: torch.Tensor,
        aux_visual_inputs: torch.Tensor,
        action_with_noise: torch.Tensor,
        proprio: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict actions from noisy actions conditioned on vision, proprioception and time.

        Inputs
        ------
        domain_id : [B]
        vlm_features : [B, T_vlm, D]
        aux_visual_inputs : [B, T_aux, D]
        action_with_noise : [B, T_action, dim_action]
        proprio : [B, dim_propio]
        t : [B]

        Returns
        -------
        Tensor
            Predicted actions, [B, T_action, dim_action]

        Raises
        ------
        ValueError
            If the action + visual token count exceeds ``max_len_seq``.
        """
        batch, n_act = action_with_noise.shape[:2]

        # 1) One token per action step: [noisy action | proprio | time embedding].
        time_feat = timestep_embedding(t, self.dim_time)  # [B, dim_time]
        per_step_time = time_feat.unsqueeze(1).expand(batch, n_act, self.dim_time)
        per_step_proprio = proprio.unsqueeze(1).expand(batch, n_act, proprio.shape[-1])
        step_inputs = torch.cat([action_with_noise, per_step_proprio, per_step_time], dim=-1)
        action_tokens = self.action_encoder(step_inputs, domain_id)  # [B, T_action, H]

        # 2) Project both visual streams to the model width.
        if self.use_hetero_proj:
            vlm_tokens = self.vlm_proj(vlm_features, domain_id)
            aux_tokens = self.aux_visual_proj(aux_visual_inputs, domain_id)
        else:
            vlm_tokens = self.vlm_proj(vlm_features)
            aux_tokens = self.aux_visual_proj(aux_visual_inputs)
        x = torch.cat([action_tokens, vlm_tokens, aux_tokens], dim=1)

        # 3) Learned absolute positions, sliced to the current sequence length.
        seq_len = x.shape[1]
        if seq_len > self.pos_emb.shape[1]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
            )
        x = x + self.pos_emb[:, :seq_len, :]

        # 4) Domain soft prompts are appended after positions are added.
        if self.len_soft_prompts > 0:
            prompts = self.soft_prompt_hub(domain_id).view(batch, self.len_soft_prompts, self.hidden_size)
            x = torch.cat([x, prompts], dim=1)

        for block in self.blocks:
            x = block(x)

        # 5) Decode only the leading n_act positions, i.e. the action tokens.
        action_hidden = self.norm(x[:, :n_act])
        return self.action_decoder(action_hidden, domain_id)