import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalSelfAttention(nn.Module):
    """Multi-head self-attention with an additive temporal bias.

    Attention logits are penalized according to the time gap between tokens,
    using either a linear (|delta_t|) or Gaussian (delta_t**2) penalty scaled
    by gamma. An optional causal mask restricts attention to past positions.
    """

    def __init__(self, embed_dim, num_heads, bias_type="linear", gamma=1.0, causal=False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert bias_type in ("linear", "gaussian"), "bias_type must be 'linear' or 'gaussian'"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.bias_type = bias_type
        self.gamma = gamma
        self.causal = causal

        # Fused projection for queries, keys, and values, plus the output projection.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, timestamps):
        """
        x: [B, T, D]
        timestamps: [B, T], real-valued time signal per token
        """
        B, T, D = x.size()

        # Project to queries, keys, and values in one pass: [B, T, 3, H, head_dim].
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)

        # Move the head dimension forward: [B, H, T, head_dim].
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product logits: [B, H, T, T].
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Pairwise time differences: delta_t[b, i, j] = timestamps[b, j] - timestamps[b, i].
        t_i = timestamps.unsqueeze(2)  # [B, T, 1]
        t_j = timestamps.unsqueeze(1)  # [B, 1, T]
        delta_t = t_j - t_i            # [B, T, T]

        # Penalize attention across large time gaps.
        if self.bias_type == "linear":
            temporal_bias = -self.gamma * torch.abs(delta_t)
        elif self.bias_type == "gaussian":
            temporal_bias = -self.gamma * (delta_t ** 2)

        # Broadcast the bias over heads: [B, 1, T, T].
        attn_logits = attn_logits + temporal_bias.unsqueeze(1)

        # Optional causal mask: token i may only attend to positions j <= i.
        if self.causal:
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)
            attn_logits = attn_logits.masked_fill(causal_mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_logits, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back into the embedding dimension and project out.
        attn_output = attn_output.transpose(1, 2).reshape(B, T, D)
        output = self.out_proj(attn_output)

        return output, attn_weights
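

# A minimal usage sketch, not part of the module above: the batch size, sequence
# length, and hyperparameters below are illustrative assumptions, chosen only to
# show the expected input and output tensor shapes.
if __name__ == "__main__":
    B, T, D = 2, 16, 64  # assumed batch size, sequence length, embedding dim
    layer = TemporalSelfAttention(embed_dim=D, num_heads=4, bias_type="gaussian",
                                  gamma=0.1, causal=True)
    x = torch.randn(B, T, D)                                   # token embeddings
    timestamps = torch.sort(torch.rand(B, T), dim=-1).values   # increasing event times
    out, weights = layer(x, timestamps)
    print(out.shape)      # torch.Size([2, 16, 64])
    print(weights.shape)  # torch.Size([2, 4, 16, 16])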