import torch
import torch.nn as nn
import torch.nn.functional as F


def _pack_int4(q):
    """Pack signed int4 values in [-8, 7] two per uint8 byte (high nibble first)."""
    q_u4 = (q + 8).to(torch.uint8)                 # shift to the unsigned nibble range [0, 15]
    if q_u4.numel() % 2 == 1:                      # pad with a zero nibble so values pair up cleanly
        q_u4 = torch.cat([q_u4, q_u4.new_zeros(1)])
    even = q_u4[::2]
    odd = q_u4[1::2]
    packed = (even << 4) | odd
    return packed.contiguous()


def _unpack_int4(packed, total_elems, device=None):
    """Invert _pack_int4, returning the first total_elems signed int4 values as int8."""
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    u4 = torch.stack([hi, lo], dim=-1).flatten()[:total_elems]
    q = (u4.to(torch.int16) - 8).to(torch.int8)
    return q.to(device) if device is not None else q


def quantize_per_outchannel_int4(weight):
    """Symmetric per-output-channel int4 quantization of a 2-D weight matrix."""
    assert weight.dim() == 2
    w = weight.detach().to(torch.float32)
    max_abs = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scale = max_abs / 7.0                          # map the largest magnitude in each row to +/-7
    q = torch.round(w / scale).clamp_(-8, 7).to(torch.int8)
    packed = _pack_int4(q.flatten())
    return packed, scale.squeeze(1).to(torch.float32), w.shape


class Int4Linear(nn.Module):
    """Drop-in replacement for nn.Linear that stores its weight as packed int4."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype))
        else:
            self.register_parameter("bias", None)
        self.register_buffer("packed_weight", torch.empty(0, dtype=torch.uint8), persistent=True)
        self.register_buffer("scales", torch.empty(out_features, dtype=torch.float32), persistent=True)
        self.register_buffer("orig_in_features", torch.tensor(in_features, dtype=torch.int32), persistent=True)

    @staticmethod
    def from_linear(m: nn.Linear):
        q = Int4Linear(m.in_features, m.out_features, bias=(m.bias is not None),
                       device=m.weight.device, dtype=m.weight.dtype)
        packed, scales, shape = quantize_per_outchannel_int4(m.weight)
        q.packed_weight = packed
        q.scales = scales
        q.orig_in_features = torch.tensor(shape[1], dtype=torch.int32, device=m.weight.device)
        if m.bias is not None:
            q.bias = nn.Parameter(m.bias.detach().to(m.weight.dtype))
        return q

    def forward(self, x):
        # Dequantize on the fly: unpack the int4 nibbles, rescale per output channel,
        # then fall back to a standard float matmul.
        total = self.out_features * int(self.orig_in_features.item())
        q = _unpack_int4(self.packed_weight, total, device=x.device)
        w_q = q.to(torch.float32).view(self.out_features, -1)
        w = (w_q * self.scales.to(device=w_q.device, dtype=w_q.dtype).unsqueeze(1)).to(x.dtype)
        return F.linear(x, w, self.bias)


def quantize_model_to_int4(model, name_exclude_patterns=()):
    """Replace every nn.Linear in `model` with an Int4Linear, except modules whose
    qualified name matches one of name_exclude_patterns (substring match)."""
    def excluded(name):
        return any(p in name for p in name_exclude_patterns)

    replaced = 0
    for name, mod in list(model.named_modules()):
        for child_name, child in list(mod.named_children()):
            full = f"{name}.{child_name}" if name else child_name
            if isinstance(child, nn.Linear) and not excluded(full):
                setattr(mod, child_name, Int4Linear.from_linear(child))
                replaced += 1
    print(f"[INT4] Replaced {replaced} Linear layers with Int4Linear.")
    return model
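

# Usage sketch (not part of the module above): the toy MLP below is an assumed
# example, quantized in place and compared against its full-precision output.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(64, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    x = torch.randn(4, 64)
    ref = model(x)                          # FP32 reference output
    quantize_model_to_int4(model)           # swaps nn.Linear -> Int4Linear in place
    out = model(x)                          # forward pass through dequantized int4 weights
    print("max abs error vs FP32:", (ref - out).abs().max().item())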