feat: added top-level docstring, made it compatible with AutoModel
modeling_bert.py  CHANGED  +10 -5
@@ -1,3 +1,10 @@
+""" Implementation of BERT, using ALiBi and Flash Attention
+
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
+and made modifications to use ALiBi.
+"""
+
 # Copyright (c) 2022, Tri Dao.
 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
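The new docstring says the model uses ALiBi in place of BERT's learned position embeddings. As a rough, illustrative sketch of what ALiBi does (this is not code from modeling_bert.py; the function names are made up, and the slope formula assumes the number of heads is a power of two): each attention head h gets a fixed slope m_h, and the score for a query/key pair is penalized by m_h times their token distance before the softmax.

import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric slopes from the ALiBi paper: 2^(-8/n), 2^(-16/n), ...
    # (valid as written when num_heads is a power of two)
    start = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    # Returns a (num_heads, seq_len, seq_len) tensor of -slope * |i - j|,
    # added to the raw attention scores; the symmetric |i - j| form fits
    # a bidirectional (BERT-style) model rather than a causal one.
    pos = torch.arange(seq_len)
    distance = (pos[None, :] - pos[:, None]).abs()
    return -alibi_slopes(num_heads)[:, None, None] * distance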
@@ -297,12 +304,10 @@ class BertPreTrainedModel(nn.Module):
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__()
-        if not isinstance(config, BertConfig):
+        if not config.__class__.__name__ == 'JinaBertConfig':
             raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
-                "To create a model from a Google pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
+                "Parameter config in `{}(config)` should be an instance of class `JinaBertConfig`.".format(
+                    self.__class__.__name__,
                 )
             )
         self.config = config
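The point of switching from an isinstance check to a class-name comparison is presumably the AutoModel compatibility named in the commit title: when the model is loaded with trust_remote_code=True, the JinaBertConfig class is imported from the repository's own code as a dynamic module, so it is not the same Python class object as one imported locally, and an isinstance check against the local class can fail even for a valid config. Comparing config.__class__.__name__ sidesteps that. A usage sketch under that assumption; the repository id below is a placeholder, not something named in this commit.

from transformers import AutoConfig, AutoModel

# Placeholder repo id: substitute the actual model repository that ships
# modeling_bert.py and JinaBertConfig as remote code.
repo_id = "organization/alibi-bert"

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# The dynamically loaded config still reports the expected class name, so the
# constructor's `config.__class__.__name__ == 'JinaBertConfig'` check passes
# where an isinstance check against a locally imported JinaBertConfig might not.
print(type(config).__name__)  # expected: 'JinaBertConfig'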