File size: 531 Bytes
8960e0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from transformers import PretrainedConfig


class FISHERConfig(PretrainedConfig):
    """Hyperparameter container for models registered as ``"fisher"``.

    Each named constructor argument is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are passed
    through to ``PretrainedConfig.__init__``.

    NOTE(review): the per-field meanings below are inferred from the
    parameter names only; confirm them against the model implementation.

    Attributes:
        band_width: Presumably the width of a single band. Default 100.
        embed_dim: Presumably the embedding dimension. Default 192.
        depth: Presumably the number of stacked layers. Default 12.
        num_heads: Presumably the number of attention heads. Default 3.
        max_band_per_sample: Presumably an upper bound on the bands used
            per sample. Default 64.
    """

    model_type = "fisher"

    def __init__(
        self,
        band_width: int = 100,
        embed_dim: int = 192,
        num_heads: int = 3,
        max_band_per_sample: int = 64,
        depth: int = 12,
        **kwargs,
    ) -> None:
        """Store the FISHER hyperparameters on the config instance.

        Args:
            band_width: Value stored as ``self.band_width``.
            embed_dim: Value stored as ``self.embed_dim``.
            num_heads: Value stored as ``self.num_heads``.
            max_band_per_sample: Value stored as
                ``self.max_band_per_sample``.
            depth: Value stored as ``self.depth``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Let the base class consume the generic options before the
        # model-specific fields are attached.
        super().__init__(**kwargs)

        # Model-specific hyperparameters, stored exactly as given.
        self.band_width = band_width
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.max_band_per_sample = max_band_per_sample