feat: choose flash attention heuristically if not set explicitly
modeling_bert.py (+2 -2)
@@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn = config.use_flash_attn
+    use_flash_attn = config.use_flash_attn if config.use_flash_attn is not None else torch.cuda.is_available()
     use_qk_norm = config.use_qk_norm
     fused_bias_fc = config.fused_bias_fc
     window_size = config.window_size
@@ -161,7 +161,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
-        self.use_flash_attn = config.use_flash_attn
+        self.use_flash_attn = config.use_flash_attn if config.use_flash_attn is not None else torch.cuda.is_available()
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
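The change turns `use_flash_attn` into a tri-state flag: an explicit `True` or `False` in the config is honored as before, while `None` now means "decide automatically" based on whether a CUDA device is present (flash attention kernels require a GPU). A minimal standalone sketch of this pattern, assuming a config field that defaults to `None`; the `Config` class and `resolve_use_flash_attn` helper here are illustrative, not part of the repo:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Config:
    # Tri-state flag: True/False force the behavior; None means "auto".
    use_flash_attn: Optional[bool] = None


def resolve_use_flash_attn(config: Config) -> bool:
    # Mirrors the heuristic in the diff: an explicit setting wins;
    # otherwise fall back to CUDA availability.
    if config.use_flash_attn is not None:
        return config.use_flash_attn
    return torch.cuda.is_available()


print(resolve_use_flash_attn(Config()))                      # auto: True only on CUDA machines
print(resolve_use_flash_attn(Config(use_flash_attn=False)))  # explicit override: always False

Note that `if config.use_flash_attn is not None` is deliberate: a plain truthiness check (`config.use_flash_attn or torch.cuda.is_available()`) would silently discard an explicit `False`.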