| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from transformers.configuration_utils import PretrainedConfig |
|
|
| from .moe_lm import AriaMoELMConfig |
| from .vision_encoder import AriaVisionConfig |
|
|
|
|
| |
class AriaConfig(PretrainedConfig):
    """
    Configuration class for the Aria model.

    This class handles the configuration for both the vision and text components of
    the Aria model, as well as additional parameters for image token handling and
    projector mapping.

    Args:
        vision_config (AriaVisionConfig, dict or None): Configuration for the vision
            component. A dict is converted with ``AriaVisionConfig(**vision_config)``;
            ``None`` (the default) builds a fresh default ``AriaVisionConfig``.
        text_config (AriaMoELMConfig, dict or None): Configuration for the text
            component. A dict is converted with ``AriaMoELMConfig(**text_config)``;
            ``None`` (the default) builds a fresh default ``AriaMoELMConfig``.
        projector_patch_to_query_dict (dict or None): Mapping used by the projector
            from a patch count (key) to a query count (value). Keys and values are
            coerced to ``int``, so string keys (e.g. after a JSON round-trip) are
            accepted. ``None`` (the default) uses ``{1225: 128, 4900: 256}``.
        ignore_index (int): Index to ignore in loss calculation. Defaults to -100.
        image_token_index (int): Index used to represent image tokens.
            Defaults to 32000.
        **kwargs: Additional keyword arguments passed to ``PretrainedConfig``.

    Attributes:
        model_type (str): Type of the model, set to "aria".
        is_composition (bool): Whether the model is a composition of multiple components.
        ignore_index (int): Index to ignore in loss calculation.
        image_token_index (int): Index used to represent image tokens.
        projector_patch_to_query_dict (dict[int, int]): Patch-count to query-count mapping.
        vision_config (AriaVisionConfig): Configuration for the vision component.
        text_config (AriaMoELMConfig): Configuration for the text component.
    """

    model_type = "aria"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        projector_patch_to_query_dict=None,
        ignore_index=-100,
        image_token_index=32000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index

        # Built per call rather than as a dict default argument, which would be a
        # single object shared by every call.
        if projector_patch_to_query_dict is None:
            projector_patch_to_query_dict = {1225: 128, 4900: 256}
        # Coerce to int: JSON object keys are always strings, so a config reloaded
        # from disk would otherwise carry "1225" instead of 1225.
        self.projector_patch_to_query_dict = {
            int(k): int(v) for k, v in projector_patch_to_query_dict.items()
        }

        self.vision_config = self._as_sub_config(vision_config, AriaVisionConfig)
        self.text_config = self._as_sub_config(text_config, AriaMoELMConfig)

    @staticmethod
    def _as_sub_config(value, config_cls):
        """Normalize a sub-config argument to a ``config_cls`` instance.

        ``None`` yields a fresh ``config_cls()``; a dict is expanded into
        ``config_cls(**value)``; anything else (typically an existing
        ``config_cls`` instance) is returned unchanged.
        """
        # A new instance per call: a `config_cls()` default argument would be
        # evaluated once and shared (mutations included) by every AriaConfig.
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            return config_cls(**value)
        return value
|
|