from transformers import PretrainedConfig


class FISHERConfig(PretrainedConfig):
    """Hugging Face configuration class for the ``fisher`` model type.

    Holds the model-specific hyperparameters listed below as instance
    attributes. Every other keyword argument is passed through unchanged to
    :class:`transformers.PretrainedConfig`.

    Args:
        band_width: Saved as ``self.band_width`` (default 100). What a "band"
            is, and its unit, are defined by the model code rather than this
            file -- TODO confirm.
        embed_dim: Saved as ``self.embed_dim`` (default 192). Presumably the
            embedding / hidden dimension; verify against the model.
        num_heads: Saved as ``self.num_heads`` (default 3). Presumably the
            number of attention heads; verify against the model.
        max_band_per_sample: Saved as ``self.max_band_per_sample``
            (default 64). Presumably an upper bound on bands per input sample.
        depth: Saved as ``self.depth`` (default 12). Presumably the number of
            stacked blocks/layers; verify against the model.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # The model_type string transformers associates with this config class.
    model_type = "fisher"

    def __init__(
        self,
        band_width: int = 100,
        embed_dim: int = 192,
        num_heads: int = 3,
        max_band_per_sample: int = 64,
        depth: int = 12,
        **kwargs,
    ) -> None:
        # The base class handles all generic options first; the model-specific
        # hyperparameters are layered on afterwards.
        super().__init__(**kwargs)

        # Kept in the original assignment order so attribute insertion order
        # on the instance is unchanged.
        self.band_width = band_width
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.max_band_per_sample = max_band_per_sample