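"""INT4 weight-only quantization utilities for Perception-LM-1B.

Linear weights are quantized per output channel to signed 4-bit values
(symmetric, range [-8, 7]), packed two per uint8 byte, and dequantized
on the fly during the forward pass.
"""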
import torch
import torch.nn as nn
import torch.nn.functional as F
def _pack_int4(q):
    """Pack a flat int8 tensor of values in [-8, 7] into uint8 nibble pairs."""
    q_u4 = (q + 8).to(torch.uint8)                  # shift to the unsigned range [0, 15]
    if q_u4.numel() % 2 == 1:                       # pad to even length so every byte holds two nibbles
        q_u4 = torch.cat([q_u4, q_u4.new_zeros(1)])
    packed = (q_u4[::2] << 4) | q_u4[1::2]          # even index -> high nibble, odd index -> low nibble
    return packed.contiguous()

def _unpack_int4(packed, total_elems, device=None):
    """Inverse of _pack_int4: recover the first total_elems int8 values in [-8, 7]."""
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    u4 = torch.stack([hi, lo], dim=-1).flatten()[:total_elems]   # re-interleave nibbles, drop padding
    q = (u4.to(torch.int16) - 8).to(torch.int8)                  # shift back to the signed range
    return q.to(device) if device is not None else q

def quantize_per_outchannel_int4(weight):
    """Symmetric per-output-channel INT4 quantization; returns (packed, scales, shape)."""
    assert weight.dim() == 2
    w = weight.detach().to(torch.float32)
    max_abs = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)   # avoid zero scales for all-zero rows
    scale = max_abs / 7.0                                         # largest magnitude maps to +/-7
    q = torch.round(w / scale).clamp_(-8, 7).to(torch.int8)
    packed = _pack_int4(q.flatten())
    return packed, scale.squeeze(1).to(torch.float32), w.shape

class Int4Linear(nn.Module):
    """Drop-in replacement for nn.Linear that stores packed INT4 weights."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype))
        else:
            self.register_parameter("bias", None)
        # Buffers: saved in the state dict and follow .to()/.cuda() moves.
        self.register_buffer("packed_weight", torch.empty(0, dtype=torch.uint8), persistent=True)
        self.register_buffer("scales", torch.empty(out_features, dtype=torch.float32), persistent=True)
        self.register_buffer("orig_in_features", torch.tensor(in_features, dtype=torch.int32), persistent=True)

    @staticmethod
    def from_linear(m: nn.Linear):
        """Build an Int4Linear from an existing nn.Linear by quantizing its weight."""
        q = Int4Linear(m.in_features, m.out_features, bias=(m.bias is not None),
                       device=m.weight.device, dtype=m.weight.dtype)
        packed, scales, shape = quantize_per_outchannel_int4(m.weight)
        q.packed_weight = packed
        q.scales = scales
        q.orig_in_features = torch.tensor(shape[1], dtype=torch.int32, device=m.weight.device)
        if m.bias is not None:
            q.bias = nn.Parameter(m.bias.detach().to(m.weight.dtype))
        return q

    def forward(self, x):
        # Dequantize on the fly: unpack, reshape to (out_features, in_features), rescale each row.
        total = self.out_features * int(self.orig_in_features.item())
        q = _unpack_int4(self.packed_weight, total, device=x.device)
        w_q = q.to(torch.float32).view(self.out_features, -1)
        w = (w_q * self.scales.to(device=w_q.device, dtype=w_q.dtype).unsqueeze(1)).to(x.dtype)
        return F.linear(x, w, self.bias)

def quantize_model_to_int4(model, name_exclude_patterns=()):
    """Replace every nn.Linear in `model` with Int4Linear, skipping modules whose
    qualified name contains any of `name_exclude_patterns`."""
    def excluded(name):
        return any(p in name for p in name_exclude_patterns)

    replaced = 0
    for name, mod in list(model.named_modules()):
        for child_name, child in list(mod.named_children()):
            full = f"{name}.{child_name}" if name else child_name
            if isinstance(child, nn.Linear) and not excluded(full):
                setattr(mod, child_name, Int4Linear.from_linear(child))
                replaced += 1
    print(f"[INT4] Replaced {replaced} Linear layers with Int4Linear.")
    return model