| from transformers import PretrainedConfig | |
class FISHERConfig(PretrainedConfig):
    """Configuration container for the ``fisher`` model type.

    The constructor stores five hyperparameters on the instance unchanged and
    forwards every other keyword argument to ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names
    only -- confirm against the model implementation that consumes them.

    Args:
        band_width: Stored as ``self.band_width`` (default 100). Presumably
            the width of one input band.
        embed_dim: Stored as ``self.embed_dim`` (default 192). Presumably the
            embedding / hidden dimension.
        num_heads: Stored as ``self.num_heads`` (default 3). Presumably the
            number of attention heads.
        max_band_per_sample: Stored as ``self.max_band_per_sample``
            (default 64). Presumably an upper bound on bands per sample.
        depth: Stored as ``self.depth`` (default 12). Presumably the number
            of stacked layers.
        **kwargs: Passed through untouched to the parent constructor.
    """

    model_type = "fisher"

    def __init__(
        self,
        band_width=100,
        embed_dim=192,
        num_heads=3,
        max_band_per_sample=64,
        depth=12,
        **kwargs,
    ):
        # The parent constructor runs first, before any of the
        # model-specific attributes are set on the instance.
        super().__init__(**kwargs)

        # Model-specific hyperparameters, stored without validation or
        # conversion.
        self.band_width = band_width
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.max_band_per_sample = max_band_per_sample