File size: 531 Bytes
8960e0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from transformers import PretrainedConfig
class FISHERConfig(PretrainedConfig):
    """Configuration object for the ``"fisher"`` model type.

    Each model-specific hyper-parameter passed to ``__init__`` is stored as
    an attribute of the same name. All other keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        band_width: Stored as ``self.band_width`` (default 100).
        embed_dim: Stored as ``self.embed_dim`` (default 192).
        num_heads: Stored as ``self.num_heads`` (default 3); presumably the
            attention head count -- confirm against the model code.
        max_band_per_sample: Stored as ``self.max_band_per_sample``
            (default 64).
        depth: Stored as ``self.depth`` (default 12).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "fisher"

    def __init__(
        self,
        band_width: int = 100,
        embed_dim: int = 192,
        num_heads: int = 3,
        max_band_per_sample: int = 64,
        depth: int = 12,
        **kwargs,
    ) -> None:
        # The base class handles its own options first; the model-specific
        # fields are attached afterwards, in a fixed order.
        super().__init__(**kwargs)
        model_fields = (
            ("band_width", band_width),
            ("embed_dim", embed_dim),
            ("depth", depth),
            ("num_heads", num_heads),
            ("max_band_per_sample", max_band_per_sample),
        )
        for field_name, field_value in model_fields:
            setattr(self, field_name, field_value)
|