import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
def scan(f, init, xs, out, checkpoint_group=0):
"""
模拟JAX中的lax.scan函数,用于序列化处理数据。
参数:
f: 处理函数,接收(carry, x)作为输入,返回(new_carry, y)
init: 初始状态值
xs: 输入序列,可以是字典或列表
out: 输出结果的存储张量
checkpoint_group: 梯度检查点分组数量,用于节省内存
返回:
carry: 最终的状态值
out: 填充好的输出张量
"""
    # Initialize the carry state
    carry = init
    # Determine the length of the input sequence
    if isinstance(xs, dict):
        # Dict input: use the length of the first value
        num_items = len(next(iter(xs.values())))
    else:
        # List input: use the length of the first element
        num_items = len(xs[0])
    def scan_fn(carry, i_start, i_end):
        """Inner scan over elements from i_start (inclusive) to i_end (exclusive)."""
        for i in range(i_start, i_end):
            # Extract the inputs at position i
            if isinstance(xs, dict):
                # Dict case: build a dict holding each value's slice at position i
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                # List case: build a list holding each tensor's slice at position i
                x = [tensor[i] for tensor in xs]
            # Apply f to obtain the new carry and the output for this step
            carry, y = f(carry, x)
            # Store the output in the result tensor
            out[i] = y
        # Return the final carry
        return carry
    # Use gradient checkpointing only when checkpoint_group > 0
    if checkpoint_group > 0:
        # Number of elements per checkpoint group (at least 1)
        ckpt_every_n = max(num_items // checkpoint_group, 1)
        # Process the sequence group by group
        for k in range(0, num_items, ckpt_every_n):
            # torch.utils.checkpoint re-computes activations in backward to save memory
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        # No checkpointing: process the whole sequence in one pass
        carry = scan_fn(carry, 0, num_items)
    # Return the final carry and the filled output tensor
    return carry, out
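# --- Usage sketch (illustrative only, not part of the original module). It runs
# scan() with a cumulative-sum step function; the helper name and sizes below
# are arbitrary choices for the demonstration.
def _example_scan():
    seq_len, dim = 8, 4
    xs = torch.randn(seq_len, dim)   # input sequence, indexed along dim 0
    out = torch.empty(seq_len, dim)  # pre-allocated output buffer
    def step(carry, x):
        # scan() passes x as a list (one slice per tensor in xs)
        new_carry = carry + x[0]     # running sum as the carried state
        return new_carry, new_carry  # y is written into out[i] by scan()
    carry, out = scan(step, torch.zeros(dim), [xs], out, checkpoint_group=0)
    assert torch.allclose(out[-1], xs.sum(dim=0), atol=1e-5)
    return carry, out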
def ln_fwd(x, gamma, beta, eps=1e-6):
"Batch forward for LayerNorm."
# Mean and variance computation
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# Normalization
std = torch.sqrt(var + eps)
x_hat = (x - mu) / std
# Scale and shift
y = gamma * x_hat + beta
return y
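# --- Sanity-check sketch (illustrative, not part of the original module):
# ln_fwd should agree with F.layer_norm when the same eps is used; the helper
# name and tensor shapes are arbitrary.
def _check_ln_fwd():
    x = torch.randn(2, 5, 16)
    gamma = torch.ones(16)
    beta = torch.zeros(16)
    ref = F.layer_norm(x, (16,), weight=gamma, bias=beta, eps=1e-6)
    assert torch.allclose(ln_fwd(x, gamma, beta, eps=1e-6), ref, atol=1e-5)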
def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
"""
层归一化(LayerNorm)与L2损失融合的反向传播函数。
这个函数执行两个操作:
1. 前向传播:对输入x进行层归一化,得到输出y
2. 反向传播:计算L2损失(y - l2_target)对输入x的梯度
参数:
x: 输入张量
l2_target: L2损失的目标值
gamma: 层归一化的缩放参数
beta: 层归一化的偏移参数
eps: 数值稳定性的小常数
返回:
z: 损失对输入x的梯度
"""
    D = x.shape[-1]  # feature dimension
    # Mean and variance
    mu = x.mean(dim=-1, keepdim=True)  # mean over the feature dimension
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance
    # Normalization
    std = torch.sqrt(var + eps)  # standard deviation
    x_hat = (x - mu) / std  # normalized input
    # Scale and shift
    y = gamma * x_hat + beta  # LayerNorm output
    # Gradient computation
    grad_output = y - l2_target  # gradient of the L2 loss w.r.t. y
    grad_x_hat = grad_output * gamma  # gradient w.r.t. the normalized input
    # Full backward formula, applying the chain rule through the normalization
    z = (
        (1.0 / D)
        * (
            D * grad_x_hat
            - grad_x_hat.sum(dim=-1, keepdim=True)  # contribution from the mean term
            - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)  # contribution from the variance term
        )
        / std  # divide by the standard deviation to complete the gradient
    )
return z
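# --- Verification sketch (illustrative, not part of the original module). It
# checks ln_fused_l2_bwd against autograd on L = 0.5 * ||LN(x) - target||^2;
# the helper name and tensor shapes are arbitrary choices for the check.
def _check_ln_fused_l2_bwd():
    torch.manual_seed(0)
    x = torch.randn(3, 8, requires_grad=True)
    gamma = torch.ones(8)
    beta = torch.zeros(8)
    target = torch.randn(3, 8)
    # Reference gradient via autograd through ln_fwd
    loss = 0.5 * ((ln_fwd(x, gamma, beta) - target) ** 2).sum()
    loss.backward()
    # Fused analytic gradient
    z = ln_fused_l2_bwd(x.detach(), target, gamma, beta)
    assert torch.allclose(z, x.grad, atol=1e-4)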
from torch.autograd import Function
class MyLinearFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias):
"""
正向计算: y = x * W^T + b
参数:
ctx :上下文对象,用于保存反向传播时需要的信息。
input :输入 tensor, 尺寸为 (N, in_features)
weight :权重 tensor, 尺寸为 (out_features, in_features)
bias :偏置 tensor, 尺寸为 (out_features)
返回:
输出 tensor, 尺寸为 (N, out_features)
"""
# 保存必要的中间变量,供 backward 时使用
ctx.save_for_backward(input, weight, bias)
# 计算输出
output = input.matmul(weight.t()) + bias
return output
@staticmethod
def backward(ctx, grad_output):
"""
反向传播:计算正向计算中各个输入的梯度。
参数:
grad_output:从上层传回来的梯度,形状与 forward 的输出相同 (N, out_features)
返回:
grad_input :关于 input 的梯度,形状 (N, in_features)
grad_weight :关于 weight 的梯度,形状 (out_features, in_features)
grad_bias :关于 bias 的梯度,形状 (out_features)
"""
# 从上下文中取出保存的变量
input, weight, bias = ctx.saved_tensors
# 链式法则:已知 output = input.matmul(weight.t()) + bias
# 关于 input 的梯度:
# ∂L/∂input = ∂L/∂output * ∂output/∂input = grad_output.matmul(weight)
grad_input = grad_output.matmul(weight)
# 关于 weight 的梯度:
# ∂L/∂weight = ∂L/∂output^T * ∂output/∂weight
# 注意到 output 对 weight 的导数为 input 的转置,此处:
# grad_weight 的计算通常为:grad_output^T.matmul(input)
grad_weight = grad_output.t().matmul(input)
# 关于 bias 的梯度:
# 因为 output = ... + bias,因此每个 bias 项对应所有样本的梯度和
grad_bias = grad_output.sum(dim=0)
# 注意:返回的梯度顺序必须与 forward 中参数的顺序一致
return grad_input, grad_weight, grad_bias
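# --- Verification sketch (illustrative, not part of the original module):
# torch.autograd.gradcheck exercises MyLinearFunction.backward in double
# precision; the helper name and sizes are arbitrary.
def _check_my_linear_function():
    inp = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    weight = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    bias = torch.randn(5, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(
        MyLinearFunction.apply, (inp, weight, bias), eps=1e-6, atol=1e-4
    )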
class TTT_Cross_Layer(nn.Module):
def __init__(self, config):
super().__init__()
self.input_size = config.concept_dim # 128
self.concept_dim = config.concept_dim # 128
# self.linear = nn.Linear(self.input_size, self.hidden_size)
# self.ln = nn.LayerNorm(self.hidden_size)
# self.logit_dim = 32
self.logit_dim = config.logit_dim
self.weight_linear = nn.Parameter(torch.empty(self.concept_dim, self.input_size, self.logit_dim))
self.weight_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
self.bias_linear = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
self.bias_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
# self.weight_linear_tmp = torch.empty_like(self.weight_linear)
# self.weight_ln_tmp = torch.empty_like(self.weight_ln)
# self.bias_linear_tmp = torch.empty_like(self.bias_linear)
# self.bias_ln_tmp = torch.empty_like(self.bias_ln)
self.config = config
self.init_weights()
# def init_tmp_weights(self):
# weight_linear_tmp = self.weight_linear.clone().to(self.weight_linear.device).to(self.weight_linear.dtype)
# weight_ln_tmp = self.weight_ln.clone().to(self.weight_ln.device).to(self.weight_ln.dtype)
# bias_linear_tmp = self.bias_linear.clone().to(self.bias_linear.device).to(self.bias_linear.dtype)
# bias_ln_tmp = self.bias_ln.clone().to(self.bias_ln.device).to(self.bias_ln.dtype)
# params = {
# 'weight_linear_tmp': weight_linear_tmp,
# 'weight_ln_tmp': weight_ln_tmp,
# 'bias_linear_tmp': bias_linear_tmp,
# 'bias_ln_tmp': bias_ln_tmp
# }
# return params
def init_params_as_logits(self, batch_size, sequence_length):
weight_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
weight_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
bias_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
bias_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
params = {
'weight_linear_tmp': weight_linear_tmp,
'weight_ln_tmp': weight_ln_tmp,
'bias_linear_tmp': bias_linear_tmp,
'bias_ln_tmp': bias_ln_tmp
}
return params
    def init_weights(self):
        # torch.manual_seed(42)  # a fixed random seed would make the initialization deterministic
        nn.init.normal_(self.weight_linear, mean=0.0, std=self.config.initializer_range)
        nn.init.constant_(self.weight_ln, 1.0 / self.logit_dim)
        # nn.init.zeros_(self.bias_linear)
        # nn.init.zeros_(self.bias_ln)
        nn.init.normal_(self.bias_linear, mean=0.0, std=self.config.initializer_range / self.logit_dim)
        nn.init.normal_(self.bias_ln, mean=0.0, std=self.config.initializer_range / self.logit_dim)
def get_weight_per_token(self, params):
weight_linear_tmp = torch.einsum('iol,bsl->bsio', self.weight_linear, params['weight_linear_tmp'])
weight_ln_tmp = torch.einsum('ol,bsl->bso', self.weight_ln, params['weight_ln_tmp'])
bias_linear_tmp = torch.einsum('ol,bsl->bso', self.bias_linear, params['bias_linear_tmp'])
bias_ln_tmp = torch.einsum('ol,bsl->bso', self.bias_ln, params['bias_ln_tmp'])
return weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp
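    # Shape conventions for the einsums above (b: batch, s: sequence,
    # i: input dim = concept_dim, o: output dim = concept_dim, l: logit_dim):
    # each per-token logit vector in `params` mixes the shared base parameters,
    # producing per-token fast weights of shape
    #   weight_linear_tmp: (b, s, i, o)
    #   weight_ln_tmp, bias_linear_tmp, bias_ln_tmp: (b, s, o)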
    def learn(self, k, v, params, lr_linear=1, lr_ln=1):
        # k and v have shape [batch_size, length, channel_dim]
        # batch_size, seq_length, channel_dim = k.shape
        # weight_linear_tmp = params['weight_linear_tmp']
        # weight_ln_tmp = params['weight_ln_tmp']
        # bias_linear_tmp = params['bias_linear_tmp']
        # bias_ln_tmp = params['bias_ln_tmp']
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
        # 1. Reshape the input to 2D for prediction
        # k_reshaped = k.reshape(-1, channel_dim)  # [batch_size*length, channel_dim]
        # output_reshaped = self.predict(k_reshaped, params)  # [batch_size*length, channel_dim]
        # z = F.linear(k_reshaped, params['weight_linear_tmp'], params['bias_linear_tmp'])
        # mu = z.mean(dim=-1, keepdim=True)
        # var = z.var(dim=-1, keepdim=True, unbiased=False)
z = torch.einsum('bsi,bsio->bso', k, weight_linear_tmp) + bias_linear_tmp
mu = z.mean(dim=-1, keepdim=True)
var = z.var(dim=-1, keepdim=True, unbiased=False)
# Normalization
eps = 1e-6
std = torch.sqrt(var + eps)
z_hat = (z - mu) / std
# output_reshaped = params['weight_ln_tmp'] * z_hat + params['bias_ln_tmp'] + k
output_reshaped = weight_ln_tmp * z_hat + bias_ln_tmp + k
        # # Compute the prediction error
        # v_reshaped = v.reshape(-1, channel_dim)
        # error_reshaped = output_reshaped - v_reshaped  # [batch_size*length, channel_dim]
error_reshaped = output_reshaped - v
        # Gradients for the LayerNorm parameters
        # LayerNorm parameter update
        # ln_rate = learning_rate * 0.1  # lower learning rate for the LN parameters
        grad_weight_ln_temp = error_reshaped * z_hat
        # grad_weight_ln = grad_weight_ln_temp.mean(dim=0)
        # weight_ln_tmp = weight_ln_tmp - ln_rate * grad_weight_ln  # sequence length, channel_dim
        grad_weight_ln = grad_weight_ln_temp
        # batch_size, sequence length, logit_dim
        params0 = params['weight_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.weight_ln, grad_weight_ln)
# bias_update = ln_rate * error_reshaped # .mean(dim=0)
# bias_ln_tmp = bias_ln_tmp - bias_update # batch_size, sequence length, concept_dim
grad_bias_ln = error_reshaped
params1 = params['bias_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.bias_ln, grad_bias_ln)
        # Linear-layer weight gradient: [out_dim, in_dim]
        # grad_linear_temp = error_reshaped - error_reshaped.mean(dim=-1, keepdim=True) - z_hat * grad_weight_ln_temp.mean(dim=-1, keepdim=True)
        grad_linear = weight_ln_tmp * error_reshaped / std  # batch_size, sequence length, concept_dim
        # grad_weight_linear = grad_linear.t() @ k  # [channel_dim, channel_dim]
        grad_weight_linear = torch.einsum('bsi,bso->bsio', k, grad_linear)
        # Apply the gradient (avoid the in-place -= operation)
        # weight_linear_tmp = weight_linear_tmp - learning_rate * grad_weight_linear.mean(dim=0)
        params2 = params['weight_linear_tmp'] - lr_linear * torch.einsum('iol,bsio->bsl', self.weight_linear, grad_weight_linear)
        # Update the bias, if present (avoid the in-place -= operation)
        grad_b = grad_linear  # .mean(dim=0)  # [channel_dim]
        # bias_linear_tmp = bias_linear_tmp - learning_rate * grad_b
        params3 = params['bias_linear_tmp'] - lr_linear * torch.einsum('ol,bso->bsl', self.bias_linear, grad_b)
params_new = {
'weight_linear_tmp': params2,
'weight_ln_tmp': params0,
'bias_linear_tmp': params3,
'bias_ln_tmp': params1
}
return params_new
def predict(self, q, params):
weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
z = torch.einsum('bsi,bsio->bso', q, weight_linear_tmp) + bias_linear_tmp
mu = z.mean(dim=-1, keepdim=True)
var = z.var(dim=-1, keepdim=True, unbiased=False)
# Normalization
eps = 1e-6
std = torch.sqrt(var + eps)
z_hat = (z - mu) / std
# output_reshaped = params['weight_ln_tmp'] * z_hat + params['bias_ln_tmp'] + k
output = weight_ln_tmp * z_hat + bias_ln_tmp + q
return output
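# --- End-to-end usage sketch (illustrative only). SimpleNamespace stands in for
# the real config object; concept_dim, logit_dim and initializer_range are the
# only fields this module reads, and the sizes below are arbitrary.
def _example_ttt_cross_layer():
    from types import SimpleNamespace
    config = SimpleNamespace(concept_dim=16, logit_dim=8, initializer_range=0.02)
    layer = TTT_Cross_Layer(config)
    batch_size, seq_len = 2, 5
    k = torch.randn(batch_size, seq_len, config.concept_dim)
    v = torch.randn(batch_size, seq_len, config.concept_dim)
    q = torch.randn(batch_size, seq_len, config.concept_dim)
    # Per-token logits that mix the shared base parameters into fast weights
    params = layer.init_params_as_logits(batch_size, seq_len)
    # One inner-loop (test-time training) step: fit the fast weights on (k, v) ...
    params = layer.learn(k, v, params, lr_linear=1e-2, lr_ln=1e-2)
    # ... then read out the queries with the updated per-token parameters
    out = layer.predict(q, params)
    assert out.shape == (batch_size, seq_len, config.concept_dim)
    return out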