import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint


def scan(f, init, xs, out, checkpoint_group=0):
    """
    Mimics JAX's lax.scan: processes a sequence step by step.

    Args:
        f: step function taking (carry, x) and returning (new_carry, y)
        init: initial carry value
        xs: input sequence, either a dict of tensors or a list of tensors
        out: pre-allocated tensor that receives the per-step outputs
        checkpoint_group: number of gradient-checkpoint groups, used to save memory

    Returns:
        carry: final carry value
        out: the filled output tensor
    """
    # Initialize the carry
    carry = init
    # Determine the sequence length
    if isinstance(xs, dict):
        # Dict input: use the length of the first value
        num_items = len(next(iter(xs.values())))
    else:
        # List input: use the length of the first element
        num_items = len(xs[0])

    def scan_fn(carry, i_start, i_end):
        """Inner scan loop over elements i_start..i_end."""
        for i in range(i_start, i_end):
            # Slice out the inputs at position i
            if isinstance(xs, dict):
                # Dict case: build a dict with each key's value at position i
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                # List case: build a list with each tensor's value at position i
                x = [tensor[i] for tensor in xs]
            # Apply the step function to get the new carry and the step output
            carry, y = f(carry, x)
            # Store the output in the result tensor
            out[i] = y
        # Return the final carry
        return carry

    # Decide whether to use gradient checkpointing
    if checkpoint_group > 0:
        # Number of elements per checkpoint group
        ckpt_every_n = num_items // checkpoint_group
        # Process the data group by group
        for k in range(0, num_items, ckpt_every_n):
            # Use torch.utils.checkpoint to save memory
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items),
                use_reentrant=False
            )
    else:
        # No checkpointing: process everything in one pass
        carry = scan_fn(carry, 0, num_items)

    # Return the final carry and the filled output tensor
    return carry, out


def ln_fwd(x, gamma, beta, eps=1e-6):
    """Batched LayerNorm forward."""
    # Mean and variance
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # Normalization
    std = torch.sqrt(var + eps)
    x_hat = (x - mu) / std

    # Scale and shift
    y = gamma * x_hat + beta

    return y


def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """
    Fused LayerNorm forward + L2-loss backward.

    This function performs two operations:
    1. Forward: layer-normalize the input x to get the output y
    2. Backward: compute the gradient of the L2 loss (y - l2_target) with respect to x

    Args:
        x: input tensor
        l2_target: target of the L2 loss
        gamma: LayerNorm scale parameter
        beta: LayerNorm shift parameter
        eps: small constant for numerical stability

    Returns:
        z: gradient of the loss with respect to the input x
    """
    D = x.shape[-1]  # feature dimension

    # Mean and variance
    mu = x.mean(dim=-1, keepdim=True)                  # mean over the feature dimension
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # variance

    # Normalization
    std = torch.sqrt(var + eps)  # standard deviation
    x_hat = (x - mu) / std       # normalized input

    # Scale and shift
    y = gamma * x_hat + beta     # LayerNorm output

    # Gradients
    grad_output = y - l2_target       # gradient of the L2 loss w.r.t. y
    grad_x_hat = grad_output * gamma  # gradient w.r.t. the normalized input

    # Full backward formula, applying the chain rule through the normalization
    z = (
        (1.0 / D)
        * (
            D * grad_x_hat
            - grad_x_hat.sum(dim=-1, keepdim=True)                    # contribution from the mean term
            - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)  # contribution from the variance term
        )
        / std                                                         # divide by std to finish the gradient
    )

    return z
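

# Added sketches (not part of the original module): minimal usage/consistency
# checks for the helpers above. The helper names (_demo_scan, _check_ln_fused_l2_bwd)
# and the tolerances are illustrative assumptions, not an official test suite.
def _demo_scan():
    # Cumulative sum with scan: the carry accumulates, y is the running total.
    xs = [torch.arange(5, dtype=torch.float32)]  # a single input stream of length 5
    out = torch.empty(5)                         # pre-allocated output buffer
    carry, out = scan(lambda c, x: (c + x[0], c + x[0]), torch.tensor(0.0), xs, out)
    # carry == 10.0, out == [0., 1., 3., 6., 10.]
    return carry, out


def _check_ln_fused_l2_bwd(D=16, eps=1e-6):
    # ln_fused_l2_bwd should match autograd's gradient of
    # 0.5 * ||ln_fwd(x) - target||^2 with respect to x.
    x = torch.randn(4, D, dtype=torch.double, requires_grad=True)
    gamma = torch.randn(D, dtype=torch.double)
    beta = torch.randn(D, dtype=torch.double)
    target = torch.randn(4, D, dtype=torch.double)
    loss = 0.5 * ((ln_fwd(x, gamma, beta, eps) - target) ** 2).sum()
    (grad_autograd,) = torch.autograd.grad(loss, x)
    grad_fused = ln_fused_l2_bwd(x, target, gamma, beta, eps)
    assert torch.allclose(grad_autograd, grad_fused, atol=1e-8)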


from torch.autograd import Function


class MyLinearFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias):
        """
        Forward pass: y = x * W^T + b

        Args:
            ctx    : context object used to stash information for the backward pass
            input  : input tensor of shape (N, in_features)
            weight : weight tensor of shape (out_features, in_features)
            bias   : bias tensor of shape (out_features)

        Returns:
            Output tensor of shape (N, out_features)
        """
        # Save the tensors needed for the backward pass
        ctx.save_for_backward(input, weight, bias)
        # Compute the output
        output = input.matmul(weight.t()) + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: compute the gradient of each forward input.

        Args:
            grad_output: gradient flowing back from the layer above, with the same
                         shape as the forward output, (N, out_features)

        Returns:
            grad_input : gradient w.r.t. input, shape (N, in_features)
            grad_weight: gradient w.r.t. weight, shape (out_features, in_features)
            grad_bias  : gradient w.r.t. bias, shape (out_features)
        """
        # Retrieve the tensors saved in the context
        input, weight, bias = ctx.saved_tensors

        # Chain rule, given output = input.matmul(weight.t()) + bias
        # Gradient w.r.t. input:
        #   dL/dinput = dL/doutput * doutput/dinput = grad_output.matmul(weight)
        grad_input = grad_output.matmul(weight)

        # Gradient w.r.t. weight:
        #   the derivative of output w.r.t. weight is the transpose of input, so
        #   grad_weight is computed as grad_output^T.matmul(input)
        grad_weight = grad_output.t().matmul(input)

        # Gradient w.r.t. bias:
        #   output = ... + bias, so each bias entry sums the gradient over all samples
        grad_bias = grad_output.sum(dim=0)

        # Note: the gradients must be returned in the same order as forward's arguments
        return grad_input, grad_weight, grad_bias
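

# Sanity-check sketch (added, not part of the original module): MyLinearFunction.apply
# is used like F.linear, and gradcheck verifies the hand-written backward in double
# precision. The helper name _check_my_linear is illustrative only.
def _check_my_linear(N=3, in_features=5, out_features=4):
    x = torch.randn(N, in_features, dtype=torch.double, requires_grad=True)
    w = torch.randn(out_features, in_features, dtype=torch.double, requires_grad=True)
    b = torch.randn(out_features, dtype=torch.double, requires_grad=True)
    # gradcheck compares the analytic backward above against numerical gradients
    assert torch.autograd.gradcheck(MyLinearFunction.apply, (x, w, b))
    return MyLinearFunction.apply(x, w, b)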


class TTT_Cross_Layer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.input_size = config.concept_dim   # 128
        self.concept_dim = config.concept_dim  # 128
        # self.linear = nn.Linear(self.input_size, self.hidden_size)
        # self.ln = nn.LayerNorm(self.hidden_size)

        # self.logit_dim = 32
        self.logit_dim = config.logit_dim

        # Shared banks of logit_dim weight/bias components; each token mixes them
        # with its own logit vector (see get_weight_per_token).
        self.weight_linear = nn.Parameter(torch.empty(self.concept_dim, self.input_size, self.logit_dim))
        self.weight_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_linear = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))

        # self.weight_linear_tmp = torch.empty_like(self.weight_linear)
        # self.weight_ln_tmp = torch.empty_like(self.weight_ln)
        # self.bias_linear_tmp = torch.empty_like(self.bias_linear)
        # self.bias_ln_tmp = torch.empty_like(self.bias_ln)

        self.config = config
        self.init_weights()

    # def init_tmp_weights(self):
    #     weight_linear_tmp = self.weight_linear.clone().to(self.weight_linear.device).to(self.weight_linear.dtype)
    #     weight_ln_tmp = self.weight_ln.clone().to(self.weight_ln.device).to(self.weight_ln.dtype)
    #     bias_linear_tmp = self.bias_linear.clone().to(self.bias_linear.device).to(self.bias_linear.dtype)
    #     bias_ln_tmp = self.bias_ln.clone().to(self.bias_ln.device).to(self.bias_ln.dtype)
    #     params = {
    #         'weight_linear_tmp': weight_linear_tmp,
    #         'weight_ln_tmp': weight_ln_tmp,
    #         'bias_linear_tmp': bias_linear_tmp,
    #         'bias_ln_tmp': bias_ln_tmp
    #     }
    #     return params

    def init_params_as_logits(self, batch_size, sequence_length):
        # Per-token logit coefficients, initialized to ones, on the same device/dtype
        # as the shared parameter banks.
        weight_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        weight_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        bias_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        bias_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        params = {
            'weight_linear_tmp': weight_linear_tmp,
            'weight_ln_tmp': weight_ln_tmp,
            'bias_linear_tmp': bias_linear_tmp,
            'bias_ln_tmp': bias_ln_tmp
        }
        return params

    def init_weights(self):
        # torch.manual_seed(42)  # a fixed seed would make the initialization predictable
        nn.init.normal_(self.weight_linear, mean=0.0, std=self.config.initializer_range)
        nn.init._no_grad_fill_(self.weight_ln, 1.0 / self.logit_dim)
        # nn.init.zeros_(self.bias_linear)
        # nn.init.zeros_(self.bias_ln)
        nn.init.normal_(self.bias_linear, mean=0.0, std=self.config.initializer_range / self.logit_dim)
        nn.init.normal_(self.bias_ln, mean=0.0, std=self.config.initializer_range / self.logit_dim)

    def get_weight_per_token(self, params):
        # Mix the shared banks with each token's logit vector to build per-token weights:
        #   weight_linear_tmp: [B, S, in, out]; the other three: [B, S, out]
        weight_linear_tmp = torch.einsum('iol,bsl->bsio', self.weight_linear, params['weight_linear_tmp'])
        weight_ln_tmp = torch.einsum('ol,bsl->bso', self.weight_ln, params['weight_ln_tmp'])
        bias_linear_tmp = torch.einsum('ol,bsl->bso', self.bias_linear, params['bias_linear_tmp'])
        bias_ln_tmp = torch.einsum('ol,bsl->bso', self.bias_ln, params['bias_ln_tmp'])
        return weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp

    def learn(self, k, v, params, lr_linear=1, lr_ln=1):
        # k and v have shape [batch_size, length, channel_dim]
        # batch_size, seq_length, channel_dim = k.shape
        # weight_linear_tmp = params['weight_linear_tmp']
        # weight_ln_tmp = params['weight_ln_tmp']
        # bias_linear_tmp = params['bias_linear_tmp']
        # bias_ln_tmp = params['bias_ln_tmp']
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)

        # 1. Run the prediction (the commented-out path reshaped the input to 2D first)
        # k_reshaped = k.reshape(-1, channel_dim)  # [batch_size*length, channel_dim]
        # output_reshaped = self.predict(k_reshaped, params)  # [batch_size*length, channel_dim]
        # z = F.linear(k_reshaped, params['weight_linear_tmp'], params['bias_linear_tmp'])
        # mu = z.mean(dim=-1, keepdim=True)
        # var = z.var(dim=-1, keepdim=True, unbiased=False)
        z = torch.einsum('bsi,bsio->bso', k, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)

        # Normalization
        eps = 1e-6
        std = torch.sqrt(var + eps)
        z_hat = (z - mu) / std
        # output_reshaped = params['weight_ln_tmp'] * z_hat + params['bias_ln_tmp'] + k
        output_reshaped = weight_ln_tmp * z_hat + bias_ln_tmp + k

        # Compute the error
        # v_reshaped = v.reshape(-1, channel_dim)
        # error_reshaped = output_reshaped - v_reshaped  # [batch_size*length, channel_dim]
        error_reshaped = output_reshaped - v

        # LayerNorm gradients and parameter updates
        # ln_rate = learning_rate * 0.1  # lower the LN learning rate
        grad_weight_ln_temp = error_reshaped * z_hat
        # grad_weight_ln = grad_weight_ln_temp.mean(dim=0)
        # weight_ln_tmp = weight_ln_tmp - ln_rate * grad_weight_ln  # [sequence_length, channel_dim]
        grad_weight_ln = grad_weight_ln_temp  # [batch_size, sequence_length, concept_dim]
        params0 = params['weight_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.weight_ln, grad_weight_ln)

        # bias_update = ln_rate * error_reshaped  # .mean(dim=0)
        # bias_ln_tmp = bias_ln_tmp - bias_update
        grad_bias_ln = error_reshaped  # [batch_size, sequence_length, concept_dim]
        params1 = params['bias_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.bias_ln, grad_bias_ln)

        # Linear-layer weight gradient: [out_dim, in_dim]
        # grad_linear_temp = error_reshaped - error_reshaped.mean(dim=-1, keepdim=True) - z_hat * grad_weight_ln_temp.mean(dim=-1, keepdim=True)
        # Note: only the direct path through z_hat is kept here; the mean/variance terms
        # of the full LayerNorm backward (cf. grad_linear_temp above) are dropped.
        grad_linear = weight_ln_tmp * error_reshaped / std  # [batch_size, sequence_length, concept_dim]
        # grad_weight_linear = grad_linear.t() @ k  # [channel_dim, channel_dim]
        grad_weight_linear = torch.einsum('bsi,bso->bsio', k, grad_linear)

        # Apply the gradient (avoid the in-place -=)
        # weight_linear_tmp = weight_linear_tmp - learning_rate * grad_weight_linear.mean(dim=0)
        params2 = params['weight_linear_tmp'] - lr_linear * torch.einsum('iol,bsio->bsl', self.weight_linear, grad_weight_linear)

        # Update the bias (avoid the in-place -=)
        grad_b = grad_linear  # .mean(dim=0)  # [channel_dim]
        # bias_linear_tmp = bias_linear_tmp - learning_rate * grad_b
        params3 = params['bias_linear_tmp'] - lr_linear * torch.einsum('ol,bso->bsl', self.bias_linear, grad_b)

        params_new = {
            'weight_linear_tmp': params2,
            'weight_ln_tmp': params0,
            'bias_linear_tmp': params3,
            'bias_ln_tmp': params1
        }
        return params_new

    def predict(self, q, params):
        # Read out with the (possibly updated) per-token fast weights.
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
        z = torch.einsum('bsi,bsio->bso', q, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)

        # Normalization
        eps = 1e-6
        std = torch.sqrt(var + eps)
        z_hat = (z - mu) / std
        # output = params['weight_ln_tmp'] * z_hat + params['bias_ln_tmp'] + q
        output = weight_ln_tmp * z_hat + bias_ln_tmp + q
        return output
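

# End-to-end sketch (added, not part of the original module). The SimpleNamespace
# config and its field values are assumptions based on the attributes the layer
# reads (concept_dim, logit_dim, initializer_range); swap in the real config class.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(concept_dim=128, logit_dim=32, initializer_range=0.02)
    layer = TTT_Cross_Layer(cfg)

    B, S = 2, 8
    k = torch.randn(B, S, cfg.concept_dim)  # keys: what the fast weights condition on
    v = torch.randn(B, S, cfg.concept_dim)  # values: the targets the fast weights fit
    q = torch.randn(B, S, cfg.concept_dim)  # queries: read out with the updated weights

    params = layer.init_params_as_logits(B, S)                    # per-token logit coefficients
    params = layer.learn(k, v, params, lr_linear=0.1, lr_ln=0.1)  # one inner update step
    out = layer.predict(q, params)
    print(out.shape)  # torch.Size([2, 8, 128])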