Try to subclass PretrainedModel
modeling_bert.py (+1 -2)

@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import PretrainedModel
+from transformers.modeling_utils import PretrainedModel
 from .configuration_bert import JinaBertConfig
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -39,7 +39,6 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import FusedMLP, Mlp
-from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
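For reference, a minimal sketch of what subclassing the transformers base class looks like. Two assumptions: transformers exports the class as PreTrainedModel (capital T) from transformers.modeling_utils, and the JinaBertModel name and its body here are hypothetical stand-ins for the real model class in this file:

import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel

from .configuration_bert import JinaBertConfig


class JinaBertModel(PreTrainedModel):
    # Hypothetical subclass: pointing config_class at JinaBertConfig is what
    # lets from_pretrained() rebuild the model from a saved config.json.
    config_class = JinaBertConfig
    base_model_prefix = "bert"

    def __init__(self, config: JinaBertConfig):
        super().__init__(config)
        # The real layers (embeddings, encoder blocks) would be built here.
        self.placeholder = nn.Identity()

    def forward(self, input_ids, attention_mask=None):
        # Placeholder; the actual encoder stack goes here.
        raise NotImplementedError

Subclassing PreTrainedModel (rather than plain nn.Module) is what provides from_pretrained()/save_pretrained() and Hub integration, which is presumably why this commit replaces the flash_attn state_dict_from_pretrained helper.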