feat: fixed _from_config
Browse files- modeling_bert.py +4 -4
modeling_bert.py
CHANGED
|
@@ -335,6 +335,10 @@ class BertPreTrainedModel(nn.Module):
|
|
| 335 |
logger.info(load_return)
|
| 336 |
return model
|
| 337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
class BertModel(BertPreTrainedModel):
|
| 340 |
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
|
@@ -523,10 +527,6 @@ class BertForPreTraining(BertPreTrainedModel):
|
|
| 523 |
seq_relationship_logits=seq_relationship_score,
|
| 524 |
)
|
| 525 |
|
| 526 |
-
@classmethod
|
| 527 |
-
def _from_config(cls, config, **kwargs):
|
| 528 |
-
pass
|
| 529 |
-
|
| 530 |
|
| 531 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
| 532 |
"""
|
|
|
|
| 335 |
logger.info(load_return)
|
| 336 |
return model
|
| 337 |
|
| 338 |
+
@classmethod
|
| 339 |
+
def _from_config(cls, config, **kwargs):
|
| 340 |
+
return cls(config, **kwargs)
|
| 341 |
+
|
| 342 |
|
| 343 |
class BertModel(BertPreTrainedModel):
|
| 344 |
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
|
|
|
| 527 |
seq_relationship_logits=seq_relationship_score,
|
| 528 |
)
|
| 529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
| 532 |
"""
|