import torch
import torch.nn as nn
import torch.nn.functional as F


def scan(f, init, xs, out, checkpoint_group=0):
    """
    Mimics jax.lax.scan for sequential processing of data.

    Args:
        f: step function that takes (carry, x) and returns (new_carry, y)
        init: initial carry value
        xs: input sequence, either a dict of tensors or a list of tensors
        out: pre-allocated tensor that receives the per-step outputs
        checkpoint_group: number of gradient-checkpoint groups (0 disables checkpointing)

    Returns:
        carry: final carry value
        out: the filled output tensor
    """
    carry = init

    if isinstance(xs, dict):
        num_items = len(next(iter(xs.values())))
    else:
        num_items = len(xs[0])

    def scan_fn(carry, i_start, i_end):
        """Inner scan loop over elements i_start (inclusive) to i_end (exclusive)."""
        for i in range(i_start, i_end):
            if isinstance(xs, dict):
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                x = [tensor[i] for tensor in xs]
            carry, y = f(carry, x)
            out[i] = y
        return carry

    if checkpoint_group > 0:
        # Split the sequence into groups and checkpoint each group to save memory.
        ckpt_every_n = num_items // checkpoint_group
        for k in range(0, num_items, ckpt_every_n):
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        carry = scan_fn(carry, 0, num_items)

    return carry, out
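

# A minimal usage sketch of scan() (not part of the original module): a running
# cumulative sum over a length-T sequence. The step function, tensor shapes, and
# checkpoint_group value below are illustrative assumptions.
def _scan_cumsum_demo():
    T, D = 8, 4
    xs = [torch.randn(T, D)]        # list of inputs, indexed along the first dim
    out = torch.empty(T, D)         # pre-allocated output buffer, one row per step

    def step(carry, x):
        new_carry = carry + x[0]    # x is a list because xs is a list
        return new_carry, new_carry

    final_carry, filled = scan(step, torch.zeros(D), xs, out, checkpoint_group=0)
    assert torch.allclose(filled[-1], final_carry)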


def ln_fwd(x, gamma, beta, eps=1e-6):
    """Batched forward pass for LayerNorm."""
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    std = torch.sqrt(var + eps)
    x_hat = (x - mu) / std

    y = gamma * x_hat + beta

    return y
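

# Quick sanity check for ln_fwd (illustrative, not part of the original module):
# with gamma/beta of shape (D,), it should match torch.nn.functional.layer_norm
# up to floating-point tolerance.
def _ln_fwd_check():
    x = torch.randn(2, 3, 16)
    gamma, beta = torch.randn(16), torch.randn(16)
    ref = F.layer_norm(x, (16,), weight=gamma, bias=beta, eps=1e-6)
    assert torch.allclose(ln_fwd(x, gamma, beta), ref, atol=1e-5)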


def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """
    Fused LayerNorm forward + backward pass through an L2 loss.

    This function performs two operations:
    1. Forward: layer-normalize the input x to obtain the output y.
    2. Backward: compute the gradient of the L2 loss between y and l2_target
       with respect to the input x.

    Args:
        x: input tensor
        l2_target: target values for the L2 loss
        gamma: LayerNorm scale parameter
        beta: LayerNorm shift parameter
        eps: small constant for numerical stability

    Returns:
        z: gradient of the loss with respect to the input x
    """
    D = x.shape[-1]

    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    std = torch.sqrt(var + eps)
    x_hat = (x - mu) / std

    y = gamma * x_hat + beta

    grad_output = y - l2_target
    grad_x_hat = grad_output * gamma

    # Standard LayerNorm backward formula applied to grad_x_hat.
    z = (
        (1.0 / D)
        * (
            D * grad_x_hat
            - grad_x_hat.sum(dim=-1, keepdim=True)
            - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)
        )
        / std
    )

    return z
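

# Illustrative autograd check (not part of the original module): the analytic
# gradient returned by ln_fused_l2_bwd should match autograd applied to
# 0.5 * ||ln_fwd(x) - target||^2 with respect to x.
def _ln_fused_l2_bwd_check():
    torch.manual_seed(0)
    x = torch.randn(2, 5, 16, dtype=torch.float64, requires_grad=True)
    target = torch.randn(2, 5, 16, dtype=torch.float64)
    gamma = torch.randn(16, dtype=torch.float64)
    beta = torch.randn(16, dtype=torch.float64)

    loss = 0.5 * ((ln_fwd(x, gamma, beta) - target) ** 2).sum()
    (grad_autograd,) = torch.autograd.grad(loss, x)

    grad_analytic = ln_fused_l2_bwd(x.detach(), target, gamma, beta)
    assert torch.allclose(grad_autograd, grad_analytic, atol=1e-6)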


from torch.autograd import Function


class MyLinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        """
        Forward pass: y = x @ W^T + b

        Args:
            ctx: context object used to stash tensors for the backward pass
            input: input tensor of shape (N, in_features)
            weight: weight tensor of shape (out_features, in_features)
            bias: bias tensor of shape (out_features,)

        Returns:
            output tensor of shape (N, out_features)
        """
        ctx.save_for_backward(input, weight, bias)

        output = input.matmul(weight.t()) + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: compute the gradients of the forward inputs.

        Args:
            grad_output: gradient from the next layer, same shape as the
                forward output, i.e. (N, out_features)

        Returns:
            grad_input: gradient w.r.t. input, shape (N, in_features)
            grad_weight: gradient w.r.t. weight, shape (out_features, in_features)
            grad_bias: gradient w.r.t. bias, shape (out_features,)
        """
        input, weight, bias = ctx.saved_tensors

        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0)

        return grad_input, grad_weight, grad_bias
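

# Illustrative usage of MyLinearFunction (not part of the original module):
# custom autograd Functions are invoked through .apply(), and gradcheck in
# double precision confirms the hand-written backward pass.
def _my_linear_function_check():
    torch.manual_seed(0)
    x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
    w = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
    b = torch.randn(5, dtype=torch.float64, requires_grad=True)

    y = MyLinearFunction.apply(x, w, b)          # shape (4, 5)
    assert torch.autograd.gradcheck(MyLinearFunction.apply, (x, w, b))
    assert torch.allclose(y, F.linear(x, w, b))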


class TTT_Cross_Layer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_size = config.concept_dim
        self.concept_dim = config.concept_dim

        self.logit_dim = config.logit_dim

        # Per-token weights are produced by mixing these base parameters along
        # the logit dimension in get_weight_per_token().
        self.weight_linear = nn.Parameter(torch.empty(self.concept_dim, self.input_size, self.logit_dim))
        self.weight_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_linear = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))

        self.config = config
        self.init_weights()

    def init_params_as_logits(self, batch_size, sequence_length):
        """Initialize the per-token mixing logits to ones."""
        device, dtype = self.weight_linear.device, self.weight_linear.dtype
        weight_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim, device=device, dtype=dtype)
        weight_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim, device=device, dtype=dtype)
        bias_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim, device=device, dtype=dtype)
        bias_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim, device=device, dtype=dtype)

        params = {
            'weight_linear_tmp': weight_linear_tmp,
            'weight_ln_tmp': weight_ln_tmp,
            'bias_linear_tmp': bias_linear_tmp,
            'bias_ln_tmp': bias_ln_tmp
        }
        return params

    def init_weights(self):
        nn.init.normal_(self.weight_linear, mean=0.0, std=self.config.initializer_range)
        nn.init.constant_(self.weight_ln, 1.0 / self.logit_dim)

        nn.init.normal_(self.bias_linear, mean=0.0, std=self.config.initializer_range / self.logit_dim)
        nn.init.normal_(self.bias_ln, mean=0.0, std=self.config.initializer_range / self.logit_dim)

    def get_weight_per_token(self, params):
        """Mix the base parameters with the per-token logits (contracting the 'l' dimension)."""
        weight_linear_tmp = torch.einsum('iol,bsl->bsio', self.weight_linear, params['weight_linear_tmp'])
        weight_ln_tmp = torch.einsum('ol,bsl->bso', self.weight_ln, params['weight_ln_tmp'])
        bias_linear_tmp = torch.einsum('ol,bsl->bso', self.bias_linear, params['bias_linear_tmp'])
        bias_ln_tmp = torch.einsum('ol,bsl->bso', self.bias_ln, params['bias_ln_tmp'])

        return weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp

    def learn(self, k, v, params, lr_linear=1, lr_ln=1):
        # Mix the base parameters with the current per-token logits.
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)

        # Forward pass: linear map, LayerNorm-style normalization, and a residual
        # connection from k.
        z = torch.einsum('bsi,bsio->bso', k, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)

        eps = 1e-6
        std = torch.sqrt(var + eps)
        z_hat = (z - mu) / std

        output_reshaped = weight_ln_tmp * z_hat + bias_ln_tmp + k

        # Reconstruction error against the target v.
        error_reshaped = output_reshaped - v

        # Gradient w.r.t. the per-token LayerNorm scale and bias, projected back
        # onto the logits through the base parameters.
        grad_weight_ln = error_reshaped * z_hat
        params0 = params['weight_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.weight_ln, grad_weight_ln)

        grad_bias_ln = error_reshaped
        params1 = params['bias_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.bias_ln, grad_bias_ln)

        # Gradient w.r.t. the linear pre-activation z (mu and std are treated as
        # constants here), then w.r.t. the per-token linear weight and bias.
        grad_linear = weight_ln_tmp * error_reshaped / std

        grad_weight_linear = torch.einsum('bsi,bso->bsio', k, grad_linear)
        params2 = params['weight_linear_tmp'] - lr_linear * torch.einsum('iol,bsio->bsl', self.weight_linear, grad_weight_linear)

        grad_b = grad_linear
        params3 = params['bias_linear_tmp'] - lr_linear * torch.einsum('ol,bso->bsl', self.bias_linear, grad_b)

        # One gradient-descent step on the per-token logits.
        params_new = {
            'weight_linear_tmp': params2,
            'weight_ln_tmp': params0,
            'bias_linear_tmp': params3,
            'bias_ln_tmp': params1
        }

        return params_new

    def predict(self, q, params):
        """Apply the current per-token parameters to the query q (same forward pass as in learn)."""
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
        z = torch.einsum('bsi,bsio->bso', q, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)

        eps = 1e-6
        std = torch.sqrt(var + eps)
        z_hat = (z - mu) / std

        output = weight_ln_tmp * z_hat + bias_ln_tmp + q

        return output
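

# Illustrative end-to-end sketch (not part of the original module): build a
# TTT_Cross_Layer from a minimal config, take one learn() step on (k, v), and
# run predict() on a query. The config values, shapes, and learning rates are
# assumptions for the example only.
def _ttt_cross_layer_demo():
    from types import SimpleNamespace

    config = SimpleNamespace(concept_dim=16, logit_dim=8, initializer_range=0.02)
    layer = TTT_Cross_Layer(config)

    batch_size, seq_len = 2, 5
    k = torch.randn(batch_size, seq_len, config.concept_dim)
    v = torch.randn(batch_size, seq_len, config.concept_dim)
    q = torch.randn(batch_size, seq_len, config.concept_dim)

    params = layer.init_params_as_logits(batch_size, seq_len)
    params = layer.learn(k, v, params, lr_linear=0.1, lr_ln=0.1)
    out = layer.predict(q, params)
    assert out.shape == (batch_size, seq_len, config.concept_dim)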