fix: use proper initilization for embedding layer
Browse files- modeling_lora.py +28 -11
modeling_lora.py
CHANGED
|
@@ -11,20 +11,37 @@ from torch.nn import Parameter
|
|
| 11 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class LoRAParametrization(nn.Module):
|
| 15 |
-
def __init__(self, fan_in, fan_out,
|
| 16 |
super().__init__()
|
| 17 |
# if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
|
| 18 |
# otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
|
|
|
|
| 19 |
self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
nn.
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
self.lora_alpha, self.rank = lora_alpha, rank
|
| 29 |
self.scaling = lora_alpha / rank
|
| 30 |
self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
|
|
@@ -55,14 +72,14 @@ class LoRAParametrization(nn.Module):
|
|
| 55 |
def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 56 |
fan_out, fan_in = layer.weight.shape
|
| 57 |
return cls(
|
| 58 |
-
fan_in, fan_out, num_adaptions=num_adaptions,
|
| 59 |
)
|
| 60 |
|
| 61 |
@classmethod
|
| 62 |
def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 63 |
fan_in, fan_out = layer.weight.shape
|
| 64 |
return cls(
|
| 65 |
-
fan_in, fan_out, num_adaptions=num_adaptions,
|
| 66 |
)
|
| 67 |
|
| 68 |
@classmethod
|
|
|
|
| 11 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
| 12 |
|
| 13 |
|
| 14 |
+
def initialized_weights(shape, num_adaptions, init='kaiming'):
|
| 15 |
+
weight_data = []
|
| 16 |
+
for _ in range(num_adaptions):
|
| 17 |
+
new_adaption = torch.zeros(shape)
|
| 18 |
+
if init == 'kaiming':
|
| 19 |
+
nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
|
| 20 |
+
elif init == 'normal':
|
| 21 |
+
nn.init.normal_(new_adaption)
|
| 22 |
+
else:
|
| 23 |
+
raise NotImplementedError
|
| 24 |
+
weight_data.append(new_adaption)
|
| 25 |
+
return torch.stack(weight_data, dim=0)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
class LoRAParametrization(nn.Module):
|
| 29 |
+
def __init__(self, fan_in, fan_out, layer_type='linear', num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 30 |
super().__init__()
|
| 31 |
# if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
|
| 32 |
# otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
|
| 33 |
+
fan_in_fan_out = (layer_type == 'embedding')
|
| 34 |
self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
|
| 35 |
+
|
| 36 |
+
if layer_type == 'linear':
|
| 37 |
+
self.lora_A = nn.Parameter(initialized_weights((rank, fan_in), num_adaptions, init='kaiming'))
|
| 38 |
+
self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
|
| 39 |
+
elif layer_type == 'embedding':
|
| 40 |
+
self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
|
| 41 |
+
self.lora_B = nn.Parameter(initialized_weights((rank, fan_out), num_adaptions=num_adaptions, init='normal'))
|
| 42 |
+
else:
|
| 43 |
+
raise NotImplementedError
|
| 44 |
+
|
| 45 |
self.lora_alpha, self.rank = lora_alpha, rank
|
| 46 |
self.scaling = lora_alpha / rank
|
| 47 |
self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
|
|
|
|
| 72 |
def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 73 |
fan_out, fan_in = layer.weight.shape
|
| 74 |
return cls(
|
| 75 |
+
fan_in, fan_out, num_adaptions=num_adaptions, layer_type='linear', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
|
| 76 |
)
|
| 77 |
|
| 78 |
@classmethod
|
| 79 |
def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
| 80 |
fan_in, fan_out = layer.weight.shape
|
| 81 |
return cls(
|
| 82 |
+
fan_in, fan_out, num_adaptions=num_adaptions, layer_type='embedding', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
|
| 83 |
)
|
| 84 |
|
| 85 |
@classmethod
|