import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


def pack_sign_bits(sign_tensor: torch.Tensor) -> torch.Tensor:
    """Pack a ±1 sign tensor into one bit per element (8 signs per byte).

    Convention: bit 1 encodes -1, bit 0 encodes +1, stored LSB-first within
    each byte, so this is the exact inverse of unpack_sign_bits below.
    """
    sign_flat = sign_tensor.flatten()
    sign_uint8 = (sign_flat < 0).to(torch.uint8)
    # Zero-pad to a multiple of 8 so the bit stream fills whole bytes.
    remainder = sign_uint8.numel() % 8
    if remainder != 0:
        padding = 8 - remainder
        sign_uint8 = torch.cat([
            sign_uint8,
            torch.zeros(padding, dtype=torch.uint8, device=sign_uint8.device),
        ])
    sign_uint8 = sign_uint8.reshape(-1, 8)
    shifts = torch.arange(8, device=sign_uint8.device, dtype=torch.uint8)
    packed = (sign_uint8 << shifts.unsqueeze(0)).sum(dim=1)
    return packed.to(torch.uint8)


def unpack_sign_bits_ultra_fast(packed: torch.Tensor,
                                original_shape: torch.Size) -> torch.Tensor:
    """Unpack bit-packed signs into a ±1 float16 tensor of original_shape."""
    device = packed.device
    dtype = torch.float16
    # An arithmetic right shift on signed ints still yields the correct
    # low 8 bits once masked with `& 1`, so the int8 view is safe here.
    int8_tensor = packed.to(torch.int8)
    shifts = torch.arange(8, device=device).view(1, 8)
    expanded_int8 = int8_tensor.unsqueeze(-1)
    unpacked_bits = ((expanded_int8 >> shifts) & 1).to(dtype)
    unpacked_bits = unpacked_bits.view(int8_tensor.shape[0], -1)
    # Map bit 1 -> -1.0 and bit 0 -> +1.0.
    fp16_tensor = -2 * unpacked_bits + 1
    # torch.Size is a tuple subclass, so this also accepts tuples and lists.
    original_shape = torch.Size(original_shape)
    total_elements = original_shape.numel()
    # Drop any padding bits added by pack_sign_bits.
    return fp16_tensor.flatten()[:total_elements].reshape(original_shape)


def unpack_sign_bits(packed: torch.Tensor,
                     original_shape: torch.Size) -> torch.Tensor:
    return unpack_sign_bits_ultra_fast(packed, original_shape)
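# Example round trip (illustrative values; any ±1 tensor works):
#   signs = torch.tensor([1, -1, -1, 1, 1, 1, -1, 1], dtype=torch.float16)
#   packed = pack_sign_bits(signs)          # one byte; bit i set where signs[i] == -1
#   unpack_sign_bits(packed, signs.shape)   # recovers the original ±1 tensor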
class OneBitLinear(nn.Module):
    """Linear layer with 1-bit (sign) weights, per-channel scales, and LayerNorm."""

    def __init__(self, in_features: int, out_features: int,
                 a_scale: Optional[torch.Tensor] = None,
                 b_scale: Optional[torch.Tensor] = None,
                 weight_packed: Optional[torch.Tensor] = None,
                 bias: Optional[torch.Tensor] = None,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Packed sign weights: one bit per weight, 8 weights per int8 byte.
        if weight_packed is not None:
            expected_size = out_features * in_features // 8
            if weight_packed.numel() == expected_size:
                weight_2d = weight_packed.reshape(out_features, in_features // 8).to(torch.int8)
            else:
                weight_2d = torch.zeros((out_features, in_features // 8),
                                        dtype=torch.int8, device=device)
            self.register_buffer("weight", weight_2d, persistent=False)
        else:
            self.register_buffer("weight",
                                 torch.zeros((out_features, in_features // 8),
                                             dtype=torch.int8, device=device),
                                 persistent=False)

        # Per-input-channel and per-output-channel fp16 scales.
        if a_scale is not None:
            self.register_buffer("input_factor", a_scale.to(torch.float16))
        else:
            self.register_buffer("input_factor",
                                 torch.ones(in_features, dtype=torch.float16, device=device))
        if b_scale is not None:
            self.register_buffer("weight_scale", b_scale.to(torch.float16))
        else:
            self.register_buffer("weight_scale",
                                 torch.ones(out_features, dtype=torch.float16, device=device))

        # Optional bias, stored as a buffer since it is not trained here.
        if bias is not None:
            self.register_buffer("bias", bias.to(torch.float16))
        else:
            self.bias = None

        self.layernorm = nn.LayerNorm(out_features, elementwise_affine=False,
                                      **factory_kwargs)
        # Cache for the dequantized weight; valid because weights are frozen.
        self._weight_cache = None

    def int8_to_fp16(self, int8_tensor: torch.Tensor) -> torch.Tensor:
        """Unpack an (out_features, in_features // 8) int8 buffer into ±1 weights."""
        dtype = self.weight_scale.dtype
        shifts = torch.arange(8, device=int8_tensor.device).view(1, 1, 8)
        expanded_int8 = int8_tensor.unsqueeze(-1)
        unpacked_bits = ((expanded_int8 >> shifts) & 1).to(dtype)
        unpacked_bits = unpacked_bits.view(int8_tensor.shape[0], -1)
        # Same mapping as unpack_sign_bits: bit 1 -> -1.0, bit 0 -> +1.0.
        return -2 * unpacked_bits + 1

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Both scale vectors broadcast over the trailing feature dimension.
        input = input * self.input_factor
        if self._weight_cache is not None:
            weight = self._weight_cache
        else:
            weight = self.int8_to_fp16(self.weight)
            self._weight_cache = weight
        output = F.linear(input, weight)
        output *= self.weight_scale
        output = self.layernorm(output)
        if self.bias is not None:
            output += self.bias
        return output

    @classmethod
    def from_safetensors(cls, state_dict: dict, layer_idx: int, module_name: str):
        """Build a OneBitLinear from a checkpoint state_dict, accepting both the
        current key names and the alternate a_scale/b_scale/sign_packed names."""
        prefix = f"model.layers.{layer_idx}.{module_name}"
        input_factor_key = f"{prefix}.input_factor"
        weight_scale_key = f"{prefix}.weight_scale"
        weight_key = f"{prefix}.weight"
        bias_key = f"{prefix}.bias"

        input_factor = None
        if input_factor_key in state_dict:
            input_factor = state_dict[input_factor_key]
        elif f"{prefix}.a_scale" in state_dict:
            input_factor = state_dict[f"{prefix}.a_scale"]

        weight_scale = None
        if weight_scale_key in state_dict:
            weight_scale = state_dict[weight_scale_key]
        elif f"{prefix}.b_scale" in state_dict:
            weight_scale = state_dict[f"{prefix}.b_scale"]

        weight_packed = None
        if weight_key in state_dict:
            weight_packed = state_dict[weight_key]
        elif f"{prefix}.sign_packed" in state_dict:
            weight_packed = state_dict[f"{prefix}.sign_packed"]

        bias = state_dict.get(bias_key)

        # Both scale vectors are required to infer the layer dimensions.
        if input_factor is None or weight_scale is None:
            return None

        in_features = input_factor.shape[0]
        out_features = weight_scale.shape[0]
        return cls(
            in_features=in_features,
            out_features=out_features,
            a_scale=input_factor,
            b_scale=weight_scale,
            weight_packed=weight_packed,
            bias=bias,
        )
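# Minimal smoke test (illustrative only: the layer index, module name "mlp",
# shapes, and the synthetic state_dict below are assumptions for demonstration,
# not any particular checkpoint layout).
if __name__ == "__main__":
    torch.manual_seed(0)

    # 1) pack/unpack must be exact inverses.
    signs = torch.randint(0, 2, (4, 16)).to(torch.float16) * 2 - 1
    packed = pack_sign_bits(signs)
    assert torch.equal(unpack_sign_bits(packed, signs.shape), signs)

    # 2) Build a layer from a synthetic state_dict via from_safetensors
    #    and check that the packed weights dequantize to the same signs.
    in_f, out_f = 16, 8
    w_signs = torch.randint(0, 2, (out_f, in_f)).to(torch.float16) * 2 - 1
    state_dict = {
        "model.layers.0.mlp.input_factor": torch.ones(in_f, dtype=torch.float16),
        "model.layers.0.mlp.weight_scale": torch.ones(out_f, dtype=torch.float16),
        "model.layers.0.mlp.weight": pack_sign_bits(w_signs),
    }
    layer = OneBitLinear.from_safetensors(state_dict, layer_idx=0, module_name="mlp")
    assert layer is not None and layer.weight.shape == (out_f, in_f // 8)
    assert torch.equal(layer.int8_to_fp16(layer.weight), w_signs)
    print("smoke test passed")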