Update modeling_bitllama.py
modeling_bitllama.py  +5 -1
@@ -253,9 +253,13 @@ def weight_quant(w):
 
 
 class BitLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__(in_features, out_features, bias=bias)
+        self.norm = LlamaRMSNorm(in_features)
+
     def forward(self, x):
         w = self.weight
-        x_norm =
+        x_norm = self.norm(x)
         x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
         w_quant = w + (weight_quant(w) - w).detach()
         return F.linear(x_quant, w_quant)
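The change gives BitLinear its own LlamaRMSNorm, so the input is normalized inside the layer before activation quantization rather than relying on the removed inline expression. For reference, below is a minimal self-contained sketch of the patched layer. The bodies of activation_quant and weight_quant live earlier in modeling_bitllama.py and are not visible in this hunk; the versions sketched here follow the BitNet b1.58 reference pseudocode (per-token absmax 8-bit activations, per-tensor absmean ternary weights) and are assumptions, as is the inlined RMSNorm standing in for transformers' LlamaRMSNorm.

import torch
import torch.nn as nn
import torch.nn.functional as F


def activation_quant(x):
    # Assumed definition (BitNet b1.58 reference pseudocode): per-token
    # absmax quantization of activations to 8 bits.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale


def weight_quant(w):
    # Assumed definition (BitNet b1.58 reference pseudocode): per-tensor
    # absmean quantization of weights to the ternary set {-1, 0, +1}.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale


class RMSNorm(nn.Module):
    # Stand-in for transformers' LlamaRMSNorm (assumed equivalent).
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class BitLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.norm = RMSNorm(in_features)  # the commit uses LlamaRMSNorm here

    def forward(self, x):
        w = self.weight
        x_norm = self.norm(x)
        # Straight-through estimator: the forward pass sees the quantized
        # values, while the backward pass treats quantization as the
        # identity because the quantization residual is detached.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        # Note: as in the commit, self.bias is not applied in this forward.
        return F.linear(x_quant, w_quant)

A quick smoke test of the sketch:

layer = BitLinear(64, 128)
y = layer(torch.randn(2, 16, 64))  # -> shape (2, 16, 128)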