feat: added option for QK normalization
Browse files- configuration_bert.py +2 -0
- modeling_bert.py +5 -3
configuration_bert.py
CHANGED
|
@@ -83,6 +83,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
| 83 |
pad_vocab_size_multiple=1,
|
| 84 |
num_tasks=0,
|
| 85 |
use_flash_attn=True,
|
|
|
|
| 86 |
**kwargs,
|
| 87 |
):
|
| 88 |
assert 'position_embedding_type' not in kwargs
|
|
@@ -110,3 +111,4 @@ class JinaBertConfig(PretrainedConfig):
|
|
| 110 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
| 111 |
self.num_tasks = num_tasks
|
| 112 |
self.use_flash_attn = use_flash_attn
|
|
|
|
|
|
| 83 |
pad_vocab_size_multiple=1,
|
| 84 |
num_tasks=0,
|
| 85 |
use_flash_attn=True,
|
| 86 |
+
use_qk_norm=True,
|
| 87 |
**kwargs,
|
| 88 |
):
|
| 89 |
assert 'position_embedding_type' not in kwargs
|
|
|
|
| 111 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
| 112 |
self.num_tasks = num_tasks
|
| 113 |
self.use_flash_attn = use_flash_attn
|
| 114 |
+
self.use_qk_norm = use_qk_norm
|
modeling_bert.py
CHANGED
|
@@ -59,9 +59,10 @@ logger = logging.getLogger(__name__)
|
|
| 59 |
|
| 60 |
|
| 61 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
| 62 |
-
use_flash_attn =
|
| 63 |
-
|
| 64 |
-
|
|
|
|
| 65 |
mixer_cls = partial(
|
| 66 |
MHA,
|
| 67 |
num_heads=config.num_attention_heads,
|
|
@@ -73,6 +74,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
| 73 |
return_residual=return_residual,
|
| 74 |
use_alibi=True,
|
| 75 |
window_size=window_size,
|
|
|
|
| 76 |
)
|
| 77 |
return mixer_cls
|
| 78 |
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
| 62 |
+
use_flash_attn = config.use_flash_attn
|
| 63 |
+
use_qk_norm = config.use_qk_norm
|
| 64 |
+
fused_bias_fc = config.fused_bias_fc
|
| 65 |
+
window_size = config.window_size
|
| 66 |
mixer_cls = partial(
|
| 67 |
MHA,
|
| 68 |
num_heads=config.num_attention_heads,
|
|
|
|
| 74 |
return_residual=return_residual,
|
| 75 |
use_alibi=True,
|
| 76 |
window_size=window_size,
|
| 77 |
+
qk_norm=use_qk_norm
|
| 78 |
)
|
| 79 |
return mixer_cls
|
| 80 |
|