guoxy25 committed on
Commit 2abe772 · verified · 1 Parent(s): bce5a8e

Upload 56 files

added_tokens.json CHANGED
@@ -7,32 +7,32 @@
   "<B_USYS>": 151666,
   "<C_A>": 151668,
   "<C_Q>": 151667,
-  "<audio_delim_baichuan>": 151693,
-  "<audio_end_baichuan>": 151677,
-  "<audio_pad_baichuan>": 151678,
-  "<audio_start_baichuan>": 151676,
-  "<baichuan_pad_token>": 151691,
-  "<box_delim_baichuan>": 151685,
-  "<box_end_baichuan>": 151684,
-  "<box_start_baichuan>": 151683,
+  "<audio_delim_ocean>": 151693,
+  "<audio_end_ocean>": 151677,
+  "<audio_pad_ocean>": 151678,
+  "<audio_start_ocean>": 151676,
+  "<ocean_pad_token>": 151691,
+  "<box_delim_ocean>": 151685,
+  "<box_end_ocean>": 151684,
+  "<box_start_ocean>": 151683,
   "<calc_end>": 151674,
   "<calc_start>": 151673,
   "<function_calling>": 151672,
-  "<img_delim_baichuan>": 151688,
-  "<img_end_baichuan>": 151680,
-  "<img_newline_baichuan>": 151682,
-  "<img_pad_baichuan>": 151681,
-  "<img_start_baichuan>": 151679,
+  "<img_delim_ocean>": 151688,
+  "<img_end_ocean>": 151680,
+  "<img_newline_ocean>": 151682,
+  "<img_pad_ocean>": 151681,
+  "<img_start_ocean>": 151679,
   "<inner_think>": 151675,
-  "<polygon_end_baichuan>": 151690,
-  "<polygon_start_baichuan>": 151689,
-  "<ref_end_baichuan>": 151687,
-  "<ref_start_baichuan>": 151686,
+  "<polygon_end_ocean>": 151690,
+  "<polygon_start_ocean>": 151689,
+  "<ref_end_ocean>": 151687,
+  "<ref_start_ocean>": 151686,
   "<reserved_113>": 151692,
   "<tool_call>": 151657,
-  "<video_end_baichuan>": 151696,
-  "<video_palce_baichuan>": 151694,
-  "<video_start_baichuan>": 151695,
+  "<video_end_ocean>": 151696,
+  "<video_palce_ocean>": 151694,
+  "<video_start_ocean>": 151695,
   "<|box_end|>": 151649,
   "<|box_start|>": 151648,
   "<|endoftext|>": 151643,
audio_modeling_ocean.py ADDED
@@ -0,0 +1,194 @@
1
+ import torch, random, fire
2
+ from transformers.models.whisper import WhisperConfig
3
+ from torch.nn import functional as F
4
+ from flash_attn import flash_attn_varlen_func
5
+ from torch import nn
6
+ import numpy as np
7
+ from transformers.activations import ACT2FN
8
+ import math
9
+
10
+ def sinusoids(length, channels, max_timescale=10000):
11
+ """Returns sinusoids for positional embedding"""
12
+ assert channels % 2 == 0
13
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
14
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
15
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
16
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
17
+
18
+ class OceanWhisperAttention(nn.Module):
19
+ def __init__(self, embed_dim, num_heads):
20
+ super().__init__()
21
+ self.embed_dim = embed_dim
22
+ self.num_heads = num_heads
23
+ self.head_dim = embed_dim // num_heads
24
+
25
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
26
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
27
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
28
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
29
+
30
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
31
+ bsz, _ = hidden_states.size()
32
+
33
+ query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
34
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
35
+ value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
36
+
37
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
38
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
39
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen, max_seqlen, causal=False) # (bsz * qlen, nheads, headdim)
40
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
41
+ attn_output = self.out_proj(attn_output)
42
+ return attn_output
43
+
44
+ class OceanWhisperEncoderLayer(nn.Module):
45
+ def __init__(self, config):
46
+ super().__init__()
47
+ self.embed_dim = config.d_model
48
+ self.self_attn = OceanWhisperAttention(self.embed_dim, config.encoder_attention_heads)
49
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
50
+ self.activation_fn = ACT2FN[config.activation_function]
51
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
52
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
53
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
54
+
55
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
56
+ residual = hidden_states
57
+ hidden_states = self.self_attn_layer_norm(hidden_states)
58
+ hidden_states = self.self_attn(hidden_states, seq_len)
59
+ hidden_states = residual + hidden_states
60
+ residual = hidden_states
61
+ hidden_states = self.final_layer_norm(hidden_states)
62
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
63
+ hidden_states = self.fc2(hidden_states)
64
+ hidden_states = residual + hidden_states
65
+
66
+ if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and (
67
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
68
+ ):
69
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
70
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
71
+ return hidden_states
72
+
73
+ class OceanAudioEncoder(nn.Module):
74
+ def __init__(self, config):
75
+ super().__init__()
76
+ config._attn_implementation = 'flash_attention_2' # this encoder's attention always uses the FlashAttention varlen kernel
77
+ self.config = config
78
+ self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
79
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
80
+
81
+ # needs to be registered during the LLM's initialization
82
+ self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
83
+ self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size, stride=config.stride_size, padding=1)
84
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
85
+
86
+ self.layers = nn.ModuleList([OceanWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
87
+ self.layer_norm = nn.LayerNorm(config.d_model)
88
+
89
+ self.gradient_checkpointing = True
90
+
91
+ @torch.no_grad()
92
+ def fake_input(self, device):
93
+ input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
94
+ encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
95
+ bridge_length = torch.ones([2], dtype=torch.int32, device=device)
96
+ return input_features, encoder_length, bridge_length
97
+
98
+ def forward(
99
+ self,
100
+ input_features,
101
+ output_length, # MAKE SURE the input is the hidden-state length after the two conv layers
102
+ ):
103
+ input_features = input_features.to(self.conv1.weight.dtype)
104
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
105
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
106
+ inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frames, channels)
107
+ bsz, tgt_len, _ = inputs_embeds.size() # tgt_len is the longest sequence in the current batch
108
+ if tgt_len < self.positional_embedding.shape[0]:
109
+ current_positional_embedding = self.positional_embedding[:tgt_len]
110
+ else:
111
+ current_positional_embedding = self.positional_embedding
112
+ hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
113
+
114
+ # packing hidden states
115
+ attention_mask = torch.arange(0, tgt_len).to(hidden_states.device)
116
+ attention_mask = torch.lt(attention_mask, output_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
117
+ unpacking_index = torch.cumsum(attention_mask.to(torch.int32).view(-1), dim=0) - 1 # convert the cumulative counts into packed-position indices
118
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
119
+
120
+ for idx, encoder_layer in enumerate(self.layers):
121
+ if self.gradient_checkpointing and self.training:
122
+ def create_custom_forward(module):
123
+ def custom_forward(*inputs):
124
+ return module(*inputs)
125
+ return custom_forward
126
+ hidden_states = torch.utils.checkpoint.checkpoint(
127
+ create_custom_forward(encoder_layer),
128
+ hidden_states,
129
+ output_length
130
+ )
131
+ else:
132
+ hidden_states = encoder_layer(hidden_states, output_length)
133
+ hidden_states = self.layer_norm(hidden_states)
134
+ # unpacking
135
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
136
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
137
+ return hidden_states
138
+
139
+ class OceanAudioBridge(nn.Module):
140
+ def __init__(self, config):
141
+ super().__init__()
142
+ self.config = config.audio_config
143
+ if self.config.avg_pooler > 1:
144
+ self.avg_pooler = nn.AvgPool1d(self.config.avg_pooler, stride=2)
145
+ else:
146
+ self.avg_pooler = None
147
+ self.proj1 = nn.Linear(self.config.d_model, config.hidden_size)
148
+ self.proj2 = nn.Linear(config.hidden_size, config.hidden_size)
149
+
150
+ def forward(self, x, output_length):
151
+ if self.avg_pooler is not None:
152
+ x = x.permute(0, 2, 1)
153
+ x = self.avg_pooler(x)
154
+ x = x.permute(0, 2, 1)
155
+
156
+ batch_size, sl, _ = x.shape
157
+ output_length = output_length.to(x.device)
158
+ valid_mask = torch.arange(0, sl).to(x.device)
159
+ valid_mask = torch.lt(valid_mask, output_length.reshape(batch_size, 1)).reshape(batch_size, sl, 1)
160
+ x = torch.masked_select(x, valid_mask).reshape(-1, self.config.d_model) # (sum(valid_sequence_length), d)
161
+ x = ACT2FN[self.config.activation_function](self.proj1(x))
162
+ x = self.proj2(x)
163
+ return x
164
+
165
+
166
+ def test_audio():
167
+ from transformers import AutoConfig
168
+ from processor_ocean import OceanAudioProcessor
169
+ # from ..configuration_ocean import OceanConfig
170
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
171
+
172
+ config.audio_config.d_model = 24
173
+ config.audio_config.encoder_layers = 2
174
+ config.audio_config.encoder_attention_heads = 4
175
+ config.audio_config.encoder_ffn_dim = 48
176
+
177
+ ae = OceanAudioEncoder(config.audio_config).cuda().to(torch.bfloat16)
178
+ bg = OceanAudioBridge(config).cuda().to(torch.bfloat16)
179
+ l = random.randint(10, 30)
180
+ bs = 3
181
+ input_length = torch.tensor([random.randint(1, l) for _ in range(bs)])
182
+ encoder_length, bridge_length = OceanAudioProcessor.inference_output_length(config.audio_config, input_length)
183
+ print("l={}, input_valid_length={},\nencoder_valid_length={}, bridge_valid_length={}".format(l, input_length, encoder_length, bridge_length))
184
+ wave_features = torch.rand((bs, config.audio_config.num_mel_bins, l))
185
+ a = ae(wave_features.to('cuda'), encoder_length.to('cuda'))
186
+ b = bg(a, bridge_length.to('cuda'))
187
+ print('encoder output={}, bridge output={}'.format(a.shape, b.shape))
188
+ print(a)
189
+ print(b)
190
+
191
+
192
+ if __name__ == '__main__':
193
+ fire.Fire()
194
+
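
OceanAudioEncoder above packs the variable-length mel features of a batch into one flat sequence before its FlashAttention layers (torch.masked_select under a length mask) and later restores the padded layout through the cumulative-sum index. The standalone sketch below mirrors that pack/unpack round trip with hypothetical shapes and plain PyTorch, without the flash-attn dependency:

import torch

# Hypothetical shapes; mirrors the pack/unpack logic in OceanAudioEncoder.forward.
bsz, tgt_len, d_model = 3, 5, 8
hidden_states = torch.randn(bsz, tgt_len, d_model)
output_length = torch.tensor([5, 2, 3])  # valid frames per sample after the two conv layers

# Length mask over the padded layout and the index used later to unpack.
mask = torch.arange(tgt_len).unsqueeze(0) < output_length.unsqueeze(1)  # (bsz, tgt_len)
mask = mask.unsqueeze(-1)                                               # (bsz, tgt_len, 1)
unpacking_index = torch.cumsum(mask.to(torch.int32).view(-1), dim=0) - 1

# Pack: keep only the valid frames, concatenated across the batch.
packed = torch.masked_select(hidden_states, mask).view(int(output_length.sum()), d_model)
# ... the encoder layers run on `packed`, with cu_seqlens derived from output_length ...

# Unpack: scatter the packed rows back to (bsz, tgt_len, d_model) and zero the padding.
unpacked = torch.index_select(packed, 0, unpacking_index).view(bsz, tgt_len, d_model)
unpacked = torch.where(mask, unpacked, torch.zeros_like(unpacked))
assert torch.allclose(unpacked[mask.expand(-1, -1, d_model)], packed.view(-1))
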
config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "_name_or_path": "/cpfs/29f69eb5e2e60f26/code/mllm/zhangtao02/workspace/3_bc_mllm/models/mm_pretrain/cs_ocr3b/cs_ocr3b_ift_1204_ckpt_epoch_1_12030256",
   "architectures": [
-    "BaichuanForCausalLM"
+    "OceanForCausalLM"
   ],
   "attention_qkv_bias": true,
   "attention_qkv_pack": false,
@@ -117,10 +117,10 @@
     "vocab_size": 51865
   },
   "auto_map": {
-    "AutoConfig": "configuration_baichuan.BaichuanConfig",
-    "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM"
+    "AutoConfig": "configuration_ocean.OceanConfig",
+    "AutoModelForCausalLM": "modeling_ocean.OceanForCausalLM"
   },
-  "baichuan_tokenizer_type": "auto",
+  "ocean_tokenizer_type": "auto",
   "bos_token_id": 1,
   "eos_token_id": 2,
   "head_dim": 128,
@@ -130,7 +130,7 @@
   "intermediate_size": 11008,
   "max_position_embeddings": 8192,
   "max_window_layers": 36,
-  "model_type": "baichuan",
+  "model_type": "ocean",
   "moe": false,
   "multimodal": [
     "image"
configuration_ocean.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright 2023 Ocean Inc. All Rights Reserved.
2
+
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+ from transformers import WhisperConfig
25
+ from transformers import CLIPVisionConfig
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class OceanConfig(PretrainedConfig):
31
+ model_type = "ocean"
32
+ keys_to_ignore_at_inference = ["past_key_values"]
33
+
34
+ def __init__(
35
+ self,
36
+ vocab_size=125696,
37
+ hidden_size=4096,
38
+ intermediate_size=11008,
39
+ num_hidden_layers=32,
40
+ num_attention_heads=32,
41
+ num_key_value_heads=None,
42
+ sparse_attention_heads=None,
43
+ sparse_attention_layers=[],
44
+ head_dim=None,
45
+ attention_qkv_pack=True,
46
+ attention_qkv_bias=False,
47
+ use_norm_head=True,
48
+ hidden_act="silu",
49
+ max_position_embeddings=4096,
50
+ position_embedding_type="rope",
51
+ initializer_range=0.02,
52
+ rms_norm_eps=1e-6,
53
+ use_cache=True,
54
+ pad_token_id=0,
55
+ bos_token_id=1,
56
+ eos_token_id=2,
57
+ tie_word_embeddings=False,
58
+ audio_config=None,
59
+ visual_config=None,
60
+ video_config=None,
61
+ **kwargs,
62
+ ):
63
+ self.vocab_size = vocab_size
64
+ self.max_position_embeddings = max_position_embeddings
65
+ self.hidden_size = hidden_size
66
+ self.intermediate_size = intermediate_size
67
+ self.num_hidden_layers = num_hidden_layers
68
+ self.num_attention_heads = num_attention_heads
69
+ self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
70
+ self.sparse_attention_heads = sparse_attention_heads
71
+ self.sparse_attention_layers = sparse_attention_layers
72
+ self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
73
+ self.attention_qkv_pack = attention_qkv_pack
74
+ self.attention_qkv_bias = attention_qkv_bias
75
+ self.use_norm_head = use_norm_head
76
+ self.hidden_act = hidden_act
77
+ self.position_embedding_type = position_embedding_type
78
+ self.initializer_range = initializer_range
79
+ self.rms_norm_eps = rms_norm_eps
80
+ self.use_cache = use_cache
81
+ assert self.position_embedding_type.lower() in ("rope", "alibi")
82
+ super().__init__(
83
+ pad_token_id=pad_token_id,
84
+ bos_token_id=bos_token_id,
85
+ eos_token_id=eos_token_id,
86
+ tie_word_embeddings=tie_word_embeddings,
87
+ **kwargs,
88
+ )
89
+ if audio_config is not None:
90
+ self.audio_config = WhisperConfig(**audio_config)
91
+ if visual_config is not None:
92
+ self.visual_config = CLIPVisionConfig(**visual_config)
93
+ if video_config is not None:
94
+ self.video_config = CLIPVisionConfig(**video_config)
95
+
96
+
97
+ def to_diff_dict(self):
98
+ data = super().to_diff_dict()
99
+ data["model_type"] = self.model_type
100
+ return data
101
+
102
+ def get_rotary_base(self):
103
+ if hasattr(self, "rotary_emb_base"):
104
+ return self.rotary_emb_base
105
+ else:
106
+ return self.rope_theta
107
+
108
+ if __name__ == '__main__':
109
+ from transformers import AutoConfig
110
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
111
+ print(config)
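
OceanConfig wraps its multimodal sub-configs: dictionaries passed as audio_config, visual_config, or video_config are re-instantiated as WhisperConfig / CLIPVisionConfig objects. A small construction sketch (all field values below are illustrative, not the shipped defaults):

# Illustrative only: building an OceanConfig with a nested audio sub-config.
from configuration_ocean import OceanConfig

config = OceanConfig(
    hidden_size=2048,
    num_hidden_layers=4,
    num_attention_heads=16,
    position_embedding_type="rope",
    audio_config={                 # re-hydrated into a transformers.WhisperConfig
        "d_model": 1280,
        "encoder_layers": 2,
        "encoder_attention_heads": 20,
    },
)
print(type(config.audio_config).__name__)  # WhisperConfig
print(config.head_dim)                     # hidden_size // num_attention_heads = 128
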
modeling_ocean.py ADDED
@@ -0,0 +1,1001 @@
1
+ # Copyright 2023 Ocean Inc. All Rights Reserved.
2
+ #
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """ PyTorch Ocean model."""
22
+ import os
23
+ import json
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+ from threading import Thread
27
+ from easydict import EasyDict
28
+ import numpy as np
29
+
30
+ import torch
31
+ import torch.utils.checkpoint
32
+ from torch import nn
33
+ from torch.nn import CrossEntropyLoss
34
+ from torch.nn import functional as F
35
+
36
+ from transformers import PreTrainedModel
37
+ from transformers.activations import ACT2FN
38
+ from dataclasses import dataclass
39
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
40
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
41
+ from transformers.generation.utils import GenerationConfig
42
+ from transformers.utils import logging
43
+
44
+ from .configuration_ocean import OceanConfig
45
+ from .audio_modeling_ocean import OceanAudioEncoder, OceanAudioBridge
46
+ from .visual_modeling_ocean import OceanVisualEncoder, OceanVisualBridge
47
+ from .processor_ocean import OceanMMProcessor
48
+ from .moe import moe_matmul
49
+
50
+ # support model paths that contain a dot (.)
51
+ try:
52
+ # step1: copy relative imports to transformers_modules
53
+ from .generation_utils import build_chat_input, TextIterStreamer
54
+ from .sequence_parallel_utils import (
55
+ create_attention_layer,
56
+ get_sequence_parallel_size,
57
+ get_sequence_parallel_chunk,
58
+ )
59
+ except ModuleNotFoundError:
60
+ # step2: direct import from transformers_modules
61
+ try: # bypass check_imports failure
62
+ import sys
63
+ sys.path.append(os.path.dirname(__file__))
64
+ from generation_utils import build_chat_input, TextIterStreamer
65
+ from sequence_parallel_utils import (
66
+ create_attention_layer,
67
+ get_sequence_parallel_size,
68
+ get_sequence_parallel_chunk,
69
+ )
70
+ except Exception:
71
+ raise
72
+
73
+
74
+ logger = logging.get_logger(__name__)
75
+
76
+ def get_slopes(n):
77
+ def get_slopes_power_of_2(n):
78
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
79
+ ratio = start
80
+ return [start * ratio ** i for i in range(n)]
81
+
82
+ if math.log2(n).is_integer():
83
+ return get_slopes_power_of_2(
84
+ n) # In the paper, we only train models that have 2^a heads for some a. This function has
85
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
86
+ closest_power_of_2 = 2 ** math.floor(
87
+ math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
88
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
89
+ :n - closest_power_of_2]
90
+
91
+
92
+ class RMSNorm(nn.Module):
93
+ def __init__(self, hidden_size, eps=1e-6):
94
+ """
95
+ RMSNorm is equivalent to T5LayerNorm
96
+ """
97
+ super().__init__()
98
+ self.weight = nn.Parameter(torch.ones(hidden_size))
99
+ self.variance_epsilon = eps
100
+
101
+ def forward(self, hidden_states):
102
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
103
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
104
+
105
+ # convert into half-precision if necessary
106
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
107
+ hidden_states = hidden_states.to(self.weight.dtype)
108
+
109
+ return self.weight * hidden_states
110
+
111
+
112
+ class RotaryEmbedding(torch.nn.Module):
113
+ def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
114
+ super().__init__()
115
+ # Fix the RoPE initialization precision issue: https://zhuanlan.zhihu.com/p/678963442
117
+ # DeepSpeed hacks torch.arange to force it onto the GPU; use the native torch.arange here
117
+ try:
118
+ import deepspeed
119
+ self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
120
+ except:
121
+ self.arange = torch.arange
122
+
123
+ self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
124
+ self.max_seq_len_cached = max_position_embeddings
125
+ t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
126
+ freqs = torch.outer(t, self.inv_freq)
127
+ emb = torch.cat((freqs, freqs), dim=-1)
128
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
129
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
130
+
131
+ def forward(self, x, seq_len=None):
132
+ # x: [bs, num_attention_heads, seq_len, head_size]
133
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
134
+ if seq_len > self.max_seq_len_cached:
135
+ self.max_seq_len_cached = seq_len
136
+ t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
137
+ freqs = torch.outer(t, self.inv_freq)
138
+ emb = torch.cat((freqs, freqs), dim=-1)
139
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
140
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
141
+ return (
142
+ self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
143
+ self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
144
+ )
145
+
146
+
147
+ def rotate_half(x):
148
+ """Rotates half the hidden dims of the input."""
149
+ x1 = x[..., : x.shape[-1] // 2]
150
+ x2 = x[..., x.shape[-1] // 2:]
151
+ return torch.cat((-x2, x1), dim=-1)
152
+
153
+
154
+ def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
155
+ cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
156
+ sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
157
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
158
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
159
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
160
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
161
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
162
+
163
+
164
+ class MLP(nn.Module):
165
+ def __init__(
166
+ self,
167
+ hidden_size: int,
168
+ intermediate_size: int,
169
+ hidden_act: str,
170
+ ):
171
+ super().__init__()
172
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
173
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
174
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
175
+ self.act_fn = ACT2FN[hidden_act]
176
+
177
+ def forward(self, x):
178
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
179
+
180
+
181
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
182
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
183
+ """
184
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
185
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
186
+ """
187
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
188
+ if n_rep == 1:
189
+ return hidden_states
190
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
191
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
192
+
193
+
194
+ class Attention(nn.Module):
195
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
196
+ def __init__(self, config: OceanConfig, is_sparse=False):
197
+ super().__init__()
198
+ self.config = config
199
+ self.position_embedding_type = config.position_embedding_type.lower()
200
+ self.num_kv_heads = config.num_key_value_heads
201
+ self.head_dim = config.head_dim
202
+ self.hidden_size = config.num_attention_heads * self.head_dim
203
+ self.hidden_kv_size = self.num_kv_heads * self.head_dim
204
+
205
+ if is_sparse:
206
+ self.num_heads = config.sparse_attention_heads
207
+ assert self.num_kv_heads == config.num_attention_heads
208
+ self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
209
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
210
+ else:
211
+ self.num_heads = config.num_attention_heads
212
+ if self.config.attention_qkv_pack:
213
+ self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
214
+ if config.moe:
215
+ self.moe_W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=False)
216
+ else:
217
+ self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
218
+ self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
219
+ self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
220
+
221
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
222
+ if config.moe:
223
+ self.moe_o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
224
+
225
+ if self.position_embedding_type == 'rope':
226
+ self.rotary_emb = RotaryEmbedding(
227
+ dim=self.head_dim,
228
+ max_position_embeddings=config.max_position_embeddings,
229
+ base=config.get_rotary_base()
230
+ )
231
+ elif self.position_embedding_type == 'alibi':
232
+ self.alibi_slopes = get_slopes(self.num_heads)
233
+ self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)
234
+
235
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
236
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
237
+
238
+ def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
239
+ assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
240
+ return repeat_kv(hidden_states, num_heads // hidden_states.size(1))
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ attention_mask: Optional[torch.Tensor] = None,
246
+ position_ids: Optional[torch.LongTensor] = None,
247
+ seqlens: Optional[torch.IntTensor] = None,
248
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
249
+ output_attentions: bool = False,
250
+ use_cache: bool = False,
251
+ group_index=None,
252
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
253
+ bsz, q_len = hidden_states.shape[:2]
254
+
255
+ if self.config.attention_qkv_pack:
256
+ if self.config.moe and group_index is not None:
257
+ proj = moe_matmul(hidden_states, [self.W_pack.weight, self.moe_W_pack.weight], group_index, lambda x, y: torch.einsum('bd,ld->bl', x, y))
258
+ if self.config.attention_qkv_bias:
259
+ proj += self.W_pack.bias
260
+ else:
261
+ proj = self.W_pack(hidden_states)
262
+ query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
263
+ else:
264
+ query_states = self.q_proj(hidden_states)
265
+ key_states = self.k_proj(hidden_states)
266
+ value_states = self.v_proj(hidden_states)
267
+
268
+ # (B, S, hidden_size) -> (B, num_heads, S, head_size)
269
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
270
+ # (B, S, hidden_size) -> (B, num_kv_heads, S, head_size)
271
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
272
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
273
+
274
+ kv_seq_len = key_states.shape[-2]
275
+ if past_key_value is not None:
276
+ kv_seq_len += past_key_value[0].shape[-2]
277
+ if self.position_embedding_type == 'rope':
278
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len * get_sequence_parallel_size())
279
+ query_states, key_states = apply_rotary_pos_emb(
280
+ query_states, key_states, cos, sin,
281
+ get_sequence_parallel_chunk(position_ids)
282
+ )
283
+
284
+ if past_key_value is not None:
285
+ # reuse k, v, self_attention
286
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
287
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
288
+ past_key_value = (key_states, value_states) if use_cache else None
289
+
290
+ # repeat k/v heads if n_kv_heads < n_heads
291
+ key_states = self._repeat_kv(key_states, query_states.size(1))
292
+ value_states = self._repeat_kv(value_states, query_states.size(1))
293
+
294
+ if seqlens is not None:
295
+ seqlens = seqlens.to(dtype=torch.int32)
296
+ max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
297
+ if self.position_embedding_type == 'alibi':
298
+ alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
299
+ else:
300
+ alibi_slopes = None
301
+ attn_output = self.attention(
302
+ query_states, key_states, value_states, seqlens, seqlens,
303
+ max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
304
+ else:
305
+ attn_output = self.attention(
306
+ query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)
307
+
308
+ attn_output = attn_output.reshape(bsz, q_len, -1)
309
+ if not self.config.moe or group_index is None:
310
+ attn_output = self.o_proj(attn_output)
311
+ else:
312
+ attn_output = moe_matmul(attn_output, [self.o_proj.weight, self.moe_o_proj.weight], group_index, lambda x, y: torch.einsum('bd,ld->bl', x, y))
313
+
314
+ return attn_output, None, past_key_value
315
+
316
+
317
+ class DecoderLayer(nn.Module):
318
+ def __init__(self, config: OceanConfig, is_sparse=False):
319
+ super().__init__()
320
+ self.hidden_size = config.hidden_size
321
+ self.self_attn = Attention(config=config, is_sparse=is_sparse)
322
+ self.mlp = MLP(
323
+ hidden_size=self.hidden_size,
324
+ intermediate_size=config.intermediate_size,
325
+ hidden_act=config.hidden_act,
326
+ )
327
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
328
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
329
+
330
+ def forward(
331
+ self,
332
+ hidden_states: torch.Tensor,
333
+ attention_mask: Optional[torch.Tensor] = None,
334
+ position_ids: Optional[torch.LongTensor] = None,
335
+ seqlens: Optional[torch.IntTensor] = None,
336
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
337
+ output_attentions: Optional[bool] = False,
338
+ use_cache: Optional[bool] = False,
339
+ group_index=None,
340
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
341
+
342
+ residual = hidden_states
343
+
344
+ hidden_states = self.input_layernorm(hidden_states)
345
+
346
+ # Self Attention
347
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
348
+ hidden_states=hidden_states,
349
+ attention_mask=attention_mask,
350
+ position_ids=position_ids,
351
+ seqlens=seqlens,
352
+ past_key_value=past_key_value,
353
+ output_attentions=output_attentions,
354
+ use_cache=use_cache,
355
+ group_index=group_index,
356
+ )
357
+ hidden_states = residual + hidden_states
358
+
359
+ # Fully Connected
360
+ residual = hidden_states
361
+ hidden_states = self.post_attention_layernorm(hidden_states)
362
+ hidden_states = self.mlp(hidden_states)
363
+ hidden_states = residual + hidden_states
364
+
365
+ outputs = (hidden_states,)
366
+
367
+ if output_attentions:
368
+ outputs += (self_attn_weights,)
369
+
370
+ if use_cache:
371
+ outputs += (present_key_value,)
372
+
373
+ return outputs
374
+
375
+
376
+ class OceanPreTrainedModel(PreTrainedModel):
377
+ config_class = OceanConfig
378
+ base_model_prefix = "model"
379
+ supports_gradient_checkpointing = True
380
+ _no_split_modules = ["DecoderLayer"]
381
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
382
+
383
+ def _init_weights(self, module):
384
+ std = self.config.initializer_range
385
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):
386
+ module.weight.data.normal_(mean=0.0, std=std)
387
+ if module.bias is not None:
388
+ module.bias.data.zero_()
389
+ elif isinstance(module, nn.Embedding):
390
+ module.weight.data.normal_(mean=0.0, std=std)
391
+ if module.padding_idx is not None:
392
+ module.weight.data[module.padding_idx].zero_()
393
+ elif isinstance(module, nn.LayerNorm):
394
+ module.weight.data.fill_(1.0)
395
+ module.bias.data.zero_()
396
+
397
+ def _set_gradient_checkpointing(self, module, value=False):
398
+ if isinstance(module, OceanModel):
399
+ module.gradient_checkpointing = value
400
+
401
+
402
+ class OceanModel(OceanPreTrainedModel):
403
+ def __init__(self, config: OceanConfig):
404
+ super().__init__(config)
405
+ self.padding_idx = config.pad_token_id
406
+ self.vocab_size = config.vocab_size
407
+ self.merge_size = 1
408
+ if config.audio_config.enable:
409
+ self.audio_model = OceanAudioEncoder(config.audio_config)
410
+ self.audio_bridge_model = OceanAudioBridge(config)
411
+ if config.visual_config.enable:
412
+ self.visual_model = OceanVisualEncoder(config.visual_config)
413
+ self.visual_bridge_model = OceanVisualBridge(config.visual_config)
414
+ self.merge_size = max(config.visual_config.merge_size, self.merge_size)
415
+ if config.video_config.enable: # in case only video_config is set without visual_config
416
+ if not config.visual_config.enable:
417
+ self.visual_model = OceanVisualEncoder(config.video_config)
418
+ self.video_bridge_model = OceanVisualBridge(config.video_config)
419
+ self.merge_size = max(config.video_config.merge_size, self.merge_size)
420
+
421
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
422
+ self.layers = nn.ModuleList([
423
+ DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
424
+ for layer_idx in range(config.num_hidden_layers)
425
+ ])
426
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
427
+
428
+ self.gradient_checkpointing = True
429
+ # Initialize weights and apply final processing
430
+ self.post_init()
431
+
432
+ def get_input_embeddings(self):
433
+ return self.embed_tokens
434
+
435
+ def set_input_embeddings(self, value):
436
+ self.embed_tokens = value
437
+
438
+ def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
439
+ '''
440
+ 获取任意模态的特殊mask,包含以下
441
+ 1. pad mask 表示文本中图像/语音/视频模态提前留出的token位置
442
+ 2. special token mask 特殊token 例如对理解模型<start> <end> 不需要next token prediction
443
+ 3. embedding mask / lm_head mask 标记出特殊token在embedding中的mask
444
+ '''
445
+
446
+ pad_mask = torch.eq(input_ids, pad_token_id)
447
+ sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
448
+ lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
449
+ for sp_id in special_token_list:
450
+ sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
451
+ lm_head_mask[sp_id, 0] = True
452
+ return pad_mask, sp_mask, lm_head_mask
453
+
454
+ def get_audio_embed(
455
+ self,
456
+ input_ids,
457
+ text_embedding, # 1. self.embed_tokens(input_ids) or 2. the output of another modality merge
458
+ features, # list of tensors
459
+ encoder_length,
460
+ bridge_length,
461
+ group_index=None, # per-modality group id, used by MoE
462
+ ):
463
+ pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, self.config.audio_config.audio_pad_token_id, self.config.multimodal_special_token_list)
464
+ if features is None or len(features) <= 0 : # empty list or None: run a fake forward pass so gradients still flow
465
+ features, encoder_length, bridge_length = self.audio_model.fake_input(input_ids.device)
466
+ fake_input = True
467
+ else:
468
+ fake_input = False
469
+ audio_embed = self.audio_model(features, encoder_length)
470
+ audio_embed = self.audio_bridge_model(audio_embed, bridge_length) # (?, d)
471
+ if not self.training: # inference supports auto device map: move the multimodal output onto the same device as input_ids
472
+ audio_embed = audio_embed.to(input_ids.device)
473
+ if not fake_input: # check that the number of multimodal tokens matches the pad mask (incorrect truncation breaks this)
474
+ assert pad_mask.sum() == audio_embed.shape[0]
475
+ else:
476
+ assert pad_mask.sum() <= 0 # 0 vs 1
477
+
478
+ # merge the current modality embeddings with the text embeddings
479
+ input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
480
+ if self.config.train_multimodal_special_tokens_only and self.training:
481
+ # only the special tokens propagate gradients into the embedding weight, keeping the LLM part unchanged
482
+ # note: the special token list should be shared across modalities, otherwise some tokens get their gradient stopped
483
+ sp_mask = sp_mask.unsqueeze(-1).to(text_embedding)
484
+ text_embedding = (1 - sp_mask) * text_embedding.detach() + sp_mask * text_embedding
485
+ text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # zero out pad-token positions (no gradient)
486
+ multimodal_embedding = torch.embedding(audio_embed, input_ids * pad_mask) # non-pad positions pick up the idx=0 row
487
+ multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # zero out non-pad positions
488
+
489
+ final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
490
+
491
+ if group_index is None:
492
+ group_index = pad_mask.to(torch.int32)
493
+ else:
494
+ current_index = torch.max(group_index) + 1
495
+ group_index += pad_mask.to(torch.int32) * current_index # assumes modalities do not overlap
496
+
497
+ return final_embedding, group_index # never return None for group_index, so the MoE parameters always receive gradients
498
+
499
+ def get_visual_embed(
500
+ self,
501
+ input_ids,
502
+ text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
503
+ images,
504
+ group_index, # 某种模态的编号 for MoE
505
+ images_grid
506
+ ):
507
+ # TODO 与get_audio_embed合并重复功能 减少冗余代码
508
+ pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, self.config.visual_config.image_pad_token_id, self.config.multimodal_special_token_list)
509
+ if images is None or len(images) <= 0 : # 空list or None 保证梯度回传
510
+ images = self.visual_model.fake_input(input_ids.device, self.merge_size)
511
+ images_grid = [(1, self.merge_size, self.merge_size)]
512
+ fake_input = True
513
+ else:
514
+ fake_input = False
515
+
516
+ images = torch.cat(images, dim=0)
517
+ images_grid = torch.tensor(np.array(images_grid))
518
+ visual_embed = self.visual_model(images, grid_thw=images_grid)
519
+ visual_embed = self.visual_bridge_model(visual_embed)
520
+
521
+ if not self.training: # inference supports auto device map: move the multimodal output onto the same device as input_ids
522
+ visual_embed = visual_embed.to(input_ids.device)
523
+ if not fake_input: # check that the number of multimodal tokens matches the pad mask (incorrect truncation breaks this)
524
+ assert pad_mask.sum() == visual_embed.shape[0], '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
525
+ else:
526
+ assert pad_mask.sum() <= 0, '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
527
+
528
+ # merge the current modality embeddings with the text embeddings
529
+ input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
530
+ if self.config.train_multimodal_special_tokens_only and self.training:
531
+ # only the special tokens propagate gradients into the embedding weight, keeping the LLM part unchanged
532
+ # note: the special token list should be shared across modalities, otherwise some tokens get their gradient stopped
533
+ sp_mask = sp_mask.unsqueeze(-1).to(text_embedding)
534
+ text_embedding = (1 - sp_mask) * text_embedding.detach() + sp_mask * text_embedding
535
+ text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # zero out pad-token positions (no gradient)
536
+ multimodal_embedding = torch.embedding(visual_embed, input_ids * pad_mask) # non-pad positions pick up the idx=0 row
537
+ multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # zero out non-pad positions
538
+
539
+ final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
540
+
541
+ if group_index is None:
542
+ group_index = pad_mask.to(torch.int32)
543
+ else:
544
+ current_index = torch.max(group_index) + 1
545
+ group_index += pad_mask.to(torch.int32) * current_index # assumes modalities do not overlap
546
+
547
+ return final_embedding, group_index # never return None for group_index, so the MoE parameters always receive gradients
548
+
549
+ def get_video_embed(
550
+ self,
551
+ input_ids,
552
+ text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
553
+ images,
554
+ group_index, # 某种模态的编号 for MoE
555
+ images_grid
556
+
557
+ ):
558
+ # TODO 与get_audio_embed合并重复功能 减少冗余代码
559
+ pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, self.config.video_config.video_place_token_id, self.config.multimodal_special_token_list)
560
+ if images is None or len(images) <= 0 : # 空list or None 保证梯度回传
561
+ images = self.visual_model.fake_input(input_ids.device, self.merge_size)
562
+ images_grid = [(1, self.merge_size, self.merge_size)]
563
+ fake_input = True
564
+ else:
565
+ fake_input = False
566
+
567
+ images = torch.cat(images, dim=0)
568
+ images_grid = torch.tensor(np.array(images_grid))
569
+ visual_embed = self.visual_model(images, grid_thw=images_grid)
570
+ visual_embed = self.video_bridge_model(visual_embed)
571
+
572
+ if not self.training: # inference supports auto device map: move the multimodal output onto the same device as input_ids
573
+ visual_embed = visual_embed.to(input_ids.device)
574
+ if not fake_input: # check that the number of multimodal tokens matches the pad mask (incorrect truncation breaks this)
575
+ assert pad_mask.sum() == visual_embed.shape[0], '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
576
+ assert pad_mask.sum() == visual_embed.shape[0], '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
577
+ else:
578
+ assert pad_mask.sum() <= 0, '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
579
+
580
+ # merge the current modality embeddings with the text embeddings
581
+ input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
582
+ if self.config.train_multimodal_special_tokens_only and self.training:
583
+ # only the special tokens propagate gradients into the embedding weight, keeping the LLM part unchanged
584
+ # note: the special token list should be shared across modalities, otherwise some tokens get their gradient stopped
585
+ sp_mask = sp_mask.unsqueeze(-1).to(text_embedding)
586
+ text_embedding = (1 - sp_mask) * text_embedding.detach() + sp_mask * text_embedding
587
+ text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # zero out pad-token positions (no gradient)
588
+ multimodal_embedding = torch.embedding(visual_embed, input_ids * pad_mask) # non-pad positions pick up the idx=0 row
589
+ multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # zero out non-pad positions
590
+
591
+ final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
592
+
593
+ if group_index is None:
594
+ group_index = pad_mask.to(torch.int32)
595
+ else:
596
+ current_index = torch.max(group_index) + 1
597
+ group_index += pad_mask.to(torch.int32) * current_index # assumes modalities do not overlap
598
+
599
+ return final_embedding, group_index # never return None for group_index, so the MoE parameters always receive gradients
600
+
601
+
602
+
603
+ def forward(
604
+ self,
605
+ input_ids: torch.LongTensor = None,
606
+ attention_mask: Optional[torch.Tensor] = None,
607
+ position_ids: Optional[torch.LongTensor] = None,
608
+ seqlens: Optional[torch.IntTensor] = None,
609
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
610
+ inputs_embeds: Optional[torch.FloatTensor] = None,
611
+ audios: Optional[List|torch.Tensor] = None,
612
+ encoder_length: Optional[torch.Tensor] = None,
613
+ bridge_length: Optional[torch.Tensor] = None,
614
+ images: Optional[List|torch.Tensor] = None,
615
+ images_grid: Optional[List|torch.Tensor] = None,
616
+ videos: Optional[List|torch.Tensor] = None,
617
+ videos_grid: Optional[List|torch.Tensor] = None,
618
+ use_cache: Optional[bool] = None,
619
+ output_attentions: Optional[bool] = None,
620
+ output_hidden_states: Optional[bool] = None,
621
+ return_dict: Optional[bool] = None,
622
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
623
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
624
+ output_hidden_states = (
625
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
626
+ )
627
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
628
+
629
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
630
+
631
+ # retrieve input_ids and inputs_embeds
632
+ if input_ids is not None and inputs_embeds is not None:
633
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
634
+ elif input_ids is not None:
635
+ batch_size, seq_length = input_ids.shape
636
+ elif inputs_embeds is not None:
637
+ batch_size, seq_length, _ = inputs_embeds.shape
638
+ else:
639
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
640
+
641
+ seq_length_with_past = seq_length
642
+ past_key_values_length = 0
643
+
644
+ if past_key_values is not None:
645
+ past_key_values_length = past_key_values[0][0].shape[2]
646
+ seq_length_with_past = seq_length_with_past + past_key_values_length
647
+
648
+ if position_ids is None:
649
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
650
+ position_ids = torch.arange(
651
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
652
+ )
653
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
654
+ else:
655
+ position_ids = position_ids.view(-1, seq_length).long()
656
+
657
+ group_index = None
658
+ if inputs_embeds is None:
659
+ sp_input_ids = get_sequence_parallel_chunk(input_ids)
660
+ inputs_embeds = self.embed_tokens(sp_input_ids)
661
+ if self.config.audio_config.enable:
662
+ inputs_embeds, group_index = self.get_audio_embed(sp_input_ids, inputs_embeds, audios, encoder_length, bridge_length)
663
+ if self.config.visual_config.enable:
664
+ inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, group_index, images_grid) # note: group_index is updated here
665
+ if self.config.video_config.enable:
666
+ inputs_embeds, group_index = self.get_video_embed(sp_input_ids, inputs_embeds, videos, group_index, videos_grid) # note: group_index is updated here
667
+
668
+ if seqlens is not None and seqlens.ndim == 2:
669
+ # flatten batched multi-pack samples into cumulative sequence lengths
670
+ cu_seqlens = []
671
+ offset, seqlen = 0, seqlens.size(1)
672
+ for lens in seqlens:
673
+ cu_seqlens.append(offset)
674
+ cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
675
+ offset += seqlen
676
+ cu_seqlens.append(offset)
677
+ seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
678
+ elif seqlens is None and self.training:
679
+ # compatibility with pre-training, where seqlens is None and every sample is assumed to use the max length
680
+ seqlens = torch.arange(
681
+ end=input_ids.size(0) + 1,
682
+ dtype=torch.int32,
683
+ device=input_ids.device
684
+ ) * input_ids.size(1)
685
+ if seqlens is not None:
686
+ attention_mask = None # unset attention_mask to save memory
687
+
688
+ if seqlens is None and attention_mask is None:
689
+ attention_mask = torch.ones(
690
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
691
+ )
692
+ if attention_mask is not None:
693
+ attention_mask = _prepare_4d_causal_attention_mask(
694
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
695
+ )
696
+
697
+ # embed positions
698
+ hidden_states = inputs_embeds
699
+
700
+ if self.gradient_checkpointing and self.training:
701
+ if use_cache:
702
+ logger.warning_once(
703
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
704
+ )
705
+ use_cache = False
706
+
707
+ # decoder layers
708
+ all_hidden_states = () if output_hidden_states else None
709
+ all_self_attns = () if output_attentions else None
710
+ next_decoder_cache = () if use_cache else None
711
+
712
+ for idx, decoder_layer in enumerate(self.layers):
713
+ if output_hidden_states:
714
+ all_hidden_states += (hidden_states,)
715
+
716
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
717
+
718
+ if self.gradient_checkpointing and self.training:
719
+
720
+ def create_custom_forward(module):
721
+ def custom_forward(*inputs):
722
+ # None for past_key_value
723
+ return module(*inputs, output_attentions, False, group_index)
724
+
725
+ return custom_forward
726
+
727
+ layer_outputs = torch.utils.checkpoint.checkpoint(
728
+ create_custom_forward(decoder_layer),
729
+ hidden_states,
730
+ attention_mask,
731
+ position_ids,
732
+ seqlens,
733
+ None,
734
+ )
735
+ else:
736
+ layer_outputs = decoder_layer(
737
+ hidden_states,
738
+ attention_mask=attention_mask,
739
+ position_ids=position_ids,
740
+ seqlens=seqlens,
741
+ past_key_value=past_key_value,
742
+ output_attentions=output_attentions,
743
+ use_cache=use_cache,
744
+ group_index=group_index,
745
+ )
746
+
747
+ hidden_states = layer_outputs[0]
748
+
749
+ if use_cache:
750
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
751
+
752
+ if output_attentions:
753
+ all_self_attns += (layer_outputs[1],)
754
+
755
+ hidden_states = self.norm(hidden_states)
756
+
757
+ # add hidden states from the last decoder layer
758
+ if output_hidden_states:
759
+ all_hidden_states += (hidden_states,)
760
+
761
+ next_cache = next_decoder_cache if use_cache else None
762
+ if not return_dict:
763
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
764
+ return BaseModelOutputWithPast(
765
+ last_hidden_state=hidden_states,
766
+ past_key_values=next_cache,
767
+ hidden_states=all_hidden_states,
768
+ attentions=all_self_attns,
769
+ )
770
+
771
+
772
+ class NormHead(nn.Module):
773
+ def __init__(self, hidden_size, vocab_size, bias=False):
774
+ super().__init__()
775
+ self.hidden_size = hidden_size
776
+ self.vocab_size = vocab_size
777
+ self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
778
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
779
+
780
+ def forward(self, hidden_states, mask=None):
781
+ norm_weight = nn.functional.normalize(self.weight)
782
+ if mask is not None:
783
+ mask = mask.to(norm_weight)
784
+ norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
785
+ return nn.functional.linear(hidden_states, norm_weight)
786
+
787
+
788
+ def extra_repr(self) -> str:
789
+ return f'in_features={self.hidden_size}, out_features={self.vocab_size}'
790
+
791
+ @dataclass
792
+ class OceanMMCausalLMOutputWithPast(ModelOutput):
793
+ loss: Optional[torch.FloatTensor] = None
794
+ logits: torch.FloatTensor = None
795
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
796
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
797
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
798
+ text_nt_loss: Optional[torch.FloatTensor] = None
799
+ flatten_loss: Optional[torch.FloatTensor] = None
800
+
801
+ class OceanForCausalLM(OceanPreTrainedModel):
802
+ def __init__(self, config):
803
+ super().__init__(config)
804
+ self.config = config
805
+ self.model = OceanModel(config)
806
+ if config.use_norm_head:
807
+ self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
808
+ else:
809
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
810
+ # Initialize weights and apply final processing
811
+ self.post_init()
812
+
813
+ def bind_processor(self, tokenizer, **kwargs):
814
+ self.processor = OceanMMProcessor(
815
+ tokenizer=tokenizer,
816
+ config=self.config,
817
+ **kwargs,
818
+ )
819
+ return self.processor
820
+
821
+ def get_input_embeddings(self):
822
+ return self.model.embed_tokens
823
+
824
+ def set_input_embeddings(self, value):
825
+ self.model.embed_tokens = value
826
+
827
+ def get_output_embeddings(self):
828
+ return self.lm_head
829
+
830
+ def set_output_embeddings(self, new_embeddings):
831
+ self.lm_head = new_embeddings
832
+
833
+ def set_decoder(self, decoder):
834
+ self.model = decoder
835
+
836
+ def get_decoder(self):
837
+ return self.model
838
+
839
+ def forward(
840
+ self,
841
+ input_ids: torch.LongTensor = None,
842
+ attention_mask: Optional[torch.Tensor] = None,
843
+ position_ids: Optional[torch.LongTensor] = None,
844
+ seqlens: Optional[torch.IntTensor] = None,
845
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
846
+ inputs_embeds: Optional[torch.FloatTensor] = None,
847
+ labels: Optional[torch.LongTensor] = None,
848
+ audios: Optional[List|torch.Tensor] = None,
849
+ encoder_length: Optional[torch.Tensor] = None,
850
+ bridge_length: Optional[torch.Tensor] = None,
851
+ images: Optional[torch.Tensor] = None,
852
+ images_grid: Optional[torch.Tensor] = None,
853
+ videos: Optional[torch.Tensor] = None,
854
+ videos_grid: Optional[torch.Tensor] = None,
855
+ use_cache: Optional[bool] = None,
856
+ output_attentions: Optional[bool] = None,
857
+ output_hidden_states: Optional[bool] = None,
858
+ return_dict: Optional[bool] = None,
859
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
860
+
861
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
862
+ output_hidden_states = (
863
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
864
+ )
865
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
866
+
867
+ _, sp_mask, _ = self.model.get_multimodal_mask(input_ids, self.config.audio_config.audio_pad_token_id, self.config.multimodal_special_token_list)
868
+ # TODO: unfreeze the lm_head parameters for the partially learnable special tokens
869
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
870
+ outputs = self.model(
871
+ input_ids=input_ids,
872
+ attention_mask=attention_mask,
873
+ position_ids=position_ids,
874
+ seqlens=seqlens,
875
+ past_key_values=past_key_values,
876
+ inputs_embeds=inputs_embeds,
877
+ audios=audios,
878
+ encoder_length=encoder_length,
879
+ bridge_length=bridge_length,
880
+ images=images,
881
+ images_grid=images_grid,
882
+ videos=videos,
883
+ videos_grid=videos_grid,
884
+ use_cache=use_cache,
885
+ output_attentions=output_attentions,
886
+ output_hidden_states=output_hidden_states,
887
+ return_dict=return_dict,
888
+ )
889
+
890
+ hidden_states = outputs[0]
891
+
892
+ # enable lm_head gradients only for the partially learnable special tokens
893
+ special_with_loss_list = list(set(self.config.multimodal_special_token_list) - set(self.config.multimodal_special_token_no_loss_list))
894
+ _, sp_with_loss_mask, lm_head_mask = self.model.get_multimodal_mask(input_ids, self.config.audio_config.audio_pad_token_id, special_with_loss_list)
895
+ if self.config.train_multimodal_special_tokens_only and self.training and len(special_with_loss_list) > 0:
896
+ if self.config.use_norm_head:
897
+ logits = self.lm_head(hidden_states, mask=lm_head_mask)
898
+ else:
899
+ lm_head_mask = lm_head_mask.to(self.lm_head.weight)
900
+ norm_weight = self.lm_head.weight * lm_head_mask + (1 - lm_head_mask) * self.lm_head.weight.detach()
901
+ logits = torch.einsum('bsd,ld->bsl', hidden_states, norm_weight)
902
+ else:
903
+ logits = self.lm_head(hidden_states)
904
+
905
+ loss = torch.tensor(0, device=hidden_states.device, dtype=hidden_states.dtype)
906
+ text_nt_loss = torch.tensor(0, device=hidden_states.device, dtype=hidden_states.dtype)
907
+ flatten_loss = None
908
+ if labels is not None:
909
+ # Shift so that tokens < n predict n
910
+ shift_logits = logits[..., :-1, :].contiguous()
911
+ shift_labels = labels[..., 1:].contiguous()
912
+
913
+ valid_mask = torch.gt(shift_labels, -1) # labels < 0 are treated as padding positions
914
+ sp_mask = sp_mask[..., 1:].contiguous()
915
+ text_mask = torch.logical_and(valid_mask, torch.logical_not(sp_mask))
916
+
917
+ # Flatten the tokens
918
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
919
+ shift_labels = shift_labels.view(-1)
920
+ # Enable model parallelism
921
+ shift_labels = shift_labels.to(shift_logits.device)
922
+ flatten_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction='none')
923
+
924
+ loss = torch.mean(torch.masked_select(flatten_loss, valid_mask.view(-1)))
925
+ text_nt_loss = torch.mean(torch.masked_select(flatten_loss, text_mask.view(-1))).detach()
926
+
927
+ if not return_dict:
928
+ output = (logits,) + outputs[1:]
929
+ return (loss,) + output if loss is not None else output
930
+
931
+ return OceanMMCausalLMOutputWithPast(
932
+ loss=loss,
933
+ logits=logits,
934
+ past_key_values=outputs.past_key_values,
935
+ hidden_states=outputs.hidden_states,
936
+ attentions=outputs.attentions,
937
+ text_nt_loss=text_nt_loss,
938
+ flatten_loss=flatten_loss
939
+ )
940
+
941
+
942
+ def prepare_inputs_for_generation(
943
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
944
+ ):
945
+ if past_key_values:
946
+ input_ids = input_ids[:, -1:]
947
+
948
+ position_ids = kwargs.get("position_ids", None)
949
+ if attention_mask is not None and position_ids is None:
950
+ # create position_ids on the fly for batch generation
951
+ position_ids = attention_mask.long().cumsum(-1) - 1
952
+ position_ids.masked_fill_(attention_mask == 0, 1)
953
+ if past_key_values:
954
+ position_ids = position_ids[:, -1].unsqueeze(-1)
955
+
956
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
957
+ if inputs_embeds is not None and past_key_values is None:
958
+ model_inputs = {"inputs_embeds": inputs_embeds}
959
+ elif past_key_values is not None:
960
+ model_inputs = {"input_ids": input_ids}
961
+ else:
962
+ model_inputs = {"input_ids": input_ids,
963
+ "audios": kwargs.get("audios", None), "encoder_length": kwargs.get("encoder_length", None), "bridge_length": kwargs.get("bridge_length", None),
964
+ "images": kwargs.get("images", None),
965
+ "videos": kwargs.get("videos", None)
966
+ }
967
+
968
+ model_inputs.update(
969
+ {
970
+ "position_ids": position_ids,
971
+ "past_key_values": past_key_values,
972
+ "use_cache": kwargs.get("use_cache"),
973
+ "attention_mask": attention_mask,
974
+ "images_grid": kwargs.get("images_grid"),
975
+ "videos_grid": kwargs.get("videos_grid"),
976
+ }
977
+ )
978
+ return model_inputs
979
+
980
+ @staticmethod
981
+ def _reorder_cache(past_key_values, beam_idx):
982
+ reordered_past = ()
983
+ for layer_past in past_key_values:
984
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
985
+ return reordered_past
986
+
987
+ def chat(self, tokenizer, messages: List[dict], stream=False,
988
+ generation_config: Optional[GenerationConfig]=None):
989
+ generation_config = generation_config or self.generation_config
990
+ input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
991
+ if stream:
992
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
993
+ Thread(target=self.generate, kwargs=dict(
994
+ inputs=input_ids, streamer=streamer,
995
+ generation_config=generation_config,
996
+ )).start()
997
+ return streamer
998
+ else:
999
+ outputs = self.generate(input_ids, generation_config=generation_config)
1000
+ response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
1001
+ return response
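+ # Minimal usage sketch for the chat entry point (assumes the checkpoint directory "./" is
+ # loadable with trust_remote_code and that the default generation_config sets max_new_tokens;
+ # the message content is illustrative):
+ #
+ #   from transformers import AutoTokenizer, AutoModelForCausalLM
+ #   tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
+ #   model = AutoModelForCausalLM.from_pretrained("./", trust_remote_code=True).eval()
+ #   print(model.chat(tokenizer, [{"role": "user", "content": "Hello"}], stream=False))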
processor_ocean.py ADDED
@@ -0,0 +1,1154 @@
1
+ import requests
2
+ import re, ujson, os, sys, fire, glob, random, time, json
3
+ import numpy as np
4
+ import io
5
+ import torch
6
+ from torch.utils.data import default_collate
7
+ import torchaudio
8
+ from typing import *
9
+ from dataclasses import dataclass, field
10
+ import transformers
11
+ from transformers.modeling_outputs import ModelOutput
12
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
13
+ from functools import lru_cache
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ from qcloud_cos import CosConfig
17
+ from qcloud_cos import CosS3Client
18
+ import tos
19
+ import concurrent.futures as cf
20
+ from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
21
+ from transformers.image_utils import PILImageResampling
22
+ from PIL import Image, ImageOps
23
+ from PIL import ImageFile
24
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
25
+ import base64
26
+ from decord import VideoReader, cpu
27
+ import cv2
28
+ import av
29
+ import imagesize
30
+ import math
31
+
32
+
33
+ def smart_resize(
34
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
35
+ ):
36
+ """Rescales the image so that the following conditions are met:
37
+
38
+ 1. Both dimensions (height and width) are divisible by 'factor'.
39
+
40
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
41
+
42
+ 3. The aspect ratio of the image is maintained as closely as possible.
43
+
44
+ """
45
+ # if height < factor or width < factor:
46
+ # raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
47
+ if max(height, width) / min(height, width) > 200:
48
+ raise ValueError(
49
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
50
+ )
51
+ h_bar = round(height / factor) * factor if height > factor else factor
52
+ w_bar = round(width / factor) * factor if width > factor else factor
53
+ if h_bar * w_bar > max_pixels:
54
+ beta = math.sqrt((height * width) / max_pixels)
55
+ h_bar = math.floor(height / beta / factor) * factor
56
+ w_bar = math.floor(width / beta / factor) * factor
57
+ elif h_bar * w_bar < min_pixels:
58
+ beta = math.sqrt(min_pixels / (height * width))
59
+ h_bar = math.ceil(height * beta / factor) * factor
60
+ w_bar = math.ceil(width * beta / factor) * factor
61
+ return h_bar, w_bar
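+ # Worked example (illustrative, hypothetical input size): with the default factor of 28,
+ # round(700 / 28) * 28 = 700 and round(1000 / 28) * 28 = 1008, and 700 * 1008 pixels lies
+ # within [min_pixels, max_pixels], so:
+ #
+ #   assert smart_resize(700, 1000) == (700, 1008)  # both sides divisible by 28, ratio kept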
62
+
63
+
64
+ def select_best_resolution(image_size, candidate_resolutions):
65
+ '''Find the best candidate resolution and rescale the original image accordingly.
66
+ image_size is usually the original size, e.g. (8*336, 16*336)
67
+ candidate_resolutions are the candidate resolutions, e.g. (1*336, 4*336)
68
+ '''
69
+ try:
70
+ original_width, original_height = image_size
71
+ except:
72
+ pass
73
+ best_fit = None
74
+ max_effective_resolution = 0
75
+ min_wasted_resolution = float("inf")
76
+ # iterate over the widths and heights in candidate_resolutions
77
+ for width, height in candidate_resolutions:
78
+ # use the smaller of width / original_width and height / original_height as the scale
79
+ scale = min(width / original_width, height / original_height) # e.g. scale = min(1/8, 1/4) = 1/8
80
+ # rescale original_width and original_height
81
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) # e.g. 1*336, 2*336
82
+ # effective_resolution is the resolution after rescaling, s^2 * w * h
83
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) # e.g. min(1*336 * 2*336, 8*336 * 16*336)
84
+ # wasted_resolution is the gap between the candidate resolution and the effective one
85
+ wasted_resolution = (width * height) - effective_resolution
86
+ # pick this candidate if (1) its effective resolution beats the current maximum, or
87
+ # (2) it ties on effective resolution but wastes less of the candidate area
88
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
89
+ max_effective_resolution = effective_resolution # update max_effective_resolution
90
+ min_wasted_resolution = wasted_resolution # update min_wasted_resolution
91
+ best_fit = (width, height)
92
+
93
+ return best_fit
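+ # Usage sketch (hypothetical numbers): for an 800x600 original and candidates
+ # [(336, 336), (672, 336), (672, 672)], the (672, 672) candidate gives the largest effective
+ # resolution (scale = min(672/800, 672/600) = 0.84), so:
+ #
+ #   best = select_best_resolution((800, 600), [(336, 336), (672, 336), (672, 672)])  # (672, 672)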
94
+
95
+
96
+ def read_video(image_path, max_frame_number, decode_way):
97
+ if decode_way=='1fps':
98
+ try:
99
+ vr = VideoReader(image_path, ctx=cpu(0))
100
+ total_frame_num = len(vr)
101
+ fps = round(vr.get_avg_fps())
102
+ frame_idx = [i for i in range(0, len(vr), fps)]
103
+ frames = vr.get_batch(frame_idx).asnumpy()
104
+ frames = [i for i in frames]
105
+ cnt = len(frames)
106
+ except Exception as e:
107
+ print(image_path)
108
+ print('error is', e)
109
+ return None
110
+ elif decode_way=='key':
111
+ try:
112
+ with av.open(image_path) as container:
113
+ stream = container.streams.video[0]
114
+ stream.codec_context.skip_frame = 'NONKEY'
115
+ frames = []
116
+ fps = int(stream.average_rate)
117
+ cnt = 0
118
+ for frame in container.decode(stream): # keep only the key frames as images
119
+ image = frame.to_image()
120
+ frames.append(image)
121
+ cnt += 1
122
+ except Exception as e:
123
+ print('error is', e)
124
+ return None
125
+ if frames is None or len(frames)==0:
126
+ return None
127
+ if len(frames)>max_frame_number and max_frame_number>0:
128
+ # generate evenly spaced indices
129
+ indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
130
+ # pick the frames at those indices
131
+ sampled_elements = [frames[idx] for idx in indices]
132
+ frames = sampled_elements
133
+ return frames
134
+
135
+ class OceanImageProcessor:
136
+ def __init__(self, config, **kwargs):
137
+ self.config = config # visual_config
138
+ self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
139
+ self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
140
+ self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
141
+ self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
142
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
143
+ self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
144
+
145
+ def image_transform(self, strseq, return_mm_data = True):
146
+ image = None
147
+ if isinstance(strseq, str):
148
+ if return_mm_data:
149
+ image = Image.open(strseq).convert("RGB")
150
+ else:
151
+ image = Image.open(BytesIO(strseq)).convert("RGB")
152
+
153
+ image = np.array(image.convert("RGB")) # convert to RGB first so the image has three channels (R, G, B), then turn it into a NumPy array for the steps below
154
+ image_org_size = image.shape[:2] # keep the original (height, width); image.shape is (height, width, channels), so [:2] picks the first two for later use
155
+
156
+ # resize, crop, scale, normalize
157
+ # the target size is given as the short- or long-side length (e.g. a short side of 336 pixels) and the other side is derived automatically to preserve the aspect ratio
158
+ # the output is a new size, usually in (width, height) form, used for the subsequent resize/crop operations
159
+ resized_height, resized_width = smart_resize(
160
+ image_org_size[0], image_org_size[1],
161
+ factor=self.patch_size * self.spatial_merge_size,
162
+ min_pixels=self.min_pixels,
163
+ max_pixels=self.max_pixels,
164
+ )
165
+ output_size = (resized_height, resized_width)
166
+
167
+ # output_size = get_resize_output_image_size(image, self.config.crop_size, False) # resize the short side to 336
168
+ # resize() rescales the image to output_size; PILImageResampling.BICUBIC selects bicubic interpolation, which usually gives good image quality
169
+ # image: the input image, either a NumPy array or a PIL image; output_size: the target size, usually a (width, height) pair, given either absolutely or relative to the original image
170
+ # resample: the resampling method used to interpolate pixels, e.g. PILImageResampling.BICUBIC for smooth bicubic interpolation commonly used when scaling images
171
+ image = resize(image, output_size, PILImageResampling.BICUBIC)
172
+ # center_crop would cut a self.config.crop_size x self.config.crop_size square from the image center; return_numpy=True makes it return the crop as a NumPy array
173
+ # image = center_crop(image, (self.config.crop_size, self.config.crop_size), return_numpy=True)
174
+ img = image.transpose(2, 0, 1)
175
+ # normalize and standardize the image
176
+ image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
177
+ # split into patches
178
+ patches = image[np.newaxis, :]
179
+ if patches.shape[0] == 1:
180
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
181
+ channel = patches.shape[1]
182
+ grid_t = patches.shape[0] // self.temporal_patch_size
183
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
184
+ patches = patches.reshape(
185
+ grid_t,
186
+ self.temporal_patch_size,
187
+ channel,
188
+ grid_h // self.spatial_merge_size,
189
+ self.spatial_merge_size,
190
+ self.patch_size,
191
+ grid_w // self.spatial_merge_size,
192
+ self.spatial_merge_size,
193
+ self.patch_size,
194
+ )
195
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
196
+ flatten_patches = patches.reshape(
197
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
198
+ )
199
+
200
+ return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
201
+
202
+
203
+ class OceanAudioProcessor:
204
+ # bundles the basic audio feature extraction module, the input parsing module, and the COS request/cache module
205
+ def __init__(
206
+ self,
207
+ config, # audio processor config
208
+ **kwargs
209
+ ):
210
+ # make sure ffmpeg is installed for the torchaudio backends, e.g. conda install -c conda-forge 'ffmpeg<7'
211
+ assert(len(torchaudio.list_audio_backends()) > 0)
212
+ self.config = config
213
+ self.mel_filters = mel_filter_bank(
214
+ num_frequency_bins=1 + self.config.n_fft // 2,
215
+ num_mel_filters=self.config.num_mel_bins,
216
+ min_frequency=0.0,
217
+ max_frequency=self.config.sampling_rate / 2.0,
218
+ sampling_rate=self.config.sampling_rate,
219
+ norm="slaney",
220
+ mel_scale="slaney",
221
+ )
222
+
223
+ @staticmethod
224
+ def zero_mean_unit_var_norm(x):
225
+ return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
226
+
227
+ def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
228
+ metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
229
+ assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
230
+ waveform_tensor, _ = torchaudio.load(uri, normalize=True)
231
+ if self.config.sampling_rate != metadata.sample_rate:
232
+ waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate)
233
+
234
+ # downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
235
+ if metadata.num_channels > 1:
236
+ waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
237
+
238
+ # normalize to zero mean (Qwen Audio does not do this, but the official Whisper implementation does)
239
+ if do_normalize:
240
+ waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
241
+
242
+ if return_tensors: # (channels, samples)
243
+ return waveform_tensor
244
+ else:
245
+ return waveform_tensor.numpy()
246
+
247
+ def split_with_overlap(self, waveform): # if the waveform exceeds the maximum length, split it into overlapping segments
248
+ channels, wave_samples = waveform.shape
249
+ max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
250
+ if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
251
+ return [waveform] # within the maximum length or no truncation configured; always return a list
252
+
253
+ split_waveform, start = [], 0
254
+ while start < wave_samples: # changed 2024-07-24: align the overlap by seconds so the sampled data is identical across different sampling rate/n_fft/hop length configurations
255
+ if start > int(self.config.sampling_rate * self.config.split_overlap):
256
+ start -= int(self.config.sampling_rate * self.config.split_overlap) # 0 means no overlap; > 0 is the overlap in seconds
257
+ end = min(start + max_audio_samples, wave_samples)
258
+ split_waveform.append(waveform[:, start:end]) # note: this may produce very short segments, which must be detected and dropped during preprocessing
259
+ start = end
260
+ return split_waveform
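+ # Numeric sketch (hypothetical config): with max_audio_seconds=30, sampling_rate=16000 and
+ # split_overlap=1, a 70-second mono waveform is split into windows covering roughly
+ # [0, 30]s, [29, 59]s and [58, 70]s -- each later window re-reads one second of the previous
+ # one, and a very short tail segment may be dropped downstream.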
261
+
262
+ @classmethod
263
+ def inference_output_length(cls, config, input_length):
264
+ # for whisper + bridge
265
+ kernel_size = config.kernel_size
266
+ stride_size = config.stride_size
267
+ avg_pooler = config.avg_pooler
268
+ encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
269
+ encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
270
+ if avg_pooler > 1:
271
+ bridge_length = encoder_length // avg_pooler
+ else:
+ bridge_length = encoder_length # assumption: without average pooling the bridge keeps the encoder length
272
+ return encoder_length, bridge_length
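+ # Worked example (hypothetical config: kernel_size=3, stride_size=2, avg_pooler=4): for
+ # input_length=3000 mel frames, conv layer 1 keeps 3000 frames ((3000 + 2 - 3) // 1 + 1),
+ # conv layer 2 halves it to 1500 ((3000 + 2 - 3) // 2 + 1), and the bridge pools it to
+ # 1500 // 4 = 375, so inference_output_length(config, 3000) returns (1500, 375).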
273
+
274
+ def extract_fbank_features(self, waveform):
275
+ # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
276
+ channels, wave_samples = waveform.shape
277
+ assert(wave_samples >= self.config.n_fft)
278
+ valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
279
+ if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
280
+ waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
281
+ else:
282
+ waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
283
+
284
+ window = torch.hann_window(self.config.n_fft)
285
+ stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
286
+ magnitudes = stft[..., :-1].abs() ** 2
287
+
288
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
289
+ mel_spec = mel_filters.T @ magnitudes
290
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
291
+
292
+ if waveform.dim() == 2:
293
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
294
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
295
+ else:
296
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
297
+ log_spec = (log_spec + 4.0) / 4.0
298
+
299
+ log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
300
+ log_spec[:, valid_frame_nums:] = 0.0 # zero-pad the tail; the collate step later uses the maximum length within the batch
301
+
302
+ return log_spec, valid_frame_nums
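+ # Shape sketch (assuming num_mel_bins=128, max_audio_seconds=30, sampling_rate=16000 and
+ # hop_length=160): log_spec always has shape (128, 3000), and valid_frame_nums tells how many
+ # of those 3000 frames carry real (non-padded) audio.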
303
+
304
+ def data_augment(self, feature: np.array, input_length, training=True):
305
+ # reference https://arxiv.org/pdf/1904.08779
306
+ # run only on cpu
307
+ def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
308
+ # compute the total number of spans to mask, then randomly pick the span start indices
309
+ num_masked_span = int(mask_prob * input_length / mask_length + random.random())
310
+ num_masked_span = max(num_masked_span, min_masks)
311
+ start_indices = list(range(input_length - mask_length))
312
+ random.shuffle(start_indices)
313
+ start_indices = start_indices[:num_masked_span]
314
+ return start_indices
315
+
316
+ if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
317
+ return feature
318
+ if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
319
+ return feature
320
+ if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
321
+ return feature
322
+
323
+ if self.config.mask_time_prob > 0:
324
+ start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
325
+ for start_idx in start_indices:
326
+ feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
327
+ if self.config.mask_feature_prob > 0:
328
+ start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
329
+ for start_idx in start_indices:
330
+ feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
331
+
332
+ return feature
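+ # Usage sketch (hypothetical config): with mask_time_prob=0.05, mask_time_length=10 and
+ # mask_feature_prob=0, data_augment(log_spec, valid_frame_nums) zeroes a few random 10-frame
+ # time spans during training (SpecAugment-style time masking) and leaves the mel axis alone;
+ # with training=False the features pass through unchanged.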
333
+
334
+ class CosClient():
335
+ def __init__(self, bucket_name='crawl-pic-1317568651',
336
+ max_retries=2):
337
+ self.config = CosConfig(
338
+ Endpoint="cos.ap-guangzhou.myqcloud.com",
339
+ # Region='ap-guangzhou',
340
+ SecretId='AKIDnRpxoOghgVs0tkU3Mfv20jAMI0SRDj02',
341
+ SecretKey='td9tRlqiPvEJ8i27wXwBIDiy5ye6JGyS',
342
+ Token=None, Scheme='https', Timeout=300)
343
+ self.client = CosS3Client(self.config)
344
+ self.max_retries = max_retries
345
+ self.bucket_name = bucket_name
346
+
347
+ def __call__(self, relative_path, bucket_name=None):
348
+ if bucket_name is None or len(bucket_name) <= 0:
349
+ bucket_name = self.bucket_name
350
+ multimodal_bytes = None
351
+ for _ in range(self.max_retries):
352
+ try:
353
+ response = self.client.get_object(Bucket=bucket_name, Key=relative_path)
354
+ fp = response['Body'].get_raw_stream()
355
+ multimodal_bytes = fp.read()
356
+ break
357
+ except Exception as e:
358
+ time.sleep(0.01)
359
+ continue
360
+ return multimodal_bytes
361
+
362
+ class TosClient(object):
363
+ def __init__(self):
364
+ ak = "AKLTYTM3MWY5MTFhNDgyNDk4YjhmYTE0ZTE3YTk5ZmU1MjU"
365
+ sk = "TVRRM1pUZGtaVEJqWTJJd05HSTNPR0ppWVdKa1lqYzVORFUwTlRobU1UVQ=="
366
+ endpoint = "tos-cn-beijing.ivolces.com" # "tos-cn-beijing.ivolces.com"
367
+ region = "cn-beijing"
368
+ self.bucket_name = "audio-dataset"
369
+ self.client = tos.TosClientV2(ak, sk, endpoint, region)
370
+
371
+ def __call__(self, path, bucket_name=None):
372
+ if bucket_name is None:
373
+ bucket_name = self.bucket_name
374
+ for _ in range(2):
375
+ try:
376
+ object_stream = self.client.get_object(bucket_name, path)
377
+ return object_stream.read()
378
+ except Exception as e:
379
+ time.sleep(0.01)
380
+ continue
381
+ return None
382
+
383
+ @dataclass
384
+ class OceanProcessorOutput(ModelOutput):
385
+ input_ids: Optional["List|torch.Tensor"] = None
386
+ labels: Optional["List|torch.Tensor"] = None
387
+ attention_mask: Optional["List|torch.Tensor"] = None
388
+ position_ids: Optional["List|torch.Tensor"] = None
389
+ seqlens: Optional["List|torch.Tensor"] = None # must be used together with the Ocean modeling code
390
+ # audio fields
391
+ audios: Optional["List|torch.Tensor"] = None
392
+ encoder_length: Optional["List|torch.Tensor"] = None
393
+ bridge_length: Optional["List|torch.Tensor"] = None
394
+ # image fields
395
+ images: Optional["List|torch.Tensor"] = None
396
+ patch_nums: Optional["List|torch.Tensor"] = None
397
+ images_size: Optional["List|torch.Tensor"] = None
398
+ crop_size: Optional["List|torch.Tensor"] = None
399
+ images_grid: Optional["List|torch.Tensor"] = None
400
+ # video fields
401
+ videos: Optional["List|torch.Tensor"] = None
402
+ videos_patch_nums: Optional["List|torch.Tensor"] = None
403
+ videos_size: Optional["List|torch.Tensor"] = None
404
+ videos_crop_size: Optional["List|torch.Tensor"] = None
405
+ videos_grid: Optional["List|torch.Tensor"] = None
406
+ # processor fields
407
+ raw_text: Optional[str] = None
408
+ index: Optional[int] = None
409
+
410
+ def concatenate(self, other): # only valid while the fields are still Python lists
411
+ def concat_one(a, b):
412
+ if a is None and b is None:
413
+ return None
414
+ elif a is None and b is not None:
415
+ return b
416
+ elif a is not None and b is None:
417
+ return a
418
+ else:
419
+ return a + b
420
+ return OceanProcessorOutput(
421
+ input_ids=concat_one(self.input_ids, other.input_ids),
422
+ labels=concat_one(self.labels, other.labels),
423
+ audios=concat_one(self.audios, other.audios),
424
+ encoder_length=concat_one(self.encoder_length, other.encoder_length),
425
+ bridge_length=concat_one(self.bridge_length, other.bridge_length),
426
+ images=concat_one(self.images, other.images),
427
+ images_grid=concat_one(self.images_grid, other.images_grid),
428
+ patch_nums=concat_one(self.patch_nums, other.patch_nums),
429
+
430
+ videos=concat_one(self.videos, other.videos),
431
+ videos_grid=concat_one(self.videos_grid, other.videos_grid),
432
+ videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
433
+
434
+ position_ids=concat_one(self.position_ids, other.position_ids),
435
+ seqlens=concat_one(self.seqlens, other.seqlens),
436
+ images_size=concat_one(self.images_size, other.images_size)
437
+ )
438
+
439
+ class OceanMMProcessor(object):
440
+ def __init__(self,
441
+ tokenizer: transformers.PreTrainedTokenizer,
442
+ config,
443
+ training,
444
+ relative_path=None,
445
+ **kwargs,
446
+ ):
447
+ self.tokenizer = tokenizer
448
+ self.config = config
449
+ self.audio_processor = None
450
+ if hasattr(config, "audio_config"):
451
+ self.audio_processor = OceanAudioProcessor(config.audio_config)
452
+ self.visual_processor = None
453
+ if hasattr(config, "visual_config"):
454
+ self.visual_processor = OceanImageProcessor(config.visual_config)
455
+ self.video_processor = None
456
+ if hasattr(config, "video_config"):
457
+ self.video_processor = OceanImageProcessor(config.video_config)
458
+ self.training = training
459
+ self.relative_path = relative_path
460
+ self.cos_client = CosClient()
461
+ self.tos_client = TosClient()
462
+ # audio tag
463
+ self.audio_start_tag = None
464
+ self.audio_end_tag = None
465
+ self.audio_pad_tag = None
466
+ self.audio_delim_tag = None
467
+ if hasattr(self.config, "audio_config"):
468
+ self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
469
+ self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
470
+ self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
471
+ self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
472
+ # image tag
473
+ self.image_start_tag = None
474
+ self.image_end_tag = None
475
+ self.image_pad_tag = None
476
+ self.video_start_tag = None
477
+ self.video_end_tag = None
478
+ if hasattr(self.config, "visual_config"):
479
+ # special token for start_tag
480
+ self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
481
+ # special token for end_tag
482
+ self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
483
+ # special token for pad_tag
484
+ self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
485
+ self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
486
+ self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
487
+ if hasattr(self.config, "video_config"):
488
+ self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
489
+ self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
490
+ self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
491
+ self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
492
+ self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
493
+ self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
494
+
495
+
496
+ # @lru_cache(maxsize=1024)
497
+ def _get_audio(self, audio_info, return_mm_data = True):
498
+ try:
499
+ audio_info = ujson.loads(audio_info)
500
+ audio_uri = None
501
+ if 'path' in audio_info.keys():
502
+
503
+ if self.relative_path is not None: # prefer a local path first
504
+ audio_uri = os.path.join(self.relative_path, audio_info['path'])
505
+ if not os.path.exists(audio_uri):
506
+ audio_uri = None
507
+ if audio_uri is None: # not found locally, fall back to COS/TOS
508
+ if audio_info.get('server', 'cos') == 'tos':
509
+ audio_uri = self.tos_client(audio_info['path'], 'audio-dataset')
510
+ else:
511
+ audio_uri = self.cos_client(audio_info['path'], 'audio-data-1317568651')
512
+
513
+ elif 'local' in audio_info.keys():
514
+ audio_uri = audio_info['local']
515
+ if not os.path.exists(audio_uri):
516
+ audio_uri = None
517
+ return OceanProcessorOutput()
518
+ else:
519
+ raise ValueError("can not find path or local in audio_info")
520
+
521
+ waveforms = self.audio_processor.load_audio_waveform(audio_uri, True)
522
+ waveforms = self.audio_processor.split_with_overlap(waveforms) # split into overlapping segments
523
+ ret = OceanProcessorOutput() # default init; the audios field starts as None
524
+ for waveform in waveforms:
525
+ audio, input_length = self.audio_processor.extract_fbank_features(waveform)
526
+ audio = self.audio_processor.data_augment(audio, input_length, self.training)
527
+ encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
528
+ if bridge_length <= 0: # filter extremely short segments: 1. if len(waveforms)==1 the result stays empty; 2. if len(waveforms)>1 the last segment was too short and is dropped
529
+ continue
530
+ current_ret = OceanProcessorOutput(
531
+ audios=[audio],
532
+ encoder_length=[encoder_length],
533
+ bridge_length=[bridge_length])
534
+ if ret.audios is None:
535
+ ret = current_ret
536
+ else:
537
+ ret = ret.concatenate(current_ret) # concatenate multiple segments
538
+ if not return_mm_data:
539
+ ret.audios = [None]
540
+ return ret
541
+ except Exception as e:
542
+ print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
543
+ return OceanProcessorOutput()
544
+
545
+ # @lru_cache(maxsize=1024)
546
+ def _get_image(self, image_info, return_mm_data = True):
547
+ try:
548
+ try: # chensong
549
+ image_info = ujson.loads(image_info)
550
+ except:
551
+ #image_info = image_info.replace("'", '"')
552
+ image_info = re.sub(r"(?<!\\)'", '"', image_info)
553
+ image_info = ujson.loads(image_info)
554
+ if 'base64' in image_info.keys():
555
+ image_data = base64.b64decode(image_info['base64'])
556
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
557
+ elif 'local' in image_info.keys():
558
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'],return_mm_data = return_mm_data)
559
+ elif 'path' in image_info.keys():
560
+ if "tos_bucket" in image_info.keys(): # tos上的每个item,一定要写明tos的桶以及tos_bucket这个key
561
+ tos_bucket = image_info['tos_bucket']
562
+ image_bytes = self.tos_client(image_info['path'], tos_bucket) # fetch the image from tos_client
563
+ else:
564
+ cos_bucket = None
565
+ if "cos_bucket" in image_info.keys():
566
+ cos_bucket = image_info['cos_bucket']
567
+ if "bucket_name" in image_info.keys():
568
+ cos_bucket = image_info['bucket_name']
569
+ image_bytes = self.cos_client(image_info['path'], cos_bucket) # fetch the image from cos_client
570
+ # obtain image_feat (image patches), org_size (the original image size) and image_list (the patch grid)
571
+ image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
572
+ else:
573
+ raise ValueError("can not find any path in image_info")
574
+
575
+ merge_length = self.visual_processor.merge_size**2
576
+ patch_nums = np.array(image_list).prod() // merge_length
577
+
578
+ if org_size[0] * org_size[1] > 16**2: # filter out extremely small images
579
+ return OceanProcessorOutput(
580
+ images=[image_feat],
581
+ patch_nums=[patch_nums],
582
+ crop_size=[image_list],
583
+ images_size= [org_size],
584
+ images_grid=[image_list]
585
+ )
586
+ else:
587
+ print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
588
+ return OceanProcessorOutput()
589
+
590
+ except Exception as e:
591
+ print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
592
+ return OceanProcessorOutput()
593
+
594
+ # @lru_cache(maxsize=1024)
595
+ def _get_video_frame(self, video_frame_info, return_mm_data = True):
596
+ try:
597
+ pattern = r'\{.*?\}'
598
+ matches = re.findall(pattern, video_frame_info)
599
+ ret = OceanProcessorOutput()
600
+ # parse the frame entries one by one
601
+ for match in matches:
602
+ video_frame_info = ujson.loads(match)
603
+ if 'local' in video_frame_info.keys():
604
+ image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'],return_mm_data = return_mm_data)
605
+ else:
606
+ raise ValueError("can not find any path in image_info")
607
+
608
+ merge_length = self.video_processor.merge_size**2
609
+ patch_nums = np.array(image_list).prod() // merge_length
610
+
611
+ if org_size[0] * org_size[1] > 16**2: # filter out extremely small frames
612
+ ret = ret.concatenate(
613
+ OceanProcessorOutput(
614
+ videos=[image_feat],
615
+ videos_patch_nums=[patch_nums],
616
+ videos_crop_size=[image_list],
617
+ videos_size= [org_size],
618
+ videos_grid=[image_list]
619
+ )
620
+ )
621
+ else:
622
+ print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
623
+ return ret
624
+
625
+ except Exception as e:
626
+ print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
627
+ return OceanProcessorOutput()
628
+
629
+ # read the raw video bytes
630
+ def _get_video_obj_byte(self, source, path, video_obj_json):
631
+ video_obj_byte = None
632
+ if source == "cos":
633
+ start_time = time.time()
634
+ video_obj_byte = self.cos_client(path, bucket_name=video_obj_json.get("cos_bucket", None))
635
+ if (time.time() - start_time) > 1.0:
636
+ self.reflash_cos_client()
637
+ if source == "local":
638
+ if os.path.exists(path):
639
+ video_obj_byte = open(path, "rb").read()
640
+ else:
641
+ video_obj_byte = None
642
+ if source == "base64":
643
+ video_obj_byte = base64.b64decode(path)
644
+ if source == "url":
645
+ video_obj_byte = requests.get(url=path).content
646
+ return video_obj_byte
647
+
648
+ # split the video into frames and save them to a sub-directory
649
+ def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
650
+ video_path = video_info['local']
651
+ # local directory where the frames are saved
652
+ frame_path = video_path.split('.')[0] + '_frames'
653
+ if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
654
+ # extract and save the frames
655
+ os.makedirs(frame_path, exist_ok=True)
656
+ mm_obj_byte = self._get_video_obj_byte('local', video_path, video_info)
657
+ if mm_obj_byte is None: # the video file could not be read
658
+ return ""
659
+ frames = read_video(io.BytesIO(mm_obj_byte), max_frame_number=max_frame_number, decode_way=decode_way) # read all frames
660
+ for frame_idx, frame in enumerate(frames):
661
+ output_filename = os.path.join(frame_path, f"{frame_idx}.jpg")
662
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
663
+ cv2.imwrite(output_filename, frame)
664
+
665
+ # select frames
666
+ frame_number = len([filename for filename in os.listdir(frame_path) if filename.endswith('.jpg')])
667
+ if frame_number>max_frame_number:
668
+ indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
669
+ else:
670
+ indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
671
+ # concatenation mode: one image tag per selected frame
672
+ replace_str = ""
673
+ for idx in indices:
674
+ frame_str = f"{self.image_start_tag}{os.path.join(frame_path, f'{idx}.jpg')}{self.image_end_tag}"
675
+ replace_str += frame_str
676
+ return replace_str
677
+
678
+ def _get_video_frame_str(self, video_info, return_mm_data = True ):
679
+ try:
680
+ video_info = ujson.loads(video_info)
681
+ if 'local' in video_info.keys():
682
+ # build a string containing the per-frame image paths, capped at max_frame_number frames
683
+ frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
684
+ if frames_str != "":
685
+ parts = frames_str.split(self.image_end_tag)
686
+ result = []
687
+ for part in parts:
688
+ if self.image_start_tag in part:
689
+ before_path, path = part.split(self.image_start_tag)
690
+ new_path = f'{self.image_start_tag}{{"local": "{path}"}}{self.image_end_tag}'
691
+ result.append(before_path + new_path)
692
+ else:
693
+ result.append(part)
694
+ return ''.join(result)
695
+ else:
696
+ raise ValueError('can not find localpath in video_info')
697
+ except Exception as e:
698
+ print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
699
+ return ""
700
+
701
+ # def _replace_audio(self, audio_text, return_mm_data = True):
702
+ # audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
703
+ # ret = self._get_audio(audio_info, return_mm_data) # 重复取结果 cached result
704
+ def _replace_audio(self, audio_text, mminfo_ret_dict):
705
+ audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
706
+ # ret = self._get_audio(audio_info) # 重复取结果 cached result
707
+ ret = mminfo_ret_dict.get(audio_info, OceanProcessorOutput()) # read directly from the dict
708
+ if ret.bridge_length is not None: # TODO: the tokenizer becomes very slow when there are many pad tokens
709
+ replaced_text = [self.audio_pad_tag * l for l in ret.bridge_length]
710
+ replaced_text = self.audio_delim_tag.join(replaced_text)
711
+ return self.audio_start_tag + replaced_text + self.audio_end_tag
712
+ return ''
713
+
714
+ # def _replace_image(self, image_text, return_mm_data = True):
715
+ # image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
716
+ # ret = self._get_image(image_info, return_mm_data) # 重复取结果 cached result
717
+ def _replace_image(self, image_text, mminfo_ret_dict):
718
+ image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
719
+ # ret = self._get_image(image_info) # 重复取结果 cached result
720
+ ret = mminfo_ret_dict.get(image_info, OceanProcessorOutput()) # read directly from the dict
721
+ if ret.patch_nums is None:
722
+ return ''
723
+ return self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
725
+
726
+ # def _replace_video_frame(self, video_frame_text, return_mm_data = True):
727
+ # video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
728
+ # ret = self._get_video_frame(video_frame_info, return_mm_data) # 重复取结果 cached result
729
+ def _replace_video_frame(self, video_frame_text, mminfo_ret_dict):
730
+ video_frame_info = re.sub(re.compile(self.video_start_tag + '|' + self.video_end_tag), '', video_frame_text)
731
+ video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_info)
732
+ # ret = self._get_video_frame(video_frame_info) # 重复取结果 cached result
733
+ ret = mminfo_ret_dict.get(video_frame_info, OceanProcessorOutput())
734
+ if ret.videos_patch_nums is None:
735
+ return ''
736
+ video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
737
+ return ''.join(video_frame_str)
738
+
739
+ def extract_replace_multimodal(self, text, mtype='audio', return_mm_data = True):
740
+ # extract the JSON-formatted audio/image info from the text, load it and convert it to features, estimate the number of encoder tokens, and fill in the same number of pad tokens
741
+ if (self.audio_start_tag != None) and (mtype == 'audio'):
742
+ match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag)
743
+ drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag)
744
+ extract_func = self._get_audio
745
+ replace_func = self._replace_audio
746
+ elif (self.image_start_tag != None) and (mtype == 'image'):
747
+ match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag)
748
+ drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
749
+ extract_func = self._get_image
750
+ replace_func = self._replace_image
751
+ elif (self.video_start_tag != None) and (mtype == 'video'):
752
+ video_match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag)
753
+ video_drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag)
754
+ # handle video: convert the video path into multiple frame image paths
755
+ mm_info_list = re.findall(video_match_regex, text)
756
+ for mm_info in mm_info_list:
757
+ frame_str = self._get_video_frame_str(re.sub(video_drop_regex, '', mm_info))
758
+ # replace the path; if the video does not exist, replace it with an empty string
759
+ text = re.sub(mm_info, self.video_start_tag + frame_str + self.video_end_tag, text)
760
+ # reuse the multi-image processing path
761
+ match_regex = re.compile(self.video_start_tag+r'(.*?)'+self.video_end_tag)
762
+ drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
763
+ extract_func = self._get_video_frame
764
+ replace_func = self._replace_video_frame
765
+ else:
766
+ raise ValueError("mtype not supportted!")
767
+
768
+ mm_info_list = re.findall(match_regex, text)
769
+ mm_info_list = [re.sub(drop_regex, '', mm_info) for mm_info in mm_info_list]
770
+
771
+ mminfo_ret_dict = {}
772
+ ret = OceanProcessorOutput()
773
+ for mm_info in mm_info_list: # if nothing matches this modality, raw_text stays equal to text and the result is never None
774
+ mm_ret = extract_func(mm_info, return_mm_data = return_mm_data)
775
+ mminfo_ret_dict[mm_info] = mm_ret
776
+ if mm_ret.audios is None and mm_ret.images is None and mm_ret.videos is None: # the sample contains audio/image/video but extraction failed, so the whole sample is invalid (ret.raw_text stays None)
777
+ return ret
778
+ ret = ret.concatenate(mm_ret) # there may be multiple results; collect them first
779
+ # ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group()), text)
780
+ ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group(), mminfo_ret_dict), text)
781
+ return ret
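+ # Replacement sketch (illustrative input): a prompt such as
+ #   '<audio_start_ocean>{"path": "a.wav"}<audio_end_ocean>Describe the sound.'
+ # keeps the surrounding text but swaps the JSON span for bridge_length copies of the audio pad
+ # token between the start/end tags, so the language model sees placeholder positions that are
+ # later filled with audio encoder states; images and video frames are handled the same way
+ # with their own pad tokens.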
782
+
783
+ def process_one(self, text, index=0, raw_only=False, return_mm_data = True):
784
+ ret = OceanProcessorOutput(index=index)
785
+ for mtype in self.config.multimodal: # loop over modalities to fetch audio/image/video results and update the raw_text field
786
+ mret = self.extract_replace_multimodal(text, mtype, return_mm_data = return_mm_data) # also covers video extraction
787
+ if mret.raw_text is None: # the sample contains this modality but fetching it failed
788
+ return OceanProcessorOutput(index=index)
789
+ ret = ret.concatenate(mret)
790
+ text = mret.raw_text
791
+ ret.raw_text = text
792
+ if raw_only:
793
+ return ret # for SFT and other callers with their own tokenization logic
794
+
795
+ # handle the trainable spans in the pretraining text
796
+ input_ids, labels = [], []
797
+ trainable_sep = re.findall(r'<trainable_start>|<trainable_end>', ret.raw_text.replace('\n', '<LF>'))
798
+ if len(trainable_sep) <= 0:
799
+ input_ids = self.tokenizer(ret.raw_text, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
800
+ labels = [True for _ in input_ids]
801
+ else:
802
+ split_content = re.split(r'<trainable_start>|<trainable_end>', ret.raw_text)
803
+ for i, sc in enumerate(split_content):
804
+ if len(sc.strip()) == 0:
805
+ continue # skip empty or whitespace-only chunks
806
+ sc_ids = self.tokenizer(sc, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
807
+ input_ids.extend(sc_ids)
808
+ if i == 0 or trainable_sep[i - 1] == '<trainable_end>': # stop gradient
809
+ labels.extend([False] * len(sc_ids))
810
+ else:
811
+ labels.extend([True] * len(sc_ids))
812
+ # input_ids += [self.tokenizer.eos_token_id]
813
+ # labels += [True]
814
+ ret.labels = [input_ids[j] if (l and input_ids[j] not in self.config.multimodal_special_token_no_loss_list) else -100 for j, l in enumerate(labels)]
815
+ ret.input_ids = input_ids
816
+ ret.index = index
817
+ return ret
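+ # Usage sketch (illustrative sample): text wrapped in <trainable_start>/<trainable_end> keeps
+ # its token ids as labels, while everything outside the tags (and any no-loss multimodal
+ # special token) is masked with -100, e.g.
+ #
+ #   out = processor.process_one("Question?<trainable_start>Answer.<trainable_end>")
+ #   # out.labels is -100 on the "Question?" tokens and equals the token ids on "Answer."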
818
+
819
+ @torch.no_grad()
820
+ def __call__(self, example, parallel=8):
821
+ # main entry point: supports pretraining strings, SFT message dicts, and lists of strings for batch inference
822
+ if isinstance(example, Dict):
823
+ pass
824
+ elif isinstance(example, str):
825
+ return self.process_one(example)
826
+ elif isinstance(example, List): # batch inference, processed asynchronously with multiple threads
827
+ with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
828
+ future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
829
+ batch_data = [key.result() for key in cf.as_completed(future_list)]
830
+ valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
831
+ assert(valid_num == len(batch_data)) # inference strictly requires every sample to be processed
832
+ batch_data = sorted(batch_data, key=lambda x: x.index) # keep the original order
833
+
834
+ ret = OceanProcessorOutput()
835
+ for i in range(len(batch_data)):
836
+ ret = ret.concatenate(batch_data[i])
837
+ self.tokenizer.padding_side = "left"
838
+ padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
839
+ ret.input_ids = padding_result["input_ids"]
840
+ ret.attention_mask = padding_result["attention_mask"] # batch inference is not packed, so seqlens is not needed
841
+ padding_result = self.tokenizer.pad({"input_ids": [r.labels for r in batch_data]}, return_tensors='pt')
842
+ ret.labels = padding_result["input_ids"]
843
+
844
+ if ret.audios is not None:
845
+ ret.audios = default_collate(ret.audios)
846
+ ret.encoder_length = default_collate(ret.encoder_length)
847
+ ret.bridge_length = default_collate(ret.bridge_length)
848
+
849
+ if ret.images is not None:
850
+ ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
851
+ # else:ret.images = default_collate(ret.images)
852
+ # ret.patch_nums = default_collate(ret.patch_nums)
853
+
854
+ if ret.videos is not None:
855
+ ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]
856
+
857
+ return ret
858
+
859
+ else:
860
+ raise ValueError("example format supported yet")
861
+
862
+ @torch.no_grad()
863
+ def pack_batch_pretrain(self, raw_batch, max_sequence_length=None, parallel=8):
864
+ if max_sequence_length is None:
865
+ max_sequence_length = self.tokenizer.model_max_length
866
+ # pack N samples into M sequences of max_sequence_length, each keeping its own multimodal inputs
867
+ assert isinstance(raw_batch, List)
868
+ start_ts = time.time()
869
+ if parallel > 1:
870
+ with cf.ThreadPoolExecutor(max_workers=parallel) as executor:
871
+ future_list = []
872
+ for idx, json_text in enumerate(raw_batch):
873
+ try: # parse the json
874
+ json_obj = ujson.loads(json_text.strip())
875
+ except:
876
+ try:
877
+ json_obj = ast.literal_eval(json_text.strip())
878
+ except:
879
+ print("parse json obj faild: {}....".format(json_text[:300]))
880
+ continue
881
+ try: # chensong
882
+ if isinstance(json_obj, list):
883
+ content = json_obj[1]
884
+ elif 'raw' in json_obj.keys():
885
+ content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["raw"]
886
+ else:
887
+ content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["content"]
888
+ except:
889
+ print("parse json raw/content error: {}....".format(json_text[:300]))
890
+ continue
891
+
892
+ future_list.append(executor.submit(self.process_one, content, idx))
893
+ # gather the results (completion order, not input order)
894
+ batch_data = [key.result() for key in cf.as_completed(future_list)]
895
+ else: # debug only
896
+
897
+ batch_data = []
898
+ for json_text in raw_batch:
899
+ data = ujson.loads(json_text.strip())
900
+ if 'raw' in data.keys():
901
+ batch_data.append(self.process_one(data['raw'], 0))
902
+ else:
903
+ batch_data.append(self.process_one(data['content'], 0))
904
+
905
+ if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0:
906
+ print('[WARNING] processing each data cost more than 1.0s')
907
+
908
+ # pack the text inputs without any truncation
909
+ current_length, packed_output, output = 0, OceanProcessorOutput(position_ids=[], seqlens=[]), []
910
+ empty_data = OceanProcessorOutput(input_ids=[], labels=[])
911
+ for idx, bd in enumerate(batch_data + [empty_data]): # append an empty sentinel so the last packed sequence is flushed to the output and nothing is missed
912
+ if bd.input_ids is None and idx < len(batch_data):
913
+ continue # the sample failed to load and this is not the sentinel
914
+ if (len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length) and idx < len(batch_data):
915
+ continue # drop empty or over-long samples, unless this is the sentinel
916
+ if current_length + len(bd.input_ids) + 1 > max_sequence_length or idx == len(batch_data):
917
+ pad_nums = max_sequence_length - current_length # right padding
918
+ if packed_output.input_ids is None or packed_output.labels is None:
919
+ packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums
920
+ packed_output.labels = [-100] * pad_nums
921
+ packed_output.position_ids += [0] * (pad_nums+1)
922
+ else:
923
+ packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums
924
+ packed_output.labels += [-100] * pad_nums
925
+ packed_output.position_ids += [0] * pad_nums
926
+ packed_output.attention_mask = [1] * current_length + [0] * pad_nums
927
+ packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens))
928
+ output.append(packed_output)
929
+ packed_output = OceanProcessorOutput(position_ids=[], seqlens=[]) # reset empty
930
+ packed_output = packed_output.concatenate(bd)
931
+ packed_output.input_ids.append(self.tokenizer.eos_token_id) # the </s> token has to be appended separately
932
+ packed_output.labels.append(self.tokenizer.eos_token_id)
933
+
934
+ packed_output.position_ids.extend(list(range(len(bd.input_ids) + 1)))
935
+ packed_output.seqlens.append(len(bd.input_ids) + 1)
936
+
937
+ current_length = len(packed_output.input_ids)
938
+ return output
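+ # Packing sketch (hypothetical lengths): with max_sequence_length=8 and three documents of
+ # 3, 2 and 4 tokens (each gets one extra EOS), documents 1 and 2 share the first packed
+ # sequence (4 + 3 = 7 real tokens plus 1 pad, seqlens [4, 3, 0, ...], position_ids restarting
+ # at 0 per document), while document 3 starts the next packed sequence.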
939
+
940
+ @torch.no_grad()
941
+ def collect_batch_pretrain(self, batch_data):
942
+ ret = OceanProcessorOutput()
943
+ for i in range(len(batch_data)):
944
+ ret = ret.concatenate(batch_data[i])
945
+ ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
946
+ ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
947
+ ret.attention_mask = default_collate([np.asarray(x.attention_mask, dtype=np.float32) for x in batch_data]).cuda(non_blocking=True)
948
+ ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
949
+ ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
950
+
951
+ ret.raw_text = None
952
+ if ret.audios is not None:
953
+ ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32)).cuda(non_blocking=True)
954
+ ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32)).cuda(non_blocking=True)
955
+ ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32)).cuda(non_blocking=True)
956
+ if ret.images is not None:
957
+ ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)).cuda(non_blocking=True) for image in ret.images]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
958
+ ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
959
+ if ret.videos is not None:
960
+ ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)).cuda(non_blocking=True) for video in ret.videos]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
961
+ ret.videos_patch_nums = default_collate(np.asarray(ret.videos_patch_nums, dtype=np.int32)).cuda(non_blocking=True)
962
+
963
+ return ret
964
+
965
+ @torch.no_grad()
966
+ def collect_batch_sft(self, batch_data):
967
+ # list of dict to dataclass
968
+ batch_data = [OceanProcessorOutput(**bd) for bd in batch_data]
969
+ ret = OceanProcessorOutput()
970
+ for i in range(len(batch_data)):
971
+ ret = ret.concatenate(batch_data[i])
972
+ ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data])
973
+ ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data])
974
+ ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data])
975
+ ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data])
976
+
977
+ ret.raw_text = None
978
+ if ret.audios is not None:
979
+ ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32))
980
+ ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32))
981
+ ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32))
982
+ if ret.images is not None:
983
+ # convert each image to a torch tensor
984
+ ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
985
+ if ret.videos is not None:
986
+ ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
987
+
988
+ # ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
989
+
990
+ ret = ret.__dict__
991
+ del ret['patch_nums']
992
+ del ret['images_size']
993
+ del ret['crop_size']
994
+ del ret['raw_text']
995
+ del ret['index']
996
+ del ret['attention_mask']
997
+ del ret['videos_patch_nums']
998
+ del ret['videos_size']
999
+ del ret['videos_crop_size']
1000
+ return ret
1001
+
1002
+
1003
+ #######################################################
1004
+ ## Unit test functions. Usage:
1005
+ ## python processor_ocean.py <test_function_name>, e.g. python processor_ocean.py test_processor
1006
+ #######################################################
1007
+
1008
+ def test_img_processor():
1009
+ from transformers import AutoConfig
1010
+ from transformers.models.clip import CLIPImageProcessor
1011
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
1012
+ processor = OceanImageProcessor(config.visual_config)
1013
+ offical_processor = CLIPImageProcessor(size=config.visual_config.crop_size, crop_size=config.visual_config.crop_size,
1014
+ image_mean=config.visual_config.image_mean, image_std=config.visual_config.image_std,
1015
+ do_convert_rgb=True)
1016
+ img_files = ['sogou/7a2c8ffc1bc61146b32805c3390f42e2', 'wukong/77c1db1c0e4200d12b478c33ba3a412d', 'wukong/62e9a5c8eb8b0ea8858a34ba3f1a999f', 'wukong/fb9ab4d7c3fe9f54289948fd6a57fc30']
1017
+ cos_client = CosClient()
1018
+ for img_file in img_files:
1019
+ img_bytes = cos_client(img_file)
1020
+ img_rbg = Image.open(io.BytesIO(img_bytes))
1021
+ image, org_size = processor.image_transform(img_bytes)
1022
+ offical_image = offical_processor.preprocess([img_rbg],
1023
+ do_resize=True, do_center_crop=True, do_rescale=True, do_normalize=True,
1024
+ return_tensors='np').data['pixel_values'][0]
1025
+ print('-'*60)
1026
+ print(np.array(img_rbg).shape)
1027
+ print(image.shape)
1028
+ print(offical_image.shape)
1029
+ print(image - offical_image)
1030
+
1031
+ def test_audio_processor():
1032
+ from transformers.models.whisper import WhisperFeatureExtractor
1033
+ from transformers import AutoConfig
1034
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
1035
+ offical_processor = WhisperFeatureExtractor(feature_size=128)
1036
+ processor = OceanAudioProcessor(config.audio_config)
1037
+ # wave_files = glob.glob('/home/nfs_bc_alignment/sunhaoze/audio-data/openaqa/openaqa-as/audio/*')
1038
+ wave_files = ['/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/7ZY0U5tfKyQ.flac', '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/Osly4Shchs4.flac']
1039
+ for wave_file in wave_files:
1040
+ wave = processor.load_audio_waveform(wave_file, True, False)
1041
+ offical_features = offical_processor(wave[0].numpy(), do_normalize=False)
1042
+ feat = offical_features['input_features'][0]
1043
+ wave, frame_nums = processor.extract_fbank_features(wave)
1044
+ print("="*60)
1045
+ print(feat.shape)
1046
+ print(wave.shape, frame_nums)
1047
+ print('the difference between the official extractor and our implementation: {}'.format(wave_file))
1048
+ print(wave[:, :frame_nums] - feat[:, :frame_nums])
1049
+ print(wave)
1050
+ # print(wave[120:-1, :])
1051
+ # print(feat[120:-1, :wave.shape[1]])
1052
+ zeros_before = np.sum(wave == 0)
1053
+ aug = processor.data_augment(wave, frame_nums)
1054
+ zeros_after = np.sum(aug == 0)
1055
+ print(zeros_before, zeros_after)
1056
+
1057
+
1058
+ def test_audio_long(): # test the truncation strategy for audio longer than 30 seconds
1059
+ from transformers import AutoConfig, AutoTokenizer
1060
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
1061
+ config.audio_config.split_overlap = 1
1062
+ tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
1063
+ processor = OceanMMProcessor(tokenizer, config, True)
1064
+ examples = ["<audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/easy_chat_xianliaohuier_30s\/easy_chat_xianliaohuier-133.mp3\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
1065
+ "what's the sound's energy? \n sound1 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-116.mp3\"}<audio_end_ocean> \n sound2 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-221.mp3\"}<audio_end_ocean>The speech energy is medium.",
1066
+ ]
1067
+ ret = processor(examples)
1068
+ print(ret)
1069
+ print(torch.sum(ret.input_ids == 151659))
1070
+ print(torch.sum(ret.input_ids == 151674))
1071
+
1072
+ def test_processor():
1073
+ from transformers import AutoConfig, AutoTokenizer
1074
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
1075
+ tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
1076
+ processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
1077
+ examples = ["<audio_start_ocean>{\"path\": \"vggsound\/7DH5fqj8j6Q.flac\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
1078
+ "hello, ocean 你好 百川智能。",
1079
+ "what's the sound's energy? \n <audio_start_ocean>{\"path\": \"iemocap\/Ses01F_script01_3_F022.wav\"}<audio_end_ocean>The speech energy is medium.",
1080
+ "sound1: <audio_start_ocean>{\"path\": \"audioset_full\/9B53NVDNT8U.flac\"}<audio_end_ocean>\n sound2: \n<audio_start_ocean>{\"path\": \"audioset_full\/a2dgzb9GDSQ.flac\"}<audio_end_ocean>How is the speech speed related to the estimated speaker age?\n<trainable_start>The slow speech speed suggests a more deliberate and thoughtful approach often seen in mature individuals.<trainable_end>",
1081
+ "<img_start_ocean>{\"path\": \"sogou\/7351ae4f3fbe58ff0e4cc165cfabb3ed\"}<img_end_ocean>新和记潮汕牛肉火锅的牛肉丸好不好吃 用户评价口味怎么样 常州美食牛肉丸实拍图片 大众点评",
1082
+ "这两个图片有什么关系?图片1<img_start_ocean>{\"path\": \"sogou\/ac91d57ab68335913ed41aa283e76356\"}<img_end_ocean>图片2\n<img_start_ocean>{\"path\": \"sogou\/6ad5e632b74265d9ef689e45936ab1aa\"}<img_end_ocean>",
1083
+ "根据图片和语音给出描述\n图片<img_start_ocean>{\"path\": \"sogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb2\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>",
1084
+ "这些图片和音频不存在<img_start_ocean>{\"path\": \"soogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb_1\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>"
1085
+ ]
1086
+ ret = processor(examples[4:-1])
1087
+ print(ret)
1088
+ print(torch.sum(ret.input_ids == 151659))
1089
+ print(torch.sum(ret.input_ids == 151662))
1090
+ try:
1091
+ print(ret.bridge_length)
1092
+ print(ret.patch_nums)
1093
+ except:
1094
+ pass
1095
+ print(torch.sum(ret.attention_mask, dim=1))
1096
+
1097
+
1098
+ def test_grounding():
1099
+ from transformers import AutoConfig, AutoTokenizer
1100
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
1101
+ tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
1102
+ processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
1103
+ examples = ["<img_start_ocean>{\"path\": \"grit\/663423bf2f0884c034bf75279bce9694\"}<img_end_ocean>\nWhere is \"A woman\" ? Answer: <trainable_start>The bounding box is <box_start_ocean>(0.58,0.8),(0.71,1.0)<box_end_ocean><trainable_end>",
1104
+ "hello, ocean 你好 百川智能。",
1105
+ "<img_start_ocean>{\"path\": \"grit\/0e6e3952c584cbac7235940a22514656\"}<img_end_ocean> Generate the caption with grounding: <trainable_start>Photo pour Portrait of <ref_start_ocean>young Asian muslim woman wearing hijab<ref_end_ocean><box_start_ocean>(0.09,0.01),(0.77,1.0)<box_end_ocean> shows regret gesture, hand on her forehead, forget something important, against red background - image libre de droit<trainable_end>",
1106
+ "Recognize the object in the outlined section <img_start_ocean>{\"path\": \"grit\/045823cf6f819670f27aee20af7ae0e6\"}<img_end_ocean> of the picture.<box_start_ocean>(0.07,0.2),(0.91,0.96)<box_end_ocean>\n<trainable_start>Inflatable water trampolines<trainable_end>"
1107
+ ]
1108
+ ret = processor(examples)
1109
+ print(ret)
1110
+ for i, input_ids in enumerate(ret.input_ids):
1111
+ print("="*60)
1112
+ print(ret.labels[i])
1113
+
1114
+ def test_pack():
1115
+ from transformers import AutoConfig, AutoTokenizer
1116
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
1117
+ tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=2048)
1118
+ processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
1119
+ examples = open('/cpfs/29f69eb5e2e60f26/user/sunhaoze/pretrain-v6/sogou/part-00000').readlines()[:5]
1120
+ examples += open('/home/nfs_bc_alignment/sunhaoze/text/openaqa-as-stage2-v1/part-00000').readlines()[:5]
1121
+ random.shuffle(examples)
1122
+ batch_output = processor.pack_batch_pretrain(examples)
1123
+ for i, b in enumerate(batch_output):
1124
+ print('='*60)
1125
+ try:
1126
+ print(b.input_ids, len(b.input_ids))
1127
+ print(b.labels, len(b.labels))
1128
+ print(b.attention_mask, len(b.attention_mask))
1129
+ print(b.position_ids, len(b.position_ids))
1130
+ print(b.seqlens, len(b.seqlens))
1131
+ print(b.audios)
1132
+ print(b.bridge_length)
1133
+ except:
1134
+ continue
1135
+
1136
+ batch_for_model = processor.collect_batch_pretrain(batch_output)
1137
+ print(batch_for_model.input_ids.shape)
1138
+ print(batch_for_model.labels.shape)
1139
+ print(batch_for_model.audios.shape)
1140
+ print(batch_for_model["bridge_length"])
1141
+ print(batch_for_model.images.shape)
1142
+ print(batch_for_model["patch_nums"])
1143
+ print(batch_for_model["position_ids"])
1144
+ print(batch_for_model["seqlens"])
1145
+
1146
+ def test_cos_audio():
1147
+ cos_client = CosClient()
1148
+ audio_bytes = cos_client('panda/data/common_voice/cv-corpus-18.0-2024-06-14/zh-CN/clips/common_voice_zh-CN_19428637.mp3', 'audio-data-1317568651')
1149
+ wave, sr = torchaudio.load(audio_bytes, normalize=False)
1150
+ print(wave.shape, sr)
1151
+ # torchaudio.save('tmp.flac', wave, sr)
1152
+
1153
+ if __name__ == '__main__':
1154
+ fire.Fire()
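
The tokenizer-side diffs that follow rename every multimodal special token from the *_baichuan form to *_ocean while keeping the original token ids. A minimal sanity check, illustrative only and not part of this upload, that the renamed tokens still resolve to the ids listed in tokenizer.json / tokenizer_config.json below:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
# ids taken from the tokenizer_config.json / tokenizer.json entries below
expected = {
    "<audio_start_ocean>": 151676,
    "<audio_pad_ocean>": 151678,
    "<img_start_ocean>": 151679,
    "<img_pad_ocean>": 151681,
    "<ocean_pad_token>": 151691,
    "<video_end_ocean>": 151696,
}
for token, token_id in expected.items():
    assert tokenizer.convert_tokens_to_ids(token) == token_id, token
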
special_tokens_map.json CHANGED
@@ -24,27 +24,27 @@
24
  "<calc_start>",
25
  "<calc_end>",
26
  "<inner_think>",
27
- "<audio_start_baichuan>",
28
- "<audio_end_baichuan>",
29
- "<audio_pad_baichuan>",
30
- "<img_start_baichuan>",
31
- "<img_end_baichuan>",
32
- "<img_pad_baichuan>",
33
- "<img_newline_baichuan>",
34
- "<box_start_baichuan>",
35
- "<box_end_baichuan>",
36
- "<box_delim_baichuan>",
37
- "<ref_start_baichuan>",
38
- "<ref_end_baichuan>",
39
- "<img_delim_baichuan>",
40
- "<polygon_start_baichuan>",
41
- "<polygon_end_baichuan>",
42
- "<baichuan_pad_token>",
43
  "<reserved_113>",
44
- "<audio_delim_baichuan>",
45
- "<video_start_baichuan>",
46
- "<video_end_baichuan>",
47
- "<video_palce_baichuan>"
48
  ],
49
  "eos_token": {
50
  "content": "<|endoftext|>",
 
24
  "<calc_start>",
25
  "<calc_end>",
26
  "<inner_think>",
27
+ "<audio_start_ocean>",
28
+ "<audio_end_ocean>",
29
+ "<audio_pad_ocean>",
30
+ "<img_start_ocean>",
31
+ "<img_end_ocean>",
32
+ "<img_pad_ocean>",
33
+ "<img_newline_ocean>",
34
+ "<box_start_ocean>",
35
+ "<box_end_ocean>",
36
+ "<box_delim_ocean>",
37
+ "<ref_start_ocean>",
38
+ "<ref_end_ocean>",
39
+ "<img_delim_ocean>",
40
+ "<polygon_start_ocean>",
41
+ "<polygon_end_ocean>",
42
+ "<ocean_pad_token>",
43
  "<reserved_113>",
44
+ "<audio_delim_ocean>",
45
+ "<video_start_ocean>",
46
+ "<video_end_ocean>",
47
+ "<video_palce_ocean>"
48
  ],
49
  "eos_token": {
50
  "content": "<|endoftext|>",
tokenizer.json CHANGED
@@ -302,7 +302,7 @@
302
  },
303
  {
304
  "id": 151676,
305
- "content": "<audio_start_baichuan>",
306
  "single_word": false,
307
  "lstrip": false,
308
  "rstrip": false,
@@ -311,7 +311,7 @@
311
  },
312
  {
313
  "id": 151677,
314
- "content": "<audio_end_baichuan>",
315
  "single_word": false,
316
  "lstrip": false,
317
  "rstrip": false,
@@ -320,7 +320,7 @@
320
  },
321
  {
322
  "id": 151678,
323
- "content": "<audio_pad_baichuan>",
324
  "single_word": false,
325
  "lstrip": false,
326
  "rstrip": false,
@@ -329,7 +329,7 @@
329
  },
330
  {
331
  "id": 151679,
332
- "content": "<img_start_baichuan>",
333
  "single_word": false,
334
  "lstrip": false,
335
  "rstrip": false,
@@ -338,7 +338,7 @@
338
  },
339
  {
340
  "id": 151680,
341
- "content": "<img_end_baichuan>",
342
  "single_word": false,
343
  "lstrip": false,
344
  "rstrip": false,
@@ -347,7 +347,7 @@
347
  },
348
  {
349
  "id": 151681,
350
- "content": "<img_pad_baichuan>",
351
  "single_word": false,
352
  "lstrip": false,
353
  "rstrip": false,
@@ -356,7 +356,7 @@
356
  },
357
  {
358
  "id": 151682,
359
- "content": "<img_newline_baichuan>",
360
  "single_word": false,
361
  "lstrip": false,
362
  "rstrip": false,
@@ -365,7 +365,7 @@
365
  },
366
  {
367
  "id": 151683,
368
- "content": "<box_start_baichuan>",
369
  "single_word": false,
370
  "lstrip": false,
371
  "rstrip": false,
@@ -374,7 +374,7 @@
374
  },
375
  {
376
  "id": 151684,
377
- "content": "<box_end_baichuan>",
378
  "single_word": false,
379
  "lstrip": false,
380
  "rstrip": false,
@@ -383,7 +383,7 @@
383
  },
384
  {
385
  "id": 151685,
386
- "content": "<box_delim_baichuan>",
387
  "single_word": false,
388
  "lstrip": false,
389
  "rstrip": false,
@@ -392,7 +392,7 @@
392
  },
393
  {
394
  "id": 151686,
395
- "content": "<ref_start_baichuan>",
396
  "single_word": false,
397
  "lstrip": false,
398
  "rstrip": false,
@@ -401,7 +401,7 @@
401
  },
402
  {
403
  "id": 151687,
404
- "content": "<ref_end_baichuan>",
405
  "single_word": false,
406
  "lstrip": false,
407
  "rstrip": false,
@@ -410,7 +410,7 @@
410
  },
411
  {
412
  "id": 151688,
413
- "content": "<img_delim_baichuan>",
414
  "single_word": false,
415
  "lstrip": false,
416
  "rstrip": false,
@@ -419,7 +419,7 @@
419
  },
420
  {
421
  "id": 151689,
422
- "content": "<polygon_start_baichuan>",
423
  "single_word": false,
424
  "lstrip": false,
425
  "rstrip": false,
@@ -428,7 +428,7 @@
428
  },
429
  {
430
  "id": 151690,
431
- "content": "<polygon_end_baichuan>",
432
  "single_word": false,
433
  "lstrip": false,
434
  "rstrip": false,
@@ -437,7 +437,7 @@
437
  },
438
  {
439
  "id": 151691,
440
- "content": "<baichuan_pad_token>",
441
  "single_word": false,
442
  "lstrip": false,
443
  "rstrip": false,
@@ -455,7 +455,7 @@
455
  },
456
  {
457
  "id": 151693,
458
- "content": "<audio_delim_baichuan>",
459
  "single_word": false,
460
  "lstrip": false,
461
  "rstrip": false,
@@ -464,7 +464,7 @@
464
  },
465
  {
466
  "id": 151694,
467
- "content": "<video_palce_baichuan>",
468
  "single_word": false,
469
  "lstrip": false,
470
  "rstrip": false,
@@ -473,7 +473,7 @@
473
  },
474
  {
475
  "id": 151695,
476
- "content": "<video_start_baichuan>",
477
  "single_word": false,
478
  "lstrip": false,
479
  "rstrip": false,
@@ -482,7 +482,7 @@
482
  },
483
  {
484
  "id": 151696,
485
- "content": "<video_end_baichuan>",
486
  "single_word": false,
487
  "lstrip": false,
488
  "rstrip": false,
 
302
  },
303
  {
304
  "id": 151676,
305
+ "content": "<audio_start_ocean>",
306
  "single_word": false,
307
  "lstrip": false,
308
  "rstrip": false,
 
311
  },
312
  {
313
  "id": 151677,
314
+ "content": "<audio_end_ocean>",
315
  "single_word": false,
316
  "lstrip": false,
317
  "rstrip": false,
 
320
  },
321
  {
322
  "id": 151678,
323
+ "content": "<audio_pad_ocean>",
324
  "single_word": false,
325
  "lstrip": false,
326
  "rstrip": false,
 
329
  },
330
  {
331
  "id": 151679,
332
+ "content": "<img_start_ocean>",
333
  "single_word": false,
334
  "lstrip": false,
335
  "rstrip": false,
 
338
  },
339
  {
340
  "id": 151680,
341
+ "content": "<img_end_ocean>",
342
  "single_word": false,
343
  "lstrip": false,
344
  "rstrip": false,
 
347
  },
348
  {
349
  "id": 151681,
350
+ "content": "<img_pad_ocean>",
351
  "single_word": false,
352
  "lstrip": false,
353
  "rstrip": false,
 
356
  },
357
  {
358
  "id": 151682,
359
+ "content": "<img_newline_ocean>",
360
  "single_word": false,
361
  "lstrip": false,
362
  "rstrip": false,
 
365
  },
366
  {
367
  "id": 151683,
368
+ "content": "<box_start_ocean>",
369
  "single_word": false,
370
  "lstrip": false,
371
  "rstrip": false,
 
374
  },
375
  {
376
  "id": 151684,
377
+ "content": "<box_end_ocean>",
378
  "single_word": false,
379
  "lstrip": false,
380
  "rstrip": false,
 
383
  },
384
  {
385
  "id": 151685,
386
+ "content": "<box_delim_ocean>",
387
  "single_word": false,
388
  "lstrip": false,
389
  "rstrip": false,
 
392
  },
393
  {
394
  "id": 151686,
395
+ "content": "<ref_start_ocean>",
396
  "single_word": false,
397
  "lstrip": false,
398
  "rstrip": false,
 
401
  },
402
  {
403
  "id": 151687,
404
+ "content": "<ref_end_ocean>",
405
  "single_word": false,
406
  "lstrip": false,
407
  "rstrip": false,
 
410
  },
411
  {
412
  "id": 151688,
413
+ "content": "<img_delim_ocean>",
414
  "single_word": false,
415
  "lstrip": false,
416
  "rstrip": false,
 
419
  },
420
  {
421
  "id": 151689,
422
+ "content": "<polygon_start_ocean>",
423
  "single_word": false,
424
  "lstrip": false,
425
  "rstrip": false,
 
428
  },
429
  {
430
  "id": 151690,
431
+ "content": "<polygon_end_ocean>",
432
  "single_word": false,
433
  "lstrip": false,
434
  "rstrip": false,
 
437
  },
438
  {
439
  "id": 151691,
440
+ "content": "<ocean_pad_token>",
441
  "single_word": false,
442
  "lstrip": false,
443
  "rstrip": false,
 
455
  },
456
  {
457
  "id": 151693,
458
+ "content": "<audio_delim_ocean>",
459
  "single_word": false,
460
  "lstrip": false,
461
  "rstrip": false,
 
464
  },
465
  {
466
  "id": 151694,
467
+ "content": "<video_palce_ocean>",
468
  "single_word": false,
469
  "lstrip": false,
470
  "rstrip": false,
 
473
  },
474
  {
475
  "id": 151695,
476
+ "content": "<video_start_ocean>",
477
  "single_word": false,
478
  "lstrip": false,
479
  "rstrip": false,
 
482
  },
483
  {
484
  "id": 151696,
485
+ "content": "<video_end_ocean>",
486
  "single_word": false,
487
  "lstrip": false,
488
  "rstrip": false,
tokenizer_config.json CHANGED
@@ -267,7 +267,7 @@
267
  "special": true
268
  },
269
  "151676": {
270
- "content": "<audio_start_baichuan>",
271
  "lstrip": false,
272
  "normalized": false,
273
  "rstrip": false,
@@ -275,7 +275,7 @@
275
  "special": true
276
  },
277
  "151677": {
278
- "content": "<audio_end_baichuan>",
279
  "lstrip": false,
280
  "normalized": false,
281
  "rstrip": false,
@@ -283,7 +283,7 @@
283
  "special": true
284
  },
285
  "151678": {
286
- "content": "<audio_pad_baichuan>",
287
  "lstrip": false,
288
  "normalized": false,
289
  "rstrip": false,
@@ -291,7 +291,7 @@
291
  "special": true
292
  },
293
  "151679": {
294
- "content": "<img_start_baichuan>",
295
  "lstrip": false,
296
  "normalized": false,
297
  "rstrip": false,
@@ -299,7 +299,7 @@
299
  "special": true
300
  },
301
  "151680": {
302
- "content": "<img_end_baichuan>",
303
  "lstrip": false,
304
  "normalized": false,
305
  "rstrip": false,
@@ -307,7 +307,7 @@
307
  "special": true
308
  },
309
  "151681": {
310
- "content": "<img_pad_baichuan>",
311
  "lstrip": false,
312
  "normalized": false,
313
  "rstrip": false,
@@ -315,7 +315,7 @@
315
  "special": true
316
  },
317
  "151682": {
318
- "content": "<img_newline_baichuan>",
319
  "lstrip": false,
320
  "normalized": false,
321
  "rstrip": false,
@@ -323,7 +323,7 @@
323
  "special": true
324
  },
325
  "151683": {
326
- "content": "<box_start_baichuan>",
327
  "lstrip": false,
328
  "normalized": false,
329
  "rstrip": false,
@@ -331,7 +331,7 @@
331
  "special": true
332
  },
333
  "151684": {
334
- "content": "<box_end_baichuan>",
335
  "lstrip": false,
336
  "normalized": false,
337
  "rstrip": false,
@@ -339,7 +339,7 @@
339
  "special": true
340
  },
341
  "151685": {
342
- "content": "<box_delim_baichuan>",
343
  "lstrip": false,
344
  "normalized": false,
345
  "rstrip": false,
@@ -347,7 +347,7 @@
347
  "special": true
348
  },
349
  "151686": {
350
- "content": "<ref_start_baichuan>",
351
  "lstrip": false,
352
  "normalized": false,
353
  "rstrip": false,
@@ -355,7 +355,7 @@
355
  "special": true
356
  },
357
  "151687": {
358
- "content": "<ref_end_baichuan>",
359
  "lstrip": false,
360
  "normalized": false,
361
  "rstrip": false,
@@ -363,7 +363,7 @@
363
  "special": true
364
  },
365
  "151688": {
366
- "content": "<img_delim_baichuan>",
367
  "lstrip": false,
368
  "normalized": false,
369
  "rstrip": false,
@@ -371,7 +371,7 @@
371
  "special": true
372
  },
373
  "151689": {
374
- "content": "<polygon_start_baichuan>",
375
  "lstrip": false,
376
  "normalized": false,
377
  "rstrip": false,
@@ -379,7 +379,7 @@
379
  "special": true
380
  },
381
  "151690": {
382
- "content": "<polygon_end_baichuan>",
383
  "lstrip": false,
384
  "normalized": false,
385
  "rstrip": false,
@@ -387,7 +387,7 @@
387
  "special": true
388
  },
389
  "151691": {
390
- "content": "<baichuan_pad_token>",
391
  "lstrip": false,
392
  "normalized": false,
393
  "rstrip": false,
@@ -403,7 +403,7 @@
403
  "special": true
404
  },
405
  "151693": {
406
- "content": "<audio_delim_baichuan>",
407
  "lstrip": false,
408
  "normalized": false,
409
  "rstrip": false,
@@ -411,7 +411,7 @@
411
  "special": true
412
  },
413
  "151694": {
414
- "content": "<video_palce_baichuan>",
415
  "lstrip": false,
416
  "normalized": false,
417
  "rstrip": false,
@@ -419,7 +419,7 @@
419
  "special": true
420
  },
421
  "151695": {
422
- "content": "<video_start_baichuan>",
423
  "lstrip": false,
424
  "normalized": false,
425
  "rstrip": false,
@@ -427,7 +427,7 @@
427
  "special": true
428
  },
429
  "151696": {
430
- "content": "<video_end_baichuan>",
431
  "lstrip": false,
432
  "normalized": false,
433
  "rstrip": false,
@@ -460,27 +460,27 @@
460
  "<calc_start>",
461
  "<calc_end>",
462
  "<inner_think>",
463
- "<audio_start_baichuan>",
464
- "<audio_end_baichuan>",
465
- "<audio_pad_baichuan>",
466
- "<img_start_baichuan>",
467
- "<img_end_baichuan>",
468
- "<img_pad_baichuan>",
469
- "<img_newline_baichuan>",
470
- "<box_start_baichuan>",
471
- "<box_end_baichuan>",
472
- "<box_delim_baichuan>",
473
- "<ref_start_baichuan>",
474
- "<ref_end_baichuan>",
475
- "<img_delim_baichuan>",
476
- "<polygon_start_baichuan>",
477
- "<polygon_end_baichuan>",
478
- "<baichuan_pad_token>",
479
  "<reserved_113>",
480
- "<audio_delim_baichuan>",
481
- "<video_start_baichuan>",
482
- "<video_end_baichuan>",
483
- "<video_palce_baichuan>"
484
  ],
485
  "bos_token": null,
486
  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
 
267
  "special": true
268
  },
269
  "151676": {
270
+ "content": "<audio_start_ocean>",
271
  "lstrip": false,
272
  "normalized": false,
273
  "rstrip": false,
 
275
  "special": true
276
  },
277
  "151677": {
278
+ "content": "<audio_end_ocean>",
279
  "lstrip": false,
280
  "normalized": false,
281
  "rstrip": false,
 
283
  "special": true
284
  },
285
  "151678": {
286
+ "content": "<audio_pad_ocean>",
287
  "lstrip": false,
288
  "normalized": false,
289
  "rstrip": false,
 
291
  "special": true
292
  },
293
  "151679": {
294
+ "content": "<img_start_ocean>",
295
  "lstrip": false,
296
  "normalized": false,
297
  "rstrip": false,
 
299
  "special": true
300
  },
301
  "151680": {
302
+ "content": "<img_end_ocean>",
303
  "lstrip": false,
304
  "normalized": false,
305
  "rstrip": false,
 
307
  "special": true
308
  },
309
  "151681": {
310
+ "content": "<img_pad_ocean>",
311
  "lstrip": false,
312
  "normalized": false,
313
  "rstrip": false,
 
315
  "special": true
316
  },
317
  "151682": {
318
+ "content": "<img_newline_ocean>",
319
  "lstrip": false,
320
  "normalized": false,
321
  "rstrip": false,
 
323
  "special": true
324
  },
325
  "151683": {
326
+ "content": "<box_start_ocean>",
327
  "lstrip": false,
328
  "normalized": false,
329
  "rstrip": false,
 
331
  "special": true
332
  },
333
  "151684": {
334
+ "content": "<box_end_ocean>",
335
  "lstrip": false,
336
  "normalized": false,
337
  "rstrip": false,
 
339
  "special": true
340
  },
341
  "151685": {
342
+ "content": "<box_delim_ocean>",
343
  "lstrip": false,
344
  "normalized": false,
345
  "rstrip": false,
 
347
  "special": true
348
  },
349
  "151686": {
350
+ "content": "<ref_start_ocean>",
351
  "lstrip": false,
352
  "normalized": false,
353
  "rstrip": false,
 
355
  "special": true
356
  },
357
  "151687": {
358
+ "content": "<ref_end_ocean>",
359
  "lstrip": false,
360
  "normalized": false,
361
  "rstrip": false,
 
363
  "special": true
364
  },
365
  "151688": {
366
+ "content": "<img_delim_ocean>",
367
  "lstrip": false,
368
  "normalized": false,
369
  "rstrip": false,
 
371
  "special": true
372
  },
373
  "151689": {
374
+ "content": "<polygon_start_ocean>",
375
  "lstrip": false,
376
  "normalized": false,
377
  "rstrip": false,
 
379
  "special": true
380
  },
381
  "151690": {
382
+ "content": "<polygon_end_ocean>",
383
  "lstrip": false,
384
  "normalized": false,
385
  "rstrip": false,
 
387
  "special": true
388
  },
389
  "151691": {
390
+ "content": "<ocean_pad_token>",
391
  "lstrip": false,
392
  "normalized": false,
393
  "rstrip": false,
 
403
  "special": true
404
  },
405
  "151693": {
406
+ "content": "<audio_delim_ocean>",
407
  "lstrip": false,
408
  "normalized": false,
409
  "rstrip": false,
 
411
  "special": true
412
  },
413
  "151694": {
414
+ "content": "<video_palce_ocean>",
415
  "lstrip": false,
416
  "normalized": false,
417
  "rstrip": false,
 
419
  "special": true
420
  },
421
  "151695": {
422
+ "content": "<video_start_ocean>",
423
  "lstrip": false,
424
  "normalized": false,
425
  "rstrip": false,
 
427
  "special": true
428
  },
429
  "151696": {
430
+ "content": "<video_end_ocean>",
431
  "lstrip": false,
432
  "normalized": false,
433
  "rstrip": false,
 
460
  "<calc_start>",
461
  "<calc_end>",
462
  "<inner_think>",
463
+ "<audio_start_ocean>",
464
+ "<audio_end_ocean>",
465
+ "<audio_pad_ocean>",
466
+ "<img_start_ocean>",
467
+ "<img_end_ocean>",
468
+ "<img_pad_ocean>",
469
+ "<img_newline_ocean>",
470
+ "<box_start_ocean>",
471
+ "<box_end_ocean>",
472
+ "<box_delim_ocean>",
473
+ "<ref_start_ocean>",
474
+ "<ref_end_ocean>",
475
+ "<img_delim_ocean>",
476
+ "<polygon_start_ocean>",
477
+ "<polygon_end_ocean>",
478
+ "<ocean_pad_token>",
479
  "<reserved_113>",
480
+ "<audio_delim_ocean>",
481
+ "<video_start_ocean>",
482
+ "<video_end_ocean>",
483
+ "<video_palce_ocean>"
484
  ],
485
  "bos_token": null,
486
  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
visual_modeling_ocean.py ADDED
@@ -0,0 +1,166 @@
1
+
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch, math
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ import transformers
7
+ from flash_attn import flash_attn_varlen_func
8
+ from transformers.activations import ACT2FN
9
+ from PIL import Image
10
+ import io, fire
11
+ from torch.nn import functional as F
12
+
13
+
14
+ class OceanVisualAttention(nn.Module):
15
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
16
+
17
+ def __init__(self, config):
18
+ super().__init__()
19
+ self.config = config
20
+ self.embed_dim = config.hidden_size
21
+ self.num_heads = config.num_attention_heads
22
+ self.head_dim = self.embed_dim // self.num_heads
23
+ if self.head_dim * self.num_heads != self.embed_dim:
24
+ raise ValueError(
25
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
26
+ f" {self.num_heads})."
27
+ )
28
+ self.scale = self.head_dim**-0.5 # not needed on the flash attention path
29
+ self.dropout = config.attention_dropout
30
+
31
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
32
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
33
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
34
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
35
+ # print("*******OceanVisualAttention monkey patched!!*******")
36
+
37
+ # weight initialization; has no effect when loading with transformers from_pretrained
38
+ factor = self.config.initializer_factor
39
+ in_proj_std = (self.embed_dim**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
40
+ out_proj_std = (self.embed_dim**-0.5) * factor
41
+ nn.init.normal_(self.q_proj.weight, std=in_proj_std)
42
+ nn.init.normal_(self.k_proj.weight, std=in_proj_std)
43
+ nn.init.normal_(self.v_proj.weight, std=in_proj_std)
44
+ nn.init.normal_(self.out_proj.weight, std=out_proj_std)
45
+
46
+ def forward(
47
+ self,
48
+ hidden_states: torch.Tensor,
49
+ attention_mask: Optional[torch.Tensor] = None,
50
+ causal_attention_mask: Optional[torch.Tensor] = None,
51
+ output_attentions: Optional[bool] = False,
52
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
53
+ """Input shape: Batch x Time x Channel"""
54
+
55
+ bsz, tgt_len, embed_dim = hidden_states.size()
56
+ src_len = tgt_len
57
+
58
+ query_states = self.q_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
59
+ key_states = self.k_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
60
+ value_states = self.v_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
61
+
62
+ # variable patch nums are not handled yet; the length is fixed to 256/1024
63
+ cu_len = torch.arange(0, (bsz + 1) * tgt_len, step=tgt_len, dtype=torch.int32, device=query_states.device)
64
+ # print(self.config.s2a, self.config.rope_scaling, cu_len, torch.sum(cu_len), q_len, kv_seq_len)
65
+ # fall back to SDPA when the dtype is not fp16/bf16 (flash attention requires them)
66
+ if query_states.dtype in [torch.float16, torch.bfloat16]:
67
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, tgt_len, tgt_len, causal=False) # (bsz * qlen, nheads, headdim)
68
+ attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
69
+ else:
70
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
71
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, 0.0)
72
+ attn_output = attn_output.transpose(1, 2)
73
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
74
+ attn_output = self.out_proj(attn_output)
75
+
76
+ return attn_output, None
77
+
78
+ # monkey patch for flash attention
79
+ # transformers.models.siglip.modeling_siglip.SiglipAttention = OceanVisualAttention
80
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
81
+ from qwen_vl_utils import process_vision_info
82
+
83
+ class OceanVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
84
+ def __init__(self, config):
85
+ super().__init__(config)
86
+ self.gradient_checkpointing = True # always enabled
87
+ self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
88
+ del self.merger
89
+
90
+ def forward(
91
+ self,
92
+ pixel_values: torch.Tensor,
93
+ grid_thw: torch.Tensor,
94
+ ):
95
+ hidden_states = pixel_values.to(self.get_dtype())
96
+ grid_thw = grid_thw.to(pixel_values.device)
97
+
98
+ hidden_states = self.patch_embed(hidden_states)
99
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
100
+
101
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
102
+ dim=0, dtype=torch.int32
103
+ )
104
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
105
+
106
+ for blk in self.blocks:
107
+ if self.gradient_checkpointing and self.training:
108
+ hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
109
+ else:
110
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
111
+
112
+ return hidden_states
113
+
114
+ @torch.no_grad()
115
+ def fake_input(self, device, merge_size=2):
116
+ merge_size = max(merge_size, self.config.spatial_merge_size)
117
+ fake_image = [torch.zeros([
118
+ 1,
119
+ self.config.temporal_patch_size,
120
+ 3,
121
+ merge_size // self.config.spatial_merge_size,
122
+ self.config.spatial_merge_size,
123
+ self.config.patch_size,
124
+ merge_size // self.config.spatial_merge_size,
125
+ self.config.spatial_merge_size,
126
+ self.config.patch_size,
127
+ ], dtype=torch.float32, device=device)]
128
+ return fake_image
129
+
130
+
131
+ class OceanVisualBridge(nn.Module):
132
+ def __init__(self, config):
133
+ super().__init__()
134
+ self.config = config
135
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
136
+ self.hidden_size = config.embed_dim * (self.merge_size**2)
137
+ self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
138
+ self.mlp = nn.Sequential(
139
+ nn.Linear(self.hidden_size, self.hidden_size),
140
+ nn.GELU(),
141
+ nn.Linear(self.hidden_size, config.hidden_size),
142
+ )
143
+
144
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
145
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
146
+ return x
147
+
148
+
149
+ def test_vision():
150
+ from transformers import AutoConfig
151
+ config = AutoConfig.from_pretrained("./", trust_remote_code=True)
152
+ ae = OceanVisualEncoder(config.visual_config).cuda().to(torch.bfloat16)
153
+ bg = OceanVisualBridge(config).cuda().to(torch.bfloat16)
154
+ print(ae)
155
+ # assumes Qwen2-VL style patchified input: (seq_len, 3 * temporal_patch_size * patch_size**2) plus grid_thw; h/w must be divisible by spatial_merge_size
156
+ grid_thw = torch.tensor([[1, 8, 8]], dtype=torch.int32)
157
+ patch_dim = 3 * config.visual_config.temporal_patch_size * config.visual_config.patch_size ** 2
158
+ pixel_input = torch.rand([int(grid_thw.prod(dim=-1).sum()), patch_dim], dtype=torch.float32).cuda()
159
+ visual_embedding = ae(pixel_input, grid_thw) # (seq_len, embed_dim); no class token with this encoder
160
+ visual_proj = bg(visual_embedding)
161
+ print(visual_proj.shape)
162
+ print(ae.fake_input(visual_proj.device))
163
+
164
+ if __name__ == '__main__':
165
+ fire.Fire()
166
+
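
For reference, a shape walk-through of OceanVisualBridge in isolation; this is an illustrative sketch, not part of the upload, and the sizes (embed_dim=1280, language-model hidden size 3584, merge_size=2, 64 patches) are assumed example values. It shows how every merge_size**2 consecutive patch embeddings from the encoder are grouped into a single visual token before being projected to the language-model hidden size, exactly as in OceanVisualBridge.forward above.

import torch
from torch import nn

embed_dim, lm_hidden, merge_size, num_patches = 1280, 3584, 2, 64  # assumed example sizes
hidden_size = embed_dim * merge_size ** 2

ln_q = nn.LayerNorm(embed_dim, eps=1e-6)
mlp = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, lm_hidden),
)

patches = torch.randn(num_patches, embed_dim)          # encoder output: one vector per patch
tokens = mlp(ln_q(patches).view(-1, hidden_size))      # group merge_size**2 patches, then project
print(tokens.shape)                                     # torch.Size([16, 3584]): num_patches / merge_size**2 visual tokens
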