from __future__ import annotations

import math
from functools import partial
from typing import Final, Iterable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _to_2tuple(x) -> Tuple:
    """Minimal replacement for timm.layers.to_2tuple."""
    if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
        t = tuple(x)
        return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
    return (x, x)


def _has_sdp_attention() -> bool:
    """Check if we can use PyTorch fused scaled_dot_product_attention."""
    return hasattr(F, "scaled_dot_product_attention")


class Mlp(nn.Module):
    """
    MLP used in ViT-style blocks.

    Supports a Linear or 1x1 Conv `linear_layer` for token/channel mixing.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        norm_layer: type[nn.Module] | None = None,
        bias: bool | Tuple[bool, bool] = True,
        drop: float | Tuple[float, float] = 0.0,
        use_conv: bool = False,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = _to_2tuple(bias)
        drop_probs = _to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = nn.GELU(approximate="tanh")
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
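

# Illustrative example (added sketch, not part of the original module): the
# `use_conv` switch changes the expected input layout. With Linear layers the
# Mlp mixes the last dim of a [B, T, C] token sequence; with 1x1 convs it
# mixes the channel dim of a [B, C, H, W] feature map.
def _example_mlp_shapes() -> None:
    mlp_tokens = Mlp(in_features=64, hidden_features=256)
    assert mlp_tokens(torch.randn(2, 16, 64)).shape == (2, 16, 64)

    mlp_conv = Mlp(in_features=64, hidden_features=256, use_conv=True)
    assert mlp_conv(torch.randn(2, 64, 8, 8)).shape == (2, 64, 8, 8)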


class Attention(nn.Module):
    """
    Multi-Head Self-Attention with a manual fallback.

    If PyTorch provides `scaled_dot_product_attention`, it is used (usually
    faster and more numerically stable); otherwise we fall back to a manual
    implementation.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = _has_sdp_attention()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape [B, T, C]
            Input sequence.

        Returns
        -------
        Tensor, shape [B, T, C]
            Output sequence after MHSA + projection.
        """
        B, T, C = x.shape
        # Project to Q, K, V in one matmul, then split into heads:
        # [B, T, 3C] -> [3, B, num_heads, T, head_dim]
        qkv = (
            self.qkv(x)
            .reshape(B, T, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads back: [B, num_heads, T, head_dim] -> [B, T, C]
        x = x.transpose(1, 2).reshape(B, T, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
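

# Illustrative example (added sketch, not part of the original module):
# Attention preserves the [B, T, C] shape; `dim` must be divisible by
# `num_heads`.
def _example_attention_shapes() -> None:
    attn = Attention(dim=64, num_heads=8, qkv_bias=True).eval()
    assert attn(torch.randn(2, 10, 64)).shape == (2, 10, 64)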


def basic_init(module: nn.Module) -> None:
    """
    Apply a basic initialization scheme to Linear layers.

    - Weight: Xavier uniform initialization.
    - Bias: set to zero.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)
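

# Usage sketch (added example, not part of the original module): `basic_init`
# is meant to be passed to `nn.Module.apply`, which visits every submodule.
def _example_basic_init() -> None:
    net = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 2))
    net.apply(basic_init)
    assert torch.all(net[0].bias == 0.0)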


def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings.

    Parameters
    ----------
    t : torch.Tensor
        Shape [B]. Each element is a timestep index; values may be fractional.
    dim : int
        Dimensionality of the output embedding.
    max_period : int, default=100
        Controls the minimum frequency of the sinusoids.

    Returns
    -------
    torch.Tensor
        Shape [B, dim]. Sinusoidal embeddings.
    """
    half = dim // 2
    # Geometric progression of frequencies from 1 down toward 1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
        / half
    )
    args = t[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:
        # Pad odd dims with a zero column so the output is exactly [B, dim].
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
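

# Worked example (added sketch, not part of the original module): at t = 0 all
# cos/sin arguments are zero, so the embedding is ones in the cosine half and
# zeros in the sine half.
def _example_timestep_embedding() -> None:
    emb = timestep_embedding(torch.tensor([0.0, 1.0]), dim=8)
    assert emb.shape == (2, 8)
    assert torch.allclose(emb[0], torch.tensor([1.0] * 4 + [0.0] * 4))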


class DomainAwareLinear(nn.Module):
    """
    Linear layer with domain-conditioned parameters (per-sample).

    Each domain has its own weight matrix and bias vector, stored flattened in
    embedding tables and gathered per sample by domain id.
    """

    def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.fc = nn.Embedding(num_domains, output_size * input_size)
        self.bias = nn.Embedding(num_domains, output_size)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.bias.weight)

    def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor
            [B, I] or [B, T, I]
        domain_id : LongTensor
            [B], domain indices.

        Returns
        -------
        Tensor
            [B, O] or [B, T, O]
        """
        B = domain_id.shape[0]
        squeeze_T = False
        if x.dim() == 2:
            # Treat [B, I] as a length-1 sequence so both cases share one path.
            x = x.unsqueeze(1)
            squeeze_T = True
        # Gather this batch's per-domain weights and biases.
        W = self.fc(domain_id).view(B, self.input_size, self.output_size)
        b = self.bias(domain_id).view(B, self.output_size)
        y = torch.matmul(x, W) + b.view(B, 1, self.output_size)
        if squeeze_T:
            y = y.squeeze(1)
        return y
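

# Illustrative example (added sketch, not part of the original module): two
# samples with different domain ids are multiplied by different weight
# matrices, so the layer acts as a per-sample batched Linear.
def _example_domain_aware_linear() -> None:
    layer = DomainAwareLinear(input_size=4, output_size=3, num_domains=2)
    x = torch.randn(2, 4)
    y = layer(x, torch.tensor([0, 1]))
    assert y.shape == (2, 3)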


class TransformerBlock(nn.Module):
    """
    Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            drop=0.1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, [B, T, H]

        Returns
        -------
        Tensor, [B, T, H]
        """
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
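

# Illustrative example (added sketch, not part of the original module): the
# block is shape-preserving, so blocks can be stacked freely.
def _example_transformer_block() -> None:
    block = TransformerBlock(hidden_size=64, num_heads=4).eval()
    x = torch.randn(2, 12, 64)
    assert block(x).shape == x.shape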


class SoftPromptedTransformer(nn.Module):
    """
    Multi-modal, domain-aware Transformer with optional soft prompts.

    See parameter and forward I/O descriptions in the method docstrings.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        multi_modal_input_size: int = 768,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        num_domains: int = 20,
        dim_action: int = 20,
        dim_propio: int = 20,
        dim_time: int = 32,
        len_soft_prompts: int = 32,
        max_len_seq: int = 512,
        use_hetero_proj: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.dim_action = dim_action
        self.dim_time = dim_time
        self.len_soft_prompts = len_soft_prompts
        self.use_hetero_proj = use_hetero_proj

        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
        )

        # Project multi-modal features into the model width, either with
        # per-domain weights or with shared Linear layers.
        if use_hetero_proj:
            self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
            self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
        else:
            self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
            self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)

        self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
        nn.init.normal_(self.pos_emb, std=0.02)

        self.norm = nn.LayerNorm(hidden_size)
        self.action_encoder = DomainAwareLinear(
            dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
        )
        self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)

        # One learned soft-prompt sequence per domain, stored flattened.
        if len_soft_prompts > 0:
            self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
            nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)

        self.apply(basic_init)

    def forward(
        self,
        domain_id: torch.LongTensor,
        vlm_features: torch.Tensor,
        aux_visual_inputs: torch.Tensor,
        action_with_noise: torch.Tensor,
        proprio: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        domain_id : [B]
        vlm_features : [B, T_vlm, D]
        aux_visual_inputs : [B, T_aux, D]
        action_with_noise : [B, T_action, dim_action]
        proprio : [B, dim_propio]
        t : [B]

        Returns
        -------
        Tensor
            Predicted actions, [B, T_action, dim_action]
        """
        B, num_actions = action_with_noise.shape[:2]

        # Build action tokens: concat noisy actions, proprioception, and the
        # sinusoidal time embedding (broadcast across the action horizon).
        time_emb = timestep_embedding(t, self.dim_time)
        time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
        proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
        action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
        x = self.action_encoder(action_tokens, domain_id)

        # Append projected multi-modal context tokens after the action tokens.
        if self.use_hetero_proj:
            x = torch.cat(
                [x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
                dim=1,
            )
        else:
            x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)

        seq_len = x.shape[1]
        if seq_len > self.pos_emb.shape[1]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
            )
        x = x + self.pos_emb[:, :seq_len, :]

        # Soft prompts are appended after positional embeddings are added, so
        # they carry no positional information of their own.
        if self.len_soft_prompts > 0:
            soft_prompts = self.soft_prompt_hub(domain_id).view(B, self.len_soft_prompts, self.hidden_size)
            x = torch.cat([x, soft_prompts], dim=1)

        for block in self.blocks:
            x = block(x)

        # Decode only the action positions (the first num_actions tokens).
        return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
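

if __name__ == "__main__":
    # Minimal smoke test (added sketch; the sizes below are illustrative
    # assumptions, not values from the original configuration).
    model = SoftPromptedTransformer(
        hidden_size=64,
        multi_modal_input_size=32,
        depth=2,
        num_heads=4,
        num_domains=4,
        dim_action=8,
        dim_propio=8,
        dim_time=16,
        len_soft_prompts=4,
        max_len_seq=128,
    )
    B, T_action, T_vlm, T_aux = 2, 5, 10, 3
    out = model(
        domain_id=torch.tensor([0, 3]),
        vlm_features=torch.randn(B, T_vlm, 32),
        aux_visual_inputs=torch.randn(B, T_aux, 32),
        action_with_noise=torch.randn(B, T_action, 8),
        proprio=torch.randn(B, 8),
        t=torch.rand(B),
    )
    assert out.shape == (B, T_action, 8)
    print("ok:", tuple(out.shape))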