Try to subclass PretrainedModel
modeling_bert.py  (CHANGED, +2 / -29)
@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import
+from transformers import PretrainedModel
 from .configuration_bert import JinaBertConfig
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -295,7 +295,7 @@ class BertPreTrainingHeads(nn.Module):
         return prediction_scores, seq_relationship_score


-class BertPreTrainedModel(nn.Module):
+class BertPreTrainedModel(PretrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for dowloading and loading pretrained models.
     """
@@ -310,33 +310,6 @@ class BertPreTrainedModel(nn.Module):
         )
         self.config = config

-    @classmethod
-    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
-        """
-        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
-        Download and cache the pre-trained model file if needed.
-
-        Params:
-            pretrained_model_name_or_path: either:
-                - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
-                    . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
-                - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
-                    . `model.chkpt` a TensorFlow checkpoint
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
-        """
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        load_return = model.load_state_dict(state_dict_from_pretrained(model_name), strict=True)
-        logger.info(load_return)
-        return model
-
-    @classmethod
-    def _from_config(cls, config, **kwargs):
-        return cls(config, **kwargs)
-

 class BertModel(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
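For context on what the change buys: once BertPreTrainedModel inherits from the transformers base model class instead of nn.Module, the hand-rolled from_pretrained and _from_config removed above are supplied by the library. The sketch below is not part of the commit; it assumes the intended base class is transformers.PreTrainedModel (the name the library exports) and that JinaBertConfig exposes initializer_range the way BertConfig does, and it shows roughly what such a subclass sets so the inherited loading path works.

# Minimal sketch, not taken from the diff. Assumptions are noted inline.
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_bert import JinaBertConfig


class BertPreTrainedModel(PreTrainedModel):
    """Abstract base: weight initialization plus the download/load interface
    (from_pretrained, save_pretrained) inherited from PreTrainedModel."""

    config_class = JinaBertConfig   # config type resolved by from_pretrained
    base_model_prefix = "bert"      # prefix used when mapping checkpoint keys

    def _init_weights(self, module):
        # Called by the base class for modules whose weights are not loaded from
        # the checkpoint. Assumes the config has initializer_range, as BertConfig does.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

With config_class set, loading would go through the inherited classmethod, e.g. BertModel.from_pretrained("path/or/hub-id") (checkpoint location hypothetical), which covers what the deleted custom from_pretrained did by hand.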