# temporal_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalSelfAttention(nn.Module):
    """Multi-head self-attention with a timestamp-aware bias on the attention
    logits: either linear (-gamma * |t_j - t_i|) or Gaussian
    (-gamma * (t_j - t_i)^2), with optional causal masking."""

def __init__(self, embed_dim, num_heads, bias_type="linear", gamma=1.0, causal=False):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert bias_type in ("linear", "gaussian"), f"unknown bias_type: {bias_type!r}"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.bias_type = bias_type
self.gamma = gamma
self.causal = causal
        # Fused projection producing queries, keys, and values in one matmul.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, timestamps):
        """
        Args:
            x: [B, T, D] token embeddings.
            timestamps: [B, T] real-valued timestamps per token (float tensor).

        Returns:
            output: [B, T, D] attended features.
            attn_weights: [B, num_heads, T, T] attention probabilities.
        """
B, T, D = x.size()
# Project input to Q, K, V
qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2) # each: [B, T, num_heads, head_dim]
q = q.transpose(1, 2) # [B, num_heads, T, head_dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Scaled dot-product attention
attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5) # [B, H, T, T]
# Compute temporal bias
t_i = timestamps.unsqueeze(2) # [B, T, 1]
t_j = timestamps.unsqueeze(1) # [B, 1, T]
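        # Broadcasting yields pairwise gaps: delta_t[b, i, j] = timestamps[b, j] - timestamps[b, i].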
delta_t = t_j - t_i # [B, T, T]
        if self.bias_type == "linear":
            # Penalty grows linearly with the absolute time gap.
            temporal_bias = -self.gamma * torch.abs(delta_t)  # [B, T, T]
        else:  # "gaussian" (guaranteed by the assert in __init__)
            # Quadratic penalty: attention decays faster for distant timestamps.
            temporal_bias = -self.gamma * (delta_t ** 2)  # [B, T, T]
# Expand for broadcasting: [B, 1, T, T]
attn_logits = attn_logits + temporal_bias.unsqueeze(1)
# Causal masking (prevent attending to future)
if self.causal:
            # Boolean lower-triangular mask; True marks positions a query may attend to.
            causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            attn_logits = attn_logits.masked_fill(~causal_mask, float("-inf"))  # broadcasts to [B, H, T, T]
attn_weights = F.softmax(attn_logits, dim=-1) # [B, H, T, T]
attn_output = torch.matmul(attn_weights, v) # [B, H, T, head_dim]
# Merge heads
attn_output = attn_output.transpose(1, 2).reshape(B, T, D)
output = self.out_proj(attn_output)
return output, attn_weights
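

if __name__ == "__main__":
    # Minimal usage sketch. The batch size, sequence length, embedding width,
    # and hyperparameters below are illustrative assumptions, not values from
    # the original repo; this just runs a random batch through the module and
    # checks the output shapes.
    torch.manual_seed(0)
    B, T, D = 2, 16, 64
    x = torch.randn(B, T, D)
    # Monotonically increasing float timestamps per sequence.
    timestamps = torch.sort(torch.rand(B, T) * 10.0, dim=-1).values
    attn = TemporalSelfAttention(embed_dim=D, num_heads=4, bias_type="gaussian", gamma=0.5, causal=True)
    out, weights = attn(x, timestamps)
    print(out.shape)      # torch.Size([2, 16, 64])
    print(weights.shape)  # torch.Size([2, 4, 16, 16])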