|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
def _pack_int4(q):
    """Pack a flat int8 tensor of signed int4 values in [-8, 7] into uint8, two nibbles per byte."""
    q_u4 = (q + 8).to(torch.uint8)  # shift to the unsigned nibble range [0, 15]
    if q_u4.numel() % 2 == 1:
        # Pad with a zero nibble so the even/odd slices below have equal length.
        q_u4 = torch.cat([q_u4, torch.zeros(1, dtype=torch.uint8, device=q_u4.device)])
    packed = (q_u4[0::2] << 4) | q_u4[1::2]  # even index -> high nibble, odd index -> low nibble
    return packed.contiguous()
|
|
def _unpack_int4(packed, total_elems, device=None):
    """Inverse of _pack_int4: recover `total_elems` signed int4 values as an int8 tensor."""
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    u4 = torch.stack([hi, lo], dim=-1).flatten()[:total_elems]  # drop the padding nibble, if any
    q = (u4.to(torch.int16) - 8).to(torch.int8)  # shift back to the signed range [-8, 7]
    return q.to(device) if device is not None else q
|
|
def quantize_per_outchannel_int4(weight):
    """Symmetric per-output-channel int4 quantization of a 2-D weight matrix.

    Returns (packed uint8 nibbles, float32 per-row scales, original weight shape).
    """
    assert weight.dim() == 2, "expected a [out_features, in_features] weight matrix"
    w = weight.detach().to(torch.float32)
    max_abs = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)  # avoid division by zero
    scale = max_abs / 7.0  # map the largest magnitude in each row to +/-7
    q = torch.round(w / scale).clamp_(-8, 7).to(torch.int8)
    packed = _pack_int4(q.flatten())
    return packed, scale.squeeze(1).to(torch.float32), w.shape
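

# The sketch below is not part of the original module: it is a minimal sanity check
# (hypothetical helper name `_roundtrip_check`) illustrating that packing followed by
# unpacking recovers the int4 codes, and that dequantizing with the per-row scales
# approximates the original weight. It is never called at import time.
def _roundtrip_check():
    w = torch.randn(3, 7)  # 3 * 7 = 21 nibbles, an odd count, exercises the padding path
    packed, scales, shape = quantize_per_outchannel_int4(w)
    q = _unpack_int4(packed, w.numel()).view(shape).to(torch.float32)
    w_hat = q * scales.unsqueeze(1)
    # Rounding error is at most half a step (scale/2) per row; assert the looser one-step bound.
    assert (w - w_hat).abs().max().item() <= scales.max().item()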
|
|
class Int4Linear(nn.Module):
    """Drop-in replacement for nn.Linear that stores its weight as packed int4 plus per-row scales."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype))
        else:
            self.bias = None
        # Buffers, not Parameters: the quantized weight is not trained but must live in state_dict.
        self.register_buffer("packed_weight", torch.empty(0, dtype=torch.uint8), persistent=True)
        self.register_buffer("scales", torch.empty(out_features, dtype=torch.float32), persistent=True)
        self.register_buffer("orig_in_features", torch.tensor(in_features, dtype=torch.int32), persistent=True)
|
|
    @staticmethod
    def from_linear(m: nn.Linear):
        """Build an Int4Linear from an existing nn.Linear by quantizing its weight."""
        q = Int4Linear(m.in_features, m.out_features, bias=(m.bias is not None),
                       device=m.weight.device, dtype=m.weight.dtype)
        packed, scales, shape = quantize_per_outchannel_int4(m.weight)
        q.packed_weight = packed
        q.scales = scales
        q.orig_in_features = torch.tensor(shape[1], dtype=torch.int32, device=m.weight.device)
        if m.bias is not None:
            q.bias = nn.Parameter(m.bias.detach().to(m.weight.dtype))
        return q
|
|
    def forward(self, x):
        # Dequantize on the fly: unpack the nibbles, reshape to [out, in], and rescale per row.
        total = self.out_features * int(self.orig_in_features.item())
        q = _unpack_int4(self.packed_weight, total, device=x.device)
        w_q = q.to(torch.float32).view(self.out_features, -1)
        w = (w_q * self.scales.to(device=w_q.device, dtype=w_q.dtype).unsqueeze(1)).to(x.dtype)
        return F.linear(x, w, self.bias)
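

# Illustrative sketch, not part of the original API: a hypothetical helper showing how
# Int4Linear.from_linear is expected to be used and how to gauge the quantization error of
# a single layer. The layer sizes and error metric are assumptions, not guaranteed bounds.
def _example_int4_vs_float():
    lin = nn.Linear(64, 32)
    qlin = Int4Linear.from_linear(lin)
    x = torch.randn(8, 64)
    y_fp, y_q = lin(x), qlin(x)
    rel_err = (y_fp - y_q).norm() / y_fp.norm()
    print(f"relative output error after int4 quantization: {rel_err.item():.4f}")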
|
|
def quantize_model_to_int4(model, name_exclude_patterns=()):
    """Replace every nn.Linear in `model` with an Int4Linear, in place.

    Modules whose qualified name contains any of `name_exclude_patterns` are left untouched.
    """
    def excluded(name):
        return any(p in name for p in name_exclude_patterns)

    replaced = 0
    for name, mod in list(model.named_modules()):
        for child_name, child in list(mod.named_children()):
            full = f"{name}.{child_name}" if name else child_name
            if isinstance(child, nn.Linear) and not excluded(full):
                setattr(mod, child_name, Int4Linear.from_linear(child))
                replaced += 1
    print(f"[INT4] Replaced {replaced} Linear layers with Int4Linear.")
    return model
|
|
|