import torch
import numpy as np
import torch.nn as nn
from functools import partial
from einops import rearrange
from typing import Callable, Optional
from dataclasses import dataclass, field, is_dataclass
from transformers import PreTrainedModel

from .configuration_fisher import FISHERConfig
from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
)
from .modules import AltBlock
from .images import (
    D2vImageConfig,
    ImageEncoder,
)


@dataclass
class D2vModalitiesConfig:
    image: D2vImageConfig = field(default_factory=lambda *args: D2vImageConfig())


@dataclass
class Data2VecMultiConfig:
    depth: int = 12

    # band split
    band_width: int = 100

    # standard vision Transformer
    start_drop_path_rate: float = 0.0
    end_drop_path_rate: float = 0.0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.0
    post_mlp_drop: float = 0.0
    attention_dropout: float = 0.0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4.0
    layer_norm_first: bool = False
    end_of_block_targets: bool = False

    # clone batch for multi-mask strategy
    max_band_per_sample: int = 64

    # normalization for teacher Transformer layer output
    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = True
    instance_norm_targets: bool = False
    layer_norm_targets: bool = True

    modalities: D2vModalitiesConfig = field(default_factory=lambda *args: D2vModalitiesConfig())


def update_dataclass(instance, data_dict):
    """Recursively copy matching keys from a plain dict onto a (possibly nested) dataclass instance."""
    if not data_dict:
        return instance
    for field_name, field_value in data_dict.items():
        if hasattr(instance, field_name):
            current_value = getattr(instance, field_name)
            if is_dataclass(current_value) and isinstance(field_value, dict):
                update_dataclass(current_value, field_value)
            else:
                setattr(instance, field_name, field_value)
    return instance


class FISHER(nn.Module):
    def __init__(self, config: FISHERConfig):
        super().__init__()
        cfg = Data2VecMultiConfig()
        update_dataclass(cfg, config.to_dict())
        cfg.modalities.image.embed_dim = cfg.embed_dim
        self.cfg = cfg

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=0.0,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        mod_cfg = getattr(cfg.modalities, 'image')
        enc = self.make_modality_encoder(
            mod_cfg,
            cfg.embed_dim,
            make_block,
            make_layer_norm,
            cfg.layer_norm_first,
            self.alibi_biases,
        )
        self.modality_encoders['IMAGE'] = enc

        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        # band split
        self.band_width = cfg.band_width
        self.patch_size = cfg.modalities.image.patch_size

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task=None,
    ) -> ModalitySpecificEncoder:
        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def forward(
        self,
        source: torch.Tensor,
        target=None,
        id=None,
        mode='IMAGE',
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        force_remove_masked=False,
        precomputed_mask: Optional[torch.Tensor] = None,
    ):
        # band split: cut the frequency axis into non-overlapping sub-bands of
        # band_width bins and fold them into the batch dimension
        num_band = source.shape[-1] // self.band_width
        source = torch.stack(source.split(self.band_width, dim=-1)[:num_band])  # drop residual
        source = rearrange(source, 'nb B c t f -> (B nb) c t f')
        clone_batch = self.cfg.max_band_per_sample // num_band

        feature_extractor = self.modality_encoders[mode]  # models.images.ImageEncoder

        # extract (unmasked) features using CNN encoder
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,  # train: True; infer: False
            clone_batch=clone_batch if not features_only else 1,
            mask_seeds=None,
            precomputed_mask=precomputed_mask,
        )

        # x has shape (batch_size * clone_batch,
        #              patch_frame(64) * patch_frequency(8) * unmask_ratio(0.2) + 1 (cls_token),
        #              768 (feature dimension))
        x = extractor_out["x"]
        # encoder_mask is applied at the sub-band level
        encoder_mask = extractor_out["encoder_mask"]  # models.base.MaskInfo: ["x_unmasked", "mask", "ids_restore", "ids_keep"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        # standard Transformer (for student encoder)
        layer_results = []
        for i, blk in enumerate(self.blocks):
            ab = masked_alibi_bias
            if ab is not None and alibi_scale is not None:
                scale = (
                    alibi_scale[i]
                    if alibi_scale.size(0) > 1
                    else alibi_scale.squeeze(0)
                )
                ab = ab * scale.type_as(ab)

            x, lr = blk(
                x,
                padding_mask=masked_padding_mask,
                alibi_bias=ab,
            )
            if features_only:
                layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        # extract features for fine-tuning
        if features_only:
            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

    def extract_features(
        self, source, mode='IMAGE', padding_mask=None, mask=False
    ):
        num_band = source.shape[-1] // self.band_width
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
        )
        x = res['x'][:, 0]  # CLS token of each sub-band
        # concatenate the per-band CLS embeddings along the feature dimension
        x = rearrange(x, '(B nb) D -> B (nb D)', nb=num_band)
        return x


class FISHERModel(PreTrainedModel):
    config_class = FISHERConfig

    def __init__(self, cfg: FISHERConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.model = FISHER(cfg)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def extract_features(self, x):
        return self.model.extract_features(x)
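

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes that
# FISHERConfig() can be constructed with its defaults and that the model
# expects a spectrogram-like input of shape (batch, channels, time, freq)
# whose frequency dimension is a multiple of `band_width`; the concrete
# frame/bin counts below are illustrative only and must be compatible with
# the configured patch size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = FISHERConfig()              # assumed default construction
    model = FISHERModel(cfg).eval()

    # Dummy spectrogram: 1 sample, 1 channel, 1024 frames, 2 sub-bands of
    # `band_width` frequency bins each (illustrative shape).
    dummy = torch.randn(1, 1, 1024, 2 * model.model.band_width)

    with torch.no_grad():
        feats = model.extract_features(dummy)   # (B, num_band * embed_dim)
    print(feats.shape)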