Upload folder using huggingface_hub
Browse files- README.md +219 -0
- added_tokens.json +7 -0
- config.json +51 -0
- config_molmo.py +60 -0
- generation_config.json +4 -0
- image_preprocessing_molmo.py +546 -0
- merges.txt +0 -0
- model-00001-of-00002.safetensors +3 -0
- model-00002-of-00002.safetensors +3 -0
- model.safetensors.index.json +928 -0
- modeling_molmo.py +2372 -0
- molmo_logo.png +0 -0
- preprocessing_molmo.py +192 -0
- preprocessor_config.json +32 -0
- quantization_config.json +19 -0
- special_tokens_map.json +37 -0
- tokenizer.json +0 -0
- tokenizer_config.json +240 -0
- vocab.json +0 -0
README.md
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
base_model:
|
| 6 |
+
- openai/clip-vit-large-patch14-336
|
| 7 |
+
- allenai/OLMo-7B-1124
|
| 8 |
+
pipeline_tag: image-text-to-text
|
| 9 |
+
tags:
|
| 10 |
+
- multimodal
|
| 11 |
+
- olmo
|
| 12 |
+
- molmo
|
| 13 |
+
- pixmo
|
| 14 |
+
library_name: transformers
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
<img src="molmo_logo.png" alt="Logo for the Molmo Project" style="width: auto; height: 50px;">
|
| 18 |
+
|
| 19 |
+
# Molmo 7B-O
|
| 20 |
+
|
| 21 |
+
Molmo is a family of open vision-language models developed by the Allen Institute for AI.
|
| 22 |
+
Molmo models are trained on PixMo, a dataset of 1 million, highly-curated image-text pairs.
|
| 23 |
+
It has state-of-the-art performance among multimodal models with a similar size while being fully open-source.
|
| 24 |
+
You can find all models in the Molmo family [here](https://huggingface.co/collections/allenai/molmo-66f379e6fe3b8ef090a8ca19).
|
| 25 |
+
**Learn more** about the Molmo family [in our announcement blog post](https://molmo.allenai.org/blog) or the [paper](https://huggingface.co/papers/2409.17146).
|
| 26 |
+
|
| 27 |
+
Molmo 7B-O is based on [OLMo-7B-1024](https://huggingface.co/allenai/OLMo-7B-1024-preview) (a **preview** of next generation of OLMo models)
|
| 28 |
+
and uses [OpenAI CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336) as vision backbone.
|
| 29 |
+
It performs comfortably between GPT-4V and GPT-4o on both academic benchmarks and human evaluation.
|
| 30 |
+
|
| 31 |
+
This checkpoint is a **preview** of the Molmo release. All artifacts used in creating Molmo (PixMo dataset, training code, evaluations, intermediate checkpoints) will be made available at a later date, furthering our commitment to open-source AI development and reproducibility.
|
| 32 |
+
|
| 33 |
+
[**Sign up here**](https://docs.google.com/forms/d/e/1FAIpQLSdML1MhNNBDsCHpgWG65Oydg2SjZzVasyqlP08nBrWjZp_c7A/viewform) to be the first to know when artifacts are released.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
Quick links:
|
| 37 |
+
- 💬 [Demo](https://molmo.allenai.org/)
|
| 38 |
+
- 📂 [All Models](https://huggingface.co/collections/allenai/molmo-66f379e6fe3b8ef090a8ca19)
|
| 39 |
+
- 📃 [Paper](https://molmo.allenai.org/paper.pdf)
|
| 40 |
+
- 🎥 [Blog with Videos](https://molmo.allenai.org/blog)
|
| 41 |
+
|
| 42 |
+
## Quick Start
|
| 43 |
+
|
| 44 |
+
To run Molmo, first install dependencies:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
pip install einops torchvision
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Then, follow these steps:
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
| 54 |
+
from PIL import Image
|
| 55 |
+
import requests
|
| 56 |
+
|
| 57 |
+
# load the processor
|
| 58 |
+
processor = AutoProcessor.from_pretrained(
|
| 59 |
+
'allenai/Molmo-7B-O-0924',
|
| 60 |
+
trust_remote_code=True,
|
| 61 |
+
torch_dtype='auto',
|
| 62 |
+
device_map='auto'
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# load the model
|
| 66 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 67 |
+
'allenai/Molmo-7B-O-0924',
|
| 68 |
+
trust_remote_code=True,
|
| 69 |
+
torch_dtype='auto',
|
| 70 |
+
device_map='auto'
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# process the image and text
|
| 74 |
+
inputs = processor.process(
|
| 75 |
+
images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)],
|
| 76 |
+
text="Describe this image."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# move inputs to the correct device and make a batch of size 1
|
| 80 |
+
inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
|
| 81 |
+
|
| 82 |
+
# generate output; maximum 200 new tokens; stop generation when <|endoftext|> is generated
|
| 83 |
+
output = model.generate_from_batch(
|
| 84 |
+
inputs,
|
| 85 |
+
GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
|
| 86 |
+
tokenizer=processor.tokenizer
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# only get generated tokens; decode them to text
|
| 90 |
+
generated_tokens = output[0,inputs['input_ids'].size(1):]
|
| 91 |
+
generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 92 |
+
|
| 93 |
+
# print the generated text
|
| 94 |
+
print(generated_text)
|
| 95 |
+
|
| 96 |
+
# >>> This photograph captures an adorable black Labrador puppy sitting on a weathered
|
| 97 |
+
# wooden deck. The deck's planks, which are a mix of light and dark brown with ...
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
To make inference more efficient, run with autocast:
|
| 101 |
+
|
| 102 |
+
```python
|
| 103 |
+
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
|
| 104 |
+
output = model.generate_from_batch(
|
| 105 |
+
inputs,
|
| 106 |
+
GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
|
| 107 |
+
tokenizer=processor.tokenizer
|
| 108 |
+
)
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
We did most of our evaluations in this setting (autocast on, but float32 weights)
|
| 112 |
+
|
| 113 |
+
To even further reduce the memory requirements, the model can be run with bfloat16 weights:
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
model.to(dtype=torch.bfloat16)
|
| 117 |
+
inputs["images"] = inputs["images"].to(torch.bfloat16)
|
| 118 |
+
output = model.generate_from_batch(
|
| 119 |
+
inputs,
|
| 120 |
+
GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
|
| 121 |
+
tokenizer=processor.tokenizer
|
| 122 |
+
)
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Note that this can sometimes change the output of the model compared to running with float32 weights.
|
| 126 |
+
|
| 127 |
+
## vLLM
|
| 128 |
+
Molmo is supported in vLLM, however please use version <=0.7.2 until a [prepreprocessing bug](https://github.com/vllm-project/vllm/issues/26451) is fixed.
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
## Evaluations
|
| 132 |
+
|
| 133 |
+
| Model | Average Score on 11 Academic Benchmarks | Human Preference Elo Rating |
|
| 134 |
+
|-----------------------------|-----------------------------------------|-----------------------------|
|
| 135 |
+
| Molmo 72B | 81.2 | 1077 |
|
| 136 |
+
| Molmo 7B-D | 77.3 | 1056 |
|
| 137 |
+
| **Molmo 7B-O (this model)** | **74.6** | **1051** |
|
| 138 |
+
| MolmoE 1B | 68.6 | 1032 |
|
| 139 |
+
| GPT-4o | 78.5 | 1079 |
|
| 140 |
+
| GPT-4V | 71.1 | 1041 |
|
| 141 |
+
| Gemini 1.5 Pro | 78.3 | 1074 |
|
| 142 |
+
| Gemini 1.5 Flash | 75.1 | 1054 |
|
| 143 |
+
| Claude 3.5 Sonnet | 76.7 | 1069 |
|
| 144 |
+
| Claude 3 Opus | 66.4 | 971 |
|
| 145 |
+
| Claude 3 Haiku | 65.3 | 999 |
|
| 146 |
+
| Qwen VL2 72B | 79.4 | 1037 |
|
| 147 |
+
| Qwen VL2 7B | 73.7 | 1025 |
|
| 148 |
+
| Intern VL2 LLAMA 76B | 77.1 | 1018 |
|
| 149 |
+
| Intern VL2 8B | 69.4 | 953 |
|
| 150 |
+
| Pixtral 12B | 69.5 | 1016 |
|
| 151 |
+
| Phi3.5-Vision 4B | 59.7 | 982 |
|
| 152 |
+
| PaliGemma 3B | 50.0 | 937 |
|
| 153 |
+
| LLAVA OneVision 72B | 76.6 | 1051 |
|
| 154 |
+
| LLAVA OneVision 7B | 72.0 | 1024 |
|
| 155 |
+
| Cambrian-1 34B | 66.8 | 953 |
|
| 156 |
+
| Cambrian-1 8B | 63.4 | 952 |
|
| 157 |
+
| xGen - MM - Interleave 4B | 59.5 | 979 |
|
| 158 |
+
| LLAVA-1.5 13B | 43.9 | 960 |
|
| 159 |
+
| LLAVA-1.5 7B | 40.7 | 951 |
|
| 160 |
+
|
| 161 |
+
*Benchmarks: AI2D test, ChartQA test, VQA v2.0 test, DocQA test, InfographicVQA test, TextVQA val, RealWorldQA, MMMU val, MathVista testmini, CountBenchQA, Flickr Count (we collected this new dataset that is significantly harder than CountBenchQA).*
|
| 162 |
+
|
| 163 |
+
## FAQs
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
### I'm getting an error a broadcast error when processing images!
|
| 167 |
+
|
| 168 |
+
Your image might not be in RGB format. You can convert it using the following code snippet:
|
| 169 |
+
|
| 170 |
+
```python
|
| 171 |
+
from PIL import Image
|
| 172 |
+
|
| 173 |
+
image = Image.open(...)
|
| 174 |
+
|
| 175 |
+
if image.mode != "RGB":
|
| 176 |
+
image = image.convert("RGB")
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
### Molmo doesn't work great with transparent images!
|
| 180 |
+
|
| 181 |
+
We received reports that Molmo models might struggle with transparent images.
|
| 182 |
+
For the time being, we recommend adding a white or dark background to your images before passing them to the model. The code snippet below shows how to do this using the Python Imaging Library (PIL):
|
| 183 |
+
|
| 184 |
+
```python
|
| 185 |
+
|
| 186 |
+
# Load the image
|
| 187 |
+
url = "..."
|
| 188 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 189 |
+
|
| 190 |
+
# Convert the image to grayscale to calculate brightness
|
| 191 |
+
gray_image = image.convert('L') # Convert to grayscale
|
| 192 |
+
|
| 193 |
+
# Calculate the average brightness
|
| 194 |
+
stat = ImageStat.Stat(gray_image)
|
| 195 |
+
average_brightness = stat.mean[0] # Get the average value
|
| 196 |
+
|
| 197 |
+
# Define background color based on brightness (threshold can be adjusted)
|
| 198 |
+
bg_color = (0, 0, 0) if average_brightness > 127 else (255, 255, 255)
|
| 199 |
+
|
| 200 |
+
# Create a new image with the same size as the original, filled with the background color
|
| 201 |
+
new_image = Image.new('RGB', image.size, bg_color)
|
| 202 |
+
|
| 203 |
+
# Paste the original image on top of the background (use image as a mask if needed)
|
| 204 |
+
new_image.paste(image, (0, 0), image if image.mode == 'RGBA' else None)
|
| 205 |
+
|
| 206 |
+
# Now you can pass the new_image to Molmo
|
| 207 |
+
processor = AutoProcessor.from_pretrained(
|
| 208 |
+
'allenai/Molmo-7B-D-0924',
|
| 209 |
+
trust_remote_code=True,
|
| 210 |
+
torch_dtype='auto',
|
| 211 |
+
device_map='auto'
|
| 212 |
+
)
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
## License and Use
|
| 217 |
+
|
| 218 |
+
This model is licensed under Apache 2.0. It is intended for research and educational use.
|
| 219 |
+
For more information, please see our [Responsible Use Guidelines](https://allenai.org/responsible-use).
|
added_tokens.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<im_col>": 100281,
|
| 3 |
+
"<im_end>": 100279,
|
| 4 |
+
"<im_patch>": 100280,
|
| 5 |
+
"<im_start>": 100278,
|
| 6 |
+
"<|image|>": 100282
|
| 7 |
+
}
|
config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MolmoForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_layer_norm": true,
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "config_molmo.MolmoConfig",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_molmo.MolmoForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"clip_qkv": null,
|
| 11 |
+
"dtype": "float32",
|
| 12 |
+
"embedding_size": 100352,
|
| 13 |
+
"hidden_size": 4096,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 22016,
|
| 16 |
+
"layer_norm_eps": 1e-06,
|
| 17 |
+
"layer_norm_type": "rms",
|
| 18 |
+
"max_position_embeddings": 4096,
|
| 19 |
+
"model_type": "molmo",
|
| 20 |
+
"norm_after": true,
|
| 21 |
+
"num_attention_heads": 32,
|
| 22 |
+
"num_hidden_layers": 32,
|
| 23 |
+
"num_key_value_heads": null,
|
| 24 |
+
"qkv_bias": false,
|
| 25 |
+
"quantization_config": {
|
| 26 |
+
"add_skip_keys": false,
|
| 27 |
+
"dequantize_fp32": false,
|
| 28 |
+
"group_size": 0,
|
| 29 |
+
"is_integer": true,
|
| 30 |
+
"modules_dtype_dict": {},
|
| 31 |
+
"modules_to_not_convert": [],
|
| 32 |
+
"non_blocking": false,
|
| 33 |
+
"quant_conv": false,
|
| 34 |
+
"quant_method": "sdnq",
|
| 35 |
+
"quantization_device": null,
|
| 36 |
+
"return_device": null,
|
| 37 |
+
"svd_rank": 32,
|
| 38 |
+
"svd_steps": 8,
|
| 39 |
+
"use_quantized_matmul": false,
|
| 40 |
+
"use_quantized_matmul_conv": false,
|
| 41 |
+
"use_svd": false,
|
| 42 |
+
"weights_dtype": "int8"
|
| 43 |
+
},
|
| 44 |
+
"rope_theta": 500000.0,
|
| 45 |
+
"tie_word_embeddings": false,
|
| 46 |
+
"transformers_version": "4.57.1",
|
| 47 |
+
"use_cache": true,
|
| 48 |
+
"use_position_ids": true,
|
| 49 |
+
"vocab_size": 100278,
|
| 50 |
+
"weight_tying": false
|
| 51 |
+
}
|
config_molmo.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig, AutoTokenizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MolmoConfig(PretrainedConfig):
|
| 7 |
+
model_type = "molmo"
|
| 8 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
vocab_size=50304,
|
| 13 |
+
embedding_size=50304,
|
| 14 |
+
hidden_size=4096,
|
| 15 |
+
intermediate_size=11008,
|
| 16 |
+
num_hidden_layers=32,
|
| 17 |
+
num_attention_heads=32,
|
| 18 |
+
num_key_value_heads=None,
|
| 19 |
+
max_position_embeddings=2048,
|
| 20 |
+
initializer_range=0.02,
|
| 21 |
+
use_cache=True,
|
| 22 |
+
layer_norm_eps: float = 1e-5,
|
| 23 |
+
rope_theta=10000.0,
|
| 24 |
+
clip_qkv=None,
|
| 25 |
+
qkv_bias: bool = False,
|
| 26 |
+
weight_tying: bool = False,
|
| 27 |
+
use_position_ids: bool=True,
|
| 28 |
+
tie_word_embeddings: bool=True,
|
| 29 |
+
attention_layer_norm: bool=False,
|
| 30 |
+
norm_after: bool = False,
|
| 31 |
+
layer_norm_type: str="rms",
|
| 32 |
+
**kwargs,
|
| 33 |
+
):
|
| 34 |
+
self.vocab_size = vocab_size
|
| 35 |
+
self.embedding_size = embedding_size
|
| 36 |
+
self.max_position_embeddings = max_position_embeddings
|
| 37 |
+
self.hidden_size = hidden_size
|
| 38 |
+
self.intermediate_size = intermediate_size
|
| 39 |
+
self.num_hidden_layers = num_hidden_layers
|
| 40 |
+
self.num_attention_heads = num_attention_heads
|
| 41 |
+
self.layer_norm_eps = layer_norm_eps
|
| 42 |
+
self.weight_tying = weight_tying
|
| 43 |
+
self.use_position_ids = use_position_ids
|
| 44 |
+
self.attention_layer_norm = attention_layer_norm
|
| 45 |
+
self.num_key_value_heads = num_key_value_heads
|
| 46 |
+
self.initializer_range = initializer_range
|
| 47 |
+
self.use_cache = use_cache
|
| 48 |
+
self.rope_theta = rope_theta
|
| 49 |
+
self.clip_qkv = clip_qkv
|
| 50 |
+
self.qkv_bias = qkv_bias
|
| 51 |
+
self.norm_after = norm_after
|
| 52 |
+
self.tie_word_embeddings = tie_word_embeddings
|
| 53 |
+
self.layer_norm_type = layer_norm_type
|
| 54 |
+
|
| 55 |
+
super().__init__(
|
| 56 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 57 |
+
**kwargs,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
MolmoConfig.register_for_auto_class()
|
generation_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"transformers_version": "4.57.1"
|
| 4 |
+
}
|
image_preprocessing_molmo.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processor class for Molmo"""
|
| 2 |
+
from typing import List, Optional, Union, Mapping
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import einops
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms
|
| 8 |
+
from torchvision.transforms import InterpolationMode
|
| 9 |
+
from torchvision.transforms.functional import convert_image_dtype
|
| 10 |
+
|
| 11 |
+
from transformers.image_utils import (
|
| 12 |
+
OPENAI_CLIP_MEAN,
|
| 13 |
+
OPENAI_CLIP_STD,
|
| 14 |
+
ImageInput,
|
| 15 |
+
is_valid_image,
|
| 16 |
+
)
|
| 17 |
+
from transformers.processing_utils import ImagesKwargs
|
| 18 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 19 |
+
from transformers.utils import logging
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def pad_to_bounding_box(
|
| 26 |
+
image, offset_height, offset_width, target_height,
|
| 27 |
+
target_width, value=0
|
| 28 |
+
):
|
| 29 |
+
height, width = image.shape[:2]
|
| 30 |
+
after_padding_width = target_width - offset_width - width
|
| 31 |
+
after_padding_height = target_height - offset_height - height
|
| 32 |
+
return np.pad(image, [
|
| 33 |
+
[offset_height, after_padding_height],
|
| 34 |
+
[offset_width, after_padding_width],
|
| 35 |
+
[0, 0]
|
| 36 |
+
], constant_values=value)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def normalize_image(image, offset, scale):
|
| 40 |
+
image -= np.array(offset, dtype=np.float32)[None, None, :]
|
| 41 |
+
image /= np.array(scale, dtype=np.float32)[None, None, :]
|
| 42 |
+
return image
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def resize_and_pad(
|
| 46 |
+
image,
|
| 47 |
+
desired_output_size,
|
| 48 |
+
resize_method="torch-bilinear",
|
| 49 |
+
pad_value=0,
|
| 50 |
+
normalize=True,
|
| 51 |
+
image_mean=OPENAI_CLIP_MEAN,
|
| 52 |
+
image_std=OPENAI_CLIP_STD,
|
| 53 |
+
):
|
| 54 |
+
desired_height, desired_width = desired_output_size
|
| 55 |
+
height, width = image.shape[:2]
|
| 56 |
+
|
| 57 |
+
# Cast into float32 since the training code did this in float32 and it (very rarely) effects
|
| 58 |
+
# the results after rounding.
|
| 59 |
+
image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
|
| 60 |
+
image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
|
| 61 |
+
image_scale = min(image_scale_x, image_scale_y)
|
| 62 |
+
scaled_height = int(np.array(height, np.float32) * image_scale)
|
| 63 |
+
scaled_width = int(np.array(width, np.float32) * image_scale)
|
| 64 |
+
|
| 65 |
+
if resize_method == "tensorflow":
|
| 66 |
+
# This how the original training code did resizing, it can produce slightly different
|
| 67 |
+
# results then using torch resize so we keep it just in case
|
| 68 |
+
import tensorflow as tf
|
| 69 |
+
image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
|
| 70 |
+
image = tf.image.resize(
|
| 71 |
+
image,
|
| 72 |
+
[scaled_height, scaled_width],
|
| 73 |
+
method=tf.image.ResizeMethod.BILINEAR,
|
| 74 |
+
antialias=True,
|
| 75 |
+
)
|
| 76 |
+
image = tf.clip_by_value(image, 0.0, 1.0)
|
| 77 |
+
image = image.numpy()
|
| 78 |
+
elif resize_method == "torch-bilinear":
|
| 79 |
+
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
| 80 |
+
image = convert_image_dtype(image) # resize in float32 to match the training code
|
| 81 |
+
image = torchvision.transforms.Resize(
|
| 82 |
+
[scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
|
| 83 |
+
)(image)
|
| 84 |
+
image = torch.clip(image, 0.0, 1.0)
|
| 85 |
+
image = torch.permute(image, [1, 2, 0]).numpy()
|
| 86 |
+
else:
|
| 87 |
+
raise NotImplementedError(resize_method)
|
| 88 |
+
|
| 89 |
+
top_pad = (desired_height - scaled_height) // 2
|
| 90 |
+
left_pad = (desired_width - scaled_width) // 2
|
| 91 |
+
padding = [
|
| 92 |
+
[top_pad, desired_height - scaled_height - top_pad],
|
| 93 |
+
[left_pad, desired_width - scaled_width - left_pad],
|
| 94 |
+
[0, 0]
|
| 95 |
+
]
|
| 96 |
+
image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
|
| 97 |
+
image = np.pad(image, padding, constant_values=pad_value)
|
| 98 |
+
if normalize:
|
| 99 |
+
image = normalize_image(image, offset=image_mean, scale=image_std)
|
| 100 |
+
return image, image_mask
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def select_tiling(h, w, patch_size, max_num_patches):
|
| 104 |
+
"""Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
| 105 |
+
original_size = np.stack([h, w]) # [1, 2]
|
| 106 |
+
original_res = h * w
|
| 107 |
+
tilings = []
|
| 108 |
+
for i in range(1, max_num_patches+1):
|
| 109 |
+
for j in range(1, max_num_patches+1):
|
| 110 |
+
if i*j <= max_num_patches:
|
| 111 |
+
tilings.append((i, j))
|
| 112 |
+
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
| 113 |
+
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
|
| 114 |
+
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
| 115 |
+
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
| 116 |
+
|
| 117 |
+
# How much we would need to scale the image to fit exactly in each tiling
|
| 118 |
+
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
| 119 |
+
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
|
| 120 |
+
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
| 121 |
+
if np.all(required_scale < 1):
|
| 122 |
+
# We are forced to downscale, so try to minimize the amount of downscaling
|
| 123 |
+
ix = np.argmax(required_scale)
|
| 124 |
+
else:
|
| 125 |
+
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
| 126 |
+
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
| 127 |
+
ix = np.argmin(required_scale)
|
| 128 |
+
return candidate_tilings[ix]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class MolmoImagesKwargs(ImagesKwargs, total=False):
|
| 132 |
+
max_crops: Optional[int]
|
| 133 |
+
overlap_margins: Optional[List[int]]
|
| 134 |
+
base_image_input_size: Optional[List[int]]
|
| 135 |
+
image_token_length_w: Optional[int]
|
| 136 |
+
image_token_length_h: Optional[int]
|
| 137 |
+
image_patch_size: Optional[int]
|
| 138 |
+
image_padding_mask: Optional[bool]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class MolmoImageProcessor(BaseImageProcessor):
|
| 142 |
+
"""Preprocess images and multi-model inputs"""
|
| 143 |
+
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
max_crops: int = 12,
|
| 147 |
+
overlap_margins: List[int] = (4, 4),
|
| 148 |
+
base_image_input_size: List[int] = (336, 336),
|
| 149 |
+
image_token_length_w: int = 12,
|
| 150 |
+
image_token_length_h: int = 12,
|
| 151 |
+
image_patch_size: int = 14,
|
| 152 |
+
image_padding_mask: bool = True,
|
| 153 |
+
do_normalize: bool = True,
|
| 154 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 155 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 156 |
+
**kwargs,
|
| 157 |
+
):
|
| 158 |
+
super().__init__(**kwargs)
|
| 159 |
+
self.max_crops = max_crops
|
| 160 |
+
self.overlap_margins = overlap_margins
|
| 161 |
+
self.base_image_input_size = base_image_input_size
|
| 162 |
+
self.image_token_length_w = image_token_length_w
|
| 163 |
+
self.image_token_length_h = image_token_length_h
|
| 164 |
+
self.image_patch_size = image_patch_size
|
| 165 |
+
self.image_padding_mask = image_padding_mask
|
| 166 |
+
self.do_normalize = do_normalize
|
| 167 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
| 168 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
| 169 |
+
|
| 170 |
+
def image_to_patches_and_tokens(
|
| 171 |
+
self,
|
| 172 |
+
image: ImageInput,
|
| 173 |
+
image_patch_token_id: int,
|
| 174 |
+
image_col_token_id: int,
|
| 175 |
+
image_start_token_id: int,
|
| 176 |
+
image_end_token_id: int,
|
| 177 |
+
max_crops: Optional[int] = None,
|
| 178 |
+
overlap_margins: Optional[List[int]] = None,
|
| 179 |
+
base_image_input_size: Optional[Union[int, List[int]]] = None,
|
| 180 |
+
image_token_length_w: Optional[int] = None,
|
| 181 |
+
image_token_length_h: Optional[int] = None,
|
| 182 |
+
image_patch_size: Optional[int] = None,
|
| 183 |
+
):
|
| 184 |
+
if isinstance(base_image_input_size, int):
|
| 185 |
+
base_image_input_size = (base_image_input_size, base_image_input_size)
|
| 186 |
+
|
| 187 |
+
base_image_input_d = image_patch_size
|
| 188 |
+
tokens_per_image = image_token_length_w * image_token_length_h
|
| 189 |
+
image_base_patch_w = base_image_input_size[1] // base_image_input_d
|
| 190 |
+
image_base_patch_h = base_image_input_size[0] // base_image_input_d
|
| 191 |
+
|
| 192 |
+
original_image_h, original_image_w = image.shape[:2]
|
| 193 |
+
crop_size = base_image_input_size[0]
|
| 194 |
+
|
| 195 |
+
# Discard this many patches from the (left/top, right/bottom) of crops
|
| 196 |
+
left_margin, right_margin = overlap_margins
|
| 197 |
+
# left_margin, right_margin = 2, 2
|
| 198 |
+
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
|
| 199 |
+
total_margin_pixels = base_image_input_d*(right_margin + left_margin) # pixels removed per dim
|
| 200 |
+
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
|
| 201 |
+
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
| 202 |
+
crop_window_size = crop_window_patches * base_image_input_d
|
| 203 |
+
tiling = select_tiling(
|
| 204 |
+
original_image_h - total_margin_pixels,
|
| 205 |
+
original_image_w - total_margin_pixels,
|
| 206 |
+
crop_window_size,
|
| 207 |
+
max_crops
|
| 208 |
+
)
|
| 209 |
+
src, img_mask = resize_and_pad(
|
| 210 |
+
image,
|
| 211 |
+
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Now we have to split the image into crops, while keeping track of how each patch in the
|
| 215 |
+
# each crop should be ordered in the global image, this require a lot of tricky booking
|
| 216 |
+
n_crops = tiling[0] * tiling[1]
|
| 217 |
+
patches_arr = []
|
| 218 |
+
mask_arr = []
|
| 219 |
+
patch_ordering_arr = []
|
| 220 |
+
|
| 221 |
+
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
|
| 222 |
+
# patches if the number of patches per side is not even
|
| 223 |
+
assert (crop_patches+1)//2 == image_token_length_h
|
| 224 |
+
assert (crop_patches+1)//2 == image_token_length_w
|
| 225 |
+
on = 0
|
| 226 |
+
on_patch = 0
|
| 227 |
+
for i in range(tiling[0]):
|
| 228 |
+
y0 = i*crop_window_size
|
| 229 |
+
if i == 0:
|
| 230 |
+
crop_y0 = 0
|
| 231 |
+
else:
|
| 232 |
+
crop_y0 = left_margin // 2
|
| 233 |
+
|
| 234 |
+
crop_h = image_base_patch_h - (right_margin + left_margin)
|
| 235 |
+
if i == 0:
|
| 236 |
+
crop_h += left_margin
|
| 237 |
+
if i == (tiling[0]-1):
|
| 238 |
+
crop_h += right_margin
|
| 239 |
+
for j in range(tiling[1]):
|
| 240 |
+
x0 = j*crop_window_size
|
| 241 |
+
if j == 0:
|
| 242 |
+
crop_x0 = 0
|
| 243 |
+
else:
|
| 244 |
+
crop_x0 = left_margin // 2
|
| 245 |
+
|
| 246 |
+
crop_w = image_base_patch_w - (right_margin + left_margin)
|
| 247 |
+
if j == 0:
|
| 248 |
+
crop_w += left_margin
|
| 249 |
+
if j == (tiling[1]-1):
|
| 250 |
+
crop_w += right_margin
|
| 251 |
+
|
| 252 |
+
pooled_w = (crop_w + 1) // 2
|
| 253 |
+
pooled_h = (crop_h + 1) // 2
|
| 254 |
+
patch_ordering_arr.append(
|
| 255 |
+
pad_to_bounding_box(
|
| 256 |
+
np.reshape(np.arange(on, on+pooled_h*pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
|
| 257 |
+
crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
|
| 258 |
+
)[:, :, 0]
|
| 259 |
+
)
|
| 260 |
+
patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
|
| 261 |
+
mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])
|
| 262 |
+
|
| 263 |
+
on += pooled_h*pooled_w
|
| 264 |
+
on_patch += 1
|
| 265 |
+
patches = np.stack(patches_arr)
|
| 266 |
+
patch_ordering = np.stack(patch_ordering_arr)
|
| 267 |
+
img_mask = np.stack(mask_arr)
|
| 268 |
+
|
| 269 |
+
# Switch to [n_crops, n_patches, pixels_per_patch] format
|
| 270 |
+
image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
|
| 271 |
+
patches = einops.rearrange(
|
| 272 |
+
patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
|
| 273 |
+
dh=base_image_input_d,
|
| 274 |
+
dw=base_image_input_d,
|
| 275 |
+
h=image_base_patch_h,
|
| 276 |
+
w=image_base_patch_w
|
| 277 |
+
)
|
| 278 |
+
img_mask = einops.rearrange(
|
| 279 |
+
img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
|
| 280 |
+
dh=base_image_input_d,
|
| 281 |
+
dw=base_image_input_d,
|
| 282 |
+
h=image_base_patch_h,
|
| 283 |
+
w=image_base_patch_w
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
img_mask = img_mask.astype(np.float32).mean(axis=-1)
|
| 287 |
+
patch_ordering = np.reshape(patch_ordering, [-1])
|
| 288 |
+
valid = patch_ordering >= 0
|
| 289 |
+
|
| 290 |
+
# Transpose order, to get left-to-right order instead of crop-by-crop order
|
| 291 |
+
patch_ordering_rh = np.reshape(
|
| 292 |
+
patch_ordering,
|
| 293 |
+
[tiling[0], tiling[1], image_token_length_h, image_token_length_w]
|
| 294 |
+
)
|
| 295 |
+
patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
|
| 296 |
+
patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
|
| 297 |
+
|
| 298 |
+
# The transpose will screw up which patches are masked, project the
|
| 299 |
+
# new order into sparse structure of `patch_ordering` to fix this
|
| 300 |
+
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
|
| 301 |
+
|
| 302 |
+
# Now build the output tokens
|
| 303 |
+
h = tiling[0] * crop_window_patches + (right_margin+left_margin)
|
| 304 |
+
w = tiling[1] * crop_window_patches + (right_margin+left_margin)
|
| 305 |
+
per_row = np.full(
|
| 306 |
+
((w+1)//2,),
|
| 307 |
+
image_patch_token_id,
|
| 308 |
+
)
|
| 309 |
+
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
|
| 310 |
+
|
| 311 |
+
joint = np.tile(per_row, [(h+1)//2])
|
| 312 |
+
joint = [
|
| 313 |
+
[image_start_token_id],
|
| 314 |
+
joint,
|
| 315 |
+
[image_end_token_id]
|
| 316 |
+
]
|
| 317 |
+
|
| 318 |
+
# Finally do the same for the global image
|
| 319 |
+
resized, _ = resize_and_pad(image, base_image_input_size)
|
| 320 |
+
resized = einops.rearrange(
|
| 321 |
+
resized, '(h dh) (w dw) c -> (h w) (dh dw c)',
|
| 322 |
+
dh=base_image_input_d,
|
| 323 |
+
dw=base_image_input_d,
|
| 324 |
+
h=image_base_patch_h,
|
| 325 |
+
w=image_base_patch_w
|
| 326 |
+
)
|
| 327 |
+
patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
|
| 328 |
+
|
| 329 |
+
# Global image goes first, so the order of patches in previous crops gets increased
|
| 330 |
+
patch_ordering = np.where(
|
| 331 |
+
patch_ordering >= 0,
|
| 332 |
+
patch_ordering + tokens_per_image,
|
| 333 |
+
-1
|
| 334 |
+
)
|
| 335 |
+
patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
|
| 336 |
+
per_row = np.full(
|
| 337 |
+
(image_token_length_w,),
|
| 338 |
+
image_patch_token_id,
|
| 339 |
+
)
|
| 340 |
+
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
|
| 341 |
+
extra_tokens = np.tile(per_row, [image_token_length_h])
|
| 342 |
+
joint = [
|
| 343 |
+
[image_start_token_id],
|
| 344 |
+
extra_tokens,
|
| 345 |
+
[image_end_token_id],
|
| 346 |
+
] + joint
|
| 347 |
+
|
| 348 |
+
joint = np.concatenate(joint, 0)
|
| 349 |
+
img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
|
| 350 |
+
return patches, joint, patch_ordering, img_mask
|
| 351 |
+
|
| 352 |
+
def build_image_input_idx(
|
| 353 |
+
self,
|
| 354 |
+
image_tokens: np.ndarray,
|
| 355 |
+
patch_order: np.ndarray,
|
| 356 |
+
image_patch_token_id: int,
|
| 357 |
+
no_image: Optional[bool] = None,
|
| 358 |
+
image_token_length_w: Optional[int] = None,
|
| 359 |
+
image_token_length_h: Optional[int] = None,
|
| 360 |
+
):
|
| 361 |
+
"""Converts `patch_order` into a mapping of token_id -> patch_id"""
|
| 362 |
+
|
| 363 |
+
tokens_per_image = image_token_length_w * image_token_length_h
|
| 364 |
+
if no_image is not None and no_image:
|
| 365 |
+
return np.zeros((0, tokens_per_image), np.int32)
|
| 366 |
+
|
| 367 |
+
# Indices to insert the patches
|
| 368 |
+
image_input_idx = image_tokens == image_patch_token_id
|
| 369 |
+
image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
|
| 370 |
+
|
| 371 |
+
if patch_order is not None:
|
| 372 |
+
n_tokens = image_input_idx.shape[0]
|
| 373 |
+
patch_order = np.reshape(patch_order, [-1])
|
| 374 |
+
n_patches = patch_order.shape[0]
|
| 375 |
+
|
| 376 |
+
valid = patch_order >= 0
|
| 377 |
+
n_valid_patches = valid.sum()
|
| 378 |
+
assert len(image_input_idx) == n_valid_patches
|
| 379 |
+
|
| 380 |
+
sorted_patch_ixs = np.zeros([n_tokens], np.int32)
|
| 381 |
+
sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
|
| 382 |
+
|
| 383 |
+
# Project the inverted mapping into same sparse structure
|
| 384 |
+
sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
|
| 385 |
+
sorted_patch_ixs_ex[valid] = sorted_patch_ixs
|
| 386 |
+
|
| 387 |
+
# Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs`
|
| 388 |
+
valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
|
| 389 |
+
image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
|
| 390 |
+
image_input_idx = image_input_idx*valid - 100*(1 - valid)
|
| 391 |
+
image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
|
| 392 |
+
return image_input_idx
|
| 393 |
+
|
| 394 |
+
def preprocess(
|
| 395 |
+
self,
|
| 396 |
+
image: np.ndarray,
|
| 397 |
+
image_patch_token_id: int,
|
| 398 |
+
image_col_token_id: int,
|
| 399 |
+
image_start_token_id: int,
|
| 400 |
+
image_end_token_id: int,
|
| 401 |
+
max_crops: Optional[int] = None,
|
| 402 |
+
overlap_margins: Optional[List[int]] = None,
|
| 403 |
+
base_image_input_size: Optional[Union[int, List[int]]] = None,
|
| 404 |
+
image_token_length_w: Optional[int] = None,
|
| 405 |
+
image_token_length_h: Optional[int] = None,
|
| 406 |
+
image_patch_size: Optional[int] = None,
|
| 407 |
+
**kwargs,
|
| 408 |
+
):
|
| 409 |
+
"""Preprocesses an image
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
|
| 413 |
+
change between images but the other dimension are fixed
|
| 414 |
+
tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the
|
| 415 |
+
patch features, might include other special tokens as well
|
| 416 |
+
image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the
|
| 417 |
+
crops after pooling, negative values indicates patches features to exclude
|
| 418 |
+
padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None
|
| 419 |
+
if the image mask is not being used.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
max_crops = max_crops or self.max_crops
|
| 423 |
+
overlap_margins = overlap_margins or self.overlap_margins
|
| 424 |
+
base_image_input_size = base_image_input_size or self.base_image_input_size
|
| 425 |
+
image_token_length_w = image_token_length_w or self.image_token_length_w
|
| 426 |
+
image_token_length_h = image_token_length_h or self.image_token_length_h
|
| 427 |
+
image_patch_size = image_patch_size or self.image_patch_size
|
| 428 |
+
|
| 429 |
+
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
|
| 430 |
+
image,
|
| 431 |
+
image_patch_token_id,
|
| 432 |
+
image_col_token_id,
|
| 433 |
+
image_start_token_id,
|
| 434 |
+
image_end_token_id,
|
| 435 |
+
max_crops,
|
| 436 |
+
overlap_margins,
|
| 437 |
+
base_image_input_size,
|
| 438 |
+
image_token_length_w,
|
| 439 |
+
image_token_length_h,
|
| 440 |
+
image_patch_size,
|
| 441 |
+
)
|
| 442 |
+
patch_idx = self.build_image_input_idx(
|
| 443 |
+
image_tokens,
|
| 444 |
+
patch_ordering,
|
| 445 |
+
image_patch_token_id,
|
| 446 |
+
image_token_length_w=image_token_length_w,
|
| 447 |
+
image_token_length_h=image_token_length_h,
|
| 448 |
+
)
|
| 449 |
+
return crops, image_tokens, patch_idx, img_mask
|
| 450 |
+
|
| 451 |
+
def multimodal_preprocess(
|
| 452 |
+
self,
|
| 453 |
+
images: np.ndarray,
|
| 454 |
+
tokens: List[int],
|
| 455 |
+
image_idx: np.ndarray,
|
| 456 |
+
sequence_length: int,
|
| 457 |
+
image_patch_token_id: int,
|
| 458 |
+
image_col_token_id: int,
|
| 459 |
+
image_start_token_id: int,
|
| 460 |
+
image_end_token_id: int,
|
| 461 |
+
**kwargs,
|
| 462 |
+
):
|
| 463 |
+
"""Merge images and text tokens into multi-modal features for the model
|
| 464 |
+
|
| 465 |
+
:param images: images to use as input
|
| 466 |
+
:param tokens: input text tokens
|
| 467 |
+
:param image_idx: where to insert the images into `tokens`
|
| 468 |
+
:params image_patch_token_id: id to use of tokens that will contain image features
|
| 469 |
+
:params image_col_token_id: token id for image column special tokens
|
| 470 |
+
:params image_start_token_id: token id for image start special tokens
|
| 471 |
+
:params image_end_token_id: token id for image end special tokens
|
| 472 |
+
:params kwargs: override preprocessor default args
|
| 473 |
+
"""
|
| 474 |
+
max_total_crops = kwargs.get("max_crops") or self.max_crops
|
| 475 |
+
image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
|
| 476 |
+
image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
|
| 477 |
+
image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
|
| 478 |
+
base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
|
| 479 |
+
image_num_patch = (
|
| 480 |
+
base_image_input_size[0] // image_patch_size,
|
| 481 |
+
base_image_input_size[1] // image_patch_size,
|
| 482 |
+
)
|
| 483 |
+
image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
|
| 484 |
+
|
| 485 |
+
tokens_per_image = image_token_length_w * image_token_length_h
|
| 486 |
+
n_pixels = image_patch_size * image_patch_size * 3
|
| 487 |
+
n_patches = image_num_patch[0] * image_num_patch[1]
|
| 488 |
+
|
| 489 |
+
if images is None:
|
| 490 |
+
return {
|
| 491 |
+
"input_ids": tokens,
|
| 492 |
+
}
|
| 493 |
+
else:
|
| 494 |
+
n = len(images)
|
| 495 |
+
all_crops = []
|
| 496 |
+
all_image_idx = []
|
| 497 |
+
out_tokens = []
|
| 498 |
+
all_crop_masks = []
|
| 499 |
+
|
| 500 |
+
for ix in range(n):
|
| 501 |
+
token_ix = image_idx[ix]
|
| 502 |
+
crops, image_tokens, patch_idx, img_mask = self.preprocess(
|
| 503 |
+
images[ix],
|
| 504 |
+
image_patch_token_id,
|
| 505 |
+
image_col_token_id,
|
| 506 |
+
image_start_token_id,
|
| 507 |
+
image_end_token_id,
|
| 508 |
+
**kwargs,
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
if token_ix == -1: # -1 is an image inserted at the very start
|
| 512 |
+
start = 0
|
| 513 |
+
token_ix = 0
|
| 514 |
+
end = 0
|
| 515 |
+
else:
|
| 516 |
+
start = 0 if ix == 0 else image_idx[ix-1] + 1
|
| 517 |
+
end = token_ix + 1
|
| 518 |
+
|
| 519 |
+
all_image_idx.append(patch_idx + token_ix)
|
| 520 |
+
all_crops.append(crops)
|
| 521 |
+
out_tokens.append(tokens[start:token_ix])
|
| 522 |
+
out_tokens.append(image_tokens)
|
| 523 |
+
if ix == (n - 1):
|
| 524 |
+
out_tokens.append(tokens[end:])
|
| 525 |
+
if image_padding_mask:
|
| 526 |
+
all_crop_masks.append(img_mask)
|
| 527 |
+
|
| 528 |
+
input_ids = np.concatenate(out_tokens, 0)
|
| 529 |
+
images = np.concatenate(all_crops, 0)
|
| 530 |
+
image_input_idx = np.concatenate(all_image_idx, 0)
|
| 531 |
+
if image_padding_mask:
|
| 532 |
+
image_masks = np.concatenate(all_crop_masks, 0)
|
| 533 |
+
else:
|
| 534 |
+
image_masks = None
|
| 535 |
+
|
| 536 |
+
out = {
|
| 537 |
+
"input_ids": input_ids,
|
| 538 |
+
"images": images,
|
| 539 |
+
"image_input_idx": image_input_idx
|
| 540 |
+
}
|
| 541 |
+
if image_masks is not None:
|
| 542 |
+
out["image_masks"] = image_masks
|
| 543 |
+
return out
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
MolmoImageProcessor.register_for_auto_class()
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model-00001-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cfcf2c9002bc69a9a8f15ebba206ff75daa8858ba3a7924bd4a709ec72f637ad
|
| 3 |
+
size 4953349016
|
model-00002-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fddc23c86e93926b88a0fa8508a80d5ad18b057411860365c72f70ab0e40f119
|
| 3 |
+
size 3964234784
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_parameters": 7671742976,
|
| 4 |
+
"total_size": 8917462016
|
| 5 |
+
},
|
| 6 |
+
"weight_map": {
|
| 7 |
+
"model.transformer.blocks.0.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 8 |
+
"model.transformer.blocks.0.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 9 |
+
"model.transformer.blocks.0.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 10 |
+
"model.transformer.blocks.0.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 11 |
+
"model.transformer.blocks.0.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 12 |
+
"model.transformer.blocks.0.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 13 |
+
"model.transformer.blocks.0.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 14 |
+
"model.transformer.blocks.0.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 15 |
+
"model.transformer.blocks.0.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 16 |
+
"model.transformer.blocks.0.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 17 |
+
"model.transformer.blocks.0.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 18 |
+
"model.transformer.blocks.0.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 19 |
+
"model.transformer.blocks.1.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 20 |
+
"model.transformer.blocks.1.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 21 |
+
"model.transformer.blocks.1.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 22 |
+
"model.transformer.blocks.1.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 23 |
+
"model.transformer.blocks.1.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 24 |
+
"model.transformer.blocks.1.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 25 |
+
"model.transformer.blocks.1.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 26 |
+
"model.transformer.blocks.1.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 27 |
+
"model.transformer.blocks.1.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 28 |
+
"model.transformer.blocks.1.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 29 |
+
"model.transformer.blocks.1.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 30 |
+
"model.transformer.blocks.1.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 31 |
+
"model.transformer.blocks.10.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 32 |
+
"model.transformer.blocks.10.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 33 |
+
"model.transformer.blocks.10.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 34 |
+
"model.transformer.blocks.10.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 35 |
+
"model.transformer.blocks.10.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 36 |
+
"model.transformer.blocks.10.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 37 |
+
"model.transformer.blocks.10.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 38 |
+
"model.transformer.blocks.10.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 39 |
+
"model.transformer.blocks.10.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 40 |
+
"model.transformer.blocks.10.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 41 |
+
"model.transformer.blocks.10.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 42 |
+
"model.transformer.blocks.10.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 43 |
+
"model.transformer.blocks.11.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 44 |
+
"model.transformer.blocks.11.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 45 |
+
"model.transformer.blocks.11.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 46 |
+
"model.transformer.blocks.11.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 47 |
+
"model.transformer.blocks.11.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 48 |
+
"model.transformer.blocks.11.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 49 |
+
"model.transformer.blocks.11.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 50 |
+
"model.transformer.blocks.11.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 51 |
+
"model.transformer.blocks.11.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 52 |
+
"model.transformer.blocks.11.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 53 |
+
"model.transformer.blocks.11.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 54 |
+
"model.transformer.blocks.11.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 55 |
+
"model.transformer.blocks.12.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 56 |
+
"model.transformer.blocks.12.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 57 |
+
"model.transformer.blocks.12.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 58 |
+
"model.transformer.blocks.12.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 59 |
+
"model.transformer.blocks.12.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 60 |
+
"model.transformer.blocks.12.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 61 |
+
"model.transformer.blocks.12.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 62 |
+
"model.transformer.blocks.12.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 63 |
+
"model.transformer.blocks.12.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 64 |
+
"model.transformer.blocks.12.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 65 |
+
"model.transformer.blocks.12.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 66 |
+
"model.transformer.blocks.12.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 67 |
+
"model.transformer.blocks.13.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 68 |
+
"model.transformer.blocks.13.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 69 |
+
"model.transformer.blocks.13.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 70 |
+
"model.transformer.blocks.13.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 71 |
+
"model.transformer.blocks.13.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 72 |
+
"model.transformer.blocks.13.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 73 |
+
"model.transformer.blocks.13.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 74 |
+
"model.transformer.blocks.13.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 75 |
+
"model.transformer.blocks.13.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 76 |
+
"model.transformer.blocks.13.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 77 |
+
"model.transformer.blocks.13.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 78 |
+
"model.transformer.blocks.13.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 79 |
+
"model.transformer.blocks.14.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 80 |
+
"model.transformer.blocks.14.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 81 |
+
"model.transformer.blocks.14.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 82 |
+
"model.transformer.blocks.14.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 83 |
+
"model.transformer.blocks.14.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 84 |
+
"model.transformer.blocks.14.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 85 |
+
"model.transformer.blocks.14.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 86 |
+
"model.transformer.blocks.14.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 87 |
+
"model.transformer.blocks.14.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 88 |
+
"model.transformer.blocks.14.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 89 |
+
"model.transformer.blocks.14.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 90 |
+
"model.transformer.blocks.14.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 91 |
+
"model.transformer.blocks.15.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 92 |
+
"model.transformer.blocks.15.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 93 |
+
"model.transformer.blocks.15.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 94 |
+
"model.transformer.blocks.15.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 95 |
+
"model.transformer.blocks.15.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 96 |
+
"model.transformer.blocks.15.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 97 |
+
"model.transformer.blocks.15.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 98 |
+
"model.transformer.blocks.15.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 99 |
+
"model.transformer.blocks.15.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 100 |
+
"model.transformer.blocks.15.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 101 |
+
"model.transformer.blocks.15.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 102 |
+
"model.transformer.blocks.15.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 103 |
+
"model.transformer.blocks.16.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 104 |
+
"model.transformer.blocks.16.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 105 |
+
"model.transformer.blocks.16.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 106 |
+
"model.transformer.blocks.16.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 107 |
+
"model.transformer.blocks.16.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 108 |
+
"model.transformer.blocks.16.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 109 |
+
"model.transformer.blocks.16.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 110 |
+
"model.transformer.blocks.16.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 111 |
+
"model.transformer.blocks.16.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 112 |
+
"model.transformer.blocks.16.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 113 |
+
"model.transformer.blocks.16.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 114 |
+
"model.transformer.blocks.16.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 115 |
+
"model.transformer.blocks.17.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 116 |
+
"model.transformer.blocks.17.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 117 |
+
"model.transformer.blocks.17.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 118 |
+
"model.transformer.blocks.17.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 119 |
+
"model.transformer.blocks.17.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 120 |
+
"model.transformer.blocks.17.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 121 |
+
"model.transformer.blocks.17.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 122 |
+
"model.transformer.blocks.17.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 123 |
+
"model.transformer.blocks.17.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 124 |
+
"model.transformer.blocks.17.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 125 |
+
"model.transformer.blocks.17.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 126 |
+
"model.transformer.blocks.17.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 127 |
+
"model.transformer.blocks.18.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 128 |
+
"model.transformer.blocks.18.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 129 |
+
"model.transformer.blocks.18.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 130 |
+
"model.transformer.blocks.18.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 131 |
+
"model.transformer.blocks.18.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 132 |
+
"model.transformer.blocks.18.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 133 |
+
"model.transformer.blocks.18.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 134 |
+
"model.transformer.blocks.18.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 135 |
+
"model.transformer.blocks.18.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 136 |
+
"model.transformer.blocks.18.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 137 |
+
"model.transformer.blocks.18.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 138 |
+
"model.transformer.blocks.18.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 139 |
+
"model.transformer.blocks.19.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 140 |
+
"model.transformer.blocks.19.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 141 |
+
"model.transformer.blocks.19.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 142 |
+
"model.transformer.blocks.19.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 143 |
+
"model.transformer.blocks.19.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 144 |
+
"model.transformer.blocks.19.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 145 |
+
"model.transformer.blocks.19.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 146 |
+
"model.transformer.blocks.19.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 147 |
+
"model.transformer.blocks.19.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 148 |
+
"model.transformer.blocks.19.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 149 |
+
"model.transformer.blocks.19.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 150 |
+
"model.transformer.blocks.19.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 151 |
+
"model.transformer.blocks.2.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 152 |
+
"model.transformer.blocks.2.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 153 |
+
"model.transformer.blocks.2.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 154 |
+
"model.transformer.blocks.2.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 155 |
+
"model.transformer.blocks.2.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 156 |
+
"model.transformer.blocks.2.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 157 |
+
"model.transformer.blocks.2.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 158 |
+
"model.transformer.blocks.2.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 159 |
+
"model.transformer.blocks.2.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 160 |
+
"model.transformer.blocks.2.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 161 |
+
"model.transformer.blocks.2.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 162 |
+
"model.transformer.blocks.2.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 163 |
+
"model.transformer.blocks.20.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 164 |
+
"model.transformer.blocks.20.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 165 |
+
"model.transformer.blocks.20.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 166 |
+
"model.transformer.blocks.20.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 167 |
+
"model.transformer.blocks.20.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 168 |
+
"model.transformer.blocks.20.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 169 |
+
"model.transformer.blocks.20.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 170 |
+
"model.transformer.blocks.20.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 171 |
+
"model.transformer.blocks.20.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 172 |
+
"model.transformer.blocks.20.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 173 |
+
"model.transformer.blocks.20.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 174 |
+
"model.transformer.blocks.20.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 175 |
+
"model.transformer.blocks.21.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 176 |
+
"model.transformer.blocks.21.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 177 |
+
"model.transformer.blocks.21.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 178 |
+
"model.transformer.blocks.21.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 179 |
+
"model.transformer.blocks.21.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 180 |
+
"model.transformer.blocks.21.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 181 |
+
"model.transformer.blocks.21.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 182 |
+
"model.transformer.blocks.21.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 183 |
+
"model.transformer.blocks.21.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 184 |
+
"model.transformer.blocks.21.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 185 |
+
"model.transformer.blocks.21.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 186 |
+
"model.transformer.blocks.21.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 187 |
+
"model.transformer.blocks.22.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 188 |
+
"model.transformer.blocks.22.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 189 |
+
"model.transformer.blocks.22.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 190 |
+
"model.transformer.blocks.22.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 191 |
+
"model.transformer.blocks.22.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 192 |
+
"model.transformer.blocks.22.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 193 |
+
"model.transformer.blocks.22.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 194 |
+
"model.transformer.blocks.22.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 195 |
+
"model.transformer.blocks.22.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 196 |
+
"model.transformer.blocks.22.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 197 |
+
"model.transformer.blocks.22.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 198 |
+
"model.transformer.blocks.22.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 199 |
+
"model.transformer.blocks.23.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 200 |
+
"model.transformer.blocks.23.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 201 |
+
"model.transformer.blocks.23.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 202 |
+
"model.transformer.blocks.23.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 203 |
+
"model.transformer.blocks.23.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 204 |
+
"model.transformer.blocks.23.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 205 |
+
"model.transformer.blocks.23.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 206 |
+
"model.transformer.blocks.23.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 207 |
+
"model.transformer.blocks.23.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 208 |
+
"model.transformer.blocks.23.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 209 |
+
"model.transformer.blocks.23.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 210 |
+
"model.transformer.blocks.23.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 211 |
+
"model.transformer.blocks.24.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 212 |
+
"model.transformer.blocks.24.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 213 |
+
"model.transformer.blocks.24.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 214 |
+
"model.transformer.blocks.24.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 215 |
+
"model.transformer.blocks.24.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 216 |
+
"model.transformer.blocks.24.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 217 |
+
"model.transformer.blocks.24.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 218 |
+
"model.transformer.blocks.24.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 219 |
+
"model.transformer.blocks.24.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 220 |
+
"model.transformer.blocks.24.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 221 |
+
"model.transformer.blocks.24.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 222 |
+
"model.transformer.blocks.24.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 223 |
+
"model.transformer.blocks.25.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 224 |
+
"model.transformer.blocks.25.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 225 |
+
"model.transformer.blocks.25.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 226 |
+
"model.transformer.blocks.25.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 227 |
+
"model.transformer.blocks.25.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 228 |
+
"model.transformer.blocks.25.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 229 |
+
"model.transformer.blocks.25.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 230 |
+
"model.transformer.blocks.25.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 231 |
+
"model.transformer.blocks.25.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 232 |
+
"model.transformer.blocks.25.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 233 |
+
"model.transformer.blocks.25.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 234 |
+
"model.transformer.blocks.25.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 235 |
+
"model.transformer.blocks.26.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 236 |
+
"model.transformer.blocks.26.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 237 |
+
"model.transformer.blocks.26.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 238 |
+
"model.transformer.blocks.26.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 239 |
+
"model.transformer.blocks.26.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 240 |
+
"model.transformer.blocks.26.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 241 |
+
"model.transformer.blocks.26.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 242 |
+
"model.transformer.blocks.26.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 243 |
+
"model.transformer.blocks.26.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 244 |
+
"model.transformer.blocks.26.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 245 |
+
"model.transformer.blocks.26.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 246 |
+
"model.transformer.blocks.26.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 247 |
+
"model.transformer.blocks.27.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 248 |
+
"model.transformer.blocks.27.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 249 |
+
"model.transformer.blocks.27.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 250 |
+
"model.transformer.blocks.27.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 251 |
+
"model.transformer.blocks.27.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 252 |
+
"model.transformer.blocks.27.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 253 |
+
"model.transformer.blocks.27.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 254 |
+
"model.transformer.blocks.27.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 255 |
+
"model.transformer.blocks.27.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 256 |
+
"model.transformer.blocks.27.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 257 |
+
"model.transformer.blocks.27.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 258 |
+
"model.transformer.blocks.27.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 259 |
+
"model.transformer.blocks.28.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 260 |
+
"model.transformer.blocks.28.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 261 |
+
"model.transformer.blocks.28.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 262 |
+
"model.transformer.blocks.28.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 263 |
+
"model.transformer.blocks.28.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 264 |
+
"model.transformer.blocks.28.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 265 |
+
"model.transformer.blocks.28.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 266 |
+
"model.transformer.blocks.28.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 267 |
+
"model.transformer.blocks.28.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 268 |
+
"model.transformer.blocks.28.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 269 |
+
"model.transformer.blocks.28.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 270 |
+
"model.transformer.blocks.28.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 271 |
+
"model.transformer.blocks.29.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 272 |
+
"model.transformer.blocks.29.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 273 |
+
"model.transformer.blocks.29.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 274 |
+
"model.transformer.blocks.29.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 275 |
+
"model.transformer.blocks.29.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 276 |
+
"model.transformer.blocks.29.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 277 |
+
"model.transformer.blocks.29.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 278 |
+
"model.transformer.blocks.29.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 279 |
+
"model.transformer.blocks.29.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 280 |
+
"model.transformer.blocks.29.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 281 |
+
"model.transformer.blocks.29.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 282 |
+
"model.transformer.blocks.29.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 283 |
+
"model.transformer.blocks.3.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 284 |
+
"model.transformer.blocks.3.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 285 |
+
"model.transformer.blocks.3.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 286 |
+
"model.transformer.blocks.3.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 287 |
+
"model.transformer.blocks.3.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 288 |
+
"model.transformer.blocks.3.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 289 |
+
"model.transformer.blocks.3.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 290 |
+
"model.transformer.blocks.3.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 291 |
+
"model.transformer.blocks.3.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 292 |
+
"model.transformer.blocks.3.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 293 |
+
"model.transformer.blocks.3.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 294 |
+
"model.transformer.blocks.3.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 295 |
+
"model.transformer.blocks.30.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 296 |
+
"model.transformer.blocks.30.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 297 |
+
"model.transformer.blocks.30.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 298 |
+
"model.transformer.blocks.30.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 299 |
+
"model.transformer.blocks.30.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 300 |
+
"model.transformer.blocks.30.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 301 |
+
"model.transformer.blocks.30.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 302 |
+
"model.transformer.blocks.30.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 303 |
+
"model.transformer.blocks.30.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 304 |
+
"model.transformer.blocks.30.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 305 |
+
"model.transformer.blocks.30.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 306 |
+
"model.transformer.blocks.30.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 307 |
+
"model.transformer.blocks.31.att_proj.scale": "model-00002-of-00002.safetensors",
|
| 308 |
+
"model.transformer.blocks.31.att_proj.weight": "model-00002-of-00002.safetensors",
|
| 309 |
+
"model.transformer.blocks.31.attn_norm.weight": "model-00002-of-00002.safetensors",
|
| 310 |
+
"model.transformer.blocks.31.attn_out.scale": "model-00002-of-00002.safetensors",
|
| 311 |
+
"model.transformer.blocks.31.attn_out.weight": "model-00002-of-00002.safetensors",
|
| 312 |
+
"model.transformer.blocks.31.ff_norm.weight": "model-00002-of-00002.safetensors",
|
| 313 |
+
"model.transformer.blocks.31.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 314 |
+
"model.transformer.blocks.31.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 315 |
+
"model.transformer.blocks.31.ff_proj.scale": "model-00002-of-00002.safetensors",
|
| 316 |
+
"model.transformer.blocks.31.ff_proj.weight": "model-00002-of-00002.safetensors",
|
| 317 |
+
"model.transformer.blocks.31.k_norm.weight": "model-00002-of-00002.safetensors",
|
| 318 |
+
"model.transformer.blocks.31.q_norm.weight": "model-00002-of-00002.safetensors",
|
| 319 |
+
"model.transformer.blocks.4.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 320 |
+
"model.transformer.blocks.4.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 321 |
+
"model.transformer.blocks.4.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 322 |
+
"model.transformer.blocks.4.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 323 |
+
"model.transformer.blocks.4.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 324 |
+
"model.transformer.blocks.4.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 325 |
+
"model.transformer.blocks.4.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 326 |
+
"model.transformer.blocks.4.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 327 |
+
"model.transformer.blocks.4.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 328 |
+
"model.transformer.blocks.4.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 329 |
+
"model.transformer.blocks.4.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 330 |
+
"model.transformer.blocks.4.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 331 |
+
"model.transformer.blocks.5.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 332 |
+
"model.transformer.blocks.5.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 333 |
+
"model.transformer.blocks.5.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 334 |
+
"model.transformer.blocks.5.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 335 |
+
"model.transformer.blocks.5.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 336 |
+
"model.transformer.blocks.5.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 337 |
+
"model.transformer.blocks.5.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 338 |
+
"model.transformer.blocks.5.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 339 |
+
"model.transformer.blocks.5.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 340 |
+
"model.transformer.blocks.5.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 341 |
+
"model.transformer.blocks.5.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 342 |
+
"model.transformer.blocks.5.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 343 |
+
"model.transformer.blocks.6.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 344 |
+
"model.transformer.blocks.6.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 345 |
+
"model.transformer.blocks.6.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 346 |
+
"model.transformer.blocks.6.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 347 |
+
"model.transformer.blocks.6.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 348 |
+
"model.transformer.blocks.6.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 349 |
+
"model.transformer.blocks.6.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 350 |
+
"model.transformer.blocks.6.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 351 |
+
"model.transformer.blocks.6.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 352 |
+
"model.transformer.blocks.6.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 353 |
+
"model.transformer.blocks.6.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 354 |
+
"model.transformer.blocks.6.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 355 |
+
"model.transformer.blocks.7.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 356 |
+
"model.transformer.blocks.7.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 357 |
+
"model.transformer.blocks.7.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 358 |
+
"model.transformer.blocks.7.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 359 |
+
"model.transformer.blocks.7.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 360 |
+
"model.transformer.blocks.7.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 361 |
+
"model.transformer.blocks.7.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 362 |
+
"model.transformer.blocks.7.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 363 |
+
"model.transformer.blocks.7.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 364 |
+
"model.transformer.blocks.7.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 365 |
+
"model.transformer.blocks.7.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 366 |
+
"model.transformer.blocks.7.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 367 |
+
"model.transformer.blocks.8.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 368 |
+
"model.transformer.blocks.8.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 369 |
+
"model.transformer.blocks.8.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 370 |
+
"model.transformer.blocks.8.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 371 |
+
"model.transformer.blocks.8.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 372 |
+
"model.transformer.blocks.8.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 373 |
+
"model.transformer.blocks.8.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 374 |
+
"model.transformer.blocks.8.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 375 |
+
"model.transformer.blocks.8.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 376 |
+
"model.transformer.blocks.8.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 377 |
+
"model.transformer.blocks.8.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 378 |
+
"model.transformer.blocks.8.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 379 |
+
"model.transformer.blocks.9.att_proj.scale": "model-00001-of-00002.safetensors",
|
| 380 |
+
"model.transformer.blocks.9.att_proj.weight": "model-00001-of-00002.safetensors",
|
| 381 |
+
"model.transformer.blocks.9.attn_norm.weight": "model-00001-of-00002.safetensors",
|
| 382 |
+
"model.transformer.blocks.9.attn_out.scale": "model-00001-of-00002.safetensors",
|
| 383 |
+
"model.transformer.blocks.9.attn_out.weight": "model-00001-of-00002.safetensors",
|
| 384 |
+
"model.transformer.blocks.9.ff_norm.weight": "model-00001-of-00002.safetensors",
|
| 385 |
+
"model.transformer.blocks.9.ff_out.scale": "model-00001-of-00002.safetensors",
|
| 386 |
+
"model.transformer.blocks.9.ff_out.weight": "model-00001-of-00002.safetensors",
|
| 387 |
+
"model.transformer.blocks.9.ff_proj.scale": "model-00001-of-00002.safetensors",
|
| 388 |
+
"model.transformer.blocks.9.ff_proj.weight": "model-00001-of-00002.safetensors",
|
| 389 |
+
"model.transformer.blocks.9.k_norm.weight": "model-00001-of-00002.safetensors",
|
| 390 |
+
"model.transformer.blocks.9.q_norm.weight": "model-00001-of-00002.safetensors",
|
| 391 |
+
"model.transformer.ff_out.scale": "model-00002-of-00002.safetensors",
|
| 392 |
+
"model.transformer.ff_out.weight": "model-00002-of-00002.safetensors",
|
| 393 |
+
"model.transformer.ln_f.weight": "model-00001-of-00002.safetensors",
|
| 394 |
+
"model.transformer.wte.embedding": "model-00001-of-00002.safetensors",
|
| 395 |
+
"model.transformer.wte.new_embedding": "model-00001-of-00002.safetensors",
|
| 396 |
+
"model.vision_backbone.image_pooling_2d.wk.bias": "model-00002-of-00002.safetensors",
|
| 397 |
+
"model.vision_backbone.image_pooling_2d.wk.scale": "model-00002-of-00002.safetensors",
|
| 398 |
+
"model.vision_backbone.image_pooling_2d.wk.weight": "model-00002-of-00002.safetensors",
|
| 399 |
+
"model.vision_backbone.image_pooling_2d.wo.bias": "model-00002-of-00002.safetensors",
|
| 400 |
+
"model.vision_backbone.image_pooling_2d.wo.scale": "model-00002-of-00002.safetensors",
|
| 401 |
+
"model.vision_backbone.image_pooling_2d.wo.weight": "model-00002-of-00002.safetensors",
|
| 402 |
+
"model.vision_backbone.image_pooling_2d.wq.bias": "model-00002-of-00002.safetensors",
|
| 403 |
+
"model.vision_backbone.image_pooling_2d.wq.scale": "model-00002-of-00002.safetensors",
|
| 404 |
+
"model.vision_backbone.image_pooling_2d.wq.weight": "model-00002-of-00002.safetensors",
|
| 405 |
+
"model.vision_backbone.image_pooling_2d.wv.bias": "model-00002-of-00002.safetensors",
|
| 406 |
+
"model.vision_backbone.image_pooling_2d.wv.scale": "model-00002-of-00002.safetensors",
|
| 407 |
+
"model.vision_backbone.image_pooling_2d.wv.weight": "model-00002-of-00002.safetensors",
|
| 408 |
+
"model.vision_backbone.image_projector.w1.scale": "model-00002-of-00002.safetensors",
|
| 409 |
+
"model.vision_backbone.image_projector.w1.weight": "model-00002-of-00002.safetensors",
|
| 410 |
+
"model.vision_backbone.image_projector.w2.scale": "model-00002-of-00002.safetensors",
|
| 411 |
+
"model.vision_backbone.image_projector.w2.weight": "model-00002-of-00002.safetensors",
|
| 412 |
+
"model.vision_backbone.image_projector.w3.scale": "model-00002-of-00002.safetensors",
|
| 413 |
+
"model.vision_backbone.image_projector.w3.weight": "model-00002-of-00002.safetensors",
|
| 414 |
+
"model.vision_backbone.image_vit.class_embedding": "model-00002-of-00002.safetensors",
|
| 415 |
+
"model.vision_backbone.image_vit.patch_embedding.scale": "model-00002-of-00002.safetensors",
|
| 416 |
+
"model.vision_backbone.image_vit.patch_embedding.weight": "model-00002-of-00002.safetensors",
|
| 417 |
+
"model.vision_backbone.image_vit.positional_embedding": "model-00002-of-00002.safetensors",
|
| 418 |
+
"model.vision_backbone.image_vit.pre_ln.bias": "model-00002-of-00002.safetensors",
|
| 419 |
+
"model.vision_backbone.image_vit.pre_ln.weight": "model-00002-of-00002.safetensors",
|
| 420 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 421 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 422 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 423 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 424 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 425 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 426 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 427 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 428 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 429 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 430 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 431 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 432 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 433 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 434 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 435 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 436 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 437 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 438 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 439 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 440 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 441 |
+
"model.vision_backbone.image_vit.transformer.resblocks.0.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 442 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 443 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 444 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 445 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 446 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 447 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 448 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 449 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 450 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 451 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 452 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 453 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 454 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 455 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 456 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 457 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 458 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 459 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 460 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 461 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 462 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 463 |
+
"model.vision_backbone.image_vit.transformer.resblocks.1.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 464 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 465 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 466 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 467 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 468 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 469 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 470 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 471 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 472 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 473 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 474 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 475 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 476 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 477 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 478 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 479 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 480 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 481 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 482 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 483 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 484 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 485 |
+
"model.vision_backbone.image_vit.transformer.resblocks.10.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 486 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 487 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 488 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 489 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 490 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 491 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 492 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 493 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 494 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 495 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 496 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 497 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 498 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 499 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 500 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 501 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 502 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 503 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 504 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 505 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 506 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 507 |
+
"model.vision_backbone.image_vit.transformer.resblocks.11.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 508 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 509 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 510 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 511 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 512 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 513 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 514 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 515 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 516 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 517 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 518 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 519 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 520 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 521 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 522 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 523 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 524 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 525 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 526 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 527 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 528 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 529 |
+
"model.vision_backbone.image_vit.transformer.resblocks.12.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 530 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 531 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 532 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 533 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 534 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 535 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 536 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 537 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 538 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 539 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 540 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 541 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 542 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 543 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 544 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 545 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 546 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 547 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 548 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 549 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 550 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 551 |
+
"model.vision_backbone.image_vit.transformer.resblocks.13.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 552 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 553 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 554 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 555 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 556 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 557 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 558 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 559 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 560 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 561 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 562 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 563 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 564 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 565 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 566 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 567 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 568 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 569 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 570 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 571 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 572 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 573 |
+
"model.vision_backbone.image_vit.transformer.resblocks.14.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 574 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 575 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 576 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 577 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 578 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 579 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 580 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 581 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 582 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 583 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 584 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 585 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 586 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 587 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 588 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 589 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 590 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 591 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 592 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 593 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 594 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 595 |
+
"model.vision_backbone.image_vit.transformer.resblocks.15.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 596 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 597 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 598 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 599 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 600 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 601 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 602 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 603 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 604 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 605 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 606 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 607 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 608 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 609 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 610 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 611 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 612 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 613 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 614 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 615 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 616 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 617 |
+
"model.vision_backbone.image_vit.transformer.resblocks.16.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 618 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 619 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 620 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 621 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 622 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 623 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 624 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 625 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 626 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 627 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 628 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 629 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 630 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 631 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 632 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 633 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 634 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 635 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 636 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 637 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 638 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 639 |
+
"model.vision_backbone.image_vit.transformer.resblocks.17.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 640 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 641 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 642 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 643 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 644 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 645 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 646 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 647 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 648 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 649 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 650 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 651 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 652 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 653 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 654 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 655 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 656 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 657 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 658 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 659 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 660 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 661 |
+
"model.vision_backbone.image_vit.transformer.resblocks.18.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 662 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 663 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 664 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 665 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 666 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 667 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 668 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 669 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 670 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 671 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 672 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 673 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 674 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 675 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 676 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 677 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 678 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 679 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 680 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 681 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 682 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 683 |
+
"model.vision_backbone.image_vit.transformer.resblocks.19.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 684 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 685 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 686 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 687 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 688 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 689 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 690 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 691 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 692 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 693 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 694 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 695 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 696 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 697 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 698 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 699 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 700 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 701 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 702 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 703 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 704 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 705 |
+
"model.vision_backbone.image_vit.transformer.resblocks.2.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 706 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 707 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 708 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 709 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 710 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 711 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 712 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 713 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 714 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 715 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 716 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 717 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 718 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 719 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 720 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 721 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 722 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 723 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 724 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 725 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 726 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 727 |
+
"model.vision_backbone.image_vit.transformer.resblocks.20.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 728 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 729 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 730 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 731 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 732 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 733 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 734 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 735 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 736 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 737 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 738 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 739 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 740 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 741 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 742 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 743 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 744 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 745 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 746 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 747 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 748 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 749 |
+
"model.vision_backbone.image_vit.transformer.resblocks.21.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 750 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 751 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 752 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 753 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 754 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 755 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 756 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 757 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 758 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 759 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 760 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 761 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 762 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 763 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 764 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 765 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 766 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 767 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 768 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 769 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 770 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 771 |
+
"model.vision_backbone.image_vit.transformer.resblocks.22.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 772 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 773 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 774 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 775 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 776 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 777 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 778 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 779 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 780 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 781 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 782 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 783 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 784 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 785 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 786 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 787 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 788 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 789 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 790 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 791 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 792 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 793 |
+
"model.vision_backbone.image_vit.transformer.resblocks.3.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 794 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 795 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 796 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 797 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 798 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 799 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 800 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 801 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 802 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 803 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 804 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 805 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 806 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 807 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 808 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 809 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 810 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 811 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 812 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 813 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 814 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 815 |
+
"model.vision_backbone.image_vit.transformer.resblocks.4.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 816 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 817 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 818 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 819 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 820 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 821 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 822 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 823 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 824 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 825 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 826 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 827 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 828 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 829 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 830 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 831 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 832 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 833 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 834 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 835 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 836 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 837 |
+
"model.vision_backbone.image_vit.transformer.resblocks.5.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 838 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 839 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 840 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 841 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 842 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 843 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 844 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 845 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 846 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 847 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 848 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 849 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 850 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 851 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 852 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 853 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 854 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 855 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 856 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 857 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 858 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 859 |
+
"model.vision_backbone.image_vit.transformer.resblocks.6.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 860 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 861 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 862 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 863 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 864 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 865 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 866 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 867 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 868 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 869 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 870 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 871 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 872 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 873 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 874 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 875 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 876 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 877 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 878 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 879 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 880 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 881 |
+
"model.vision_backbone.image_vit.transformer.resblocks.7.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 882 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 883 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 884 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 885 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 886 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 887 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 888 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 889 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 890 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 891 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 892 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 893 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 894 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 895 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 896 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 897 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 898 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 899 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 900 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 901 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 902 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 903 |
+
"model.vision_backbone.image_vit.transformer.resblocks.8.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 904 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wk.bias": "model-00002-of-00002.safetensors",
|
| 905 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wk.scale": "model-00002-of-00002.safetensors",
|
| 906 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wk.weight": "model-00002-of-00002.safetensors",
|
| 907 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wo.bias": "model-00002-of-00002.safetensors",
|
| 908 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wo.scale": "model-00002-of-00002.safetensors",
|
| 909 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wo.weight": "model-00002-of-00002.safetensors",
|
| 910 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wq.bias": "model-00002-of-00002.safetensors",
|
| 911 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wq.scale": "model-00002-of-00002.safetensors",
|
| 912 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wq.weight": "model-00002-of-00002.safetensors",
|
| 913 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wv.bias": "model-00002-of-00002.safetensors",
|
| 914 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wv.scale": "model-00002-of-00002.safetensors",
|
| 915 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention.wv.weight": "model-00002-of-00002.safetensors",
|
| 916 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention_norm.bias": "model-00002-of-00002.safetensors",
|
| 917 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.attention_norm.weight": "model-00002-of-00002.safetensors",
|
| 918 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.feed_forward.w1.bias": "model-00002-of-00002.safetensors",
|
| 919 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.feed_forward.w1.scale": "model-00002-of-00002.safetensors",
|
| 920 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.feed_forward.w1.weight": "model-00002-of-00002.safetensors",
|
| 921 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.feed_forward.w2.bias": "model-00002-of-00002.safetensors",
|
| 922 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.feed_forward.w2.scale": "model-00002-of-00002.safetensors",
|
| 923 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.feed_forward.w2.weight": "model-00002-of-00002.safetensors",
|
| 924 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.ffn_norm.bias": "model-00002-of-00002.safetensors",
|
| 925 |
+
"model.vision_backbone.image_vit.transformer.resblocks.9.ffn_norm.weight": "model-00002-of-00002.safetensors",
|
| 926 |
+
"model.vision_backbone.pad_embed": "model-00002-of-00002.safetensors"
|
| 927 |
+
}
|
| 928 |
+
}
|
modeling_molmo.py
ADDED
|
@@ -0,0 +1,2372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from dataclasses import fields, dataclass, replace
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import List, Optional, Tuple, Union, Dict, Any, Sequence, Callable, cast, MutableMapping
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from einops import einsum, einops
|
| 10 |
+
from transformers import PreTrainedModel, GenerationConfig
|
| 11 |
+
from transformers.cache_utils import Cache
|
| 12 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
|
| 13 |
+
from transformers.models.auto import AutoModelForCausalLM
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
from .config_molmo import MolmoConfig
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
|
| 24 |
+
"""
|
| 25 |
+
Cache for attention biases and other things that would normally be stored as buffers.
|
| 26 |
+
We avoid using buffers because we've run into various issues doing so with FSDP.
|
| 27 |
+
In general it appears the way FSDP handles buffers is not well-defined.
|
| 28 |
+
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
|
| 29 |
+
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
|
| 30 |
+
NaNs when they're synchronized due to casting or some other issue.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class StrEnum(str, Enum):
|
| 35 |
+
def __str__(self) -> str:
|
| 36 |
+
return self.value
|
| 37 |
+
|
| 38 |
+
def __repr__(self) -> str:
|
| 39 |
+
return f"'{str(self)}'"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ImageProjectType(StrEnum):
|
| 43 |
+
mlp = "mlp"
|
| 44 |
+
mlpx2 = "2mlp"
|
| 45 |
+
linear = "linear"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ImagePooling2DType(StrEnum):
|
| 49 |
+
attention = "attention"
|
| 50 |
+
attention_meanq = "attention-meanq"
|
| 51 |
+
attention_2wide = "attention_2wide"
|
| 52 |
+
attention_v2 = "attention-v2"
|
| 53 |
+
none = "none"
|
| 54 |
+
stack = "stack"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ActivationType(StrEnum):
|
| 58 |
+
quick_gelu = "quick_gelu"
|
| 59 |
+
gelu = "gelu"
|
| 60 |
+
gelu_tanh = "gelu_tanh"
|
| 61 |
+
relu = "relu"
|
| 62 |
+
silu = "silu"
|
| 63 |
+
llama_geglu = "llama_geglu"
|
| 64 |
+
llama_geglu_tanh = "llama_geglu_tanh"
|
| 65 |
+
llama_swiglu = "llama_swiglu"
|
| 66 |
+
swiglu = "swiglu"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
| 70 |
+
"""
|
| 71 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
| 72 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
| 73 |
+
"""
|
| 74 |
+
if check_neg_inf:
|
| 75 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
| 76 |
+
if check_pos_inf:
|
| 77 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class MolmoConfigurationError(Exception):
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _non_meta_init_device(config) -> torch.device:
|
| 85 |
+
if config.init_device is not None and config.init_device != "meta":
|
| 86 |
+
return torch.device(config.init_device)
|
| 87 |
+
else:
|
| 88 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class RotaryEmbedding(nn.Module):
|
| 92 |
+
"""
|
| 93 |
+
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(self, config: MolmoConfig, cache: BufferCache):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.config = config
|
| 99 |
+
self.__cache = cache
|
| 100 |
+
# Warm up cache.
|
| 101 |
+
self.get_rotary_embedding(
|
| 102 |
+
config.max_position_embeddings or config.max_sequence_length,
|
| 103 |
+
_non_meta_init_device(config)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 107 |
+
if (
|
| 108 |
+
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
|
| 109 |
+
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
|
| 110 |
+
and pos_sin.shape[-2] >= seq_len
|
| 111 |
+
and pos_cos.shape[-2] >= seq_len
|
| 112 |
+
):
|
| 113 |
+
if pos_sin.device != device:
|
| 114 |
+
pos_sin = pos_sin.to(device)
|
| 115 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 116 |
+
if pos_cos.device != device:
|
| 117 |
+
pos_cos = pos_cos.to(device)
|
| 118 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 119 |
+
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
|
| 120 |
+
|
| 121 |
+
with torch.autocast(device.type, enabled=False):
|
| 122 |
+
dim = self.config.d_model // self.config.n_heads
|
| 123 |
+
inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
|
| 124 |
+
seq = torch.arange(seq_len, device=device, dtype=torch.float)
|
| 125 |
+
freqs = torch.einsum("i , j -> i j", seq, inv_freq)
|
| 126 |
+
if self.config.rope_impl == "interleave":
|
| 127 |
+
positions = freqs.repeat_interleave(2, dim=-1)
|
| 128 |
+
else:
|
| 129 |
+
positions = torch.cat((freqs, freqs), dim=-1)
|
| 130 |
+
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
|
| 131 |
+
self.__cache["rope_pos_sin"] = pos_sin
|
| 132 |
+
self.__cache["rope_pos_cos"] = pos_cos
|
| 133 |
+
return pos_sin, pos_cos
|
| 134 |
+
|
| 135 |
+
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
|
| 136 |
+
B, nh, T, hs = x.size()
|
| 137 |
+
x = x.view(B, nh, T, 2, hs // 2)
|
| 138 |
+
x1, x2 = x.unbind(dim=-2)
|
| 139 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 140 |
+
|
| 141 |
+
def rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
|
| 142 |
+
B, nh, T, hs = x.size()
|
| 143 |
+
x = x.view(B, nh, T, hs // 2, 2)
|
| 144 |
+
x1, x2 = x.unbind(dim=-1)
|
| 145 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 146 |
+
return x.view(B, nh, T, hs)
|
| 147 |
+
|
| 148 |
+
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 149 |
+
if self.config.rope_impl == "interleave":
|
| 150 |
+
return ((t * pos_cos) + (self.rotate_every_two(t) * pos_sin)).to(t.dtype)
|
| 151 |
+
else:
|
| 152 |
+
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
|
| 153 |
+
|
| 154 |
+
def forward(
|
| 155 |
+
self,
|
| 156 |
+
q: torch.Tensor,
|
| 157 |
+
k: torch.Tensor,
|
| 158 |
+
position_ids: Optional[torch.Tensor] = None
|
| 159 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 160 |
+
if self.config.rope_full_precision:
|
| 161 |
+
q_, k_ = q.float(), k.float()
|
| 162 |
+
else:
|
| 163 |
+
q_, k_ = q, k
|
| 164 |
+
|
| 165 |
+
with torch.autocast(q.device.type, enabled=False):
|
| 166 |
+
batch_size = q_.shape[0]
|
| 167 |
+
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
|
| 168 |
+
if position_ids is not None:
|
| 169 |
+
freqs_cis_len = (self.config.max_position_embeddings or self.config.max_sequence_length)
|
| 170 |
+
else:
|
| 171 |
+
freqs_cis_len = key_len
|
| 172 |
+
pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device)
|
| 173 |
+
pos_sin = pos_sin.type_as(q_)
|
| 174 |
+
pos_cos = pos_cos.type_as(q_)
|
| 175 |
+
if position_ids is not None:
|
| 176 |
+
assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
|
| 177 |
+
pos_sin = pos_sin[0, 0][position_ids].view(
|
| 178 |
+
(batch_size, 1, key_len, pos_sin.shape[-1])
|
| 179 |
+
)
|
| 180 |
+
pos_cos = pos_cos[0, 0][position_ids].view(
|
| 181 |
+
(batch_size, 1, key_len, pos_cos.shape[-1])
|
| 182 |
+
)
|
| 183 |
+
q_ = self.apply_rotary_pos_emb(
|
| 184 |
+
pos_sin[:, :, key_len - query_len : key_len, :],
|
| 185 |
+
pos_cos[:, :, key_len - query_len : key_len, :],
|
| 186 |
+
q_,
|
| 187 |
+
)
|
| 188 |
+
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
|
| 189 |
+
return q_.type_as(q), k_.type_as(k)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class MolmoBlock(nn.Module):
|
| 193 |
+
"""
|
| 194 |
+
A base class for transformer block implementations.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
|
| 198 |
+
super().__init__()
|
| 199 |
+
self.layer_id = layer_id
|
| 200 |
+
self.config = config
|
| 201 |
+
self.hidden_size = (
|
| 202 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
| 203 |
+
)
|
| 204 |
+
self.__cache = cache
|
| 205 |
+
self._activation_checkpoint_fn = None
|
| 206 |
+
|
| 207 |
+
# Dropout.
|
| 208 |
+
self.dropout = Dropout(config.residual_dropout)
|
| 209 |
+
|
| 210 |
+
# Layer norms.
|
| 211 |
+
self.k_norm: Optional[LayerNormBase] = None
|
| 212 |
+
self.q_norm: Optional[LayerNormBase] = None
|
| 213 |
+
if config.attention_layer_norm:
|
| 214 |
+
assert config.effective_n_kv_heads is not None
|
| 215 |
+
self.k_norm = LayerNormBase.build(
|
| 216 |
+
config,
|
| 217 |
+
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
|
| 218 |
+
elementwise_affine=config.attention_layer_norm_with_affine,
|
| 219 |
+
)
|
| 220 |
+
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
|
| 221 |
+
|
| 222 |
+
# Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
|
| 223 |
+
if config.clip_qkv is not None:
|
| 224 |
+
assert config.clip_qkv > 0
|
| 225 |
+
|
| 226 |
+
# Activation function.
|
| 227 |
+
self.act = Activation.build(config)
|
| 228 |
+
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
| 229 |
+
|
| 230 |
+
# Attention output projection.
|
| 231 |
+
input_dim = config.d_model
|
| 232 |
+
self.attn_out = nn.Linear(
|
| 233 |
+
input_dim, config.d_model,
|
| 234 |
+
bias=config.include_bias,
|
| 235 |
+
device=config.init_device
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# Feed-forward output projection.
|
| 239 |
+
self.ff_out = nn.Linear(
|
| 240 |
+
int(self.act.output_multiplier * self.hidden_size),
|
| 241 |
+
config.d_model,
|
| 242 |
+
bias=config.include_bias,
|
| 243 |
+
device=config.init_device,
|
| 244 |
+
)
|
| 245 |
+
self.ff_out._is_residual = True # type: ignore
|
| 246 |
+
|
| 247 |
+
# Rotary embeddings.
|
| 248 |
+
if self.config.rope:
|
| 249 |
+
self.rotary_emb = RotaryEmbedding(config, self.__cache)
|
| 250 |
+
|
| 251 |
+
self.flash_attn_func = None
|
| 252 |
+
if config.attention_type == "flash":
|
| 253 |
+
try:
|
| 254 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 255 |
+
|
| 256 |
+
self.flash_attn_func = flash_attn_func
|
| 257 |
+
except ModuleNotFoundError:
|
| 258 |
+
pass
|
| 259 |
+
|
| 260 |
+
def reset_parameters(self):
|
| 261 |
+
if self.k_norm is not None:
|
| 262 |
+
self.k_norm.reset_parameters()
|
| 263 |
+
if self.q_norm is not None:
|
| 264 |
+
self.q_norm.reset_parameters()
|
| 265 |
+
init_weights(
|
| 266 |
+
self.config,
|
| 267 |
+
self.attn_out,
|
| 268 |
+
d=self.config.d_model,
|
| 269 |
+
layer_id=self.layer_id,
|
| 270 |
+
type_of_module=ModuleType.out_module,
|
| 271 |
+
)
|
| 272 |
+
init_weights(
|
| 273 |
+
self.config,
|
| 274 |
+
self.ff_out,
|
| 275 |
+
d=self.ff_out.in_features,
|
| 276 |
+
layer_id=self.layer_id,
|
| 277 |
+
type_of_module=ModuleType.out_module,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
@classmethod
|
| 281 |
+
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
|
| 282 |
+
target_dtype = input_dtype
|
| 283 |
+
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
|
| 284 |
+
# `is_autocast_cpu_enabled()` for CPU autocast.
|
| 285 |
+
# See https://github.com/pytorch/pytorch/issues/110966.
|
| 286 |
+
if bias.device.type == "cuda" and torch.is_autocast_enabled():
|
| 287 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 288 |
+
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
|
| 289 |
+
target_dtype = torch.get_autocast_cpu_dtype()
|
| 290 |
+
if bias.dtype != target_dtype:
|
| 291 |
+
bias = bias.to(target_dtype)
|
| 292 |
+
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
|
| 293 |
+
return bias
|
| 294 |
+
|
| 295 |
+
def _scaled_dot_product_attention(
|
| 296 |
+
self,
|
| 297 |
+
q: torch.Tensor,
|
| 298 |
+
k: torch.Tensor,
|
| 299 |
+
v: torch.Tensor,
|
| 300 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 301 |
+
dropout_p: float = 0.0,
|
| 302 |
+
response_dropout_p: float = 0.0,
|
| 303 |
+
is_causal: bool = False,
|
| 304 |
+
) -> torch.Tensor:
|
| 305 |
+
"""
|
| 306 |
+
Computes scaled dot product attention on query, key and value tensors, using an optional
|
| 307 |
+
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
|
| 308 |
+
"""
|
| 309 |
+
if attn_mask is not None:
|
| 310 |
+
attn_mask = attn_mask.to(q.device)
|
| 311 |
+
|
| 312 |
+
if self.flash_attn_func is not None and attn_mask is None:
|
| 313 |
+
r = self.flash_attn_func(
|
| 314 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
|
| 315 |
+
)
|
| 316 |
+
return r.transpose(1, 2)
|
| 317 |
+
else:
|
| 318 |
+
# torch's sdpa doesn't support GQA, so we're doing this
|
| 319 |
+
assert k.size(1) == v.size(1)
|
| 320 |
+
num_kv_heads = k.size(1)
|
| 321 |
+
num_q_heads = q.size(1)
|
| 322 |
+
if num_q_heads != num_kv_heads:
|
| 323 |
+
assert num_q_heads % num_kv_heads == 0
|
| 324 |
+
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 325 |
+
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
|
| 326 |
+
|
| 327 |
+
return F.scaled_dot_product_attention(
|
| 328 |
+
q,
|
| 329 |
+
k,
|
| 330 |
+
v,
|
| 331 |
+
attn_mask=attn_mask,
|
| 332 |
+
dropout_p=dropout_p,
|
| 333 |
+
is_causal=is_causal,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
def attention(
|
| 337 |
+
self,
|
| 338 |
+
q: torch.Tensor,
|
| 339 |
+
k: torch.Tensor,
|
| 340 |
+
v: torch.Tensor,
|
| 341 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 342 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 343 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 344 |
+
use_cache: bool = False,
|
| 345 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 346 |
+
B, T, C = q.size() # batch size, sequence length, d_model
|
| 347 |
+
dtype = k.dtype
|
| 348 |
+
|
| 349 |
+
# Optionally apply layer norm to keys and queries.
|
| 350 |
+
if self.q_norm is not None and self.k_norm is not None:
|
| 351 |
+
q = self.q_norm(q).to(dtype=dtype)
|
| 352 |
+
k = self.k_norm(k).to(dtype=dtype)
|
| 353 |
+
|
| 354 |
+
# Move head forward to be next to the batch dim.
|
| 355 |
+
# shape: (B, nh, T, hs)
|
| 356 |
+
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
|
| 357 |
+
# shape: (B, n_kv_h, T, hs)
|
| 358 |
+
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 359 |
+
# shape: (B, n_kv_h, T, hs)
|
| 360 |
+
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
|
| 361 |
+
|
| 362 |
+
if self.config.use_position_ids and self.config.rope:
|
| 363 |
+
# Apply rotary embeddings
|
| 364 |
+
q, k = self.rotary_emb(q, k, position_ids=position_ids)
|
| 365 |
+
|
| 366 |
+
if layer_past is not None:
|
| 367 |
+
past_key, past_value = layer_past
|
| 368 |
+
k = torch.cat((past_key.to(k.device), k), dim=-2)
|
| 369 |
+
v = torch.cat((past_value.to(v.device), v), dim=-2)
|
| 370 |
+
|
| 371 |
+
present = (k, v) if use_cache else None
|
| 372 |
+
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
|
| 373 |
+
|
| 374 |
+
if not self.config.use_position_ids and self.config.rope:
|
| 375 |
+
# Apply rotary embeddings
|
| 376 |
+
q, k = self.rotary_emb(q, k)
|
| 377 |
+
|
| 378 |
+
if attention_bias is not None:
|
| 379 |
+
# Resize and cast attention bias.
|
| 380 |
+
# The current dtype of the attention bias might not match the dtype that the SDP attn function will
|
| 381 |
+
# run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
|
| 382 |
+
# as down-casting the attention bias to the autocast precision will result in -infs, which will
|
| 383 |
+
# cause the SDP attn function to produce NaNs.
|
| 384 |
+
attention_bias = self._cast_attn_bias(
|
| 385 |
+
attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Get the attention scores.
|
| 389 |
+
# shape: (B, nh, T, hs)
|
| 390 |
+
att = self._scaled_dot_product_attention(
|
| 391 |
+
q,
|
| 392 |
+
k,
|
| 393 |
+
v,
|
| 394 |
+
attn_mask=attention_bias,
|
| 395 |
+
dropout_p=0.0 if not self.training else self.config.attention_dropout,
|
| 396 |
+
response_dropout_p=0.0 if not self.training else self.config.response_attention_dropout,
|
| 397 |
+
is_causal=attention_bias is None,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Re-assemble all head outputs side-by-side.
|
| 401 |
+
att = att.transpose(1, 2).contiguous().view(B, T, C)
|
| 402 |
+
|
| 403 |
+
# Apply output projection.
|
| 404 |
+
return self.attn_out(att), present
|
| 405 |
+
|
| 406 |
+
def forward(
|
| 407 |
+
self,
|
| 408 |
+
x: torch.Tensor,
|
| 409 |
+
attention_bias: Optional[torch.FloatTensor] = None,
|
| 410 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 411 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 412 |
+
use_cache: bool = False,
|
| 413 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 414 |
+
raise NotImplementedError
|
| 415 |
+
|
| 416 |
+
@classmethod
|
| 417 |
+
def build(cls, layer_id: int, config: MolmoConfig, cache: BufferCache):
|
| 418 |
+
return MolmoSequentialBlock(layer_id, config, cache)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class MolmoSequentialBlock(MolmoBlock):
|
| 422 |
+
"""
|
| 423 |
+
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
|
| 424 |
+
(plus another skip connection).
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
|
| 428 |
+
super().__init__(layer_id, config, cache)
|
| 429 |
+
# Layer norms.
|
| 430 |
+
self.attn_norm = LayerNorm.build(config)
|
| 431 |
+
self.ff_norm = LayerNorm.build(config)
|
| 432 |
+
# Attention input projection. Projects x -> (q, k, v)
|
| 433 |
+
|
| 434 |
+
head_dim = config.d_model // config.n_heads
|
| 435 |
+
self.fused_dims = (
|
| 436 |
+
config.d_model,
|
| 437 |
+
config.effective_n_kv_heads * head_dim,
|
| 438 |
+
config.effective_n_kv_heads * head_dim,
|
| 439 |
+
)
|
| 440 |
+
self.att_proj = nn.Linear(
|
| 441 |
+
config.d_model, sum(self.fused_dims),
|
| 442 |
+
bias=config.include_bias or config.qkv_bias,
|
| 443 |
+
device=config.init_device
|
| 444 |
+
)
|
| 445 |
+
# Feed-forward input projection.
|
| 446 |
+
self.ff_proj = nn.Linear(
|
| 447 |
+
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
def reset_parameters(self):
|
| 451 |
+
super().reset_parameters()
|
| 452 |
+
self.attn_norm.reset_parameters()
|
| 453 |
+
self.ff_norm.reset_parameters()
|
| 454 |
+
# NOTE: the standard deviation for these weights does not depend on the layer.
|
| 455 |
+
init_weights(
|
| 456 |
+
self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 457 |
+
)
|
| 458 |
+
init_weights(
|
| 459 |
+
self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
def forward(
|
| 463 |
+
self,
|
| 464 |
+
x: torch.Tensor,
|
| 465 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 466 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 467 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 468 |
+
use_cache: bool = False,
|
| 469 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 470 |
+
# Get query, key, value projections.
|
| 471 |
+
# shape:
|
| 472 |
+
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
|
| 473 |
+
# - for multi-query attn q: (batch_size, seq_len, d_model)
|
| 474 |
+
# k, v: (batch_size, seq_len, d_model // n_heads)
|
| 475 |
+
# - for group query attn q: (batch_size, seq_len, d_model)
|
| 476 |
+
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
|
| 477 |
+
|
| 478 |
+
if not self.config.norm_after:
|
| 479 |
+
if self._activation_checkpoint_fn is not None:
|
| 480 |
+
atten_in = self._activation_checkpoint_fn(self.attn_norm, x)
|
| 481 |
+
else:
|
| 482 |
+
atten_in = self.attn_norm(x)
|
| 483 |
+
else:
|
| 484 |
+
atten_in = x
|
| 485 |
+
qkv = self.att_proj(atten_in)
|
| 486 |
+
|
| 487 |
+
if self.config.clip_qkv is not None:
|
| 488 |
+
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 489 |
+
|
| 490 |
+
q, k, v = qkv.split(self.fused_dims, dim=-1)
|
| 491 |
+
|
| 492 |
+
# Get attention scores.
|
| 493 |
+
if self._activation_checkpoint_fn is not None:
|
| 494 |
+
att, cache = self._activation_checkpoint_fn( # type: ignore
|
| 495 |
+
self.attention, q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
|
| 496 |
+
)
|
| 497 |
+
else:
|
| 498 |
+
att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
|
| 499 |
+
|
| 500 |
+
if self.config.norm_after:
|
| 501 |
+
if self._activation_checkpoint_fn is not None:
|
| 502 |
+
att = self._activation_checkpoint_fn(self.attn_norm, att)
|
| 503 |
+
else:
|
| 504 |
+
att = self.attn_norm(att)
|
| 505 |
+
|
| 506 |
+
# Add attention scores.
|
| 507 |
+
# shape: (B, T, C)
|
| 508 |
+
x = x + self.dropout(att)
|
| 509 |
+
|
| 510 |
+
# Add feed-forward projection.
|
| 511 |
+
# shape: (batch_size, seq_len, d_model)
|
| 512 |
+
og_x = x
|
| 513 |
+
|
| 514 |
+
if not self.config.norm_after:
|
| 515 |
+
if self._activation_checkpoint_fn is not None:
|
| 516 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 517 |
+
else:
|
| 518 |
+
x = self.ff_norm(x)
|
| 519 |
+
|
| 520 |
+
x = self.ff_proj(x)
|
| 521 |
+
if self._activation_checkpoint_fn is not None:
|
| 522 |
+
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
|
| 523 |
+
else:
|
| 524 |
+
x = self.act(x)
|
| 525 |
+
x = self.ff_out(x)
|
| 526 |
+
|
| 527 |
+
if self.config.norm_after:
|
| 528 |
+
if self._activation_checkpoint_fn is not None:
|
| 529 |
+
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
|
| 530 |
+
else:
|
| 531 |
+
x = self.ff_norm(x)
|
| 532 |
+
|
| 533 |
+
x = self.dropout(x)
|
| 534 |
+
x = og_x + x
|
| 535 |
+
|
| 536 |
+
return x, cache
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class Embedding(nn.Module):
|
| 540 |
+
def __init__(
|
| 541 |
+
self,
|
| 542 |
+
num_embeddings: int,
|
| 543 |
+
num_new_embeddings: int,
|
| 544 |
+
features: int,
|
| 545 |
+
device: Union[str, torch.device],
|
| 546 |
+
initializer_range: float = 0.02,
|
| 547 |
+
new_embed_initializer_range: float = 0.02,
|
| 548 |
+
):
|
| 549 |
+
super().__init__()
|
| 550 |
+
self.initializer_range = initializer_range
|
| 551 |
+
self.new_embed_initializer_range = new_embed_initializer_range
|
| 552 |
+
self.embedding = nn.Parameter(
|
| 553 |
+
torch.zeros(num_embeddings, features, device=device),
|
| 554 |
+
)
|
| 555 |
+
self.new_embedding = nn.Parameter(
|
| 556 |
+
torch.zeros(num_new_embeddings, features, device=device),
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
def reset_parameters(self):
|
| 560 |
+
nn.init.normal_(self.embedding, std=self.initializer_range)
|
| 561 |
+
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
|
| 562 |
+
|
| 563 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 564 |
+
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class Dropout(nn.Dropout):
|
| 568 |
+
def __init__(
|
| 569 |
+
self,
|
| 570 |
+
p: float = 0.5,
|
| 571 |
+
inplace: bool = False,
|
| 572 |
+
mask_p: float = 0,
|
| 573 |
+
broadcast_dims: Sequence[int] = (),
|
| 574 |
+
):
|
| 575 |
+
super().__init__(p, inplace)
|
| 576 |
+
self.mask_p = mask_p
|
| 577 |
+
self.broadcast_dims = broadcast_dims
|
| 578 |
+
|
| 579 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 580 |
+
"""
|
| 581 |
+
:param input: A tensor of shape `(batch_size, seq_len, embed_dim)`
|
| 582 |
+
"""
|
| 583 |
+
if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
|
| 584 |
+
return input
|
| 585 |
+
else:
|
| 586 |
+
if self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
|
| 587 |
+
keep_prob = 1.0 - self.p
|
| 588 |
+
dropout_shape = list(input.shape)
|
| 589 |
+
for dim in self.broadcast_dims:
|
| 590 |
+
dropout_shape[dim] = 1
|
| 591 |
+
keep = input.new_empty(dropout_shape).bernoulli_(keep_prob)
|
| 592 |
+
multiplier = keep.broadcast_to(input.shape)
|
| 593 |
+
multiplier.div_(keep_prob)
|
| 594 |
+
input = input * multiplier
|
| 595 |
+
else:
|
| 596 |
+
return F.dropout(input, self.p, self.training, self.inplace)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
@dataclass
|
| 600 |
+
class VisionBackboneConfig:
|
| 601 |
+
image_default_input_size: Tuple[int, int] = (336, 336)
|
| 602 |
+
image_patch_size: int = 14
|
| 603 |
+
image_pos_patch_size: int = 14
|
| 604 |
+
image_emb_dim: int = 1024
|
| 605 |
+
image_num_heads: int = 16
|
| 606 |
+
image_num_key_value_heads: int = 16
|
| 607 |
+
image_num_layers: int = 24
|
| 608 |
+
image_head_dim: int = 64
|
| 609 |
+
image_mlp_dim: int = 4096
|
| 610 |
+
image_mlp_activations: str = "gelu"
|
| 611 |
+
image_dropout_rate: float = 0.0
|
| 612 |
+
image_num_pos: int = 577
|
| 613 |
+
image_norm_eps: float = 1e-5
|
| 614 |
+
attention_dropout: float = 0.0
|
| 615 |
+
residual_dropout: float = 0.0
|
| 616 |
+
initializer_range: float = 0.02
|
| 617 |
+
fsdp_wrap: bool = False
|
| 618 |
+
resize_mode: str = "default"
|
| 619 |
+
|
| 620 |
+
def __post_init__(self):
|
| 621 |
+
self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment]
|
| 622 |
+
|
| 623 |
+
@property
|
| 624 |
+
def image_num_patch(self):
|
| 625 |
+
h, w = self.image_default_input_size
|
| 626 |
+
return h // self.image_patch_size, w // self.image_patch_size
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
@dataclass
|
| 630 |
+
class FullMolmoConfig:
|
| 631 |
+
d_model: int = 768
|
| 632 |
+
n_heads: int = 12
|
| 633 |
+
n_kv_heads: Optional[int] = None
|
| 634 |
+
qkv_bias: bool = False
|
| 635 |
+
clip_qkv: Optional[float] = None
|
| 636 |
+
n_layers: int = 12
|
| 637 |
+
mlp_ratio: int = 4
|
| 638 |
+
mlp_hidden_size: Optional[int] = None
|
| 639 |
+
activation_type: str = "swiglu"
|
| 640 |
+
block_group_size: int = 1
|
| 641 |
+
rope: bool = True
|
| 642 |
+
rope_full_precision: bool = True
|
| 643 |
+
rope_theta: float = 10000.
|
| 644 |
+
rope_impl: str = "interleave"
|
| 645 |
+
vision_backbone: Optional[VisionBackboneConfig] = None
|
| 646 |
+
attention_type: str = "sdpa"
|
| 647 |
+
float32_attention: bool = True
|
| 648 |
+
attention_dropout: float = 0.1
|
| 649 |
+
response_attention_dropout: float = 0.0
|
| 650 |
+
multi_query_attention: Optional[bool] = None
|
| 651 |
+
attention_layer_norm: bool = False
|
| 652 |
+
residual_dropout: float = 0.1
|
| 653 |
+
embedding_dropout: float = 0.1
|
| 654 |
+
layer_norm_type: str = "default"
|
| 655 |
+
layer_norm_with_affine: bool = True
|
| 656 |
+
layer_norm_eps: Optional[float] = None
|
| 657 |
+
attention_layer_norm_with_affine: bool = True
|
| 658 |
+
max_sequence_length: int = 1024
|
| 659 |
+
max_position_embeddings: Optional[int] = None
|
| 660 |
+
include_bias: bool = True
|
| 661 |
+
bias_for_layer_norm: Optional[bool] = None
|
| 662 |
+
scale_logits: bool = False
|
| 663 |
+
vocab_size: int = 50257
|
| 664 |
+
embedding_size: Optional[int] = 50304
|
| 665 |
+
additional_vocab_size: Optional[int] = None
|
| 666 |
+
new_embedding_init_range: float = 0.02
|
| 667 |
+
weight_tying: bool = True
|
| 668 |
+
pad_token_id: int = -1
|
| 669 |
+
init_device: Optional[str] = None
|
| 670 |
+
init_std: float = 0.02
|
| 671 |
+
init_cutoff_factor: Optional[float] = None
|
| 672 |
+
norm_after: bool = False
|
| 673 |
+
precision: Optional[str] = None
|
| 674 |
+
image_padding_embed: Optional[str] = None
|
| 675 |
+
vit_layers: Tuple = (-1,)
|
| 676 |
+
image_pooling_h: int = 2
|
| 677 |
+
image_pooling_w: int = 2
|
| 678 |
+
image_pooling_2d: str = "attention"
|
| 679 |
+
image_projector: str = "mlp"
|
| 680 |
+
image_feature_dropout: float = 0.0
|
| 681 |
+
initializer_range: float = 0.02
|
| 682 |
+
normalize_input_embeds: bool = False
|
| 683 |
+
use_position_ids: bool = True
|
| 684 |
+
|
| 685 |
+
@property
|
| 686 |
+
def effective_n_kv_heads(self) -> int:
|
| 687 |
+
if self.n_kv_heads is None:
|
| 688 |
+
if self.multi_query_attention is True:
|
| 689 |
+
return 1
|
| 690 |
+
else:
|
| 691 |
+
return self.n_heads
|
| 692 |
+
else:
|
| 693 |
+
if self.multi_query_attention is None:
|
| 694 |
+
return self.n_kv_heads
|
| 695 |
+
if self.multi_query_attention:
|
| 696 |
+
n_kv_heads_should_be = 1
|
| 697 |
+
else:
|
| 698 |
+
n_kv_heads_should_be = self.n_heads
|
| 699 |
+
if self.n_kv_heads == n_kv_heads_should_be:
|
| 700 |
+
return n_kv_heads_should_be
|
| 701 |
+
else:
|
| 702 |
+
raise MolmoConfigurationError(
|
| 703 |
+
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
@property
|
| 707 |
+
def image_num_patch(self):
|
| 708 |
+
assert self.vision_backbone is not None
|
| 709 |
+
return self.vision_backbone.image_num_patch
|
| 710 |
+
|
| 711 |
+
@property
|
| 712 |
+
def image_patch_size(self):
|
| 713 |
+
assert self.vision_backbone is not None
|
| 714 |
+
return self.visoin_backbone.image_patch_size
|
| 715 |
+
|
| 716 |
+
def llm_patches_per_crop(self):
|
| 717 |
+
h, w = self.image_num_patch
|
| 718 |
+
# Round up in case we need to pad the image features for pooling
|
| 719 |
+
h = (h + self.image_pooling_h - 1) // self.image_pooling_h
|
| 720 |
+
w = (w + self.image_pooling_w - 1) // self.image_pooling_w
|
| 721 |
+
return h, w
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def _expand_token(token, batch_size: int):
|
| 725 |
+
return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
class ViTMLP(nn.Module):
|
| 729 |
+
def __init__(self, config: FullMolmoConfig):
|
| 730 |
+
super().__init__()
|
| 731 |
+
self.config = config
|
| 732 |
+
v_cfg = config.vision_backbone
|
| 733 |
+
|
| 734 |
+
self.w1 = nn.Linear(
|
| 735 |
+
v_cfg.image_emb_dim,
|
| 736 |
+
v_cfg.image_mlp_dim,
|
| 737 |
+
bias=True,
|
| 738 |
+
device=config.init_device,
|
| 739 |
+
)
|
| 740 |
+
# Activation function.
|
| 741 |
+
cfg = deepcopy(config)
|
| 742 |
+
cfg.activation_type = v_cfg.image_mlp_activations
|
| 743 |
+
self.act = Activation.build(cfg)
|
| 744 |
+
self.w2 = nn.Linear(
|
| 745 |
+
v_cfg.image_mlp_dim,
|
| 746 |
+
v_cfg.image_emb_dim,
|
| 747 |
+
bias=True,
|
| 748 |
+
device=config.init_device,
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
def reset_parameters(self):
|
| 752 |
+
v_cfg = self.config.vision_backbone
|
| 753 |
+
nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
|
| 754 |
+
nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
|
| 755 |
+
nn.init.zeros_(self.w1.bias)
|
| 756 |
+
nn.init.zeros_(self.w2.bias)
|
| 757 |
+
|
| 758 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 759 |
+
x = self.w1(x)
|
| 760 |
+
x = self.act(x)
|
| 761 |
+
x = self.w2(x)
|
| 762 |
+
return x
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
class ResidualAttentionBlock(nn.Module):
|
| 766 |
+
|
| 767 |
+
def __init__(self, config: FullMolmoConfig):
|
| 768 |
+
super().__init__()
|
| 769 |
+
self.config = config
|
| 770 |
+
|
| 771 |
+
v_cfg = config.vision_backbone
|
| 772 |
+
self.attention = MultiHeadDotProductAttention(config)
|
| 773 |
+
self.feed_forward = ViTMLP(config)
|
| 774 |
+
self.attention_norm = nn.LayerNorm(
|
| 775 |
+
v_cfg.image_emb_dim,
|
| 776 |
+
eps=v_cfg.image_norm_eps,
|
| 777 |
+
device=config.init_device,
|
| 778 |
+
)
|
| 779 |
+
self.ffn_norm = nn.LayerNorm(
|
| 780 |
+
v_cfg.image_emb_dim,
|
| 781 |
+
eps=v_cfg.image_norm_eps,
|
| 782 |
+
device=config.init_device,
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
def reset_parameters(self):
|
| 786 |
+
self.attention.reset_parameters()
|
| 787 |
+
self.feed_forward.reset_parameters()
|
| 788 |
+
self.attention_norm.reset_parameters()
|
| 789 |
+
self.ffn_norm.reset_parameters()
|
| 790 |
+
|
| 791 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 792 |
+
x = x + self.attention(self.attention_norm(x))
|
| 793 |
+
x = x + self.feed_forward(self.ffn_norm(x))
|
| 794 |
+
return x
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
class BlockCollection(nn.Module):
|
| 798 |
+
|
| 799 |
+
def __init__(self, config: FullMolmoConfig):
|
| 800 |
+
super().__init__()
|
| 801 |
+
self.config = config
|
| 802 |
+
self.grad_checkpointing: bool = False
|
| 803 |
+
|
| 804 |
+
v_cfg = config.vision_backbone
|
| 805 |
+
self.resblocks = nn.ModuleList([
|
| 806 |
+
ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)
|
| 807 |
+
])
|
| 808 |
+
|
| 809 |
+
def reset_parameters(self):
|
| 810 |
+
for r in self.resblocks:
|
| 811 |
+
r.reset_parameters()
|
| 812 |
+
|
| 813 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 814 |
+
hidden_states = []
|
| 815 |
+
for r in self.resblocks:
|
| 816 |
+
x = r(x)
|
| 817 |
+
hidden_states.append(x)
|
| 818 |
+
return hidden_states
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
class LayerNormFp32(nn.LayerNorm):
|
| 822 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 823 |
+
orig_type = x.dtype
|
| 824 |
+
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32),
|
| 825 |
+
self.bias.to(torch.float32), self.eps)
|
| 826 |
+
return x.to(orig_type)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
class VisionTransformer(nn.Module):
|
| 830 |
+
|
| 831 |
+
def __init__(self, config: FullMolmoConfig):
|
| 832 |
+
super().__init__()
|
| 833 |
+
self.config = config
|
| 834 |
+
|
| 835 |
+
v_cfg = config.vision_backbone
|
| 836 |
+
# class embeddings and positional embeddings
|
| 837 |
+
self.scale = v_cfg.image_emb_dim ** -0.5
|
| 838 |
+
self.class_embedding = nn.Parameter(
|
| 839 |
+
torch.zeros(v_cfg.image_emb_dim, device=config.init_device),
|
| 840 |
+
)
|
| 841 |
+
self.num_prefix_tokens: int = 1
|
| 842 |
+
self.positional_embedding = nn.Parameter(
|
| 843 |
+
torch.zeros(v_cfg.image_num_pos, v_cfg.image_emb_dim, device=config.init_device),
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
image_patch_size = v_cfg.image_patch_size
|
| 847 |
+
self.patch_embedding = nn.Linear(
|
| 848 |
+
image_patch_size * image_patch_size * 3,
|
| 849 |
+
v_cfg.image_emb_dim,
|
| 850 |
+
bias=False,
|
| 851 |
+
device=config.init_device,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
self.pre_ln = LayerNormFp32(
|
| 855 |
+
v_cfg.image_emb_dim,
|
| 856 |
+
eps=v_cfg.image_norm_eps,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
self.transformer = BlockCollection(config)
|
| 860 |
+
|
| 861 |
+
@torch.jit.ignore
|
| 862 |
+
def set_grad_checkpointing(self, enable=True):
|
| 863 |
+
self.transformer.grad_checkpointing = enable
|
| 864 |
+
|
| 865 |
+
def reset_parameters(self):
|
| 866 |
+
nn.init.normal_(self.class_embedding, std=self.scale)
|
| 867 |
+
nn.init.normal_(self.positional_embedding, std=self.scale)
|
| 868 |
+
nn.init.normal_(self.patch_embedding.weight, std=0.02)
|
| 869 |
+
self.pre_ln.reset_parameters()
|
| 870 |
+
self.transformer.reset_parameters()
|
| 871 |
+
|
| 872 |
+
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
| 873 |
+
cls_emb = self.positional_embedding[0:1]
|
| 874 |
+
pos_emb = self.positional_embedding[1:]
|
| 875 |
+
|
| 876 |
+
pos_emb = pos_emb.reshape(
|
| 877 |
+
(int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
(patch_num_0, patch_num_1) = patch_num
|
| 881 |
+
|
| 882 |
+
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
|
| 883 |
+
# Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 884 |
+
# antialias: default True in jax.image.resize
|
| 885 |
+
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
|
| 886 |
+
pos_emb = F.interpolate(
|
| 887 |
+
pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True,
|
| 888 |
+
)
|
| 889 |
+
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
| 890 |
+
|
| 891 |
+
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
|
| 892 |
+
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
|
| 893 |
+
return x
|
| 894 |
+
|
| 895 |
+
def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]:
|
| 896 |
+
"""
|
| 897 |
+
: param x: (batch_size, num_patch, n_pixels)
|
| 898 |
+
"""
|
| 899 |
+
if patch_num is None:
|
| 900 |
+
patch_num = self.config.vision_backbone.image_num_patch
|
| 901 |
+
B, N, D = x.shape
|
| 902 |
+
|
| 903 |
+
x = self.patch_embedding(x)
|
| 904 |
+
|
| 905 |
+
# class embeddings and positional embeddings
|
| 906 |
+
x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
|
| 907 |
+
x = self.add_pos_emb(x, patch_num)
|
| 908 |
+
|
| 909 |
+
x = self.pre_ln(x)
|
| 910 |
+
|
| 911 |
+
hidden_states = self.transformer(x)
|
| 912 |
+
return hidden_states
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
class MultiHeadDotProductAttention(nn.Module):
|
| 916 |
+
def __init__(self, config: FullMolmoConfig, use_bias: bool = True, is_vit_layer: Optional[bool] = True):
|
| 917 |
+
super().__init__()
|
| 918 |
+
self.config = config
|
| 919 |
+
self.use_bias = use_bias
|
| 920 |
+
|
| 921 |
+
v_cfg = config.vision_backbone
|
| 922 |
+
self.embed_dim = v_cfg.image_emb_dim
|
| 923 |
+
self.num_heads = v_cfg.image_num_heads
|
| 924 |
+
self.head_dim = v_cfg.image_head_dim
|
| 925 |
+
self.num_key_value_heads = v_cfg.image_num_key_value_heads
|
| 926 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 927 |
+
self.initializer_range = v_cfg.initializer_range
|
| 928 |
+
self.is_vit_layer = is_vit_layer
|
| 929 |
+
|
| 930 |
+
nlayers = 1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers)
|
| 931 |
+
|
| 932 |
+
self.wq = nn.Linear(
|
| 933 |
+
nlayers * self.embed_dim,
|
| 934 |
+
self.num_heads * self.head_dim,
|
| 935 |
+
bias=use_bias,
|
| 936 |
+
device=config.init_device,
|
| 937 |
+
)
|
| 938 |
+
self.wk = nn.Linear(
|
| 939 |
+
nlayers * self.embed_dim,
|
| 940 |
+
self.num_key_value_heads * self.head_dim,
|
| 941 |
+
bias=use_bias,
|
| 942 |
+
device=config.init_device,
|
| 943 |
+
)
|
| 944 |
+
self.wv = nn.Linear(
|
| 945 |
+
nlayers * self.embed_dim,
|
| 946 |
+
self.num_key_value_heads * self.head_dim,
|
| 947 |
+
bias=use_bias,
|
| 948 |
+
device=config.init_device,
|
| 949 |
+
)
|
| 950 |
+
self.wo = nn.Linear(
|
| 951 |
+
self.num_heads * self.head_dim,
|
| 952 |
+
self.embed_dim,
|
| 953 |
+
bias=use_bias,
|
| 954 |
+
device=config.init_device,
|
| 955 |
+
)
|
| 956 |
+
self.attention_dropout: Optional[Dropout] = None
|
| 957 |
+
if v_cfg.attention_dropout > 0:
|
| 958 |
+
self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
|
| 959 |
+
self.residual_dropout = Dropout(v_cfg.residual_dropout)
|
| 960 |
+
|
| 961 |
+
def reset_parameters(self):
|
| 962 |
+
nn.init.normal_(self.wq.weight, std=self.initializer_range)
|
| 963 |
+
nn.init.normal_(self.wk.weight, std=self.initializer_range)
|
| 964 |
+
nn.init.normal_(self.wv.weight, std=self.initializer_range)
|
| 965 |
+
nn.init.normal_(self.wo.weight, std=self.initializer_range)
|
| 966 |
+
if self.use_bias:
|
| 967 |
+
nn.init.constant_(self.wq.bias, 0)
|
| 968 |
+
nn.init.constant_(self.wk.bias, 0)
|
| 969 |
+
nn.init.constant_(self.wv.bias, 0)
|
| 970 |
+
nn.init.constant_(self.wo.bias, 0)
|
| 971 |
+
|
| 972 |
+
def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
|
| 973 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
|
| 974 |
+
|
| 975 |
+
def _merge_heads(self, hidden_states) -> torch.Tensor:
|
| 976 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
| 977 |
+
|
| 978 |
+
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 979 |
+
|
| 980 |
+
if inputs_kv is not None:
|
| 981 |
+
inputs_k = inputs_kv
|
| 982 |
+
inputs_v = inputs_kv
|
| 983 |
+
else:
|
| 984 |
+
inputs_k = inputs_q
|
| 985 |
+
inputs_v = inputs_q
|
| 986 |
+
|
| 987 |
+
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
|
| 988 |
+
|
| 989 |
+
xq = self._split_heads(xq, self.num_heads)
|
| 990 |
+
xk = self._split_heads(xk, self.num_key_value_heads)
|
| 991 |
+
xv = self._split_heads(xv, self.num_key_value_heads)
|
| 992 |
+
|
| 993 |
+
if self.num_heads != self.num_key_value_heads:
|
| 994 |
+
xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
|
| 995 |
+
xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
|
| 996 |
+
|
| 997 |
+
og_dtype = xq.dtype
|
| 998 |
+
|
| 999 |
+
if self.config.float32_attention:
|
| 1000 |
+
xq = xq.to(torch.float)
|
| 1001 |
+
xk = xk.to(torch.float)
|
| 1002 |
+
|
| 1003 |
+
if self.config.attention_type == "direct":
|
| 1004 |
+
attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
|
| 1005 |
+
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
|
| 1006 |
+
if self.attention_dropout is not None:
|
| 1007 |
+
attn_weights = self.attention_dropout(attn_weights)
|
| 1008 |
+
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
|
| 1009 |
+
|
| 1010 |
+
elif self.config.attention_type == "sdpa":
|
| 1011 |
+
if self.config.float32_attention and not torch.is_autocast_enabled():
|
| 1012 |
+
xv = xv.to(torch.float32)
|
| 1013 |
+
attn_output = F.scaled_dot_product_attention(
|
| 1014 |
+
xq.transpose(1, 2).contiguous(),
|
| 1015 |
+
xk.transpose(1, 2).contiguous(),
|
| 1016 |
+
xv.transpose(1, 2).contiguous(),
|
| 1017 |
+
is_causal=False,
|
| 1018 |
+
dropout_p=self.config.vision_backbone.attention_dropout
|
| 1019 |
+
).transpose(1, 2)
|
| 1020 |
+
else:
|
| 1021 |
+
raise NotImplementedError(self.config.attention_type)
|
| 1022 |
+
attn_output = attn_output.to(og_dtype)
|
| 1023 |
+
attn_output = self._merge_heads(attn_output)
|
| 1024 |
+
attn_output = self.wo(attn_output)
|
| 1025 |
+
attn_output = self.residual_dropout(attn_output)
|
| 1026 |
+
|
| 1027 |
+
return attn_output
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
class MultiHeadAttentionPool(nn.Module):
|
| 1031 |
+
def __init__(
|
| 1032 |
+
self,
|
| 1033 |
+
config: FullMolmoConfig,
|
| 1034 |
+
factor: int = 1,
|
| 1035 |
+
use_bias: bool = True,
|
| 1036 |
+
dropout: bool = True,
|
| 1037 |
+
output_layer: bool = True,
|
| 1038 |
+
mean_residual: bool = False,
|
| 1039 |
+
query: str = "mean",
|
| 1040 |
+
is_vit_layer: Optional[bool] = True
|
| 1041 |
+
):
|
| 1042 |
+
super().__init__()
|
| 1043 |
+
self.config = config
|
| 1044 |
+
self.factor = factor
|
| 1045 |
+
self.use_bias = use_bias
|
| 1046 |
+
self.dropout = dropout
|
| 1047 |
+
self.output_layer = output_layer
|
| 1048 |
+
self.mean_residual = mean_residual
|
| 1049 |
+
self.query = query
|
| 1050 |
+
|
| 1051 |
+
v_cfg = config.vision_backbone
|
| 1052 |
+
input_dim = v_cfg.image_emb_dim
|
| 1053 |
+
self.embed_dim = v_cfg.image_emb_dim * factor
|
| 1054 |
+
self.num_heads = v_cfg.image_num_heads
|
| 1055 |
+
self.head_dim = v_cfg.image_head_dim * factor
|
| 1056 |
+
self.num_key_value_heads = v_cfg.image_num_key_value_heads
|
| 1057 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 1058 |
+
self.initializer_range = v_cfg.initializer_range
|
| 1059 |
+
|
| 1060 |
+
nlayers = 1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers)
|
| 1061 |
+
|
| 1062 |
+
if query != "vector":
|
| 1063 |
+
self.wq = nn.Linear(
|
| 1064 |
+
nlayers * input_dim,
|
| 1065 |
+
self.num_heads * self.head_dim,
|
| 1066 |
+
bias=use_bias,
|
| 1067 |
+
device=config.init_device,
|
| 1068 |
+
)
|
| 1069 |
+
self.wk = nn.Linear(
|
| 1070 |
+
nlayers * input_dim,
|
| 1071 |
+
self.num_key_value_heads * self.head_dim,
|
| 1072 |
+
bias=use_bias,
|
| 1073 |
+
device=config.init_device,
|
| 1074 |
+
)
|
| 1075 |
+
self.wv = nn.Linear(
|
| 1076 |
+
nlayers * input_dim,
|
| 1077 |
+
self.num_key_value_heads * self.head_dim,
|
| 1078 |
+
bias=use_bias,
|
| 1079 |
+
device=config.init_device,
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
if query == "vector":
|
| 1083 |
+
self.attention_query = nn.Parameter(
|
| 1084 |
+
torch.zeros(
|
| 1085 |
+
1, self.num_key_value_heads * self.head_dim, device=config.init_device,
|
| 1086 |
+
),
|
| 1087 |
+
)
|
| 1088 |
+
|
| 1089 |
+
if output_layer:
|
| 1090 |
+
self.wo = nn.Linear(
|
| 1091 |
+
self.num_heads * self.head_dim,
|
| 1092 |
+
self.embed_dim,
|
| 1093 |
+
bias=use_bias,
|
| 1094 |
+
device=config.init_device,
|
| 1095 |
+
)
|
| 1096 |
+
self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
|
| 1097 |
+
if dropout:
|
| 1098 |
+
self.residual_dropout = Dropout(v_cfg.residual_dropout)
|
| 1099 |
+
|
| 1100 |
+
def reset_parameters(self):
|
| 1101 |
+
if self.query != "vector":
|
| 1102 |
+
nn.init.normal_(self.wq.weight, std=self.initializer_range)
|
| 1103 |
+
nn.init.normal_(self.wk.weight, std=self.initializer_range)
|
| 1104 |
+
nn.init.normal_(self.wv.weight, std=self.initializer_range)
|
| 1105 |
+
if self.output_layer:
|
| 1106 |
+
nn.init.normal_(self.wo.weight, std=self.initializer_range)
|
| 1107 |
+
if self.use_bias:
|
| 1108 |
+
if self.query != "vector":
|
| 1109 |
+
nn.init.constant_(self.wq.bias, 0)
|
| 1110 |
+
nn.init.constant_(self.wk.bias, 0)
|
| 1111 |
+
nn.init.constant_(self.wv.bias, 0)
|
| 1112 |
+
if self.output_layer:
|
| 1113 |
+
nn.init.constant_(self.wo.bias, 0)
|
| 1114 |
+
if self.query == "vector":
|
| 1115 |
+
nn.init.normal_(self.attention_query, std=self.initializer_range)
|
| 1116 |
+
|
| 1117 |
+
def _split_heads(self, hidden_states, num_heads):
|
| 1118 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
|
| 1119 |
+
|
| 1120 |
+
def _merge_heads(self, hidden_states):
|
| 1121 |
+
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
|
| 1122 |
+
|
| 1123 |
+
def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
|
| 1124 |
+
|
| 1125 |
+
xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
|
| 1126 |
+
|
| 1127 |
+
if self.query == "mean":
|
| 1128 |
+
inputs_q = inputs_kv.mean(dim=1, keepdim=True)
|
| 1129 |
+
xq = self.wq(inputs_q)
|
| 1130 |
+
elif self.query == "first":
|
| 1131 |
+
inputs_q = inputs_kv[:, :1]
|
| 1132 |
+
xq = self.wq(inputs_q)
|
| 1133 |
+
elif self.query == "vector":
|
| 1134 |
+
xq = self.attention_query.expand(inputs_kv.size(0), -1, -1)
|
| 1135 |
+
elif self.query == "constant":
|
| 1136 |
+
inputs_q = torch.ones_like(inputs_kv[:, :1]) / math.sqrt(inputs_kv.shape[-1])
|
| 1137 |
+
xq = self.wq(inputs_q)
|
| 1138 |
+
else:
|
| 1139 |
+
raise ValueError(f"Unknown query type: {self.query}")
|
| 1140 |
+
|
| 1141 |
+
xq = self._split_heads(xq, self.num_heads)
|
| 1142 |
+
xk = self._split_heads(xk, self.num_key_value_heads)
|
| 1143 |
+
xv = self._split_heads(xv, self.num_key_value_heads)
|
| 1144 |
+
|
| 1145 |
+
if self.num_heads != self.num_key_value_heads:
|
| 1146 |
+
xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
|
| 1147 |
+
xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
|
| 1148 |
+
|
| 1149 |
+
xq = xq.to(torch.float)
|
| 1150 |
+
xk = xk.to(torch.float)
|
| 1151 |
+
|
| 1152 |
+
xq = xq / math.sqrt(xq.size(-1))
|
| 1153 |
+
attn_weights = torch.einsum("...qhd,...khd->...hqk", xq, xk)
|
| 1154 |
+
|
| 1155 |
+
attn_weights = F.softmax(attn_weights, dim=-1).to(xq.dtype)
|
| 1156 |
+
|
| 1157 |
+
attn_weights = self.attention_dropout(attn_weights).to(xv.dtype)
|
| 1158 |
+
|
| 1159 |
+
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights, xv)
|
| 1160 |
+
attn_output = self._merge_heads(attn_output)
|
| 1161 |
+
if self.output_layer:
|
| 1162 |
+
attn_output = self.wo(attn_output)
|
| 1163 |
+
if self.dropout:
|
| 1164 |
+
attn_output = self.residual_dropout(attn_output)
|
| 1165 |
+
if self.mean_residual:
|
| 1166 |
+
attn_output += inputs_kv.mean(dim=1, keepdim=True)
|
| 1167 |
+
|
| 1168 |
+
return attn_output
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
class MLP(nn.Module):
|
| 1172 |
+
def __init__(self, config: FullMolmoConfig, input_dim: int, dropout: float = 0.0):
|
| 1173 |
+
super().__init__()
|
| 1174 |
+
self.config = config
|
| 1175 |
+
self.hidden_size = (
|
| 1176 |
+
config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
|
| 1177 |
+
)
|
| 1178 |
+
self.initializer_range = config.initializer_range
|
| 1179 |
+
|
| 1180 |
+
self.w1 = nn.Linear(
|
| 1181 |
+
input_dim,
|
| 1182 |
+
self.hidden_size // 2,
|
| 1183 |
+
bias=False,
|
| 1184 |
+
device=config.init_device,
|
| 1185 |
+
)
|
| 1186 |
+
self.w2 = nn.Linear(
|
| 1187 |
+
self.hidden_size // 2,
|
| 1188 |
+
config.d_model,
|
| 1189 |
+
bias=False,
|
| 1190 |
+
device=config.init_device,
|
| 1191 |
+
)
|
| 1192 |
+
self.w3 = nn.Linear(
|
| 1193 |
+
input_dim,
|
| 1194 |
+
self.hidden_size // 2,
|
| 1195 |
+
bias=False,
|
| 1196 |
+
device=config.init_device,
|
| 1197 |
+
)
|
| 1198 |
+
# Activation function.
|
| 1199 |
+
self.act = Activation.build(config)
|
| 1200 |
+
self.dropout = Dropout(dropout)
|
| 1201 |
+
|
| 1202 |
+
def reset_parameters(self):
|
| 1203 |
+
nn.init.normal_(self.w1.weight, std=self.initializer_range)
|
| 1204 |
+
nn.init.normal_(self.w2.weight, std=self.initializer_range)
|
| 1205 |
+
nn.init.normal_(self.w3.weight, std=self.initializer_range)
|
| 1206 |
+
|
| 1207 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1208 |
+
x = self.w2(self.act(self.w1(x), self.w3(x)))
|
| 1209 |
+
x = self.dropout(x)
|
| 1210 |
+
return x
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
class Residual(nn.Module):
|
| 1214 |
+
def __init__(self, submodule: nn.Module):
|
| 1215 |
+
super().__init__()
|
| 1216 |
+
self.submodule = submodule
|
| 1217 |
+
|
| 1218 |
+
def reset_parameters(self):
|
| 1219 |
+
self.submodule.reset_parameters()
|
| 1220 |
+
|
| 1221 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1222 |
+
return x + self.submodule(x)
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
class OLMoVisionBackbone(nn.Module):
|
| 1226 |
+
def __init__(self, config: FullMolmoConfig):
|
| 1227 |
+
super().__init__()
|
| 1228 |
+
self.config = config
|
| 1229 |
+
self.image_vit = VisionTransformer(config)
|
| 1230 |
+
|
| 1231 |
+
input_dim: int = None
|
| 1232 |
+
self.image_pooling_2d: nn.Module = None
|
| 1233 |
+
if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}:
|
| 1234 |
+
self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
|
| 1235 |
+
input_dim = config.vision_backbone.image_emb_dim
|
| 1236 |
+
elif config.image_pooling_2d == ImagePooling2DType.attention_2wide:
|
| 1237 |
+
cfg = deepcopy(config)
|
| 1238 |
+
cfg.vision_backbone.image_emb_dim *= 2
|
| 1239 |
+
cfg.vision_backbone.image_head_dim *= 2
|
| 1240 |
+
self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False)
|
| 1241 |
+
input_dim = cfg.vision_backbone.image_emb_dim
|
| 1242 |
+
elif config.image_pooling_2d == ImagePooling2DType.attention_v2:
|
| 1243 |
+
assert config.vit_layers is not None
|
| 1244 |
+
use_bias = True
|
| 1245 |
+
dropout = True
|
| 1246 |
+
output_layer = True
|
| 1247 |
+
query = "mean"
|
| 1248 |
+
mean_residual = False
|
| 1249 |
+
factor = len(config.vit_layers)
|
| 1250 |
+
self.image_pooling_2d = MultiHeadAttentionPool(
|
| 1251 |
+
config,
|
| 1252 |
+
factor=factor,
|
| 1253 |
+
use_bias=use_bias,
|
| 1254 |
+
dropout=dropout,
|
| 1255 |
+
output_layer=output_layer,
|
| 1256 |
+
mean_residual=mean_residual,
|
| 1257 |
+
query=query,
|
| 1258 |
+
is_vit_layer=False,
|
| 1259 |
+
)
|
| 1260 |
+
input_dim = config.vision_backbone.image_emb_dim * factor
|
| 1261 |
+
elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]:
|
| 1262 |
+
self.image_pooling_2d = None
|
| 1263 |
+
nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
|
| 1264 |
+
input_dim = nlayers * config.vision_backbone.image_emb_dim
|
| 1265 |
+
else:
|
| 1266 |
+
raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
|
| 1267 |
+
|
| 1268 |
+
self.input_dim = input_dim
|
| 1269 |
+
|
| 1270 |
+
# `MLP` assume the activation takes two inputs, so it must be a 'llama' version
|
| 1271 |
+
if config.activation_type == ActivationType.swiglu:
|
| 1272 |
+
mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)
|
| 1273 |
+
elif config.activation_type == ActivationType.gelu:
|
| 1274 |
+
mlp_config = replace(config, activation_type=ActivationType.llama_geglu)
|
| 1275 |
+
else:
|
| 1276 |
+
mlp_config = config
|
| 1277 |
+
if config.image_projector == ImageProjectType.mlpx2:
|
| 1278 |
+
self.image_projector = nn.ModuleList(
|
| 1279 |
+
[MLP(mlp_config, input_dim), Residual(MLP(config, input_dim))]
|
| 1280 |
+
)
|
| 1281 |
+
elif config.image_projector == ImageProjectType.mlp:
|
| 1282 |
+
self.image_projector = MLP(mlp_config, input_dim)
|
| 1283 |
+
elif config.image_projector == ImageProjectType.linear:
|
| 1284 |
+
self.image_projector = nn.Linear(
|
| 1285 |
+
input_dim,
|
| 1286 |
+
config.d_model,
|
| 1287 |
+
bias=False,
|
| 1288 |
+
device=config.init_device,
|
| 1289 |
+
)
|
| 1290 |
+
else:
|
| 1291 |
+
raise NotImplementedError(f"Unknown image projector: {config.image_projector}")
|
| 1292 |
+
|
| 1293 |
+
self.image_feature_dropout = Dropout(config.image_feature_dropout)
|
| 1294 |
+
|
| 1295 |
+
def reset_parameters(self):
|
| 1296 |
+
if self.image_pooling_2d is not None:
|
| 1297 |
+
self.image_pooling_2d.reset_parameters()
|
| 1298 |
+
if self.config.image_projector == "2mlp":
|
| 1299 |
+
for module in self.image_projector:
|
| 1300 |
+
module.reset_parameters()
|
| 1301 |
+
elif self.config.image_projector == "linear":
|
| 1302 |
+
nn.init.xavier_uniform_(self.image_projector.weight)
|
| 1303 |
+
else:
|
| 1304 |
+
self.image_projector.reset_parameters()
|
| 1305 |
+
|
| 1306 |
+
def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 1307 |
+
raise NotImplementedError
|
| 1308 |
+
|
| 1309 |
+
|
| 1310 |
+
class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
|
| 1311 |
+
def __init__(self, config: FullMolmoConfig):
|
| 1312 |
+
super().__init__(config)
|
| 1313 |
+
v_cfg = self.config.vision_backbone
|
| 1314 |
+
self.grad_checkpointing = False
|
| 1315 |
+
|
| 1316 |
+
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
|
| 1317 |
+
assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
|
| 1318 |
+
|
| 1319 |
+
self.pad_embed = None
|
| 1320 |
+
if config.image_padding_embed:
|
| 1321 |
+
image_dim = v_cfg.image_emb_dim*len(self.config.vit_layers)
|
| 1322 |
+
if config.image_padding_embed in ["pad_embed", "regress"]:
|
| 1323 |
+
self.pad_embed = nn.Parameter(
|
| 1324 |
+
torch.zeros((image_dim,), device=config.init_device))
|
| 1325 |
+
elif config.image_padding_embed == "pad_and_partial_pad":
|
| 1326 |
+
self.pad_embed = nn.Parameter(
|
| 1327 |
+
torch.zeros((2, image_dim), device=config.init_device))
|
| 1328 |
+
else:
|
| 1329 |
+
raise ValueError(config.image_padding_embed)
|
| 1330 |
+
|
| 1331 |
+
def reset_parameters(self):
|
| 1332 |
+
super().reset_parameters()
|
| 1333 |
+
self.image_vit.reset_parameters()
|
| 1334 |
+
|
| 1335 |
+
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
|
| 1336 |
+
"""
|
| 1337 |
+
: param images: (batch_size, num_crops, num_patch, n_pixels)
|
| 1338 |
+
"""
|
| 1339 |
+
cfg = self.config
|
| 1340 |
+
v_cfg = self.config.vision_backbone
|
| 1341 |
+
B, T, N, D = images.shape
|
| 1342 |
+
|
| 1343 |
+
mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
|
| 1344 |
+
|
| 1345 |
+
# Output all hidden states
|
| 1346 |
+
# n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim)
|
| 1347 |
+
images = images.view(B * T, N, D)
|
| 1348 |
+
image_features = self.image_vit(images)
|
| 1349 |
+
|
| 1350 |
+
if cfg.vit_layers is not None:
|
| 1351 |
+
features = []
|
| 1352 |
+
for layer in cfg.vit_layers:
|
| 1353 |
+
features.append(image_features[layer])
|
| 1354 |
+
image_features = torch.cat(features, dim=-1)
|
| 1355 |
+
else:
|
| 1356 |
+
image_features = image_features[-1]
|
| 1357 |
+
|
| 1358 |
+
cls_embed: torch.Tensor = None
|
| 1359 |
+
if self.num_prefix_tokens > 0:
|
| 1360 |
+
cls_embed = image_features[:, 0]
|
| 1361 |
+
image_features = image_features[:, 1:]
|
| 1362 |
+
|
| 1363 |
+
image_features = image_features * mask
|
| 1364 |
+
image_features = image_features.view(B, T, N, -1)
|
| 1365 |
+
|
| 1366 |
+
cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
|
| 1367 |
+
|
| 1368 |
+
return image_features, cls_embed
|
| 1369 |
+
|
| 1370 |
+
def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 1371 |
+
cfg = self.config
|
| 1372 |
+
|
| 1373 |
+
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
|
| 1374 |
+
batch_size, num_image = images.shape[:2]
|
| 1375 |
+
image_features, cls_embed = self.encode_image(images)
|
| 1376 |
+
|
| 1377 |
+
if cfg.image_padding_embed:
|
| 1378 |
+
assert image_masks is not None
|
| 1379 |
+
if cfg.image_padding_embed == "pad_embed":
|
| 1380 |
+
all_pad = (image_masks == 0).to(dtype=torch.float32)
|
| 1381 |
+
pad_embed = self.pad_embed[None, None, None, :]
|
| 1382 |
+
image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
|
| 1383 |
+
elif cfg.image_padding_embed == "regress":
|
| 1384 |
+
pad_embed = self.pad_embed[None, None, None, :]
|
| 1385 |
+
image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
|
| 1386 |
+
elif cfg.image_padding_embed == "pad_and_partial_pad":
|
| 1387 |
+
pad_embed = self.pad_embed[:, None, None, None, :]
|
| 1388 |
+
all_pad = image_masks == 0
|
| 1389 |
+
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
|
| 1390 |
+
all_pad = all_pad.to(dtype=image_features.dtype)
|
| 1391 |
+
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
|
| 1392 |
+
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
|
| 1393 |
+
else:
|
| 1394 |
+
raise ValueError(cfg.image_padding_embed)
|
| 1395 |
+
|
| 1396 |
+
image_features = self.image_feature_dropout(image_features)
|
| 1397 |
+
if cls_embed is not None:
|
| 1398 |
+
cls_embed = self.image_feature_dropout(cls_embed)
|
| 1399 |
+
|
| 1400 |
+
image_features = image_features.reshape(
|
| 1401 |
+
(batch_size, num_image) + cfg.image_num_patch + (-1,),
|
| 1402 |
+
)
|
| 1403 |
+
|
| 1404 |
+
if cfg.image_num_patch[0] % cfg.image_pooling_h == 1:
|
| 1405 |
+
# Pad so we can still pool 2x2 patches
|
| 1406 |
+
image_features = F.pad(
|
| 1407 |
+
image_features,
|
| 1408 |
+
(0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
|
| 1409 |
+
)
|
| 1410 |
+
|
| 1411 |
+
# image pooling
|
| 1412 |
+
image_features = einops.rearrange(
|
| 1413 |
+
image_features,
|
| 1414 |
+
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
|
| 1415 |
+
dh=cfg.image_pooling_h,
|
| 1416 |
+
dw=cfg.image_pooling_w,
|
| 1417 |
+
)
|
| 1418 |
+
|
| 1419 |
+
if cfg.image_pooling_2d == ImagePooling2DType.attention_meanq:
|
| 1420 |
+
query = image_features.mean(-2, keepdim=True)
|
| 1421 |
+
image_features = self.image_pooling_2d(query, image_features)
|
| 1422 |
+
elif cfg.image_pooling_2d not in {ImagePooling2DType.none, ImagePooling2DType.stack}:
|
| 1423 |
+
if self.grad_checkpointing:
|
| 1424 |
+
from torch.utils.checkpoint import checkpoint
|
| 1425 |
+
image_features = checkpoint(self.image_pooling_2d, image_features[:, :1, :], image_features, use_reentrant=False)
|
| 1426 |
+
else:
|
| 1427 |
+
image_features = self.image_pooling_2d(image_features[:, :1, :], image_features)
|
| 1428 |
+
|
| 1429 |
+
h, w = cfg.llm_patches_per_crop()
|
| 1430 |
+
image_features = image_features.reshape(batch_size, num_image, h * w, -1)
|
| 1431 |
+
|
| 1432 |
+
# MLP layer to map the feature.
|
| 1433 |
+
if self.grad_checkpointing:
|
| 1434 |
+
from torch.utils.checkpoint import checkpoint
|
| 1435 |
+
image_features = checkpoint(self.image_projector, image_features, use_reentrant=False)
|
| 1436 |
+
else:
|
| 1437 |
+
image_features = self.image_projector(image_features)
|
| 1438 |
+
|
| 1439 |
+
# image_features: (batch_size, num_image, num_patch, d_model)
|
| 1440 |
+
# cls_embed: (batch_size, num_image, d_model)
|
| 1441 |
+
return image_features, cls_embed
|
| 1442 |
+
|
| 1443 |
+
|
| 1444 |
+
class ModuleType(str, Enum):
|
| 1445 |
+
in_module = "in"
|
| 1446 |
+
out_module = "out"
|
| 1447 |
+
emb = "emb"
|
| 1448 |
+
final_out = "final_out"
|
| 1449 |
+
|
| 1450 |
+
|
| 1451 |
+
def init_weights(
|
| 1452 |
+
config: FullMolmoConfig,
|
| 1453 |
+
module: Union[nn.Linear, nn.Embedding],
|
| 1454 |
+
d: Optional[int] = None,
|
| 1455 |
+
layer_id: Optional[int] = None,
|
| 1456 |
+
std_factor: float = 1.0,
|
| 1457 |
+
type_of_module: Optional[ModuleType] = None,
|
| 1458 |
+
) -> None:
|
| 1459 |
+
d = d if d is not None else config.d_model
|
| 1460 |
+
std = config.init_std * std_factor
|
| 1461 |
+
if config.init_cutoff_factor is not None:
|
| 1462 |
+
cutoff_value = config.init_cutoff_factor * std
|
| 1463 |
+
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
|
| 1464 |
+
else:
|
| 1465 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 1466 |
+
|
| 1467 |
+
|
| 1468 |
+
class LlamaSwiGLU(nn.Module):
|
| 1469 |
+
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
| 1470 |
+
return F.silu(x1) * x2
|
| 1471 |
+
|
| 1472 |
+
@property
|
| 1473 |
+
def output_multiplier(self) -> float:
|
| 1474 |
+
return 0.5
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
class SwiGLU(nn.Module):
|
| 1478 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1479 |
+
x, gate = x.chunk(2, dim=-1)
|
| 1480 |
+
return F.silu(gate) * x
|
| 1481 |
+
|
| 1482 |
+
@property
|
| 1483 |
+
def output_multiplier(self) -> float:
|
| 1484 |
+
return 0.5
|
| 1485 |
+
|
| 1486 |
+
|
| 1487 |
+
class Activation(nn.Module):
|
| 1488 |
+
def __init__(self, config: FullMolmoConfig):
|
| 1489 |
+
super().__init__()
|
| 1490 |
+
self.config = config
|
| 1491 |
+
|
| 1492 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1493 |
+
raise NotImplementedError
|
| 1494 |
+
|
| 1495 |
+
@property
|
| 1496 |
+
def output_multiplier(self) -> float:
|
| 1497 |
+
raise NotImplementedError
|
| 1498 |
+
|
| 1499 |
+
@classmethod
|
| 1500 |
+
def build(cls, config: FullMolmoConfig) -> 'Activation':
|
| 1501 |
+
if config.activation_type == "quick_gelu":
|
| 1502 |
+
return QuickGELU(config)
|
| 1503 |
+
elif config.activation_type == "gelu":
|
| 1504 |
+
return cast(Activation, GELU(approximate="none"))
|
| 1505 |
+
elif config.activation_type == "gelu_tanh":
|
| 1506 |
+
return cast(Activation, GELU(approximate="tanh"))
|
| 1507 |
+
elif config.activation_type == "relu":
|
| 1508 |
+
return cast(Activation, ReLU(inplace=False))
|
| 1509 |
+
elif config.activation_type == "silu":
|
| 1510 |
+
return cast(Activation, SiLU(inplace=False))
|
| 1511 |
+
# elif config.activation_type == "llama_geglu":
|
| 1512 |
+
# return LlamaGEGLU(config)
|
| 1513 |
+
# elif config.activation_type == "llama_geglu_tanh":
|
| 1514 |
+
# return LlamaGEGLUTanh(config)
|
| 1515 |
+
elif config.activation_type == "llama_swiglu":
|
| 1516 |
+
return LlamaSwiGLU()
|
| 1517 |
+
elif config.activation_type == "swiglu":
|
| 1518 |
+
return SwiGLU()
|
| 1519 |
+
else:
|
| 1520 |
+
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
| 1521 |
+
|
| 1522 |
+
|
| 1523 |
+
class QuickGELU(Activation):
|
| 1524 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1525 |
+
return x * torch.sigmoid(1.702 * x)
|
| 1526 |
+
|
| 1527 |
+
@property
|
| 1528 |
+
def output_multiplier(self) -> float:
|
| 1529 |
+
return 1.0
|
| 1530 |
+
|
| 1531 |
+
|
| 1532 |
+
class GELU(nn.GELU):
|
| 1533 |
+
@property
|
| 1534 |
+
def output_multiplier(self) -> float:
|
| 1535 |
+
return 1.0
|
| 1536 |
+
|
| 1537 |
+
|
| 1538 |
+
class ReLU(nn.ReLU):
|
| 1539 |
+
@property
|
| 1540 |
+
def output_multiplier(self) -> float:
|
| 1541 |
+
return 1.0
|
| 1542 |
+
|
| 1543 |
+
|
| 1544 |
+
class SiLU(nn.SiLU):
|
| 1545 |
+
@property
|
| 1546 |
+
def output_multiplier(self) -> float:
|
| 1547 |
+
return 1.0
|
| 1548 |
+
|
| 1549 |
+
|
| 1550 |
+
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
|
| 1551 |
+
att_bias = torch.triu(
|
| 1552 |
+
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
|
| 1553 |
+
diagonal=1,
|
| 1554 |
+
)
|
| 1555 |
+
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
|
| 1556 |
+
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
|
| 1557 |
+
|
| 1558 |
+
|
| 1559 |
+
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
|
| 1560 |
+
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
|
| 1561 |
+
if causal_bias.device != device:
|
| 1562 |
+
causal_bias = causal_bias.to(device)
|
| 1563 |
+
cache["causal_attention_bias"] = causal_bias
|
| 1564 |
+
return causal_bias
|
| 1565 |
+
with torch.autocast(device.type, enabled=False):
|
| 1566 |
+
causal_bias = causal_attention_bias(seq_len, device)
|
| 1567 |
+
cache["causal_attention_bias"] = causal_bias
|
| 1568 |
+
return causal_bias
|
| 1569 |
+
|
| 1570 |
+
|
| 1571 |
+
class LayerNormBase(nn.Module):
|
| 1572 |
+
def __init__(
|
| 1573 |
+
self,
|
| 1574 |
+
config: MolmoConfig,
|
| 1575 |
+
*,
|
| 1576 |
+
size: Optional[int] = None,
|
| 1577 |
+
elementwise_affine: Optional[bool] = True,
|
| 1578 |
+
eps: float = 1e-05,
|
| 1579 |
+
weight_initializer: Optional[Callable] = torch.ones,
|
| 1580 |
+
bias_initializer: Optional[Callable] = torch.zeros,
|
| 1581 |
+
):
|
| 1582 |
+
super().__init__()
|
| 1583 |
+
self.config = config
|
| 1584 |
+
self.eps = self.config.layer_norm_eps or eps
|
| 1585 |
+
self.normalized_shape = (size or config.d_model,)
|
| 1586 |
+
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
|
| 1587 |
+
self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))
|
| 1588 |
+
use_bias = self.config.bias_for_layer_norm
|
| 1589 |
+
if use_bias is None:
|
| 1590 |
+
use_bias = self.config.include_bias
|
| 1591 |
+
if use_bias:
|
| 1592 |
+
self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))
|
| 1593 |
+
else:
|
| 1594 |
+
self.register_parameter("bias", None)
|
| 1595 |
+
else:
|
| 1596 |
+
self.register_parameter("bias", None)
|
| 1597 |
+
self.register_parameter("weight", None)
|
| 1598 |
+
|
| 1599 |
+
@classmethod
|
| 1600 |
+
def build(cls, config: FullMolmoConfig, size: Optional[int] = None, **kwargs):
|
| 1601 |
+
if config.layer_norm_type == "default":
|
| 1602 |
+
return LayerNorm(config, size=size, low_precision=False, **kwargs)
|
| 1603 |
+
elif config.layer_norm_type == "low_precision":
|
| 1604 |
+
return LayerNorm(config, size=size, low_precision=True, **kwargs)
|
| 1605 |
+
elif config.layer_norm_type == "rms":
|
| 1606 |
+
return RMSLayerNorm(config, size=size, **kwargs)
|
| 1607 |
+
else:
|
| 1608 |
+
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
|
| 1609 |
+
|
| 1610 |
+
|
| 1611 |
+
class RMSLayerNorm(LayerNormBase):
|
| 1612 |
+
"""
|
| 1613 |
+
RMS layer norm, a simplified :class:`LayerNorm` implementation
|
| 1614 |
+
"""
|
| 1615 |
+
|
| 1616 |
+
def __init__(
|
| 1617 |
+
self,
|
| 1618 |
+
config: FullMolmoConfig,
|
| 1619 |
+
size: Optional[int] = None,
|
| 1620 |
+
elementwise_affine: Optional[bool] = None,
|
| 1621 |
+
eps: float = 1e-5,
|
| 1622 |
+
):
|
| 1623 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
| 1624 |
+
|
| 1625 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1626 |
+
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 1627 |
+
og_dtype = x.dtype
|
| 1628 |
+
x = x.to(torch.float32)
|
| 1629 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 1630 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 1631 |
+
x = x.to(og_dtype)
|
| 1632 |
+
|
| 1633 |
+
if self.weight is not None:
|
| 1634 |
+
if self.bias is not None:
|
| 1635 |
+
return self.weight * x + self.bias
|
| 1636 |
+
else:
|
| 1637 |
+
return self.weight * x
|
| 1638 |
+
else:
|
| 1639 |
+
return x
|
| 1640 |
+
|
| 1641 |
+
|
| 1642 |
+
class LayerNorm(LayerNormBase):
|
| 1643 |
+
"""
|
| 1644 |
+
The default :class:`LayerNorm` implementation which can optionally run in low precision.
|
| 1645 |
+
"""
|
| 1646 |
+
|
| 1647 |
+
def __init__(
|
| 1648 |
+
self,
|
| 1649 |
+
config: FullMolmoConfig,
|
| 1650 |
+
size: Optional[int] = None,
|
| 1651 |
+
low_precision: bool = False,
|
| 1652 |
+
elementwise_affine: Optional[bool] = None,
|
| 1653 |
+
eps: float = 1e-05,
|
| 1654 |
+
):
|
| 1655 |
+
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
|
| 1656 |
+
self.low_precision = low_precision
|
| 1657 |
+
|
| 1658 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 1659 |
+
if self.low_precision:
|
| 1660 |
+
module_device = x.device
|
| 1661 |
+
downcast_x = self._cast_if_autocast_enabled(x)
|
| 1662 |
+
downcast_weight = (
|
| 1663 |
+
self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 1664 |
+
)
|
| 1665 |
+
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
| 1666 |
+
with torch.autocast(enabled=False, device_type=module_device.type):
|
| 1667 |
+
return F.layer_norm(
|
| 1668 |
+
downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
|
| 1669 |
+
)
|
| 1670 |
+
else:
|
| 1671 |
+
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
|
| 1672 |
+
|
| 1673 |
+
|
| 1674 |
+
class Molmo(nn.Module):
|
| 1675 |
+
def __init__(self, config: FullMolmoConfig, init_params: bool = True):
|
| 1676 |
+
super().__init__()
|
| 1677 |
+
self.config = config
|
| 1678 |
+
self.__cache = BufferCache()
|
| 1679 |
+
|
| 1680 |
+
# Validate config.
|
| 1681 |
+
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
|
| 1682 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 1683 |
+
raise MolmoConfigurationError("embedding size should be at least as big as vocab size")
|
| 1684 |
+
elif self.config.embedding_size % 128 != 0:
|
| 1685 |
+
import warnings
|
| 1686 |
+
|
| 1687 |
+
warnings.warn(
|
| 1688 |
+
"Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
|
| 1689 |
+
)
|
| 1690 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 1691 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
|
| 1692 |
+
|
| 1693 |
+
wte = None
|
| 1694 |
+
if self.config.additional_vocab_size is not None:
|
| 1695 |
+
wte = Embedding(
|
| 1696 |
+
config.embedding_size or config.vocab_size,
|
| 1697 |
+
config.additional_vocab_size,
|
| 1698 |
+
config.d_model,
|
| 1699 |
+
device=config.init_device,
|
| 1700 |
+
initializer_range=config.initializer_range,
|
| 1701 |
+
new_embed_initializer_range=config.new_embedding_init_range
|
| 1702 |
+
)
|
| 1703 |
+
else:
|
| 1704 |
+
wte=nn.Embedding(
|
| 1705 |
+
config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
|
| 1706 |
+
)
|
| 1707 |
+
|
| 1708 |
+
self.transformer = nn.ModuleDict(
|
| 1709 |
+
dict(
|
| 1710 |
+
wte=wte,
|
| 1711 |
+
emb_drop=Dropout(config.embedding_dropout),
|
| 1712 |
+
ln_f=LayerNorm.build(config),
|
| 1713 |
+
)
|
| 1714 |
+
)
|
| 1715 |
+
|
| 1716 |
+
blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
|
| 1717 |
+
if self.config.block_group_size > 1:
|
| 1718 |
+
raise NotImplementedError()
|
| 1719 |
+
else:
|
| 1720 |
+
self.transformer.update({"blocks": nn.ModuleList(blocks)})
|
| 1721 |
+
|
| 1722 |
+
if not self.config.rope:
|
| 1723 |
+
self.transformer.update(
|
| 1724 |
+
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
|
| 1725 |
+
)
|
| 1726 |
+
if not config.weight_tying:
|
| 1727 |
+
self.transformer.update(
|
| 1728 |
+
{
|
| 1729 |
+
"ff_out": nn.Linear(
|
| 1730 |
+
config.d_model,
|
| 1731 |
+
config.embedding_size or config.vocab_size,
|
| 1732 |
+
bias=config.include_bias,
|
| 1733 |
+
device=config.init_device,
|
| 1734 |
+
)
|
| 1735 |
+
}
|
| 1736 |
+
)
|
| 1737 |
+
|
| 1738 |
+
self.vision_backbone: Optional[OLMoVisionBackbone] = None
|
| 1739 |
+
if config.vision_backbone is not None:
|
| 1740 |
+
self.vision_backbone = OLMoPretrainedVisionBackbone(config)
|
| 1741 |
+
|
| 1742 |
+
self.__num_fwd_flops: Optional[int] = None
|
| 1743 |
+
|
| 1744 |
+
def reset_parameters(self):
|
| 1745 |
+
if self.vision_backbone is not None:
|
| 1746 |
+
self.vision_backbone.reset_parameters()
|
| 1747 |
+
self.reset_non_vision_parameters()
|
| 1748 |
+
|
| 1749 |
+
def reset_non_vision_parameters(self):
|
| 1750 |
+
self.transformer.wte.reset_parameters()
|
| 1751 |
+
if hasattr(self.transformer.wte, "new_embedding"):
|
| 1752 |
+
nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range)
|
| 1753 |
+
|
| 1754 |
+
if hasattr(self.transformer, "wpe"):
|
| 1755 |
+
nn.init.normal_(self.transformer.wpe, mean=0.0, std=1.0)
|
| 1756 |
+
|
| 1757 |
+
self.transformer.ln_f.reset_parameters() # type: ignore
|
| 1758 |
+
|
| 1759 |
+
if hasattr(self.transformer, "ff_out"):
|
| 1760 |
+
nn.init.normal_(self.transformer.ff_out, mean=0.0, std=0.02)
|
| 1761 |
+
|
| 1762 |
+
if self.config.block_group_size == 1:
|
| 1763 |
+
for block in self.transformer.blocks:
|
| 1764 |
+
block.reset_parameters()
|
| 1765 |
+
else:
|
| 1766 |
+
for block_group in self.transformer.block_groups:
|
| 1767 |
+
block_group.reset_parameters()
|
| 1768 |
+
|
| 1769 |
+
|
| 1770 |
+
def forward(
|
| 1771 |
+
self,
|
| 1772 |
+
input_ids: torch.LongTensor,
|
| 1773 |
+
input_embeddings: Optional[torch.FloatTensor] = None,
|
| 1774 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1775 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 1776 |
+
response_mask: Optional[torch.Tensor] = None,
|
| 1777 |
+
images: Optional[torch.Tensor] = None,
|
| 1778 |
+
image_masks: Optional[torch.Tensor] = None,
|
| 1779 |
+
image_input_idx: Optional[torch.Tensor] = None,
|
| 1780 |
+
subsegment_ids: Optional[torch.Tensor] = None,
|
| 1781 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 1782 |
+
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 1783 |
+
use_cache: bool = False,
|
| 1784 |
+
last_logits_only: bool = False,
|
| 1785 |
+
output_hidden_states: Optional[bool] = None,
|
| 1786 |
+
append_last_valid_logits: Optional[torch.Tensor] = None,
|
| 1787 |
+
) -> ModelOutput:
|
| 1788 |
+
"""
|
| 1789 |
+
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
|
| 1790 |
+
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
|
| 1791 |
+
embeddings. When provided, it is treated as the output of the input embedding layer.
|
| 1792 |
+
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
| 1793 |
+
which input IDs are masked. A `1` value in the mask means that
|
| 1794 |
+
the corresponding input ID should *not* be ignored. A `0` means
|
| 1795 |
+
that the corresponding input ID is masked.
|
| 1796 |
+
|
| 1797 |
+
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
|
| 1798 |
+
library.
|
| 1799 |
+
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
|
| 1800 |
+
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
|
| 1801 |
+
to introduce causal or other biases.
|
| 1802 |
+
|
| 1803 |
+
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
|
| 1804 |
+
indicates that the i-th element in the sequence is allowed to attend to the j-th
|
| 1805 |
+
element in the sequence.
|
| 1806 |
+
|
| 1807 |
+
If the tensor is a float tensor, it will just be added to the attention
|
| 1808 |
+
scores before the softmax.
|
| 1809 |
+
|
| 1810 |
+
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
|
| 1811 |
+
:param response_mask: A tensor of shape `(batch_size, seq_len)` that indicates
|
| 1812 |
+
the response mask. A `1` value in the mask means that the corresponding token
|
| 1813 |
+
is a response token. A `0` means that the corresponding token is not
|
| 1814 |
+
a response token.
|
| 1815 |
+
:param past_key_values: Pre-computed keys and values for each attention block.
|
| 1816 |
+
Can be used to speed up sequential decoding. The `input_ids` which have
|
| 1817 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
| 1818 |
+
:param use_cache: If `True`, return key and value tensors for each block.
|
| 1819 |
+
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
|
| 1820 |
+
This can speed up decoding when you only care about the next token.
|
| 1821 |
+
"""
|
| 1822 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 1823 |
+
|
| 1824 |
+
if past_key_values:
|
| 1825 |
+
assert len(past_key_values) == self.config.n_layers
|
| 1826 |
+
|
| 1827 |
+
has_image = images is not None
|
| 1828 |
+
|
| 1829 |
+
assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
|
| 1830 |
+
assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images."
|
| 1831 |
+
|
| 1832 |
+
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
|
| 1833 |
+
if past_key_values is None:
|
| 1834 |
+
past_length = 0
|
| 1835 |
+
else:
|
| 1836 |
+
past_length = past_key_values[0][0].size(-2)
|
| 1837 |
+
|
| 1838 |
+
if self.config.use_position_ids and attention_mask is None:
|
| 1839 |
+
attention_mask = input_ids != -1
|
| 1840 |
+
|
| 1841 |
+
if subsegment_ids is not None:
|
| 1842 |
+
assert not use_cache, "Subsegment_ids cannot be used with cache."
|
| 1843 |
+
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
|
| 1844 |
+
attention_mask = (
|
| 1845 |
+
subsegment_mask.to(attention_mask.dtype) *
|
| 1846 |
+
attention_mask.unsqueeze(2) *
|
| 1847 |
+
attention_mask.unsqueeze(1))
|
| 1848 |
+
if position_ids is None:
|
| 1849 |
+
raise ValueError(f"Positioned ids must be given if using subsegment_ids")
|
| 1850 |
+
else:
|
| 1851 |
+
if self.config.use_position_ids and position_ids is None:
|
| 1852 |
+
position_ids = torch.clamp(
|
| 1853 |
+
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
|
| 1854 |
+
min=0,
|
| 1855 |
+
).broadcast_to((batch_size, attention_mask.shape[-1]))
|
| 1856 |
+
|
| 1857 |
+
# Get embeddings of input.
|
| 1858 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1859 |
+
if input_ids is not None:
|
| 1860 |
+
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
|
| 1861 |
+
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
|
| 1862 |
+
|
| 1863 |
+
num_image: Optional[int] = None
|
| 1864 |
+
if images is not None:
|
| 1865 |
+
# shape: (batch_size, num_image, num_patch, d_model)
|
| 1866 |
+
# cls_embed: (batch_size, num_image, d_model)
|
| 1867 |
+
image_features, cls_embed = self.vision_backbone(images, image_masks)
|
| 1868 |
+
num_image, num_patch = image_features.shape[1:3]
|
| 1869 |
+
assert image_input_idx.shape == (batch_size, num_image, num_patch)
|
| 1870 |
+
|
| 1871 |
+
# inster the image feature into the embedding.
|
| 1872 |
+
image_features = image_features.view(batch_size, num_image * num_patch, -1)
|
| 1873 |
+
image_input_idx = image_input_idx.view(batch_size, num_image * num_patch)
|
| 1874 |
+
|
| 1875 |
+
valid = image_input_idx >= 0
|
| 1876 |
+
batch_idx = torch.arange(batch_size, device=x.device)
|
| 1877 |
+
batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
|
| 1878 |
+
|
| 1879 |
+
# For hf demo/endpoint
|
| 1880 |
+
image_features = image_features.to(x.device)
|
| 1881 |
+
|
| 1882 |
+
x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
|
| 1883 |
+
|
| 1884 |
+
if not self.config.rope:
|
| 1885 |
+
# Get positional embeddings.
|
| 1886 |
+
# shape: (1, seq_len)
|
| 1887 |
+
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
|
| 1888 |
+
# shape: (1, seq_len, d_model)
|
| 1889 |
+
pos_emb = self.transformer.wpe(pos) # type: ignore
|
| 1890 |
+
x = pos_emb + x
|
| 1891 |
+
|
| 1892 |
+
# Add input + positional embeddings and apply dropout.
|
| 1893 |
+
# shape: (batch_size, seq_len, d_model)
|
| 1894 |
+
x = self.transformer.emb_drop(x) # type: ignore
|
| 1895 |
+
|
| 1896 |
+
# normalized
|
| 1897 |
+
if self.config.normalize_input_embeds:
|
| 1898 |
+
x = x * (self.config.d_model ** 0.5)
|
| 1899 |
+
|
| 1900 |
+
# Transform the attention mask into what the blocks expect.
|
| 1901 |
+
if attention_mask is not None:
|
| 1902 |
+
# shape: (batch_size, 1, 1, seq_len)
|
| 1903 |
+
if len(attention_mask.shape) == 2:
|
| 1904 |
+
attention_mask = attention_mask[:, :past_length + seq_len]
|
| 1905 |
+
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
|
| 1906 |
+
else:
|
| 1907 |
+
attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
|
| 1908 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
| 1909 |
+
|
| 1910 |
+
# Merge attention mask with attention bias.
|
| 1911 |
+
if (
|
| 1912 |
+
attention_bias is not None
|
| 1913 |
+
or attention_mask is not None
|
| 1914 |
+
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
|
| 1915 |
+
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
|
| 1916 |
+
# scores correctly.
|
| 1917 |
+
or past_key_values is not None
|
| 1918 |
+
):
|
| 1919 |
+
if attention_bias is None:
|
| 1920 |
+
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
|
| 1921 |
+
elif attention_bias.dtype in (torch.int8, torch.bool):
|
| 1922 |
+
attention_bias = attention_bias.to(dtype=torch.float)
|
| 1923 |
+
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
|
| 1924 |
+
|
| 1925 |
+
# Transform to the right shape and data type.
|
| 1926 |
+
mask_len = seq_len
|
| 1927 |
+
if attention_mask is not None:
|
| 1928 |
+
mask_len = attention_mask.shape[-1]
|
| 1929 |
+
elif past_key_values is not None:
|
| 1930 |
+
mask_len = past_key_values[0][0].shape[-2] + seq_len
|
| 1931 |
+
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
|
| 1932 |
+
|
| 1933 |
+
# Add in the masking bias.
|
| 1934 |
+
if attention_mask is not None:
|
| 1935 |
+
attention_bias = attention_bias + attention_mask
|
| 1936 |
+
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
|
| 1937 |
+
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
|
| 1938 |
+
# it can produce NaNs.
|
| 1939 |
+
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
|
| 1940 |
+
|
| 1941 |
+
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
|
| 1942 |
+
|
| 1943 |
+
# decoder layers
|
| 1944 |
+
all_hidden_states = []
|
| 1945 |
+
|
| 1946 |
+
# Apply blocks one-by-one.
|
| 1947 |
+
if self.config.block_group_size == 1:
|
| 1948 |
+
for block_idx, block in enumerate(self.transformer.blocks):
|
| 1949 |
+
if output_hidden_states:
|
| 1950 |
+
# add hidden states
|
| 1951 |
+
all_hidden_states.append(x)
|
| 1952 |
+
|
| 1953 |
+
layer_past = None if past_key_values is None else past_key_values[block_idx]
|
| 1954 |
+
x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
|
| 1955 |
+
|
| 1956 |
+
if attn_key_values is not None:
|
| 1957 |
+
assert cache is not None
|
| 1958 |
+
attn_key_values.append(cache)
|
| 1959 |
+
else:
|
| 1960 |
+
for group_idx, block_group in enumerate(self.transformer.block_groups):
|
| 1961 |
+
if output_hidden_states:
|
| 1962 |
+
# add hidden states
|
| 1963 |
+
all_hidden_states.append(x)
|
| 1964 |
+
|
| 1965 |
+
layers_past = (
|
| 1966 |
+
None
|
| 1967 |
+
if past_key_values is None
|
| 1968 |
+
else past_key_values[
|
| 1969 |
+
group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
|
| 1970 |
+
]
|
| 1971 |
+
)
|
| 1972 |
+
x, cache = block_group(
|
| 1973 |
+
x, attention_bias=attention_bias, position_ids=position_ids, layers_past=layers_past, use_cache=use_cache
|
| 1974 |
+
)
|
| 1975 |
+
if attn_key_values is not None:
|
| 1976 |
+
assert cache is not None
|
| 1977 |
+
attn_key_values.extend(cache)
|
| 1978 |
+
|
| 1979 |
+
if last_logits_only:
|
| 1980 |
+
# shape: (batch_size, 1, d_model)
|
| 1981 |
+
if append_last_valid_logits is not None:
|
| 1982 |
+
last_valid_output = x[
|
| 1983 |
+
torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)]
|
| 1984 |
+
x = last_valid_output.unsqueeze(1)
|
| 1985 |
+
else:
|
| 1986 |
+
x = x[:, -1, :].unsqueeze(1)
|
| 1987 |
+
|
| 1988 |
+
# Apply final layer norm.
|
| 1989 |
+
# shape: (batch_size, seq_len or 1, d_model)
|
| 1990 |
+
x = self.transformer.ln_f(x) # type: ignore
|
| 1991 |
+
if output_hidden_states:
|
| 1992 |
+
# add final hidden state post-final-layernorm, following HuggingFace's convention
|
| 1993 |
+
all_hidden_states.append(x)
|
| 1994 |
+
|
| 1995 |
+
# Get logits.
|
| 1996 |
+
# shape: (batch_size, seq_len or 1, vocab_size)
|
| 1997 |
+
if self.config.weight_tying:
|
| 1998 |
+
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
|
| 1999 |
+
else:
|
| 2000 |
+
logits = self.transformer.ff_out(x) # type: ignore
|
| 2001 |
+
if self.config.scale_logits:
|
| 2002 |
+
logits.mul_(1 / math.sqrt(self.config.d_model))
|
| 2003 |
+
|
| 2004 |
+
if not last_logits_only and append_last_valid_logits is not None:
|
| 2005 |
+
last_valid_logit = logits[
|
| 2006 |
+
torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits]
|
| 2007 |
+
logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
|
| 2008 |
+
|
| 2009 |
+
return ModelOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
|
| 2010 |
+
|
| 2011 |
+
|
| 2012 |
+
class MolmoForCausalLM(PreTrainedModel):
|
| 2013 |
+
config_class = MolmoConfig
|
| 2014 |
+
base_model_prefix = "model"
|
| 2015 |
+
_no_split_modules = ["MolmoBlock"]
|
| 2016 |
+
|
| 2017 |
+
def __init__(self, config: MolmoConfig, model: Optional[Molmo] = None, init_params: bool = False):
|
| 2018 |
+
super().__init__(config)
|
| 2019 |
+
|
| 2020 |
+
if not model:
|
| 2021 |
+
full_config = FullMolmoConfig(
|
| 2022 |
+
image_padding_embed="pad_and_partial_pad",
|
| 2023 |
+
image_pooling_2d="attention-meanq",
|
| 2024 |
+
attention_layer_norm=config.attention_layer_norm,
|
| 2025 |
+
rope_impl="llama",
|
| 2026 |
+
vocab_size=config.vocab_size,
|
| 2027 |
+
max_sequence_length=config.max_position_embeddings,
|
| 2028 |
+
qkv_bias=config.qkv_bias,
|
| 2029 |
+
norm_after=config.norm_after,
|
| 2030 |
+
embedding_size=config.embedding_size,
|
| 2031 |
+
attention_type="sdpa",
|
| 2032 |
+
embedding_dropout=0,
|
| 2033 |
+
attention_dropout=0,
|
| 2034 |
+
residual_dropout=0,
|
| 2035 |
+
rope=True,
|
| 2036 |
+
weight_tying=False,
|
| 2037 |
+
include_bias=False,
|
| 2038 |
+
d_model=config.hidden_size,
|
| 2039 |
+
mlp_hidden_size=config.intermediate_size,
|
| 2040 |
+
n_layers=config.num_hidden_layers,
|
| 2041 |
+
additional_vocab_size=128,
|
| 2042 |
+
n_heads=config.num_attention_heads,
|
| 2043 |
+
n_kv_heads=config.num_key_value_heads,
|
| 2044 |
+
rope_theta=config.rope_theta,
|
| 2045 |
+
layer_norm_eps=config.layer_norm_eps,
|
| 2046 |
+
layer_norm_type=config.layer_norm_type,
|
| 2047 |
+
vit_layers=[-2, -9],
|
| 2048 |
+
vision_backbone=VisionBackboneConfig(
|
| 2049 |
+
image_default_input_size=(336, 336),
|
| 2050 |
+
image_patch_size=14,
|
| 2051 |
+
image_pos_patch_size=14,
|
| 2052 |
+
image_emb_dim=1024,
|
| 2053 |
+
image_num_heads=16,
|
| 2054 |
+
image_num_key_value_heads=16,
|
| 2055 |
+
image_num_layers=23,
|
| 2056 |
+
image_head_dim=64,
|
| 2057 |
+
image_mlp_dim=4096,
|
| 2058 |
+
image_mlp_activations="quick_gelu",
|
| 2059 |
+
image_dropout_rate=0.0,
|
| 2060 |
+
image_num_pos=577,
|
| 2061 |
+
image_norm_eps=1e-5,
|
| 2062 |
+
attention_dropout=0.0,
|
| 2063 |
+
residual_dropout=0.0,
|
| 2064 |
+
initializer_range=0.02,
|
| 2065 |
+
)
|
| 2066 |
+
)
|
| 2067 |
+
self.model = Molmo(full_config, init_params=init_params)
|
| 2068 |
+
else:
|
| 2069 |
+
self.model = model
|
| 2070 |
+
|
| 2071 |
+
|
| 2072 |
+
def forward(
|
| 2073 |
+
self,
|
| 2074 |
+
input_ids: torch.LongTensor = None,
|
| 2075 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 2076 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 2077 |
+
attention_bias: Optional[torch.Tensor] = None,
|
| 2078 |
+
response_mask: Optional[torch.Tensor] = None,
|
| 2079 |
+
images: Optional[torch.Tensor] = None,
|
| 2080 |
+
image_masks: Optional[torch.Tensor] = None,
|
| 2081 |
+
image_input_idx: Optional[torch.Tensor] = None,
|
| 2082 |
+
subsegment_ids: Optional[torch.Tensor] = None,
|
| 2083 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 2084 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 2085 |
+
labels: Optional[torch.LongTensor] = None,
|
| 2086 |
+
loss_masks: Optional[torch.Tensor] = None,
|
| 2087 |
+
use_cache: Optional[bool] = None,
|
| 2088 |
+
last_logits_only: Optional[bool] = None,
|
| 2089 |
+
output_attentions: Optional[bool] = None,
|
| 2090 |
+
output_hidden_states: Optional[bool] = None,
|
| 2091 |
+
append_last_valid_logits: Optional[torch.Tensor] = None,
|
| 2092 |
+
return_dict: Optional[bool] = None,
|
| 2093 |
+
cache_position: Optional[
|
| 2094 |
+
Cache
|
| 2095 |
+
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
|
| 2096 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 2097 |
+
if use_cache is None:
|
| 2098 |
+
use_cache = self.config.use_cache
|
| 2099 |
+
|
| 2100 |
+
if output_attentions:
|
| 2101 |
+
raise ValueError("output_attentions is not yet supported in Molmo")
|
| 2102 |
+
|
| 2103 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2104 |
+
|
| 2105 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 2106 |
+
outputs = self.model.forward(
|
| 2107 |
+
input_ids=input_ids,
|
| 2108 |
+
input_embeddings=inputs_embeds,
|
| 2109 |
+
attention_mask=attention_mask,
|
| 2110 |
+
attention_bias=attention_bias,
|
| 2111 |
+
response_mask=response_mask,
|
| 2112 |
+
images=images,
|
| 2113 |
+
image_masks=image_masks,
|
| 2114 |
+
image_input_idx=image_input_idx,
|
| 2115 |
+
subsegment_ids=subsegment_ids,
|
| 2116 |
+
position_ids=position_ids,
|
| 2117 |
+
past_key_values=past_key_values,
|
| 2118 |
+
use_cache=use_cache,
|
| 2119 |
+
last_logits_only=last_logits_only,
|
| 2120 |
+
output_hidden_states=output_hidden_states,
|
| 2121 |
+
append_last_valid_logits=append_last_valid_logits,
|
| 2122 |
+
)
|
| 2123 |
+
|
| 2124 |
+
logits = outputs.logits
|
| 2125 |
+
hidden_states = outputs.hidden_states
|
| 2126 |
+
|
| 2127 |
+
loss = None
|
| 2128 |
+
if labels is not None:
|
| 2129 |
+
if loss_masks is not None:
|
| 2130 |
+
loss_masks = loss_masks * (loss_masks > 0)
|
| 2131 |
+
batch_size_in_tokens = max(loss_masks.sum().item(), 1)
|
| 2132 |
+
labels = labels.long()
|
| 2133 |
+
labels.masked_fill_(~(loss_masks > 0), -100)
|
| 2134 |
+
labels = labels.view(-1)
|
| 2135 |
+
logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
|
| 2136 |
+
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
|
| 2137 |
+
loss = loss_fct(logits_for_loss, labels)
|
| 2138 |
+
loss = loss.view(input_ids.shape[0], -1)
|
| 2139 |
+
loss = loss * loss_masks
|
| 2140 |
+
loss = loss.sum() / batch_size_in_tokens
|
| 2141 |
+
use_zloss = getattr(self.config, "softmax_auxiliary_loss", False)
|
| 2142 |
+
if use_zloss:
|
| 2143 |
+
z_squared = logits_for_loss.logsumexp(-1).pow(2)
|
| 2144 |
+
z_loss = self.config.softmax_auxiliary_loss_scale * z_squared
|
| 2145 |
+
z_loss = z_loss.view(input_ids.shape[0], -1)
|
| 2146 |
+
z_loss = z_loss * loss_masks
|
| 2147 |
+
z_loss = z_loss.sum() / batch_size_in_tokens
|
| 2148 |
+
loss += z_loss
|
| 2149 |
+
else:
|
| 2150 |
+
# Shift so that tokens < n predict n
|
| 2151 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 2152 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 2153 |
+
# Flatten the tokens
|
| 2154 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
| 2155 |
+
shift_logits = shift_logits.view(-1, self.config.embedding_size)
|
| 2156 |
+
shift_labels = shift_labels.view(-1)
|
| 2157 |
+
# Enable model parallelism
|
| 2158 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 2159 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 2160 |
+
|
| 2161 |
+
if not return_dict:
|
| 2162 |
+
output = (logits,) + outputs[1:]
|
| 2163 |
+
return (loss,) + output if loss is not None else output
|
| 2164 |
+
|
| 2165 |
+
return CausalLMOutputWithPast(
|
| 2166 |
+
loss=loss,
|
| 2167 |
+
logits=logits,
|
| 2168 |
+
past_key_values=outputs.attn_key_values,
|
| 2169 |
+
hidden_states=hidden_states,
|
| 2170 |
+
)
|
| 2171 |
+
|
| 2172 |
+
def can_generate(self) -> bool:
|
| 2173 |
+
return True
|
| 2174 |
+
|
| 2175 |
+
@torch.no_grad()
|
| 2176 |
+
def generate_from_batch(
|
| 2177 |
+
self,
|
| 2178 |
+
batch: Dict[str, Any],
|
| 2179 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 2180 |
+
**kwargs,
|
| 2181 |
+
):
|
| 2182 |
+
if generation_config is not None:
|
| 2183 |
+
assert generation_config.use_cache
|
| 2184 |
+
|
| 2185 |
+
images = batch.get("images")
|
| 2186 |
+
image_masks = batch.get("image_masks")
|
| 2187 |
+
image_input_idx = batch.get("image_input_idx")
|
| 2188 |
+
|
| 2189 |
+
# Validate inputs.
|
| 2190 |
+
input_ids = batch["input_ids"]
|
| 2191 |
+
batch_size, seq_len = input_ids.shape
|
| 2192 |
+
attention_mask = batch.get("attention_mask", None)
|
| 2193 |
+
max_new_tokens = generation_config.max_new_tokens
|
| 2194 |
+
assert max_new_tokens is not None
|
| 2195 |
+
mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
|
| 2196 |
+
position_ids: Optional[torch.Tensor] = None
|
| 2197 |
+
append_last_valid_logits: Optional[torch.Tensor] = None
|
| 2198 |
+
if self.config.use_position_ids and attention_mask is None:
|
| 2199 |
+
attention_mask = input_ids != -1
|
| 2200 |
+
position_ids = torch.clamp(
|
| 2201 |
+
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
|
| 2202 |
+
min=0
|
| 2203 |
+
)
|
| 2204 |
+
append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
|
| 2205 |
+
attention_mask = torch.cat(
|
| 2206 |
+
[attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
|
| 2207 |
+
dim=1,
|
| 2208 |
+
)
|
| 2209 |
+
if attention_mask is not None:
|
| 2210 |
+
assert attention_mask.shape == (batch_size, mask_len)
|
| 2211 |
+
|
| 2212 |
+
out = super().generate(
|
| 2213 |
+
batch["input_ids"],
|
| 2214 |
+
generation_config,
|
| 2215 |
+
attention_mask=attention_mask,
|
| 2216 |
+
images=images,
|
| 2217 |
+
image_masks=image_masks,
|
| 2218 |
+
image_input_idx=image_input_idx,
|
| 2219 |
+
position_ids=position_ids,
|
| 2220 |
+
append_last_valid_logits=append_last_valid_logits,
|
| 2221 |
+
**kwargs,
|
| 2222 |
+
)
|
| 2223 |
+
|
| 2224 |
+
return out
|
| 2225 |
+
|
| 2226 |
+
def prepare_inputs_for_generation(
|
| 2227 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
| 2228 |
+
):
|
| 2229 |
+
if past_key_values:
|
| 2230 |
+
# This is because we want the model to only process the last generated token.
|
| 2231 |
+
input_ids = input_ids[:, -1:]
|
| 2232 |
+
|
| 2233 |
+
if self.config.use_position_ids:
|
| 2234 |
+
attention_mask = kwargs.get("attention_mask")
|
| 2235 |
+
images = kwargs.get("images")
|
| 2236 |
+
image_masks = kwargs.get("image_masks")
|
| 2237 |
+
image_input_idx = kwargs.get("image_input_idx")
|
| 2238 |
+
position_ids = kwargs.get("position_ids")
|
| 2239 |
+
append_last_valid_logits = kwargs.get("append_last_valid_logits")
|
| 2240 |
+
model_inputs = {
|
| 2241 |
+
"input_ids": input_ids,
|
| 2242 |
+
"attention_mask": attention_mask,
|
| 2243 |
+
"position_ids": position_ids,
|
| 2244 |
+
"past_key_values": past_key_values,
|
| 2245 |
+
"use_cache": True,
|
| 2246 |
+
"last_logits_only": True,
|
| 2247 |
+
}
|
| 2248 |
+
if past_key_values is None:
|
| 2249 |
+
model_inputs["images"] = images
|
| 2250 |
+
model_inputs["image_masks"] = image_masks
|
| 2251 |
+
model_inputs["image_input_idx"] = image_input_idx
|
| 2252 |
+
model_inputs["append_last_valid_logits"] = append_last_valid_logits
|
| 2253 |
+
else:
|
| 2254 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 2255 |
+
|
| 2256 |
+
model_inputs.update(kwargs)
|
| 2257 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
| 2258 |
+
return model_inputs
|
| 2259 |
+
|
| 2260 |
+
def _update_model_kwargs_for_generation(
|
| 2261 |
+
self,
|
| 2262 |
+
outputs: ModelOutput,
|
| 2263 |
+
model_kwargs: Dict[str, Any],
|
| 2264 |
+
is_encoder_decoder: bool = False,
|
| 2265 |
+
num_new_tokens: int = 1,
|
| 2266 |
+
) -> Dict[str, Any]:
|
| 2267 |
+
if self.config.use_position_ids:
|
| 2268 |
+
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
| 2269 |
+
if "append_last_valid_logits" in model_kwargs:
|
| 2270 |
+
del model_kwargs["append_last_valid_logits"]
|
| 2271 |
+
if "images" in model_kwargs:
|
| 2272 |
+
del model_kwargs["images"]
|
| 2273 |
+
del model_kwargs["image_masks"]
|
| 2274 |
+
del model_kwargs["image_input_idx"]
|
| 2275 |
+
cache_name, cache = super()._extract_past_from_model_output(outputs)
|
| 2276 |
+
try:
|
| 2277 |
+
cache_name, cache = super()._extract_past_from_model_output(outputs)
|
| 2278 |
+
except AttributeError:
|
| 2279 |
+
past_key_values = outputs.past_key_values if "past_key_values" in outputs else None
|
| 2280 |
+
cache_name, cache = "past_key_values", past_key_values
|
| 2281 |
+
model_kwargs[cache_name] = cache
|
| 2282 |
+
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
| 2283 |
+
return model_kwargs
|
| 2284 |
+
|
| 2285 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
| 2286 |
+
return self.model.transformer.wte
|
| 2287 |
+
|
| 2288 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
| 2289 |
+
self.model.transformer.wte = value
|
| 2290 |
+
|
| 2291 |
+
def get_output_embeddings(self):
|
| 2292 |
+
if self.config.weight_tying:
|
| 2293 |
+
return self.model.transformer.wte
|
| 2294 |
+
else:
|
| 2295 |
+
return self.model.transformer.ff_out
|
| 2296 |
+
|
| 2297 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
| 2298 |
+
if self.config.weight_tying:
|
| 2299 |
+
self.model.transformer.wte = value
|
| 2300 |
+
else:
|
| 2301 |
+
self.model.transformer.ff_out = value
|
| 2302 |
+
|
| 2303 |
+
def tie_weights(self):
|
| 2304 |
+
"""
|
| 2305 |
+
This function is intentionally left as a no-op.
|
| 2306 |
+
|
| 2307 |
+
Weight tying is handled as follows:
|
| 2308 |
+
- When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
|
| 2309 |
+
See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
|
| 2310 |
+
- When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
|
| 2311 |
+
See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
|
| 2312 |
+
|
| 2313 |
+
Therefore, there is no need to explicitly tie the weights in this function.
|
| 2314 |
+
"""
|
| 2315 |
+
pass
|
| 2316 |
+
|
| 2317 |
+
def resize_token_embeddings(
|
| 2318 |
+
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
|
| 2319 |
+
) -> torch.nn.Embedding:
|
| 2320 |
+
"""
|
| 2321 |
+
Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
|
| 2322 |
+
|
| 2323 |
+
Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
| 2324 |
+
|
| 2325 |
+
Arguments:
|
| 2326 |
+
new_num_tokens (`int`, *optional*):
|
| 2327 |
+
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
|
| 2328 |
+
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
|
| 2329 |
+
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
|
| 2330 |
+
pad_to_multiple_of (`int`, *optional*):
|
| 2331 |
+
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
|
| 2332 |
+
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
|
| 2333 |
+
|
| 2334 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
| 2335 |
+
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
|
| 2336 |
+
details about this, or help on choosing the correct value for resizing, refer to this guide:
|
| 2337 |
+
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
|
| 2338 |
+
|
| 2339 |
+
Return:
|
| 2340 |
+
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
| 2341 |
+
|
| 2342 |
+
Note:
|
| 2343 |
+
This method differs from the base class implementation by resizing the `embedding_size` attribute of the
|
| 2344 |
+
model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
|
| 2345 |
+
is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
|
| 2346 |
+
embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
|
| 2347 |
+
"""
|
| 2348 |
+
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 2349 |
+
if new_num_tokens is None and pad_to_multiple_of is None:
|
| 2350 |
+
return model_embeds
|
| 2351 |
+
|
| 2352 |
+
# Update base model and current model config
|
| 2353 |
+
self.config.embedding_size = model_embeds.weight.shape[0]
|
| 2354 |
+
self.model.config.embedding_size = model_embeds.weight.shape[0]
|
| 2355 |
+
|
| 2356 |
+
# Check if the embedding size is less than the vocab size
|
| 2357 |
+
if self.config.embedding_size < self.config.vocab_size:
|
| 2358 |
+
warning_message = (
|
| 2359 |
+
f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
|
| 2360 |
+
f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
|
| 2361 |
+
"size is less than or equal to the new token embedding size."
|
| 2362 |
+
)
|
| 2363 |
+
log.warning(warning_message)
|
| 2364 |
+
|
| 2365 |
+
# Tie weights again if needed
|
| 2366 |
+
self.tie_weights()
|
| 2367 |
+
|
| 2368 |
+
return model_embeds
|
| 2369 |
+
|
| 2370 |
+
|
| 2371 |
+
# Always register for multi-modal features
|
| 2372 |
+
AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM)
|
molmo_logo.png
ADDED
|
preprocessing_molmo.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Processor class for Molmo.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import PIL
|
| 8 |
+
from PIL import ImageOps
|
| 9 |
+
from PIL.Image import Image
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from typing import Unpack
|
| 13 |
+
except ImportError:
|
| 14 |
+
from typing_extensions import Unpack
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from transformers.image_utils import ImageInput
|
| 20 |
+
from transformers.processing_utils import (
|
| 21 |
+
TextKwargs,
|
| 22 |
+
ProcessingKwargs,
|
| 23 |
+
ProcessorMixin,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
from transformers import AutoTokenizer
|
| 30 |
+
from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
|
| 37 |
+
DEFAULT_IM_START_TOKEN = f"<im_start>"
|
| 38 |
+
DEFAULT_IM_END_TOKEN = f"<im_end>"
|
| 39 |
+
DEFAULT_IM_COL_TOKEN = f"<im_col>"
|
| 40 |
+
IMAGE_PROMPT = "<|image|>"
|
| 41 |
+
|
| 42 |
+
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_special_token_ids(tokenizer):
|
| 46 |
+
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
|
| 47 |
+
assert len(ids) == len(EXTRA_TOKENS)
|
| 48 |
+
return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class MolmoTextKwargs(TextKwargs, total=False):
|
| 52 |
+
style: Optional[str]
|
| 53 |
+
system_prompt: Optional[str]
|
| 54 |
+
message_format: Optional[str]
|
| 55 |
+
always_start_with_space: Optional[bool]
|
| 56 |
+
sequence_length: Optional[int]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
|
| 60 |
+
text_kwargs: MolmoTextKwargs
|
| 61 |
+
images_kwargs: MolmoImagesKwargs
|
| 62 |
+
_defaults = {
|
| 63 |
+
"images_kwargs": {
|
| 64 |
+
"max_crops": 12,
|
| 65 |
+
"overlap_margins": [4, 4],
|
| 66 |
+
"base_image_input_size": [336, 336],
|
| 67 |
+
"image_token_length_w": 12,
|
| 68 |
+
"image_token_length_h": 12,
|
| 69 |
+
"image_patch_size": 14,
|
| 70 |
+
"image_padding_mask": True,
|
| 71 |
+
},
|
| 72 |
+
"text_kwargs": {
|
| 73 |
+
"style": "long_caption",
|
| 74 |
+
"system_prompt": "none",
|
| 75 |
+
"message_format": "role",
|
| 76 |
+
"always_start_with_space": True,
|
| 77 |
+
"sequence_length": 1536,
|
| 78 |
+
"padding": False,
|
| 79 |
+
},
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class MolmoProcessor(ProcessorMixin):
|
| 84 |
+
attributes = ["image_processor", "tokenizer"]
|
| 85 |
+
image_processor_class = "AutoImageProcessor"
|
| 86 |
+
tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
|
| 87 |
+
|
| 88 |
+
def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer : AutoTokenizer = None, **kwargs):
|
| 89 |
+
# self.image_processor = image_processor
|
| 90 |
+
# self.tokenizer = tokenizer
|
| 91 |
+
super().__init__(image_processor, tokenizer)
|
| 92 |
+
self._special_tokens = None
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def special_token_ids(self):
|
| 96 |
+
if self._special_tokens is None:
|
| 97 |
+
self._special_tokens = get_special_token_ids(self.tokenizer)
|
| 98 |
+
return self._special_tokens
|
| 99 |
+
|
| 100 |
+
def get_tokens_input(self, prompt, message_format, always_start_with_space):
|
| 101 |
+
if message_format == "none" or message_format is None:
|
| 102 |
+
pass
|
| 103 |
+
elif message_format == "role":
|
| 104 |
+
prompt = "User: " + prompt + " Assistant:"
|
| 105 |
+
else:
|
| 106 |
+
raise NotImplementedError(f"Message format {message_format} not implemented")
|
| 107 |
+
|
| 108 |
+
if always_start_with_space:
|
| 109 |
+
prompt = " " + prompt
|
| 110 |
+
|
| 111 |
+
tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
|
| 112 |
+
|
| 113 |
+
return tokens
|
| 114 |
+
|
| 115 |
+
def process(
|
| 116 |
+
self,
|
| 117 |
+
text: TextInput = None,
|
| 118 |
+
images: ImageInput = None,
|
| 119 |
+
*,
|
| 120 |
+
tokens: Optional[PreTokenizedInput] = None,
|
| 121 |
+
**kwargs: Unpack[MolmoProcessorKwargs],
|
| 122 |
+
):
|
| 123 |
+
output_kwargs = self._merge_kwargs(
|
| 124 |
+
MolmoProcessorKwargs,
|
| 125 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 126 |
+
**kwargs,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if tokens is None:
|
| 130 |
+
tokens = self.get_tokens_input(
|
| 131 |
+
text,
|
| 132 |
+
output_kwargs["text_kwargs"]["message_format"],
|
| 133 |
+
output_kwargs["text_kwargs"]["always_start_with_space"],
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
image_token_id = self.special_token_ids[IMAGE_PROMPT]
|
| 137 |
+
|
| 138 |
+
if images is not None:
|
| 139 |
+
if not isinstance(images, (list, tuple)):
|
| 140 |
+
images = [images]
|
| 141 |
+
image_arrays = []
|
| 142 |
+
for image in images:
|
| 143 |
+
if isinstance(image, Image):
|
| 144 |
+
image = image.convert("RGB")
|
| 145 |
+
# Handle images with EXIF orientation tags, which PIL will ignore by default
|
| 146 |
+
# https://github.com/python-pillow/Pillow/issues/4703
|
| 147 |
+
img = ImageOps.exif_transpose(image)
|
| 148 |
+
image_arrays.append(np.array(image))
|
| 149 |
+
else:
|
| 150 |
+
assert len(image.shape) == 3 and image.shape[-1] == 3
|
| 151 |
+
image_arrays.append(image.astype(np.uint8))
|
| 152 |
+
images = image_arrays
|
| 153 |
+
# For now only support inserting images at the start
|
| 154 |
+
image_idx = [-1]*len(images)
|
| 155 |
+
else:
|
| 156 |
+
image_idx = None
|
| 157 |
+
|
| 158 |
+
sequence_length = output_kwargs["text_kwargs"]["sequence_length"]
|
| 159 |
+
|
| 160 |
+
image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
|
| 161 |
+
image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
|
| 162 |
+
image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
|
| 163 |
+
image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
|
| 164 |
+
out = self.image_processor.multimodal_preprocess(
|
| 165 |
+
images=images,
|
| 166 |
+
image_idx=image_idx,
|
| 167 |
+
tokens=np.asarray(tokens).astype(np.int32),
|
| 168 |
+
sequence_length=sequence_length,
|
| 169 |
+
image_patch_token_id=image_patch_token_id,
|
| 170 |
+
image_col_token_id=image_col_token_id,
|
| 171 |
+
image_start_token_id=image_start_token_id,
|
| 172 |
+
image_end_token_id=image_end_token_id,
|
| 173 |
+
**output_kwargs["images_kwargs"]
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Prepend BOS
|
| 177 |
+
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic seperator token.
|
| 178 |
+
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
| 179 |
+
decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
|
| 180 |
+
out["input_ids"] = decoder_input_tokens
|
| 181 |
+
if "image_input_idx" in out:
|
| 182 |
+
# Shift patch mapping up by one since we added BOS
|
| 183 |
+
image_input_idx = out["image_input_idx"]
|
| 184 |
+
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
|
| 185 |
+
|
| 186 |
+
for k, v in out.items():
|
| 187 |
+
out[k] = torch.from_numpy(v)
|
| 188 |
+
|
| 189 |
+
return out
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
MolmoProcessor.register_for_auto_class()
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoImageProcessor": "image_preprocessing_molmo.MolmoImageProcessor",
|
| 4 |
+
"AutoProcessor": "preprocessing_molmo.MolmoProcessor"
|
| 5 |
+
},
|
| 6 |
+
"base_image_input_size": [
|
| 7 |
+
336,
|
| 8 |
+
336
|
| 9 |
+
],
|
| 10 |
+
"do_normalize": true,
|
| 11 |
+
"image_mean": [
|
| 12 |
+
0.48145466,
|
| 13 |
+
0.4578275,
|
| 14 |
+
0.40821073
|
| 15 |
+
],
|
| 16 |
+
"image_padding_mask": true,
|
| 17 |
+
"image_patch_size": 14,
|
| 18 |
+
"image_processor_type": "MolmoImageProcessor",
|
| 19 |
+
"image_std": [
|
| 20 |
+
0.26862954,
|
| 21 |
+
0.26130258,
|
| 22 |
+
0.27577711
|
| 23 |
+
],
|
| 24 |
+
"image_token_length_h": 12,
|
| 25 |
+
"image_token_length_w": 12,
|
| 26 |
+
"max_crops": 12,
|
| 27 |
+
"overlap_margins": [
|
| 28 |
+
4,
|
| 29 |
+
4
|
| 30 |
+
],
|
| 31 |
+
"processor_class": "MolmoProcessor"
|
| 32 |
+
}
|
quantization_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_skip_keys": false,
|
| 3 |
+
"dequantize_fp32": false,
|
| 4 |
+
"group_size": 0,
|
| 5 |
+
"is_integer": true,
|
| 6 |
+
"modules_dtype_dict": {},
|
| 7 |
+
"modules_to_not_convert": [],
|
| 8 |
+
"non_blocking": false,
|
| 9 |
+
"quant_conv": false,
|
| 10 |
+
"quant_method": "sdnq",
|
| 11 |
+
"quantization_device": null,
|
| 12 |
+
"return_device": null,
|
| 13 |
+
"svd_rank": 32,
|
| 14 |
+
"svd_steps": 8,
|
| 15 |
+
"use_quantized_matmul": false,
|
| 16 |
+
"use_quantized_matmul_conv": false,
|
| 17 |
+
"use_svd": false,
|
| 18 |
+
"weights_dtype": "int8"
|
| 19 |
+
}
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<im_start>",
|
| 4 |
+
"<im_end>",
|
| 5 |
+
"<im_patch>",
|
| 6 |
+
"<im_col>",
|
| 7 |
+
"<|image|>"
|
| 8 |
+
],
|
| 9 |
+
"bos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"eos_token": {
|
| 17 |
+
"content": "<|endoftext|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"pad_token": {
|
| 24 |
+
"content": "<|pad|>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"unk_token": {
|
| 31 |
+
"content": "<|endoftext|>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
}
|
| 37 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"100256": {
|
| 5 |
+
"content": "<|extra_id_0|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": false
|
| 11 |
+
},
|
| 12 |
+
"100257": {
|
| 13 |
+
"content": "<|endoftext|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"100258": {
|
| 21 |
+
"content": "<|fim_prefix|>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"100259": {
|
| 29 |
+
"content": "<|fim_middle|>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"100260": {
|
| 37 |
+
"content": "<|fim_suffix|>",
|
| 38 |
+
"lstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
},
|
| 44 |
+
"100261": {
|
| 45 |
+
"content": "|||PHONE_NUMBER|||",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false,
|
| 50 |
+
"special": false
|
| 51 |
+
},
|
| 52 |
+
"100262": {
|
| 53 |
+
"content": "|||EMAIL_ADDRESS|||",
|
| 54 |
+
"lstrip": false,
|
| 55 |
+
"normalized": false,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"single_word": false,
|
| 58 |
+
"special": false
|
| 59 |
+
},
|
| 60 |
+
"100263": {
|
| 61 |
+
"content": "|||IP_ADDRESS|||",
|
| 62 |
+
"lstrip": false,
|
| 63 |
+
"normalized": false,
|
| 64 |
+
"rstrip": false,
|
| 65 |
+
"single_word": false,
|
| 66 |
+
"special": false
|
| 67 |
+
},
|
| 68 |
+
"100264": {
|
| 69 |
+
"content": "<|im_start|>",
|
| 70 |
+
"lstrip": false,
|
| 71 |
+
"normalized": false,
|
| 72 |
+
"rstrip": false,
|
| 73 |
+
"single_word": false,
|
| 74 |
+
"special": true
|
| 75 |
+
},
|
| 76 |
+
"100265": {
|
| 77 |
+
"content": "<|im_end|>",
|
| 78 |
+
"lstrip": false,
|
| 79 |
+
"normalized": false,
|
| 80 |
+
"rstrip": false,
|
| 81 |
+
"single_word": false,
|
| 82 |
+
"special": true
|
| 83 |
+
},
|
| 84 |
+
"100266": {
|
| 85 |
+
"content": "<|extra_id_1|>",
|
| 86 |
+
"lstrip": false,
|
| 87 |
+
"normalized": false,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false,
|
| 90 |
+
"special": false
|
| 91 |
+
},
|
| 92 |
+
"100267": {
|
| 93 |
+
"content": "<|extra_id_2|>",
|
| 94 |
+
"lstrip": false,
|
| 95 |
+
"normalized": false,
|
| 96 |
+
"rstrip": false,
|
| 97 |
+
"single_word": false,
|
| 98 |
+
"special": false
|
| 99 |
+
},
|
| 100 |
+
"100268": {
|
| 101 |
+
"content": "<|extra_id_3|>",
|
| 102 |
+
"lstrip": false,
|
| 103 |
+
"normalized": false,
|
| 104 |
+
"rstrip": false,
|
| 105 |
+
"single_word": false,
|
| 106 |
+
"special": false
|
| 107 |
+
},
|
| 108 |
+
"100269": {
|
| 109 |
+
"content": "<|extra_id_4|>",
|
| 110 |
+
"lstrip": false,
|
| 111 |
+
"normalized": false,
|
| 112 |
+
"rstrip": false,
|
| 113 |
+
"single_word": false,
|
| 114 |
+
"special": false
|
| 115 |
+
},
|
| 116 |
+
"100270": {
|
| 117 |
+
"content": "<|extra_id_5|>",
|
| 118 |
+
"lstrip": false,
|
| 119 |
+
"normalized": false,
|
| 120 |
+
"rstrip": false,
|
| 121 |
+
"single_word": false,
|
| 122 |
+
"special": false
|
| 123 |
+
},
|
| 124 |
+
"100271": {
|
| 125 |
+
"content": "<|extra_id_6|>",
|
| 126 |
+
"lstrip": false,
|
| 127 |
+
"normalized": false,
|
| 128 |
+
"rstrip": false,
|
| 129 |
+
"single_word": false,
|
| 130 |
+
"special": false
|
| 131 |
+
},
|
| 132 |
+
"100272": {
|
| 133 |
+
"content": "<|extra_id_7|>",
|
| 134 |
+
"lstrip": false,
|
| 135 |
+
"normalized": false,
|
| 136 |
+
"rstrip": false,
|
| 137 |
+
"single_word": false,
|
| 138 |
+
"special": false
|
| 139 |
+
},
|
| 140 |
+
"100273": {
|
| 141 |
+
"content": "<|extra_id_8|>",
|
| 142 |
+
"lstrip": false,
|
| 143 |
+
"normalized": false,
|
| 144 |
+
"rstrip": false,
|
| 145 |
+
"single_word": false,
|
| 146 |
+
"special": false
|
| 147 |
+
},
|
| 148 |
+
"100274": {
|
| 149 |
+
"content": "<|extra_id_9|>",
|
| 150 |
+
"lstrip": false,
|
| 151 |
+
"normalized": false,
|
| 152 |
+
"rstrip": false,
|
| 153 |
+
"single_word": false,
|
| 154 |
+
"special": false
|
| 155 |
+
},
|
| 156 |
+
"100275": {
|
| 157 |
+
"content": "<|extra_id_10|>",
|
| 158 |
+
"lstrip": false,
|
| 159 |
+
"normalized": false,
|
| 160 |
+
"rstrip": false,
|
| 161 |
+
"single_word": false,
|
| 162 |
+
"special": false
|
| 163 |
+
},
|
| 164 |
+
"100276": {
|
| 165 |
+
"content": "<|endofprompt|>",
|
| 166 |
+
"lstrip": false,
|
| 167 |
+
"normalized": false,
|
| 168 |
+
"rstrip": false,
|
| 169 |
+
"single_word": false,
|
| 170 |
+
"special": true
|
| 171 |
+
},
|
| 172 |
+
"100277": {
|
| 173 |
+
"content": "<|pad|>",
|
| 174 |
+
"lstrip": false,
|
| 175 |
+
"normalized": false,
|
| 176 |
+
"rstrip": false,
|
| 177 |
+
"single_word": false,
|
| 178 |
+
"special": true
|
| 179 |
+
},
|
| 180 |
+
"100278": {
|
| 181 |
+
"content": "<im_start>",
|
| 182 |
+
"lstrip": false,
|
| 183 |
+
"normalized": false,
|
| 184 |
+
"rstrip": false,
|
| 185 |
+
"single_word": false,
|
| 186 |
+
"special": true
|
| 187 |
+
},
|
| 188 |
+
"100279": {
|
| 189 |
+
"content": "<im_end>",
|
| 190 |
+
"lstrip": false,
|
| 191 |
+
"normalized": false,
|
| 192 |
+
"rstrip": false,
|
| 193 |
+
"single_word": false,
|
| 194 |
+
"special": true
|
| 195 |
+
},
|
| 196 |
+
"100280": {
|
| 197 |
+
"content": "<im_patch>",
|
| 198 |
+
"lstrip": false,
|
| 199 |
+
"normalized": false,
|
| 200 |
+
"rstrip": false,
|
| 201 |
+
"single_word": false,
|
| 202 |
+
"special": true
|
| 203 |
+
},
|
| 204 |
+
"100281": {
|
| 205 |
+
"content": "<im_col>",
|
| 206 |
+
"lstrip": false,
|
| 207 |
+
"normalized": false,
|
| 208 |
+
"rstrip": false,
|
| 209 |
+
"single_word": false,
|
| 210 |
+
"special": true
|
| 211 |
+
},
|
| 212 |
+
"100282": {
|
| 213 |
+
"content": "<|image|>",
|
| 214 |
+
"lstrip": false,
|
| 215 |
+
"normalized": false,
|
| 216 |
+
"rstrip": false,
|
| 217 |
+
"single_word": false,
|
| 218 |
+
"special": true
|
| 219 |
+
}
|
| 220 |
+
},
|
| 221 |
+
"additional_special_tokens": [
|
| 222 |
+
"<im_start>",
|
| 223 |
+
"<im_end>",
|
| 224 |
+
"<im_patch>",
|
| 225 |
+
"<im_col>",
|
| 226 |
+
"<|image|>"
|
| 227 |
+
],
|
| 228 |
+
"auto_map": {
|
| 229 |
+
"AutoProcessor": "preprocessing_molmo.MolmoProcessor"
|
| 230 |
+
},
|
| 231 |
+
"bos_token": "<|endoftext|>",
|
| 232 |
+
"chat_template": "{% for message in messages -%}\n {%- if (loop.index % 2 == 1 and message['role'] != 'user') or \n (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif -%}\n {{ message['role'].capitalize() + ': ' + message['content'] }}\n {%- if not loop.last -%}\n {{ ' ' }}\n {%- endif %}\n {%- endfor -%}\n {%- if add_generation_prompt -%}\n {{ ' Assistant:' }}\n {%- endif %}",
|
| 233 |
+
"clean_up_tokenization_spaces": false,
|
| 234 |
+
"eos_token": "<|endoftext|>",
|
| 235 |
+
"model_max_length": 8192,
|
| 236 |
+
"pad_token": "<|pad|>",
|
| 237 |
+
"processor_class": "MolmoProcessor",
|
| 238 |
+
"tokenizer_class": "GPT2Tokenizer",
|
| 239 |
+
"unk_token": "<|endoftext|>"
|
| 240 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|