import torch
from torch import nn

import quiptools_cuda
from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda


def get_grid():
    """Return the 16 half-integer codewords -7.5, -6.5, ..., 7.5 as a (16, 1) tensor."""
    hintr = torch.arange(-8, 8) + 1 / 2
    return hintr.unsqueeze(-1)


_HI4B1C_CACHED = get_grid()
# Squared L2 norm of each codeword (diagonal of the Gram matrix), shape (16,).
_HI4B1C_NORM_CACHED = torch.diag(_HI4B1C_CACHED @ _HI4B1C_CACHED.T)


class HI4B1C_codebook(nn.Module):
    """Scalar (codesz=1) codebook with 16 half-integer levels, i.e. 4 bits per index.

    Groups of ``packsz`` (8) indices are packed into one integer by
    ``maybe_pack_idxs`` and unpacked by ``by_idxs(..., packed=True)``.
    """

    # Bit offset of the 4-bit index at each position within a group of 8.
    # Even positions fill bits 31..16 and odd positions fill bits 15..0.
    # Shared by pack and unpack so the two layouts cannot drift apart.
    # NOTE(review): presumably this must also match
    # quiptools_cuda.decompress_hi4b1c_packed -- confirm before changing.
    _NIBBLE_SHIFTS = (28, 12, 24, 8, 20, 4, 16, 0)

    def __init__(self, inference=False):
        super().__init__()
        # opt_scale was tuned offline with scipy.optimize.minimize_scalar over
        # s in (0.1, 100). The objective was the per-element squared error of
        # quantize(samples * s) / s against 200k standard-normal samples.
        self.opt_scale = 2.97
        self.codesz = 1
        self.idx_dtype = torch.int32
        self.packsz = 8
        self.pack_out = False
        self.version = 0

        # Clone so each instance owns its buffers. Otherwise an in-place write
        # (e.g. load_state_dict's copy_) would alter the module-level tensors
        # shared by every other codebook instance.
        self.register_buffer('grid', _HI4B1C_CACHED.clone())
        if not inference:
            # grid_norm is only read by round()/quantize().
            self.register_buffer('grid_norm', _HI4B1C_NORM_CACHED.clone())

    def round(self, X, grid, grid_norm):
        """Snap each row of X to its nearest codeword.

        argmax_g (2 <x, g> - ||g||^2) equals argmin_g ||x - g||^2, because
        ||x||^2 is constant within a row.

        Returns:
            (codewords, indices): the chosen codewords ``grid[idx]`` and the
            int64 argmax indices.
        """
        assert X.shape[-1] == self.codesz
        Xqidx = (2 * X @ grid.T - grid_norm).argmax(-1)
        return grid[Xqidx], Xqidx

    def quantize(self, X, return_idx=True):
        """Quantize X to the codebook.

        Returns the quantized values only when ``return_idx`` is False.
        Otherwise returns ``(values, indices)`` with indices cast to
        ``self.idx_dtype``. Requires the ``grid_norm`` buffer, which is
        absent when the codebook was built with ``inference=True``.
        """
        vals, idx = self.round(X, self.grid, self.grid_norm)
        if not return_idx:
            return vals
        return vals, idx.to(self.idx_dtype)

    def maybe_pack_idxs(self, idxs):
        """Pack each run of 8 consecutive 4-bit indices along dim 1 into one integer.

        Input values are expected in [0, 15]. Bit fields are disjoint, so the
        summation is equivalent to a bitwise OR, including int32 wraparound
        of the top nibble.
        """
        return sum(
            idxs[:, pos::self.packsz] << shift
            for pos, shift in enumerate(self._NIBBLE_SHIFTS)
        )

    def by_idxs(self, idxs, packed=False):
        """Decode indices to codewords.

        With ``packed=True``, each integer along the last dim is first
        expanded into ``packsz`` 4-bit indices using the inverse of
        ``maybe_pack_idxs``. Returns ``self.grid[idxs]``, i.e. the index
        shape with a trailing codeword dimension of size ``codesz``.
        """
        if packed:
            # repeat_interleave returns a new tensor, so the caller's idxs is untouched.
            idxs = idxs.repeat_interleave(self.packsz, dim=-1)
            for pos, shift in enumerate(self._NIBBLE_SHIFTS):
                # `& 15` also discards sign-extension from arithmetic right shifts.
                idxs[:, pos::self.packsz] = (idxs[:, pos::self.packsz] >> shift) & 15
        return self.grid[idxs.int()]


class QuantizedHI4B1CLinear(nn.Module):
    """Linear layer whose weight is stored as HI4B1C codebook indices."""

    def __init__(self, device):
        super().__init__()
        # fp16 codebook so decoded weights match the fp16 matmul in forward().
        self.codebook = HI4B1C_codebook(inference=True).to(torch.float16).to(device)

    def forward(self, input, Qidxs, SU, SV, Wscale, had_left, had_right, K_left, K_right,
                rank=-1, A=None, B=None, rescale_WH=False, scaleWH=None, packed=False):
        """Apply the quantized linear map to ``input``.

        Sequence of operations:
          1. optionally divide by ``scaleWH``;
          2. multiply elementwise by ``SU``;
          3. left Hadamard transform (``matmul_hadUt_cuda``);
          4. multiply by the decoded weight, then scale by ``Wscale``;
          5. optionally add the low-rank term ``(x @ B.T) @ A.T``, computed
             from the post-Hadamard x;
          6. right Hadamard transform (``matmul_hadU_cuda``);
          7. multiply elementwise by ``SV``.

        Args:
            input: tensor whose last dim is n = len(SU).
            Qidxs: weight indices. When ``packed`` they are decoded by
                ``quiptools_cuda.decompress_hi4b1c_packed`` into an (m, n)
                fp16 matrix; otherwise by ``codebook.by_idxs`` and reshaped
                to (-1, n).
            SU, SV: per-feature input (length n) / output (length m) multipliers.
            Wscale: scale applied to the matmul result.
            had_left, had_right, K_left, K_right: passed through to the
                Hadamard helpers (semantics defined in lib.utils.matmul_had).
            rank, A, B: low-rank correction, used when ``rank > 0``.
            rescale_WH, scaleWH: optional input divisor.
            packed: selects the CUDA decompression kernel path.

        Returns:
            Tensor of shape ``(*input.shape[:-1], m)`` with m = len(SV).
        """
        n, m = len(SU), len(SV)
        # reshape (not view) also accepts non-contiguous inputs.
        x = input.reshape(-1, n).to(torch.float32)
        if rescale_WH:
            # Out of place on purpose: when `input` is already fp32, `x` shares
            # its storage and an in-place divide would modify the caller's tensor.
            x = x / scaleWH
        x = x * SU
        x = matmul_hadUt_cuda(x, had_left, K_left)

        if rank > 0:
            Bx = x @ B.t().to(torch.float32)
            ABx = Bx @ A.t().to(torch.float32)

        # Shrink activations before the fp16 cast and undo it afterwards via
        # Wscale * num_scale -- presumably to keep the fp16 matmul from
        # overflowing (confirm).
        num_scale = 1024
        x = x / num_scale
        x = x.to(torch.float16)

        if packed:
            W_decompressed = torch.zeros(m, n, dtype=torch.float16, device=x.device)
            quiptools_cuda.decompress_hi4b1c_packed(Qidxs, self.codebook.grid, W_decompressed)
        else:
            W_decompressed = self.codebook.by_idxs(Qidxs, packed=False).reshape(-1, n)

        z = x @ W_decompressed.t()

        x = z.to(torch.float32)
        x = x * (Wscale * num_scale)

        if rank > 0:
            x = x + ABx.to(torch.float32)

        x = matmul_hadU_cuda(x, had_right, K_right)
        x = x * SV

        return x.view(*input.shape[:-1], m)