Upload 56 files
- added_tokens.json +20 -20
- audio_modeling_ocean.py +194 -0
- config.json +5 -5
- configuration_ocean.py +111 -0
- modeling_ocean.py +1001 -0
- processor_ocean.py +1154 -0
- special_tokens_map.json +20 -20
- tokenizer.json +20 -20
- tokenizer_config.json +40 -40
- visual_modeling_ocean.py +166 -0
added_tokens.json
CHANGED
@@ -7,32 +7,32 @@
   "<B_USYS>": 151666,
   "<C_A>": 151668,
   "<C_Q>": 151667,
-  "<
-  "<
-  "<
-  "<
-  "<
-  "<
-  "<
-  "<
+  "<audio_delim_ocean>": 151693,
+  "<audio_end_ocean>": 151677,
+  "<audio_pad_ocean>": 151678,
+  "<audio_start_ocean>": 151676,
+  "<ocean_pad_token>": 151691,
+  "<box_delim_ocean>": 151685,
+  "<box_end_ocean>": 151684,
+  "<box_start_ocean>": 151683,
   "<calc_end>": 151674,
   "<calc_start>": 151673,
   "<function_calling>": 151672,
-  "<
-  "<
-  "<
-  "<
-  "<
+  "<img_delim_ocean>": 151688,
+  "<img_end_ocean>": 151680,
+  "<img_newline_ocean>": 151682,
+  "<img_pad_ocean>": 151681,
+  "<img_start_ocean>": 151679,
   "<inner_think>": 151675,
-  "<
-  "<
-  "<
-  "<
+  "<polygon_end_ocean>": 151690,
+  "<polygon_start_ocean>": 151689,
+  "<ref_end_ocean>": 151687,
+  "<ref_start_ocean>": 151686,
   "<reserved_113>": 151692,
   "<tool_call>": 151657,
-  "<
-  "<
-  "<
+  "<video_end_ocean>": 151696,
+  "<video_palce_ocean>": 151694,
+  "<video_start_ocean>": 151695,
   "<|box_end|>": 151649,
   "<|box_start|>": 151648,
   "<|endoftext|>": 151643,
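The new *_ocean entries reserve IDs 151676-151696 for audio, image, box/polygon/ref grounding, and video markers. A quick way to sanity-check the mapping after cloning the repo is to look the tokens up through the tokenizer; the snippet below is a minimal sketch (not part of this commit) and assumes a local checkout at "./" plus trust_remote_code=True.

# Minimal sketch: check that the newly added multimodal special tokens resolve to
# the IDs declared in added_tokens.json. Assumes a local checkout of this repo at "./".
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
for token, expected_id in [
    ("<audio_start_ocean>", 151676),
    ("<img_pad_ocean>", 151681),
    ("<video_palce_ocean>", 151694),  # spelling as registered in added_tokens.json
]:
    assert tokenizer.convert_tokens_to_ids(token) == expected_id, token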
audio_modeling_ocean.py
ADDED
@@ -0,0 +1,194 @@
import torch, random, fire
from transformers.models.whisper import WhisperConfig
from torch.nn import functional as F
from flash_attn import flash_attn_varlen_func
from torch import nn
import numpy as np
from transformers.activations import ACT2FN
import math

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

class OceanWhisperAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        bsz, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)

        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
        max_seqlen = torch.max(seq_len).to(torch.int32).detach()
        attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen, max_seqlen, causal=False)  # (bsz * qlen, nheads, headdim)
        attn_output = attn_output.reshape(bsz, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output

class OceanWhisperEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = OceanWhisperAttention(self.embed_dim, config.encoder_attention_heads)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = ACT2FN[config.activation_function]
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, seq_len)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states

class OceanAudioEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        config._attn_implementation = 'flash_attention_2'
        self.config = config
        self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        # These modules must be registered during LLM initialization.
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
        self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size, stride=config.stride_size, padding=1)
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model))  # 1500 * d

        self.layers = nn.ModuleList([OceanWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = True

    @torch.no_grad()
    def fake_input(self, device):
        input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
        encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
        bridge_length = torch.ones([2], dtype=torch.int32, device=device)
        return input_features, encoder_length, bridge_length

    def forward(
        self,
        input_features,
        output_length,  # MAKE SURE: this must already be the hidden-state length after the two conv layers
    ):
        input_features = input_features.to(self.conv1.weight.dtype)
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))  # (bs, channels, frames)
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))  # (bs, channels, frames // 2)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)  # (bs, frames, channels)
        bsz, tgt_len, _ = inputs_embeds.size()  # maximum length in the current batch
        if tgt_len < self.positional_embedding.shape[0]:
            current_positional_embedding = self.positional_embedding[:tgt_len]
        else:
            current_positional_embedding = self.positional_embedding
        hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)

        # packing hidden states
        attention_mask = torch.arange(0, tgt_len).to(hidden_states.device)
        attention_mask = torch.lt(attention_mask, output_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
        unpacking_index = torch.cumsum(attention_mask.to(torch.int32).view(-1), dim=0) - 1  # turn cumulative counts into gather indices
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)

        for idx, encoder_layer in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    output_length
                )
            else:
                hidden_states = encoder_layer(hidden_states, output_length)
        hidden_states = self.layer_norm(hidden_states)
        # unpacking
        hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
        hidden_states = torch.where(attention_mask, hidden_states, 0)
        return hidden_states

class OceanAudioBridge(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config.audio_config
        if self.config.avg_pooler > 1:
            self.avg_pooler = nn.AvgPool1d(self.config.avg_pooler, stride=2)
        else:
            self.avg_pooler = None
        self.proj1 = nn.Linear(self.config.d_model, config.hidden_size)
        self.proj2 = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x, output_length):
        if self.avg_pooler is not None:
            x = x.permute(0, 2, 1)
            x = self.avg_pooler(x)
            x = x.permute(0, 2, 1)

        batch_size, sl, _ = x.shape
        output_length = output_length.to(x.device)
        valid_mask = torch.arange(0, sl).to(x.device)
        valid_mask = torch.lt(valid_mask, output_length.reshape(batch_size, 1)).reshape(batch_size, sl, 1)
        x = torch.masked_select(x, valid_mask).reshape(-1, self.config.d_model)  # (sum(valid_sequence_length), d)
        x = ACT2FN[self.config.activation_function](self.proj1(x))
        x = self.proj2(x)
        return x


def test_audio():
    from transformers import AutoConfig
    from processor_ocean import OceanAudioProcessor
    # from ..configuration_ocean import OceanConfig
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)

    config.audio_config.d_model = 24
    config.audio_config.encoder_layers = 2
    config.audio_config.encoder_attention_heads = 4
    config.audio_config.encoder_ffn_dim = 48

    ae = OceanAudioEncoder(config.audio_config).cuda().to(torch.bfloat16)
    bg = OceanAudioBridge(config).cuda().to(torch.bfloat16)
    l = random.randint(10, 30)
    bs = 3
    input_length = torch.tensor([random.randint(1, l) for _ in range(bs)])
    encoder_length, bridge_length = OceanAudioProcessor.inference_output_length(config.audio_config, input_length)
    print("l={}, input_valid_length={},\nencoder_valid_length={}, bridge_valid_length={}".format(l, input_length, encoder_length, bridge_length))
    wave_features = torch.rand((bs, config.audio_config.num_mel_bins, l))
    a = ae(wave_features.to('cuda'), encoder_length.to('cuda'))
    b = bg(a, bridge_length.to('cuda'))
    print('encoder output={}, bridge output={}'.format(a.shape, b.shape))
    print(a)
    print(b)


if __name__ == '__main__':
    fire.Fire()
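OceanAudioEncoder expects output_length to already be the post-convolution length, and OceanAudioBridge expects the length after the optional average pooling; test_audio above obtains both from OceanAudioProcessor.inference_output_length, which lives in processor_ocean.py and is not shown in this section. The sketch below only illustrates that length bookkeeping under assumed hyperparameters (kernel_size=3, padding=1, conv2 stride 2, avg_pooler=2); it is not the repository's implementation.

# Minimal sketch (assumed hyperparameters, not the processor_ocean.py code): relate the
# raw mel-frame count to the encoder and bridge valid lengths.
import torch

def conv_out_len(length, kernel=3, stride=1, padding=1):
    # standard Conv1d length formula: floor((L + 2*pad - kernel) / stride) + 1
    return (length + 2 * padding - kernel) // stride + 1

def sketch_output_length(mel_frames: torch.Tensor, stride_size=2, avg_pooler=2):
    encoder_length = conv_out_len(conv_out_len(mel_frames), stride=stride_size)  # after conv1 + conv2
    # AvgPool1d(kernel_size=avg_pooler, stride=2) with no padding
    bridge_length = (encoder_length - avg_pooler) // 2 + 1 if avg_pooler > 1 else encoder_length
    return encoder_length, bridge_length

print(sketch_output_length(torch.tensor([10, 30])))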
config.json
CHANGED
@@ -1,7 +1,7 @@
 {
   "_name_or_path": "/cpfs/29f69eb5e2e60f26/code/mllm/zhangtao02/workspace/3_bc_mllm/models/mm_pretrain/cs_ocr3b/cs_ocr3b_ift_1204_ckpt_epoch_1_12030256",
   "architectures": [
-    "
+    "OceanForCausalLM"
   ],
   "attention_qkv_bias": true,
   "attention_qkv_pack": false,
@@ -117,10 +117,10 @@
     "vocab_size": 51865
   },
   "auto_map": {
-    "AutoConfig": "
-    "AutoModelForCausalLM": "
+    "AutoConfig": "configuration_ocean.OceanConfig",
+    "AutoModelForCausalLM": "modeling_ocean.OceanForCausalLM"
   },
-  "
+  "ocean_tokenizer_type": "auto",
   "bos_token_id": 1,
   "eos_token_id": 2,
   "head_dim": 128,
@@ -130,7 +130,7 @@
   "intermediate_size": 11008,
   "max_position_embeddings": 8192,
   "max_window_layers": 36,
-  "model_type": "
+  "model_type": "ocean",
   "moe": false,
   "multimodal": [
     "image"
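With the auto_map entries above and model_type set to "ocean", the checkpoint resolves to the custom classes shipped in this repository when loaded with trust_remote_code. A minimal sketch (not part of the commit), assuming a local checkout at "./":

# Loading through the standard Auto classes; auto_map points them at the custom code.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("./", trust_remote_code=True)
print(type(config).__name__)   # should be OceanConfig (from configuration_ocean.py)
print(config.model_type)       # "ocean"

model = AutoModelForCausalLM.from_pretrained("./", trust_remote_code=True)
print(type(model).__name__)    # should be OceanForCausalLM (from modeling_ocean.py)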
configuration_ocean.py
ADDED
@@ -0,0 +1,111 @@
# Copyright 2023 Ocean Inc. All Rights Reserved.

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import WhisperConfig
from transformers import CLIPVisionConfig

logger = logging.get_logger(__name__)


class OceanConfig(PretrainedConfig):
    model_type = "ocean"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=125696,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        sparse_attention_heads=None,
        sparse_attention_layers=[],
        head_dim=None,
        attention_qkv_pack=True,
        attention_qkv_bias=False,
        use_norm_head=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        position_embedding_type="rope",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        audio_config=None,
        visual_config=None,
        video_config=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
        self.sparse_attention_heads = sparse_attention_heads
        self.sparse_attention_layers = sparse_attention_layers
        self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
        self.attention_qkv_pack = attention_qkv_pack
        self.attention_qkv_bias = attention_qkv_bias
        self.use_norm_head = use_norm_head
        self.hidden_act = hidden_act
        self.position_embedding_type = position_embedding_type
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        assert self.position_embedding_type.lower() in ("rope", "alibi")
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        if audio_config is not None:
            self.audio_config = WhisperConfig(**audio_config)
        if visual_config is not None:
            self.visual_config = CLIPVisionConfig(**visual_config)
        if video_config is not None:
            self.video_config = CLIPVisionConfig(**video_config)


    def to_diff_dict(self):
        data = super().to_diff_dict()
        data["model_type"] = self.model_type
        return data

    def get_rotary_base(self):
        if hasattr(self, "rotary_emb_base"):
            return self.rotary_emb_base
        else:
            return self.rope_theta

if __name__ == '__main__':
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    print(config)
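OceanConfig promotes plain dicts passed as audio_config/visual_config/video_config into WhisperConfig and CLIPVisionConfig instances. A minimal construction sketch with illustrative values (not the shipped configuration), assuming it is run from inside the repo checkout so configuration_ocean.py is importable:

# Nested modality configs become typed config objects.
from configuration_ocean import OceanConfig

config = OceanConfig(
    hidden_size=1024,
    num_hidden_layers=4,
    num_attention_heads=8,
    position_embedding_type="rope",
    audio_config={"d_model": 256, "encoder_layers": 2},          # becomes a WhisperConfig
    visual_config={"hidden_size": 256, "num_hidden_layers": 2},  # becomes a CLIPVisionConfig
)
print(type(config.audio_config).__name__, type(config.visual_config).__name__)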
modeling_ocean.py
ADDED
@@ -0,0 +1,1001 @@
| 1 |
+
# Copyright 2023 Ocean Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
""" PyTorch Ocean model."""
|
| 22 |
+
import os
|
| 23 |
+
import json
|
| 24 |
+
import math
|
| 25 |
+
from typing import List, Optional, Tuple, Union
|
| 26 |
+
from threading import Thread
|
| 27 |
+
from easydict import EasyDict
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.utils.checkpoint
|
| 32 |
+
from torch import nn
|
| 33 |
+
from torch.nn import CrossEntropyLoss
|
| 34 |
+
from torch.nn import functional as F
|
| 35 |
+
|
| 36 |
+
from transformers import PreTrainedModel
|
| 37 |
+
from transformers.activations import ACT2FN
|
| 38 |
+
from dataclasses import dataclass
|
| 39 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 40 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
| 41 |
+
from transformers.generation.utils import GenerationConfig
|
| 42 |
+
from transformers.utils import logging
|
| 43 |
+
|
| 44 |
+
from .configuration_ocean import OceanConfig
|
| 45 |
+
from .audio_modeling_ocean import OceanAudioEncoder, OceanAudioBridge
|
| 46 |
+
from .visual_modeling_ocean import OceanVisualEncoder, OceanVisualBridge
|
| 47 |
+
from .processor_ocean import OceanMMProcessor
|
| 48 |
+
from .moe import moe_matmul
|
| 49 |
+
|
| 50 |
+
# support model path contain point(.)
|
| 51 |
+
try:
|
| 52 |
+
# step1: copy relative imports to transformers_modules
|
| 53 |
+
from .generation_utils import build_chat_input, TextIterStreamer
|
| 54 |
+
from .sequence_parallel_utils import (
|
| 55 |
+
create_attention_layer,
|
| 56 |
+
get_sequence_parallel_size,
|
| 57 |
+
get_sequence_parallel_chunk,
|
| 58 |
+
)
|
| 59 |
+
except ModuleNotFoundError:
|
| 60 |
+
# step2: direct import from transformers_modules
|
| 61 |
+
try: # bypass check_imports failure
|
| 62 |
+
import sys
|
| 63 |
+
sys.path.append(os.path.dirname(__file__))
|
| 64 |
+
from generation_utils import build_chat_input, TextIterStreamer
|
| 65 |
+
from sequence_parallel_utils import (
|
| 66 |
+
create_attention_layer,
|
| 67 |
+
get_sequence_parallel_size,
|
| 68 |
+
get_sequence_parallel_chunk,
|
| 69 |
+
)
|
| 70 |
+
except Exception:
|
| 71 |
+
raise
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
logger = logging.get_logger(__name__)
|
| 75 |
+
|
| 76 |
+
def get_slopes(n):
|
| 77 |
+
def get_slopes_power_of_2(n):
|
| 78 |
+
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
| 79 |
+
ratio = start
|
| 80 |
+
return [start * ratio ** i for i in range(n)]
|
| 81 |
+
|
| 82 |
+
if math.log2(n).is_integer():
|
| 83 |
+
return get_slopes_power_of_2(
|
| 84 |
+
n) # In the paper, we only train models that have 2^a heads for some a. This function has
|
| 85 |
+
else: # some good properties that only occur when the input is a power of 2. To maintain that even
|
| 86 |
+
closest_power_of_2 = 2 ** math.floor(
|
| 87 |
+
math.log2(n)) # when the number of heads is not a power of 2, we use this workaround.
|
| 88 |
+
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
|
| 89 |
+
:n - closest_power_of_2]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class RMSNorm(nn.Module):
|
| 93 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 94 |
+
"""
|
| 95 |
+
RMSNorm is equivalent to T5LayerNorm
|
| 96 |
+
"""
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 99 |
+
self.variance_epsilon = eps
|
| 100 |
+
|
| 101 |
+
def forward(self, hidden_states):
|
| 102 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 103 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 104 |
+
|
| 105 |
+
# convert into half-precision if necessary
|
| 106 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 107 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
| 108 |
+
|
| 109 |
+
return self.weight * hidden_states
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class RotaryEmbedding(torch.nn.Module):
|
| 113 |
+
def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
|
| 114 |
+
super().__init__()
|
| 115 |
+
# 修复RePE初始化精度问题 https://zhuanlan.zhihu.com/p/678963442
|
| 116 |
+
# DeepSpeed 会 Hack torch.arange 强制在 GPU 上运行,这里使用原生的 torch.arange
|
| 117 |
+
try:
|
| 118 |
+
import deepspeed
|
| 119 |
+
self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
|
| 120 |
+
except:
|
| 121 |
+
self.arange = torch.arange
|
| 122 |
+
|
| 123 |
+
self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
|
| 124 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 125 |
+
t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
| 126 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 127 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 128 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
|
| 129 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
|
| 130 |
+
|
| 131 |
+
def forward(self, x, seq_len=None):
|
| 132 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 133 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
| 134 |
+
if seq_len > self.max_seq_len_cached:
|
| 135 |
+
self.max_seq_len_cached = seq_len
|
| 136 |
+
t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
| 137 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 138 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 139 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
|
| 140 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
|
| 141 |
+
return (
|
| 142 |
+
self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
|
| 143 |
+
self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def rotate_half(x):
|
| 148 |
+
"""Rotates half the hidden dims of the input."""
|
| 149 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 150 |
+
x2 = x[..., x.shape[-1] // 2:]
|
| 151 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
|
| 155 |
+
cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
|
| 156 |
+
sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
|
| 157 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
| 158 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
| 159 |
+
q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
|
| 160 |
+
k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
|
| 161 |
+
return q_embed.to(q.dtype), k_embed.to(k.dtype)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class MLP(nn.Module):
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
hidden_size: int,
|
| 168 |
+
intermediate_size: int,
|
| 169 |
+
hidden_act: str,
|
| 170 |
+
):
|
| 171 |
+
super().__init__()
|
| 172 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 173 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 174 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 175 |
+
self.act_fn = ACT2FN[hidden_act]
|
| 176 |
+
|
| 177 |
+
def forward(self, x):
|
| 178 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 182 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 183 |
+
"""
|
| 184 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 185 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 186 |
+
"""
|
| 187 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 188 |
+
if n_rep == 1:
|
| 189 |
+
return hidden_states
|
| 190 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 191 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class Attention(nn.Module):
|
| 195 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 196 |
+
def __init__(self, config: OceanConfig, is_sparse=False):
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.config = config
|
| 199 |
+
self.position_embedding_type = config.position_embedding_type.lower()
|
| 200 |
+
self.num_kv_heads = config.num_key_value_heads
|
| 201 |
+
self.head_dim = config.head_dim
|
| 202 |
+
self.hidden_size = config.num_attention_heads * self.head_dim
|
| 203 |
+
self.hidden_kv_size = self.num_kv_heads * self.head_dim
|
| 204 |
+
|
| 205 |
+
if is_sparse:
|
| 206 |
+
self.num_heads = config.sparse_attention_heads
|
| 207 |
+
assert self.num_kv_heads == config.num_attention_heads
|
| 208 |
+
self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
|
| 209 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 210 |
+
else:
|
| 211 |
+
self.num_heads = config.num_attention_heads
|
| 212 |
+
if self.config.attention_qkv_pack:
|
| 213 |
+
self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
|
| 214 |
+
if config.moe:
|
| 215 |
+
self.moe_W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=False)
|
| 216 |
+
else:
|
| 217 |
+
self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
|
| 218 |
+
self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
|
| 219 |
+
self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
|
| 220 |
+
|
| 221 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
| 222 |
+
if config.moe:
|
| 223 |
+
self.moe_o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
| 224 |
+
|
| 225 |
+
if self.position_embedding_type == 'rope':
|
| 226 |
+
self.rotary_emb = RotaryEmbedding(
|
| 227 |
+
dim=self.head_dim,
|
| 228 |
+
max_position_embeddings=config.max_position_embeddings,
|
| 229 |
+
base=config.get_rotary_base()
|
| 230 |
+
)
|
| 231 |
+
elif self.position_embedding_type == 'alibi':
|
| 232 |
+
self.alibi_slopes = get_slopes(self.num_heads)
|
| 233 |
+
self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)
|
| 234 |
+
|
| 235 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 236 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 237 |
+
|
| 238 |
+
def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
|
| 239 |
+
assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
|
| 240 |
+
return repeat_kv(hidden_states, num_heads // hidden_states.size(1))
|
| 241 |
+
|
| 242 |
+
def forward(
|
| 243 |
+
self,
|
| 244 |
+
hidden_states: torch.Tensor,
|
| 245 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 246 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 247 |
+
seqlens: Optional[torch.IntTensor] = None,
|
| 248 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 249 |
+
output_attentions: bool = False,
|
| 250 |
+
use_cache: bool = False,
|
| 251 |
+
group_index=None,
|
| 252 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 253 |
+
bsz, q_len = hidden_states.shape[:2]
|
| 254 |
+
|
| 255 |
+
if self.config.attention_qkv_pack:
|
| 256 |
+
if self.config.moe and group_index is not None:
|
| 257 |
+
proj = moe_matmul(hidden_states, [self.W_pack.weight, self.moe_W_pack.weight], group_index, lambda x, y: torch.einsum('bd,ld->bl', x, y))
|
| 258 |
+
if self.config.attention_qkv_bias:
|
| 259 |
+
proj += self.W_pack.bias
|
| 260 |
+
else:
|
| 261 |
+
proj = self.W_pack(hidden_states)
|
| 262 |
+
query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
|
| 263 |
+
else:
|
| 264 |
+
query_states = self.q_proj(hidden_states)
|
| 265 |
+
key_states = self.k_proj(hidden_states)
|
| 266 |
+
value_states = self.v_proj(hidden_states)
|
| 267 |
+
|
| 268 |
+
# (B, S, hidden_size) -> (B, num_heads, S, head_size)
|
| 269 |
+
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 270 |
+
# (B, S, hidden_size) -> (B, num_kv_heads, S, head_size)
|
| 271 |
+
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 272 |
+
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
| 273 |
+
|
| 274 |
+
kv_seq_len = key_states.shape[-2]
|
| 275 |
+
if past_key_value is not None:
|
| 276 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 277 |
+
if self.position_embedding_type == 'rope':
|
| 278 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len * get_sequence_parallel_size())
|
| 279 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 280 |
+
query_states, key_states, cos, sin,
|
| 281 |
+
get_sequence_parallel_chunk(position_ids)
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
if past_key_value is not None:
|
| 285 |
+
# reuse k, v, self_attention
|
| 286 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 287 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 288 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
| 289 |
+
|
| 290 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 291 |
+
key_states = self._repeat_kv(key_states, query_states.size(1))
|
| 292 |
+
value_states = self._repeat_kv(value_states, query_states.size(1))
|
| 293 |
+
|
| 294 |
+
if seqlens is not None:
|
| 295 |
+
seqlens = seqlens.to(dtype=torch.int32)
|
| 296 |
+
max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
|
| 297 |
+
if self.position_embedding_type == 'alibi':
|
| 298 |
+
alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
|
| 299 |
+
else:
|
| 300 |
+
alibi_slopes = None
|
| 301 |
+
attn_output = self.attention(
|
| 302 |
+
query_states, key_states, value_states, seqlens, seqlens,
|
| 303 |
+
max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
|
| 304 |
+
else:
|
| 305 |
+
attn_output = self.attention(
|
| 306 |
+
query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)
|
| 307 |
+
|
| 308 |
+
attn_output = attn_output.reshape(bsz, q_len, -1)
|
| 309 |
+
if not self.config.moe or group_index is None:
|
| 310 |
+
attn_output = self.o_proj(attn_output)
|
| 311 |
+
else:
|
| 312 |
+
attn_output = moe_matmul(attn_output, [self.o_proj.weight, self.moe_o_proj.weight], group_index, lambda x, y: torch.einsum('bd,ld->bl', x, y))
|
| 313 |
+
|
| 314 |
+
return attn_output, None, past_key_value
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class DecoderLayer(nn.Module):
|
| 318 |
+
def __init__(self, config: OceanConfig, is_sparse=False):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.hidden_size = config.hidden_size
|
| 321 |
+
self.self_attn = Attention(config=config, is_sparse=is_sparse)
|
| 322 |
+
self.mlp = MLP(
|
| 323 |
+
hidden_size=self.hidden_size,
|
| 324 |
+
intermediate_size=config.intermediate_size,
|
| 325 |
+
hidden_act=config.hidden_act,
|
| 326 |
+
)
|
| 327 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 328 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 329 |
+
|
| 330 |
+
def forward(
|
| 331 |
+
self,
|
| 332 |
+
hidden_states: torch.Tensor,
|
| 333 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 334 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 335 |
+
seqlens: Optional[torch.IntTensor] = None,
|
| 336 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 337 |
+
output_attentions: Optional[bool] = False,
|
| 338 |
+
use_cache: Optional[bool] = False,
|
| 339 |
+
group_index=None,
|
| 340 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 341 |
+
|
| 342 |
+
residual = hidden_states
|
| 343 |
+
|
| 344 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 345 |
+
|
| 346 |
+
# Self Attention
|
| 347 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 348 |
+
hidden_states=hidden_states,
|
| 349 |
+
attention_mask=attention_mask,
|
| 350 |
+
position_ids=position_ids,
|
| 351 |
+
seqlens=seqlens,
|
| 352 |
+
past_key_value=past_key_value,
|
| 353 |
+
output_attentions=output_attentions,
|
| 354 |
+
use_cache=use_cache,
|
| 355 |
+
group_index=group_index,
|
| 356 |
+
)
|
| 357 |
+
hidden_states = residual + hidden_states
|
| 358 |
+
|
| 359 |
+
# Fully Connected
|
| 360 |
+
residual = hidden_states
|
| 361 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 362 |
+
hidden_states = self.mlp(hidden_states)
|
| 363 |
+
hidden_states = residual + hidden_states
|
| 364 |
+
|
| 365 |
+
outputs = (hidden_states,)
|
| 366 |
+
|
| 367 |
+
if output_attentions:
|
| 368 |
+
outputs += (self_attn_weights,)
|
| 369 |
+
|
| 370 |
+
if use_cache:
|
| 371 |
+
outputs += (present_key_value,)
|
| 372 |
+
|
| 373 |
+
return outputs
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class OceanPreTrainedModel(PreTrainedModel):
|
| 377 |
+
config_class = OceanConfig
|
| 378 |
+
base_model_prefix = "model"
|
| 379 |
+
supports_gradient_checkpointing = True
|
| 380 |
+
_no_split_modules = ["DecoderLayer"]
|
| 381 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
| 382 |
+
|
| 383 |
+
def _init_weights(self, module):
|
| 384 |
+
std = self.config.initializer_range
|
| 385 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):
|
| 386 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 387 |
+
if module.bias is not None:
|
| 388 |
+
module.bias.data.zero_()
|
| 389 |
+
elif isinstance(module, nn.Embedding):
|
| 390 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 391 |
+
if module.padding_idx is not None:
|
| 392 |
+
module.weight.data[module.padding_idx].zero_()
|
| 393 |
+
elif isinstance(module, nn.LayerNorm):
|
| 394 |
+
module.weight.data.fill_(1.0)
|
| 395 |
+
module.bias.data.zero_()
|
| 396 |
+
|
| 397 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 398 |
+
if isinstance(module, OceanModel):
|
| 399 |
+
module.gradient_checkpointing = value
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class OceanModel(OceanPreTrainedModel):
|
| 403 |
+
def __init__(self, config: OceanConfig):
|
| 404 |
+
super().__init__(config)
|
| 405 |
+
self.padding_idx = config.pad_token_id
|
| 406 |
+
self.vocab_size = config.vocab_size
|
| 407 |
+
self.merge_size = 1
|
| 408 |
+
if config.audio_config.enable:
|
| 409 |
+
self.audio_model = OceanAudioEncoder(config.audio_config)
|
| 410 |
+
self.audio_bridge_model = OceanAudioBridge(config)
|
| 411 |
+
if config.visual_config.enable:
|
| 412 |
+
self.visual_model = OceanVisualEncoder(config.visual_config)
|
| 413 |
+
self.visual_bridge_model = OceanVisualBridge(config.visual_config)
|
| 414 |
+
self.merge_size = max(config.visual_config.merge_size, self.merge_size)
|
| 415 |
+
if config.video_config.enable: # in case 没有visual_config而只有video_config
|
| 416 |
+
if not config.visual_config.enable:
|
| 417 |
+
self.visual_model = OceanVisualEncoder(config.video_config)
|
| 418 |
+
self.video_bridge_model = OceanVisualBridge(config.video_config)
|
| 419 |
+
self.merge_size = max(config.video_config.merge_size, self.merge_size)
|
| 420 |
+
|
| 421 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 422 |
+
self.layers = nn.ModuleList([
|
| 423 |
+
DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
|
| 424 |
+
for layer_idx in range(config.num_hidden_layers)
|
| 425 |
+
])
|
| 426 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 427 |
+
|
| 428 |
+
self.gradient_checkpointing = True
|
| 429 |
+
# Initialize weights and apply final processing
|
| 430 |
+
self.post_init()
|
| 431 |
+
|
| 432 |
+
def get_input_embeddings(self):
|
| 433 |
+
return self.embed_tokens
|
| 434 |
+
|
| 435 |
+
def set_input_embeddings(self, value):
|
| 436 |
+
self.embed_tokens = value
|
| 437 |
+
|
| 438 |
+
def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
|
| 439 |
+
'''
|
| 440 |
+
获取任意模态的特殊mask,包含以下
|
| 441 |
+
1. pad mask 表示文本中图像/语音/视频模态提前留出的token位置
|
| 442 |
+
2. special token mask 特殊token 例如对理解模型<start> <end> 不需要next token prediction
|
| 443 |
+
3. embedding mask / lm_head mask 标记出特殊token在embedding中的mask
|
| 444 |
+
'''
|
| 445 |
+
|
| 446 |
+
pad_mask = torch.eq(input_ids, pad_token_id)
|
| 447 |
+
sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
| 448 |
+
lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
|
| 449 |
+
for sp_id in special_token_list:
|
| 450 |
+
sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
|
| 451 |
+
lm_head_mask[sp_id, 0] = True
|
| 452 |
+
return pad_mask, sp_mask, lm_head_mask
|
| 453 |
+
|
| 454 |
+
def get_audio_embed(
|
| 455 |
+
self,
|
| 456 |
+
input_ids,
|
| 457 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
| 458 |
+
features, # list of tensors
|
| 459 |
+
encoder_length,
|
| 460 |
+
bridge_length,
|
| 461 |
+
group_index=None, # 某种模态的编号 for MoE
|
| 462 |
+
):
|
| 463 |
+
pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, self.config.audio_config.audio_pad_token_id, self.config.multimodal_special_token_list)
|
| 464 |
+
if features is None or len(features) <= 0 : # 空list or None 保证梯度回传
|
| 465 |
+
features, encoder_length, bridge_length = self.audio_model.fake_input(input_ids.device)
|
| 466 |
+
fake_input = True
|
| 467 |
+
else:
|
| 468 |
+
fake_input = False
|
| 469 |
+
audio_embed = self.audio_model(features, encoder_length)
|
| 470 |
+
audio_embed = self.audio_bridge_model(audio_embed, bridge_length) # (?, d)
|
| 471 |
+
if not self.training: # 推理支持auto map 把多模态模块输出和input_ids 统一到一个device
|
| 472 |
+
audio_embed = audio_embed.to(input_ids.device)
|
| 473 |
+
if not fake_input: # 检查多模态token 和 pad mask数量一致 (不正确的截断会导致该问题)
|
| 474 |
+
assert pad_mask.sum() == audio_embed.shape[0]
|
| 475 |
+
else:
|
| 476 |
+
assert pad_mask.sum() <= 0 # 0 vs 1
|
| 477 |
+
|
| 478 |
+
# 合并 当前模态embeddings 和text embeddings
|
| 479 |
+
input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
|
| 480 |
+
if self.config.train_multimodal_special_tokens_only and self.training:
|
| 481 |
+
# 仅special token传梯度到embedding weight, 保证LLM部分不变
|
| 482 |
+
# 注意: 多种模态之间special token list应该共享,否则会有部分被stop gradient
|
| 483 |
+
sp_mask = sp_mask.unsqueeze(-1).to(text_embedding)
|
| 484 |
+
text_embedding = (1 - sp_mask) * text_embedding.detach() + sp_mask * text_embedding
|
| 485 |
+
text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # pad token位置填0 (不传梯度)
|
| 486 |
+
multimodal_embedding = torch.embedding(audio_embed, input_ids * pad_mask) # 非 pad token 位置填idx=0位置结果
|
| 487 |
+
multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # 非pad token 位置填0
|
| 488 |
+
|
| 489 |
+
final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
|
| 490 |
+
|
| 491 |
+
if group_index is None:
|
| 492 |
+
group_index = pad_mask.to(torch.int32)
|
| 493 |
+
else:
|
| 494 |
+
current_index = torch.max(group_index) + 1
|
| 495 |
+
group_index += pad_mask.to(torch.int32) * current_index # 假设模态无重叠
|
| 496 |
+
|
| 497 |
+
return final_embedding, group_index # group_index 不传None 防止MoE部分参数无梯度
|
| 498 |
+
|
| 499 |
+
def get_visual_embed(
|
| 500 |
+
self,
|
| 501 |
+
input_ids,
|
| 502 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
| 503 |
+
images,
|
| 504 |
+
group_index, # 某种模态的编号 for MoE
|
| 505 |
+
images_grid
|
| 506 |
+
):
|
| 507 |
+
# TODO 与get_audio_embed合并重复功能 减少冗余代码
|
| 508 |
+
pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, self.config.visual_config.image_pad_token_id, self.config.multimodal_special_token_list)
|
| 509 |
+
if images is None or len(images) <= 0 : # 空list or None 保证梯度回传
|
| 510 |
+
images = self.visual_model.fake_input(input_ids.device, self.merge_size)
|
| 511 |
+
images_grid = [(1, self.merge_size, self.merge_size)]
|
| 512 |
+
fake_input = True
|
| 513 |
+
else:
|
| 514 |
+
fake_input = False
|
| 515 |
+
|
| 516 |
+
images = torch.cat(images, dim=0)
|
| 517 |
+
images_grid = torch.tensor(np.array(images_grid))
|
| 518 |
+
visual_embed = self.visual_model(images, grid_thw=images_grid)
|
| 519 |
+
visual_embed = self.visual_bridge_model(visual_embed)
|
| 520 |
+
|
| 521 |
+
if not self.training: # 推理支持auto map 把多模态模块输出和input_ids 统一到一个device
|
| 522 |
+
visual_embed = visual_embed.to(input_ids.device)
|
| 523 |
+
if not fake_input: # 检查多模态token 和 pad mask数量一致 (不正确的截断会导致该问题)
|
| 524 |
+
assert pad_mask.sum() == visual_embed.shape[0], '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
|
| 525 |
+
else:
|
| 526 |
+
assert pad_mask.sum() <= 0, '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
|
| 527 |
+
|
| 528 |
+
# 合并 当前模态embeddings 和text embeddings
|
| 529 |
+
input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
|
| 530 |
+
if self.config.train_multimodal_special_tokens_only and self.training:
|
| 531 |
+
# 仅special token传梯度到embedding weight, 保证LLM部分不变
|
| 532 |
+
# 注意: 多种模态之间special token list应���共享,否则会有部分被stop gradient
|
| 533 |
+
sp_mask = sp_mask.unsqueeze(-1).to(text_embedding)
|
| 534 |
+
text_embedding = (1 - sp_mask) * text_embedding.detach() + sp_mask * text_embedding
|
| 535 |
+
text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # pad token位置填0 (不传梯度)
|
| 536 |
+
multimodal_embedding = torch.embedding(visual_embed, input_ids * pad_mask) # 非 pad token 位置填idx=0位置结果
|
| 537 |
+
multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # 非pad token 位置填0
|
| 538 |
+
|
| 539 |
+
final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
|
| 540 |
+
|
| 541 |
+
if group_index is None:
|
| 542 |
+
group_index = pad_mask.to(torch.int32)
|
| 543 |
+
else:
|
| 544 |
+
current_index = torch.max(group_index) + 1
|
| 545 |
+
group_index += pad_mask.to(torch.int32) * current_index # 假设模态无重叠
|
| 546 |
+
|
| 547 |
+
return final_embedding, group_index # group_index 不传None 防止MoE部分参数无梯度
|
| 548 |
+
|
| 549 |
+
def get_video_embed(
|
| 550 |
+
self,
|
| 551 |
+
input_ids,
|
| 552 |
+
text_embedding, # 1. self.embed_tokens(input_ids) 2. 其他模态结果
|
| 553 |
+
images,
|
| 554 |
+
group_index, # 某种模态的编号 for MoE
|
| 555 |
+
images_grid
|
| 556 |
+
|
| 557 |
+
):
|
| 558 |
+
# TODO 与get_audio_embed合并重复功能 减少冗余代码
|
| 559 |
+
pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, self.config.video_config.video_place_token_id, self.config.multimodal_special_token_list)
|
| 560 |
+
if images is None or len(images) <= 0 : # 空list or None 保证梯度回传
|
| 561 |
+
images = self.visual_model.fake_input(input_ids.device, self.merge_size)
|
| 562 |
+
images_grid = [(1, self.merge_size, self.merge_size)]
|
| 563 |
+
fake_input = True
|
| 564 |
+
else:
|
| 565 |
+
fake_input = False
|
| 566 |
+
|
| 567 |
+
images = torch.cat(images, dim=0)
|
| 568 |
+
images_grid = torch.tensor(np.array(images_grid))
|
| 569 |
+
visual_embed = self.visual_model(images, grid_thw=images_grid)
|
| 570 |
+
visual_embed = self.video_bridge_model(visual_embed)
|
| 571 |
+
|
| 572 |
+
if not self.training: # 推理支持auto map 把多模态模块输出和input_ids 统一到一个device
|
| 573 |
+
visual_embed = visual_embed.to(input_ids.device)
|
| 574 |
+
if not fake_input: # 检查多模态token 和 pad mask数量一致 (不正确的截断会导致该问题)
|
| 575 |
+
assert pad_mask.sum() == visual_embed.shape[0], '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
|
| 576 |
+
assert pad_mask.sum() == visual_embed.shape[0], '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
|
| 577 |
+
else:
|
| 578 |
+
assert pad_mask.sum() <= 0, '{} != {}'.format(pad_mask.sum(), visual_embed.shape[0])
|
| 579 |
+
|
| 580 |
+
# 合并 当前模态embeddings 和text embeddings
|
| 581 |
+
input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape)-1, input_ids)
|
| 582 |
+
if self.config.train_multimodal_special_tokens_only and self.training:
|
| 583 |
+
# only the special tokens propagate gradients to the embedding weights, keeping the LLM part unchanged
|
| 584 |
+
# note: the special token list should be shared across modalities, otherwise some tokens get their gradients stopped
|
| 585 |
+
sp_mask = sp_mask.unsqueeze(-1).to(text_embedding)
|
| 586 |
+
text_embedding = (1 - sp_mask) * text_embedding.detach() + sp_mask * text_embedding
|
| 587 |
+
text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding # zero out the pad-token positions (no gradient flows there)
|
| 588 |
+
multimodal_embedding = torch.embedding(visual_embed, input_ids * pad_mask) # non-pad positions pick up the row at index 0 (masked out below)
|
| 589 |
+
multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding # zero out the non-pad positions
|
| 590 |
+
|
| 591 |
+
final_embedding = multimodal_embedding.to(text_embedding) + text_embedding
|
| 592 |
+
|
| 593 |
+
if group_index is None:
|
| 594 |
+
group_index = pad_mask.to(torch.int32)
|
| 595 |
+
else:
|
| 596 |
+
current_index = torch.max(group_index) + 1
|
| 597 |
+
group_index += pad_mask.to(torch.int32) * current_index # assumes the modalities do not overlap
|
| 598 |
+
|
| 599 |
+
return final_embedding, group_index # never return group_index as None, so the MoE parameters always receive gradients
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def forward(
|
| 604 |
+
self,
|
| 605 |
+
input_ids: torch.LongTensor = None,
|
| 606 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 607 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 608 |
+
seqlens: Optional[torch.IntTensor] = None,
|
| 609 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 610 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 611 |
+
audios: Optional[List|torch.Tensor] = None,
|
| 612 |
+
encoder_length: Optional[torch.Tensor] = None,
|
| 613 |
+
bridge_length: Optional[torch.Tensor] = None,
|
| 614 |
+
images: Optional[List|torch.Tensor] = None,
|
| 615 |
+
images_grid: Optional[List|torch.Tensor] = None,
|
| 616 |
+
videos: Optional[List|torch.Tensor] = None,
|
| 617 |
+
videos_grid: Optional[List|torch.Tensor] = None,
|
| 618 |
+
use_cache: Optional[bool] = None,
|
| 619 |
+
output_attentions: Optional[bool] = None,
|
| 620 |
+
output_hidden_states: Optional[bool] = None,
|
| 621 |
+
return_dict: Optional[bool] = None,
|
| 622 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 623 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 624 |
+
output_hidden_states = (
|
| 625 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 626 |
+
)
|
| 627 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 628 |
+
|
| 629 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 630 |
+
|
| 631 |
+
# retrieve input_ids and inputs_embeds
|
| 632 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 633 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
| 634 |
+
elif input_ids is not None:
|
| 635 |
+
batch_size, seq_length = input_ids.shape
|
| 636 |
+
elif inputs_embeds is not None:
|
| 637 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 638 |
+
else:
|
| 639 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 640 |
+
|
| 641 |
+
seq_length_with_past = seq_length
|
| 642 |
+
past_key_values_length = 0
|
| 643 |
+
|
| 644 |
+
if past_key_values is not None:
|
| 645 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
| 646 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
| 647 |
+
|
| 648 |
+
if position_ids is None:
|
| 649 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 650 |
+
position_ids = torch.arange(
|
| 651 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
| 652 |
+
)
|
| 653 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
| 654 |
+
else:
|
| 655 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
| 656 |
+
|
| 657 |
+
group_index = None
|
| 658 |
+
if inputs_embeds is None:
|
| 659 |
+
sp_input_ids = get_sequence_parallel_chunk(input_ids)
|
| 660 |
+
inputs_embeds = self.embed_tokens(sp_input_ids)
|
| 661 |
+
if self.config.audio_config.enable:
|
| 662 |
+
inputs_embeds, group_index = self.get_audio_embed(sp_input_ids, inputs_embeds, audios, encoder_length, bridge_length)
|
| 663 |
+
if self.config.visual_config.enable:
|
| 664 |
+
inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, group_index, images_grid) # note: group_index is updated here
|
| 665 |
+
if self.config.video_config.enable:
|
| 666 |
+
inputs_embeds, group_index = self.get_video_embed(sp_input_ids, inputs_embeds, videos, group_index, videos_grid) # note: group_index is updated here
|
| 667 |
+
|
| 668 |
+
if seqlens is not None and seqlens.ndim == 2:
|
| 669 |
+
# flatten batched multi-pack samples
|
| 670 |
+
cu_seqlens = []
|
| 671 |
+
offset, seqlen = 0, seqlens.size(1)
|
| 672 |
+
for lens in seqlens:
|
| 673 |
+
cu_seqlens.append(offset)
|
| 674 |
+
cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
|
| 675 |
+
offset += seqlen
|
| 676 |
+
cu_seqlens.append(offset)
|
| 677 |
+
seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
|
| 678 |
+
elif seqlens is None and self.training:
|
| 679 |
+
# compatibility with pretraining, where seqlens=None and every sample uses the max length
|
| 680 |
+
seqlens = torch.arange(
|
| 681 |
+
end=input_ids.size(0) + 1,
|
| 682 |
+
dtype=torch.int32,
|
| 683 |
+
device=input_ids.device
|
| 684 |
+
) * input_ids.size(1)
|
| 685 |
+
if seqlens is not None:
|
| 686 |
+
attention_mask = None # unset attention_mask to save memory
|
| 687 |
+
|
| 688 |
+
if seqlens is None and attention_mask is None:
|
| 689 |
+
attention_mask = torch.ones(
|
| 690 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
| 691 |
+
)
|
| 692 |
+
if attention_mask is not None:
|
| 693 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
| 694 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
# embed positions
|
| 698 |
+
hidden_states = inputs_embeds
|
| 699 |
+
|
| 700 |
+
if self.gradient_checkpointing and self.training:
|
| 701 |
+
if use_cache:
|
| 702 |
+
logger.warning_once(
|
| 703 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 704 |
+
)
|
| 705 |
+
use_cache = False
|
| 706 |
+
|
| 707 |
+
# decoder layers
|
| 708 |
+
all_hidden_states = () if output_hidden_states else None
|
| 709 |
+
all_self_attns = () if output_attentions else None
|
| 710 |
+
next_decoder_cache = () if use_cache else None
|
| 711 |
+
|
| 712 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 713 |
+
if output_hidden_states:
|
| 714 |
+
all_hidden_states += (hidden_states,)
|
| 715 |
+
|
| 716 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 717 |
+
|
| 718 |
+
if self.gradient_checkpointing and self.training:
|
| 719 |
+
|
| 720 |
+
def create_custom_forward(module):
|
| 721 |
+
def custom_forward(*inputs):
|
| 722 |
+
# None for past_key_value
|
| 723 |
+
return module(*inputs, output_attentions, False, group_index)
|
| 724 |
+
|
| 725 |
+
return custom_forward
|
| 726 |
+
|
| 727 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 728 |
+
create_custom_forward(decoder_layer),
|
| 729 |
+
hidden_states,
|
| 730 |
+
attention_mask,
|
| 731 |
+
position_ids,
|
| 732 |
+
seqlens,
|
| 733 |
+
None,
|
| 734 |
+
)
|
| 735 |
+
else:
|
| 736 |
+
layer_outputs = decoder_layer(
|
| 737 |
+
hidden_states,
|
| 738 |
+
attention_mask=attention_mask,
|
| 739 |
+
position_ids=position_ids,
|
| 740 |
+
seqlens=seqlens,
|
| 741 |
+
past_key_value=past_key_value,
|
| 742 |
+
output_attentions=output_attentions,
|
| 743 |
+
use_cache=use_cache,
|
| 744 |
+
group_index=group_index,
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
hidden_states = layer_outputs[0]
|
| 748 |
+
|
| 749 |
+
if use_cache:
|
| 750 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
| 751 |
+
|
| 752 |
+
if output_attentions:
|
| 753 |
+
all_self_attns += (layer_outputs[1],)
|
| 754 |
+
|
| 755 |
+
hidden_states = self.norm(hidden_states)
|
| 756 |
+
|
| 757 |
+
# add hidden states from the last decoder layer
|
| 758 |
+
if output_hidden_states:
|
| 759 |
+
all_hidden_states += (hidden_states,)
|
| 760 |
+
|
| 761 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 762 |
+
if not return_dict:
|
| 763 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 764 |
+
return BaseModelOutputWithPast(
|
| 765 |
+
last_hidden_state=hidden_states,
|
| 766 |
+
past_key_values=next_cache,
|
| 767 |
+
hidden_states=all_hidden_states,
|
| 768 |
+
attentions=all_self_attns,
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
class NormHead(nn.Module):
|
| 773 |
+
def __init__(self, hidden_size, vocab_size, bias=False):
|
| 774 |
+
super().__init__()
|
| 775 |
+
self.hidden_size = hidden_size
|
| 776 |
+
self.vocab_size = vocab_size
|
| 777 |
+
self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
|
| 778 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 779 |
+
|
| 780 |
+
def forward(self, hidden_states, mask=None):
|
| 781 |
+
norm_weight = nn.functional.normalize(self.weight)
|
| 782 |
+
if mask is not None:
|
| 783 |
+
mask = mask.to(norm_weight)
|
| 784 |
+
norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
|
| 785 |
+
return nn.functional.linear(hidden_states, norm_weight)
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def extra_repr(self) -> str:
|
| 789 |
+
return f'in_features={self.hidden_size}, out_features={self.vocab_size}'
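NormHead computes logits against L2-normalized rows of the output matrix, so each logit behaves like a cosine similarity scaled by the hidden-state norm. A minimal standalone sketch of that equivalence (illustrative shapes only, not part of the upload):

import torch
import torch.nn.functional as F

hidden = torch.randn(2, 5, 64)      # (batch, seq, hidden_size)
weight = torch.randn(1000, 64)      # (vocab_size, hidden_size)
logits = F.linear(hidden, F.normalize(weight))                    # what NormHead.forward computes
manual = hidden @ (weight / weight.norm(dim=-1, keepdim=True)).T  # explicit row-wise normalization
assert torch.allclose(logits, manual, atol=1e-5)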
|
| 790 |
+
|
| 791 |
+
@dataclass
|
| 792 |
+
class OceanMMCausalLMOutputWithPast(ModelOutput):
|
| 793 |
+
loss: Optional[torch.FloatTensor] = None
|
| 794 |
+
logits: torch.FloatTensor = None
|
| 795 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
| 796 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 797 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 798 |
+
text_nt_loss: Optional[torch.FloatTensor] = None
|
| 799 |
+
flatten_loss: Optional[torch.FloatTensor] = None
|
| 800 |
+
|
| 801 |
+
class OceanForCausalLM(OceanPreTrainedModel):
|
| 802 |
+
def __init__(self, config):
|
| 803 |
+
super().__init__(config)
|
| 804 |
+
self.config = config
|
| 805 |
+
self.model = OceanModel(config)
|
| 806 |
+
if config.use_norm_head:
|
| 807 |
+
self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
|
| 808 |
+
else:
|
| 809 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 810 |
+
# Initialize weights and apply final processing
|
| 811 |
+
self.post_init()
|
| 812 |
+
|
| 813 |
+
def bind_processor(self, tokenizer, **kwargs):
|
| 814 |
+
self.processor = OceanMMProcessor(
|
| 815 |
+
tokenizer=tokenizer,
|
| 816 |
+
config=self.config,
|
| 817 |
+
**kwargs,
|
| 818 |
+
)
|
| 819 |
+
return self.processor
|
| 820 |
+
|
| 821 |
+
def get_input_embeddings(self):
|
| 822 |
+
return self.model.embed_tokens
|
| 823 |
+
|
| 824 |
+
def set_input_embeddings(self, value):
|
| 825 |
+
self.model.embed_tokens = value
|
| 826 |
+
|
| 827 |
+
def get_output_embeddings(self):
|
| 828 |
+
return self.lm_head
|
| 829 |
+
|
| 830 |
+
def set_output_embeddings(self, new_embeddings):
|
| 831 |
+
self.lm_head = new_embeddings
|
| 832 |
+
|
| 833 |
+
def set_decoder(self, decoder):
|
| 834 |
+
self.model = decoder
|
| 835 |
+
|
| 836 |
+
def get_decoder(self):
|
| 837 |
+
return self.model
|
| 838 |
+
|
| 839 |
+
def forward(
|
| 840 |
+
self,
|
| 841 |
+
input_ids: torch.LongTensor = None,
|
| 842 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 843 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 844 |
+
seqlens: Optional[torch.IntTensor] = None,
|
| 845 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 846 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 847 |
+
labels: Optional[torch.LongTensor] = None,
|
| 848 |
+
audios: Optional[List|torch.Tensor] = None,
|
| 849 |
+
encoder_length: Optional[torch.Tensor] = None,
|
| 850 |
+
bridge_length: Optional[torch.Tensor] = None,
|
| 851 |
+
images: Optional[torch.Tensor] = None,
|
| 852 |
+
images_grid: Optional[torch.Tensor] = None,
|
| 853 |
+
videos: Optional[torch.Tensor] = None,
|
| 854 |
+
videos_grid: Optional[torch.Tensor] = None,
|
| 855 |
+
use_cache: Optional[bool] = None,
|
| 856 |
+
output_attentions: Optional[bool] = None,
|
| 857 |
+
output_hidden_states: Optional[bool] = None,
|
| 858 |
+
return_dict: Optional[bool] = None,
|
| 859 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 860 |
+
|
| 861 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 862 |
+
output_hidden_states = (
|
| 863 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 864 |
+
)
|
| 865 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 866 |
+
|
| 867 |
+
_, sp_mask, _ = self.model.get_multimodal_mask(input_ids, self.config.audio_config.audio_pad_token_id, self.config.multimodal_special_token_list)
|
| 868 |
+
# TODO: unfreeze the lm_head parameters of the learnable special tokens
|
| 869 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 870 |
+
outputs = self.model(
|
| 871 |
+
input_ids=input_ids,
|
| 872 |
+
attention_mask=attention_mask,
|
| 873 |
+
position_ids=position_ids,
|
| 874 |
+
seqlens=seqlens,
|
| 875 |
+
past_key_values=past_key_values,
|
| 876 |
+
inputs_embeds=inputs_embeds,
|
| 877 |
+
audios=audios,
|
| 878 |
+
encoder_length=encoder_length,
|
| 879 |
+
bridge_length=bridge_length,
|
| 880 |
+
images=images,
|
| 881 |
+
images_grid=images_grid,
|
| 882 |
+
videos=videos,
|
| 883 |
+
videos_grid=videos_grid,
|
| 884 |
+
use_cache=use_cache,
|
| 885 |
+
output_attentions=output_attentions,
|
| 886 |
+
output_hidden_states=output_hidden_states,
|
| 887 |
+
return_dict=return_dict,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
hidden_states = outputs[0]
|
| 891 |
+
|
| 892 |
+
# allow lm_head gradients only for the partially learnable special tokens
|
| 893 |
+
special_with_loss_list = list(set(self.config.multimodal_special_token_list) - set(self.config.multimodal_special_token_no_loss_list))
|
| 894 |
+
_, sp_with_loss_mask, lm_head_mask = self.model.get_multimodal_mask(input_ids, self.config.audio_config.audio_pad_token_id, special_with_loss_list)
|
| 895 |
+
if self.config.train_multimodal_special_tokens_only and self.training and len(special_with_loss_list) > 0:
|
| 896 |
+
if self.config.use_norm_head:
|
| 897 |
+
logits = self.lm_head(hidden_states, mask=lm_head_mask)
|
| 898 |
+
else:
|
| 899 |
+
lm_head_mask = lm_head_mask.to(self.lm_head.weight)
|
| 900 |
+
norm_weight = self.lm_head.weight * lm_head_mask + (1 - lm_head_mask) * self.lm_head.weight.detach()
|
| 901 |
+
logits = torch.einsum('bsd,ld->bsl', hidden_states, norm_weight)
|
| 902 |
+
else:
|
| 903 |
+
logits = self.lm_head(hidden_states)
|
| 904 |
+
|
| 905 |
+
loss = torch.tensor(0, device=hidden_states.device, dtype=hidden_states.dtype)
|
| 906 |
+
text_nt_loss = torch.tensor(0, device=hidden_states.device, dtype=hidden_states.dtype)
|
| 907 |
+
flatten_loss = None
|
| 908 |
+
if labels is not None:
|
| 909 |
+
# Shift so that tokens < n predict n
|
| 910 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 911 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 912 |
+
|
| 913 |
+
valid_mask = torch.gt(shift_labels, -1) # labels < 0 are treated as pad positions
|
| 914 |
+
sp_mask = sp_mask[..., 1:].contiguous()
|
| 915 |
+
text_mask = torch.logical_and(valid_mask, torch.logical_not(sp_mask))
|
| 916 |
+
|
| 917 |
+
# Flatten the tokens
|
| 918 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 919 |
+
shift_labels = shift_labels.view(-1)
|
| 920 |
+
# Enable model parallelism
|
| 921 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 922 |
+
flatten_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction='none')
|
| 923 |
+
|
| 924 |
+
loss = torch.mean(torch.masked_select(flatten_loss, valid_mask.view(-1)))
|
| 925 |
+
text_nt_loss = torch.mean(torch.masked_select(flatten_loss, text_mask.view(-1))).detach()
|
| 926 |
+
|
| 927 |
+
if not return_dict:
|
| 928 |
+
output = (logits,) + outputs[1:]
|
| 929 |
+
return (loss,) + output if loss is not None else output
|
| 930 |
+
|
| 931 |
+
return OceanMMCausalLMOutputWithPast(
|
| 932 |
+
loss=loss,
|
| 933 |
+
logits=logits,
|
| 934 |
+
past_key_values=outputs.past_key_values,
|
| 935 |
+
hidden_states=outputs.hidden_states,
|
| 936 |
+
attentions=outputs.attentions,
|
| 937 |
+
text_nt_loss=text_nt_loss,
|
| 938 |
+
flatten_loss=flatten_loss
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def prepare_inputs_for_generation(
|
| 943 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 944 |
+
):
|
| 945 |
+
if past_key_values:
|
| 946 |
+
input_ids = input_ids[:, -1:]
|
| 947 |
+
|
| 948 |
+
position_ids = kwargs.get("position_ids", None)
|
| 949 |
+
if attention_mask is not None and position_ids is None:
|
| 950 |
+
# create position_ids on the fly for batch generation
|
| 951 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 952 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 953 |
+
if past_key_values:
|
| 954 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 955 |
+
|
| 956 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 957 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 958 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 959 |
+
elif past_key_values is not None:
|
| 960 |
+
model_inputs = {"input_ids": input_ids}
|
| 961 |
+
else:
|
| 962 |
+
model_inputs = {"input_ids": input_ids,
|
| 963 |
+
"audios": kwargs.get("audios", None), "encoder_length": kwargs.get("encoder_length", None), "bridge_length": kwargs.get("bridge_length", None),
|
| 964 |
+
"images": kwargs.get("images", None),
|
| 965 |
+
"videos": kwargs.get("videos", None)
|
| 966 |
+
}
|
| 967 |
+
|
| 968 |
+
model_inputs.update(
|
| 969 |
+
{
|
| 970 |
+
"position_ids": position_ids,
|
| 971 |
+
"past_key_values": past_key_values,
|
| 972 |
+
"use_cache": kwargs.get("use_cache"),
|
| 973 |
+
"attention_mask": attention_mask,
|
| 974 |
+
"images_grid": kwargs.get("images_grid"),
|
| 975 |
+
"videos_grid": kwargs.get("videos_grid"),
|
| 976 |
+
}
|
| 977 |
+
)
|
| 978 |
+
return model_inputs
|
| 979 |
+
|
| 980 |
+
@staticmethod
|
| 981 |
+
def _reorder_cache(past_key_values, beam_idx):
|
| 982 |
+
reordered_past = ()
|
| 983 |
+
for layer_past in past_key_values:
|
| 984 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 985 |
+
return reordered_past
|
| 986 |
+
|
| 987 |
+
def chat(self, tokenizer, messages: List[dict], stream=False,
|
| 988 |
+
generation_config: Optional[GenerationConfig]=None):
|
| 989 |
+
generation_config = generation_config or self.generation_config
|
| 990 |
+
input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
|
| 991 |
+
if stream:
|
| 992 |
+
streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 993 |
+
Thread(target=self.generate, kwargs=dict(
|
| 994 |
+
inputs=input_ids, streamer=streamer,
|
| 995 |
+
generation_config=generation_config,
|
| 996 |
+
)).start()
|
| 997 |
+
return streamer
|
| 998 |
+
else:
|
| 999 |
+
outputs = self.generate(input_ids, generation_config=generation_config)
|
| 1000 |
+
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
| 1001 |
+
return response
|
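The chat() helper above wires build_chat_input, generate, and optional streaming together. A hedged usage sketch for text-only chat (the checkpoint path, dtype, and trust_remote_code flow are assumptions for illustration, not something this upload pins down):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = "path/to/this/checkpoint"   # hypothetical local path to this repo
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
messages = [{"role": "user", "content": "Hello!"}]
print(model.chat(tokenizer, messages, stream=False))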
processor_ocean.py
ADDED
|
@@ -0,0 +1,1154 @@
|
| 1 |
+
import requests
|
| 2 |
+
import re, ujson, os, sys, fire, glob, random, time, json
|
| 3 |
+
import numpy as np
|
| 4 |
+
import io
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import default_collate
|
| 7 |
+
import torchaudio
|
| 8 |
+
from typing import *
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
import transformers
|
| 11 |
+
from transformers.modeling_outputs import ModelOutput
|
| 12 |
+
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
|
| 13 |
+
from functools import lru_cache
|
| 14 |
+
from io import BytesIO
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from qcloud_cos import CosConfig
|
| 17 |
+
from qcloud_cos import CosS3Client
|
| 18 |
+
import tos
|
| 19 |
+
import concurrent.futures as cf
|
| 20 |
+
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
|
| 21 |
+
from transformers.image_utils import PILImageResampling
|
| 22 |
+
from PIL import Image, ImageOps
|
| 23 |
+
from PIL import ImageFile
|
| 24 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 25 |
+
import base64
|
| 26 |
+
from decord import VideoReader, cpu
|
| 27 |
+
import cv2
|
| 28 |
+
import av
|
| 29 |
+
import imagesize
|
| 30 |
+
import math
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def smart_resize(
|
| 34 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
| 35 |
+
):
|
| 36 |
+
"""Rescales the image so that the following conditions are met:
|
| 37 |
+
|
| 38 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 39 |
+
|
| 40 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 41 |
+
|
| 42 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 43 |
+
|
| 44 |
+
"""
|
| 45 |
+
# if height < factor or width < factor:
|
| 46 |
+
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 47 |
+
if max(height, width) / min(height, width) > 200:
|
| 48 |
+
raise ValueError(
|
| 49 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
| 50 |
+
)
|
| 51 |
+
h_bar = round(height / factor) * factor if height > factor else factor
|
| 52 |
+
w_bar = round(width / factor) * factor if width > factor else factor
|
| 53 |
+
if h_bar * w_bar > max_pixels:
|
| 54 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 55 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 56 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 57 |
+
elif h_bar * w_bar < min_pixels:
|
| 58 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 59 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 60 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 61 |
+
return h_bar, w_bar
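For intuition, smart_resize only rounds each side to a multiple of factor, then rescales if the pixel budget is violated, keeping the aspect ratio roughly intact. A quick check with illustrative numbers (run in this module, default min_pixels/max_pixels):

h, w = smart_resize(1000, 700, factor=28)
assert h % 28 == 0 and w % 28 == 0
print(h, w)   # 1008 700 -- within the default pixel budget, so only rounding applies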
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def select_best_resolution(image_size, candidate_resolutions):
|
| 65 |
+
'''Find the best candidate resolution for rescaling the original image.
|
| 66 |
+
image_size is usually the original size, e.g. (8*336, 16*336)
|
| 67 |
+
candidate_resolutions lists the candidate resolutions, e.g. (1*336, 4*336)
|
| 68 |
+
'''
|
| 69 |
+
try:
|
| 70 |
+
original_width, original_height = image_size
|
| 71 |
+
except:
|
| 72 |
+
pass
|
| 73 |
+
best_fit = None
|
| 74 |
+
max_effective_resolution = 0
|
| 75 |
+
min_wasted_resolution = float("inf")
|
| 76 |
+
# iterate over the (width, height) pairs in candidate_resolutions
|
| 77 |
+
for width, height in candidate_resolutions:
|
| 78 |
+
# the smaller of width / original_width and height / original_height is used as the scale
|
| 79 |
+
scale = min(width / original_width, height / original_height) # e.g. scale =min (1/8, 1/4) = 1/8
|
| 80 |
+
# downscale original_width and original_height
|
| 81 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) # e.g. 1*336, 2*336
|
| 82 |
+
# effective_resolution is the resolution actually covered after downscaling (roughly scale^2 * w * h)
|
| 83 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) # e.g. min(1*336 * 2*336, 8*336 * 16*336)
|
| 84 |
+
# wasted_resolution is the gap between the candidate resolution and the effective resolution
|
| 85 |
+
wasted_resolution = (width * height) - effective_resolution
|
| 86 |
+
# if (1) the downscaled resolution is larger than the current max_effective_resolution,
|
| 87 |
+
# or (2) it ties max_effective_resolution but wastes fewer pixels
|
| 88 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
| 89 |
+
max_effective_resolution = effective_resolution # update max_effective_resolution
|
| 90 |
+
min_wasted_resolution = wasted_resolution # update min_wasted_resolution
|
| 91 |
+
best_fit = (width, height)
|
| 92 |
+
|
| 93 |
+
return best_fit
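Illustrative call (values invented for the example): among the candidates below, (672, 672) keeps the most effective pixels for an 800x600 input, so it is selected.

candidates = [(336, 336), (672, 336), (336, 672), (672, 672)]
print(select_best_resolution((800, 600), candidates))   # (672, 672)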
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def read_video(image_path, max_frame_number, decode_way):
|
| 97 |
+
if decode_way=='1fps':
|
| 98 |
+
try:
|
| 99 |
+
vr = VideoReader(image_path, ctx=cpu(0))
|
| 100 |
+
total_frame_num = len(vr)
|
| 101 |
+
fps = round(vr.get_avg_fps())
|
| 102 |
+
frame_idx = [i for i in range(0, len(vr), fps)]
|
| 103 |
+
frames = vr.get_batch(frame_idx).asnumpy()
|
| 104 |
+
frames = [i for i in frames]
|
| 105 |
+
cnt = len(frames)
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(image_path)
|
| 108 |
+
print('error is', e)
|
| 109 |
+
return None
|
| 110 |
+
elif decode_way=='key':
|
| 111 |
+
try:
|
| 112 |
+
with av.open(image_path) as container:
|
| 113 |
+
stream = container.streams.video[0]
|
| 114 |
+
stream.codec_context.skip_frame = 'NONKEY'
|
| 115 |
+
frames = []
|
| 116 |
+
fps = int(stream.average_rate)
|
| 117 |
+
cnt = 0
|
| 118 |
+
for frame in container.decode(stream): # keep the key frames as images
|
| 119 |
+
image = frame.to_image()
|
| 120 |
+
frames.append(image)
|
| 121 |
+
cnt += 1
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print('error is', e)
|
| 124 |
+
return None
|
| 125 |
+
if frames is None or len(frames)==0:
|
| 126 |
+
return None
|
| 127 |
+
if len(frames)>max_frame_number and max_frame_number>0:
|
| 128 |
+
# generate evenly spaced indices
|
| 129 |
+
indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
|
| 130 |
+
# pick the frames at those indices
|
| 131 |
+
sampled_elements = [frames[idx] for idx in indices]
|
| 132 |
+
frames = sampled_elements
|
| 133 |
+
return frames
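Illustrative call (the file path is hypothetical): decode roughly one frame per second and cap the result at 32 uniformly sampled frames.

frames = read_video("example.mp4", max_frame_number=32, decode_way="1fps")
if frames is not None:
    print(len(frames))   # <= 32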
|
| 134 |
+
|
| 135 |
+
class OceanImageProcessor:
|
| 136 |
+
def __init__(self, config, **kwargs):
|
| 137 |
+
self.config = config # visual_config
|
| 138 |
+
self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
|
| 139 |
+
self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
|
| 140 |
+
self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
|
| 141 |
+
self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
|
| 142 |
+
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
|
| 143 |
+
self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
|
| 144 |
+
|
| 145 |
+
def image_transform(self, strseq, return_mm_data = True):
|
| 146 |
+
image = None
|
| 147 |
+
if isinstance(strseq, str):
|
| 148 |
+
if return_mm_data:
|
| 149 |
+
image = Image.open(strseq).convert("RGB")
|
| 150 |
+
else:
|
| 151 |
+
image = Image.open(BytesIO(strseq)).convert("RGB")
|
| 152 |
+
|
| 153 |
+
image = np.array(image.convert("RGB")) # convert to RGB (three channels) and then to a NumPy array for the steps below
|
| 154 |
+
image_org_size = image.shape[:2] # original (height, width); image.shape is (height, width, channels)
|
| 155 |
+
|
| 156 |
+
# resize, crop, scale, normalize
|
| 157 |
+
# the resize helper takes a target size (typically a target short- or long-side length, e.g. a 336 px short side) and derives the other side so the aspect ratio is preserved
|
| 158 |
+
# it outputs a new size, normally in (width, height) form, used by the scaling/cropping steps below
|
| 159 |
+
resized_height, resized_width = smart_resize(
|
| 160 |
+
image_org_size[0], image_org_size[1],
|
| 161 |
+
factor=self.patch_size * self.spatial_merge_size,
|
| 162 |
+
min_pixels=self.min_pixels,
|
| 163 |
+
max_pixels=self.max_pixels,
|
| 164 |
+
)
|
| 165 |
+
output_size = (resized_height, resized_width)
|
| 166 |
+
|
| 167 |
+
# output_size = get_resize_output_image_size(image, self.config.crop_size, False) # resize the short side to 336
|
| 168 |
+
# resize the image to output_size; PILImageResampling.BICUBIC selects bicubic interpolation, which usually gives good quality
|
| 169 |
+
# image: the input image (NumPy array or PIL image); output_size: the target size, usually a (width, height) pair, either absolute or relative to the original
|
| 170 |
+
# resample: optional resampling method controlling pixel interpolation; bicubic is a smooth choice commonly used for image scaling
|
| 171 |
+
image = resize(image, output_size, PILImageResampling.BICUBIC)
|
| 172 |
+
# (optional) crop a self.config.crop_size x self.config.crop_size square from the image center; return_numpy=True would return the crop as a NumPy array
|
| 173 |
+
# image = center_crop(image, (self.config.crop_size, self.config.crop_size), return_numpy=True)
|
| 174 |
+
img = image.transpose(2, 0, 1)
|
| 175 |
+
# normalize and standardize the image with image_mean / image_std
|
| 176 |
+
image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
|
| 177 |
+
# split into patches
|
| 178 |
+
patches = image[np.newaxis, :]
|
| 179 |
+
if patches.shape[0] == 1:
|
| 180 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
| 181 |
+
channel = patches.shape[1]
|
| 182 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
| 183 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 184 |
+
patches = patches.reshape(
|
| 185 |
+
grid_t,
|
| 186 |
+
self.temporal_patch_size,
|
| 187 |
+
channel,
|
| 188 |
+
grid_h // self.spatial_merge_size,
|
| 189 |
+
self.spatial_merge_size,
|
| 190 |
+
self.patch_size,
|
| 191 |
+
grid_w // self.spatial_merge_size,
|
| 192 |
+
self.spatial_merge_size,
|
| 193 |
+
self.patch_size,
|
| 194 |
+
)
|
| 195 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
| 196 |
+
flatten_patches = patches.reshape(
|
| 197 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class OceanAudioProcessor:
|
| 204 |
+
# contains the basic audio feature extraction, input-data parsing, and COS request/caching modules
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
config, # audio processor config
|
| 208 |
+
**kwargs
|
| 209 |
+
):
|
| 210 |
+
# make sure ffmpeg is installed for torchaudio, e.g. conda install -c conda-forge 'ffmpeg<7'
|
| 211 |
+
assert(len(torchaudio.list_audio_backends()) > 0)
|
| 212 |
+
self.config = config
|
| 213 |
+
self.mel_filters = mel_filter_bank(
|
| 214 |
+
num_frequency_bins=1 + self.config.n_fft // 2,
|
| 215 |
+
num_mel_filters=self.config.num_mel_bins,
|
| 216 |
+
min_frequency=0.0,
|
| 217 |
+
max_frequency=self.config.sampling_rate / 2.0,
|
| 218 |
+
sampling_rate=self.config.sampling_rate,
|
| 219 |
+
norm="slaney",
|
| 220 |
+
mel_scale="slaney",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
@staticmethod
|
| 224 |
+
def zero_mean_unit_var_norm(x):
|
| 225 |
+
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
|
| 226 |
+
|
| 227 |
+
def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
|
| 228 |
+
metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
|
| 229 |
+
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
|
| 230 |
+
waveform_tensor, _ = torchaudio.load(uri, normalize=True)
|
| 231 |
+
if self.config.sampling_rate != metadata.sample_rate:
|
| 232 |
+
waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate)
|
| 233 |
+
|
| 234 |
+
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
|
| 235 |
+
if metadata.num_channels > 1:
|
| 236 |
+
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
|
| 237 |
+
|
| 238 |
+
# normalize to zero mean (Qwen Audio skips this, but the official Whisper implementation does it)
|
| 239 |
+
if do_normalize:
|
| 240 |
+
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
|
| 241 |
+
|
| 242 |
+
if return_tensors: # (channels, samples)
|
| 243 |
+
return waveform_tensor
|
| 244 |
+
else:
|
| 245 |
+
return waveform_tensor.numpy()
|
| 246 |
+
|
| 247 |
+
def split_with_overlap(self, waveform): # if the waveform exceeds the max length, split it into overlapping segments
|
| 248 |
+
channels, wave_samples = waveform.shape
|
| 249 |
+
max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
|
| 250 |
+
if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
|
| 251 |
+
return [waveform] # within the max length (or no split logic); always return a list
|
| 252 |
+
|
| 253 |
+
split_waveform, start = [], 0
|
| 254 |
+
while start < wave_samples: # 20240724 change: align the overlap by seconds so the sampled data is identical across sampling rate / n_fft / hop length configurations
|
| 255 |
+
if start > int(self.config.sampling_rate * self.config.split_overlap):
|
| 256 |
+
start -= int(self.config.sampling_rate * self.config.split_overlap) # 0 means no overlap; > 0 is the overlap in seconds
|
| 257 |
+
end = min(start + max_audio_samples, wave_samples)
|
| 258 |
+
split_waveform.append(waveform[:, start:end]) # note: this may produce very short segments, which should be detected and dropped during preprocessing
|
| 259 |
+
start = end
|
| 260 |
+
return split_waveform
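The splitting loop rewinds each window start by split_overlap seconds before cutting the next max_audio_seconds window. A standalone sketch of the same arithmetic (the config values are assumptions for the example):

import torch
sr, max_s, overlap_s = 16000, 30, 2        # assumed sampling_rate / max_audio_seconds / split_overlap
wave = torch.zeros(1, 70 * sr)             # 70 s of mono audio
start, windows = 0, []
while start < wave.shape[1]:
    if start > overlap_s * sr:
        start -= overlap_s * sr            # rewind by the overlap before cutting the next window
    end = min(start + max_s * sr, wave.shape[1])
    windows.append((start / sr, end / sr))
    start = end
print(windows)   # [(0.0, 30.0), (28.0, 58.0), (56.0, 70.0)]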
|
| 261 |
+
|
| 262 |
+
@classmethod
|
| 263 |
+
def inference_output_length(cls, config, input_length):
|
| 264 |
+
# for whisper + bridge
|
| 265 |
+
kernel_size = config.kernel_size
|
| 266 |
+
stride_size = config.stride_size
|
| 267 |
+
avg_pooler = config.avg_pooler
|
| 268 |
+
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
|
| 269 |
+
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
|
| 270 |
+
if avg_pooler > 1:
|
| 271 |
+
bridge_length = encoder_length // avg_pooler
|
| 272 |
+
return encoder_length, bridge_length
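Worked example of the length bookkeeping (kernel_size=3, stride_size=2, avg_pooler=4 and a 3000-frame mel input are assumed values): the first conv keeps the frame count, the strided conv halves it, and the pooler divides by avg_pooler.

k, s, pool, n = 3, 2, 4, 3000
enc = (n + 2 * (k // 2) - k) // 1 + 1    # 3000 after conv layer 1 (stride 1, pad k//2)
enc = (enc + 2 * (k // 2) - k) // s + 1  # 1500 after conv layer 2 (stride 2, pad k//2)
print(enc, enc // pool)                  # 1500 375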
|
| 273 |
+
|
| 274 |
+
def extract_fbank_features(self, waveform):
|
| 275 |
+
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
|
| 276 |
+
channels, wave_samples = waveform.shape
|
| 277 |
+
assert(wave_samples >= self.config.n_fft)
|
| 278 |
+
valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
|
| 279 |
+
if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
|
| 280 |
+
waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
|
| 281 |
+
else:
|
| 282 |
+
waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
|
| 283 |
+
|
| 284 |
+
window = torch.hann_window(self.config.n_fft)
|
| 285 |
+
stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
|
| 286 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
| 287 |
+
|
| 288 |
+
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
| 289 |
+
mel_spec = mel_filters.T @ magnitudes
|
| 290 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 291 |
+
|
| 292 |
+
if waveform.dim() == 2:
|
| 293 |
+
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
|
| 294 |
+
log_spec = torch.maximum(log_spec, max_val - 8.0)
|
| 295 |
+
else:
|
| 296 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 297 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 298 |
+
|
| 299 |
+
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
|
| 300 |
+
log_spec[:, valid_frame_nums:] = 0.0 # zero padding; the collate step keeps only the max length within the batch
|
| 301 |
+
|
| 302 |
+
return log_spec, valid_frame_nums
|
| 303 |
+
|
| 304 |
+
def data_augment(self, feature: np.array, input_length, training=True):
|
| 305 |
+
# reference https://arxiv.org/pdf/1904.08779
|
| 306 |
+
# run only on cpu
|
| 307 |
+
def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
|
| 308 |
+
# compute how many spans to mask, then randomly pick the span start indices
|
| 309 |
+
num_masked_span = int(mask_prob * input_length / mask_length + random.random())
|
| 310 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 311 |
+
start_indices = list(range(input_length - mask_length))
|
| 312 |
+
random.shuffle(start_indices)
|
| 313 |
+
start_indices = start_indices[:num_masked_span]
|
| 314 |
+
return start_indices
|
| 315 |
+
|
| 316 |
+
if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
|
| 317 |
+
return feature
|
| 318 |
+
if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
|
| 319 |
+
return feature
|
| 320 |
+
if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
|
| 321 |
+
return feature
|
| 322 |
+
|
| 323 |
+
if self.config.mask_time_prob > 0:
|
| 324 |
+
start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
|
| 325 |
+
for start_idx in start_indices:
|
| 326 |
+
feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
|
| 327 |
+
if self.config.mask_feature_prob > 0:
|
| 328 |
+
start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
|
| 329 |
+
for start_idx in start_indices:
|
| 330 |
+
feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
|
| 331 |
+
|
| 332 |
+
return feature
|
| 333 |
+
|
| 334 |
+
class CosClient():
|
| 335 |
+
def __init__(self, bucket_name='crawl-pic-1317568651',
|
| 336 |
+
max_retries=2):
|
| 337 |
+
self.config = CosConfig(
|
| 338 |
+
Endpoint="cos.ap-guangzhou.myqcloud.com",
|
| 339 |
+
# Region='ap-guangzhou',
|
| 340 |
+
SecretId='AKIDnRpxoOghgVs0tkU3Mfv20jAMI0SRDj02',
|
| 341 |
+
SecretKey='td9tRlqiPvEJ8i27wXwBIDiy5ye6JGyS',
|
| 342 |
+
Token=None, Scheme='https', Timeout=300)
|
| 343 |
+
self.client = CosS3Client(self.config)
|
| 344 |
+
self.max_retries = max_retries
|
| 345 |
+
self.bucket_name = bucket_name
|
| 346 |
+
|
| 347 |
+
def __call__(self, relative_path, bucket_name=None):
|
| 348 |
+
if bucket_name is None or len(bucket_name) <= 0:
|
| 349 |
+
bucket_name = self.bucket_name
|
| 350 |
+
multimodal_bytes = None
|
| 351 |
+
for _ in range(self.max_retries):
|
| 352 |
+
try:
|
| 353 |
+
response = self.client.get_object(Bucket=bucket_name, Key=relative_path)
|
| 354 |
+
fp = response['Body'].get_raw_stream()
|
| 355 |
+
multimodal_bytes = fp.read()
|
| 356 |
+
break
|
| 357 |
+
except Exception as e:
|
| 358 |
+
time.sleep(0.01)
|
| 359 |
+
continue
|
| 360 |
+
return multimodal_bytes
|
| 361 |
+
|
| 362 |
+
class TosClient(object):
|
| 363 |
+
def __init__(self):
|
| 364 |
+
ak = "AKLTYTM3MWY5MTFhNDgyNDk4YjhmYTE0ZTE3YTk5ZmU1MjU"
|
| 365 |
+
sk = "TVRRM1pUZGtaVEJqWTJJd05HSTNPR0ppWVdKa1lqYzVORFUwTlRobU1UVQ=="
|
| 366 |
+
endpoint = "tos-cn-beijing.ivolces.com" # "tos-cn-beijing.ivolces.com"
|
| 367 |
+
region = "cn-beijing"
|
| 368 |
+
self.bucket_name = "audio-dataset"
|
| 369 |
+
self.client = tos.TosClientV2(ak, sk, endpoint, region)
|
| 370 |
+
|
| 371 |
+
def __call__(self, path, bucket_name=None):
|
| 372 |
+
if bucket_name is None:
|
| 373 |
+
bucket_name = self.bucket_name
|
| 374 |
+
for _ in range(2):
|
| 375 |
+
try:
|
| 376 |
+
object_stream = self.client.get_object(bucket_name, path)
|
| 377 |
+
return object_stream.read()
|
| 378 |
+
except Exception as e:
|
| 379 |
+
time.sleep(0.01)
|
| 380 |
+
continue
|
| 381 |
+
return None
|
| 382 |
+
|
| 383 |
+
@dataclass
|
| 384 |
+
class OceanProcessorOutput(ModelOutput):
|
| 385 |
+
input_ids: Optional["List|torch.Tensor"] = None
|
| 386 |
+
labels: Optional["List|torch.Tensor"] = None
|
| 387 |
+
attention_mask: Optional["List|torch.Tensor"] = None
|
| 388 |
+
position_ids: Optional["List|torch.Tensor"] = None
|
| 389 |
+
seqlens: Optional["List|torch.Tensor"] = None # must be used together with the Ocean modeling code
|
| 390 |
+
# audio fields
|
| 391 |
+
audios: Optional["List|torch.Tensor"] = None
|
| 392 |
+
encoder_length: Optional["List|torch.Tensor"] = None
|
| 393 |
+
bridge_length: Optional["List|torch.Tensor"] = None
|
| 394 |
+
# image fields
|
| 395 |
+
images: Optional["List|torch.Tensor"] = None
|
| 396 |
+
patch_nums: Optional["List|torch.Tensor"] = None
|
| 397 |
+
images_size: Optional["List|torch.Tensor"] = None
|
| 398 |
+
crop_size: Optional["List|torch.Tensor"] = None
|
| 399 |
+
images_grid: Optional["List|torch.Tensor"] = None
|
| 400 |
+
# video fields
|
| 401 |
+
videos: Optional["List|torch.Tensor"] = None
|
| 402 |
+
videos_patch_nums: Optional["List|torch.Tensor"] = None
|
| 403 |
+
videos_size: Optional["List|torch.Tensor"] = None
|
| 404 |
+
videos_crop_size: Optional["List|torch.Tensor"] = None
|
| 405 |
+
videos_grid: Optional["List|torch.Tensor"] = None
|
| 406 |
+
# processor fields
|
| 407 |
+
raw_text: Optional[str] = None
|
| 408 |
+
index: Optional[int] = None
|
| 409 |
+
|
| 410 |
+
def concatenate(self, other): # only valid while the fields are still lists
|
| 411 |
+
def concat_one(a, b):
|
| 412 |
+
if a is None and b is None:
|
| 413 |
+
return None
|
| 414 |
+
elif a is None and b is not None:
|
| 415 |
+
return b
|
| 416 |
+
elif a is not None and b is None:
|
| 417 |
+
return a
|
| 418 |
+
else:
|
| 419 |
+
return a + b
|
| 420 |
+
return OceanProcessorOutput(
|
| 421 |
+
input_ids=concat_one(self.input_ids, other.input_ids),
|
| 422 |
+
labels=concat_one(self.labels, other.labels),
|
| 423 |
+
audios=concat_one(self.audios, other.audios),
|
| 424 |
+
encoder_length=concat_one(self.encoder_length, other.encoder_length),
|
| 425 |
+
bridge_length=concat_one(self.bridge_length, other.bridge_length),
|
| 426 |
+
images=concat_one(self.images, other.images),
|
| 427 |
+
images_grid=concat_one(self.images_grid, other.images_grid),
|
| 428 |
+
patch_nums=concat_one(self.patch_nums, other.patch_nums),
|
| 429 |
+
|
| 430 |
+
videos=concat_one(self.videos, other.videos),
|
| 431 |
+
videos_grid=concat_one(self.videos_grid, other.videos_grid),
|
| 432 |
+
videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
|
| 433 |
+
|
| 434 |
+
position_ids=concat_one(self.position_ids, other.position_ids),
|
| 435 |
+
seqlens=concat_one(self.seqlens, other.seqlens),
|
| 436 |
+
images_size=concat_one(self.images_size, other.images_size)
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
class OceanMMProcessor(object):
|
| 440 |
+
def __init__(self,
|
| 441 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 442 |
+
config,
|
| 443 |
+
training,
|
| 444 |
+
relative_path=None,
|
| 445 |
+
**kwargs,
|
| 446 |
+
):
|
| 447 |
+
self.tokenizer = tokenizer
|
| 448 |
+
self.config = config
|
| 449 |
+
self.audio_processor = None
|
| 450 |
+
if hasattr(config, "audio_config"):
|
| 451 |
+
self.audio_processor = OceanAudioProcessor(config.audio_config)
|
| 452 |
+
self.visual_processor = None
|
| 453 |
+
if hasattr(config, "visual_config"):
|
| 454 |
+
self.visual_processor = OceanImageProcessor(config.visual_config)
|
| 455 |
+
self.video_processor = None
|
| 456 |
+
if hasattr(config, "video_config"):
|
| 457 |
+
self.video_processor = OceanImageProcessor(config.video_config)
|
| 458 |
+
self.training = training
|
| 459 |
+
self.relative_path = relative_path
|
| 460 |
+
self.cos_client = CosClient()
|
| 461 |
+
self.tos_client = TosClient()
|
| 462 |
+
# audio tag
|
| 463 |
+
self.audio_start_tag = None
|
| 464 |
+
self.audio_end_tag = None
|
| 465 |
+
self.audio_pad_tag = None
|
| 466 |
+
self.audio_delim_tag = None
|
| 467 |
+
if hasattr(self.config, "audio_config"):
|
| 468 |
+
self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
|
| 469 |
+
self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
|
| 470 |
+
self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
|
| 471 |
+
self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
|
| 472 |
+
# image tag
|
| 473 |
+
self.image_start_tag = None
|
| 474 |
+
self.image_end_tag = None
|
| 475 |
+
self.image_pad_tag = None
|
| 476 |
+
self.video_start_tag = None
|
| 477 |
+
self.video_end_tag = None
|
| 478 |
+
if hasattr(self.config, "visual_config"):
|
| 479 |
+
# special token for start_tag
|
| 480 |
+
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
|
| 481 |
+
# special token for end_tag
|
| 482 |
+
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
|
| 483 |
+
# special token for pad_tag
|
| 484 |
+
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
|
| 485 |
+
self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
|
| 486 |
+
self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
|
| 487 |
+
if hasattr(self.config, "video_config"):
|
| 488 |
+
self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
|
| 489 |
+
self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
|
| 490 |
+
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
|
| 491 |
+
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
|
| 492 |
+
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
|
| 493 |
+
self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# @lru_cache(maxsize=1024)
|
| 497 |
+
def _get_audio(self, audio_info, return_mm_data = True):
|
| 498 |
+
try:
|
| 499 |
+
audio_info = ujson.loads(audio_info)
|
| 500 |
+
audio_uri = None
|
| 501 |
+
if 'path' in audio_info.keys():
|
| 502 |
+
|
| 503 |
+
if self.relative_path is not None: # prefer the local path first
|
| 504 |
+
audio_uri = os.path.join(self.relative_path, audio_info['path'])
|
| 505 |
+
if not os.path.exists(audio_uri):
|
| 506 |
+
audio_uri = None
|
| 507 |
+
if audio_uri is None: # not found locally, fall back to COS/TOS
|
| 508 |
+
if audio_info.get('server', 'cos') == 'tos':
|
| 509 |
+
audio_uri = self.tos_client(audio_info['path'], 'audio-dataset')
|
| 510 |
+
else:
|
| 511 |
+
audio_uri = self.cos_client(audio_info['path'], 'audio-data-1317568651')
|
| 512 |
+
|
| 513 |
+
elif 'local' in audio_info.keys():
|
| 514 |
+
audio_uri = audio_info['local']
|
| 515 |
+
if not os.path.exists(audio_uri):
|
| 516 |
+
audio_uri = None
|
| 517 |
+
return OceanProcessorOutput()
|
| 518 |
+
else:
|
| 519 |
+
raise ValueError("can not find path or local in audio_info")
|
| 520 |
+
|
| 521 |
+
waveforms = self.audio_processor.load_audio_waveform(audio_uri, True)
|
| 522 |
+
waveforms = self.audio_processor.split_with_overlap(waveforms) # split into segments
|
| 523 |
+
ret = OceanProcessorOutput() # default init: the audios field is None
|
| 524 |
+
for waveform in waveforms:
|
| 525 |
+
audio, input_length = self.audio_processor.extract_fbank_features(waveform)
|
| 526 |
+
audio = self.audio_processor.data_augment(audio, input_length, self.training)
|
| 527 |
+
encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
|
| 528 |
+
if bridge_length <= 0: # filter extremely short data: 1. if len(waveforms)==1 the result stays empty; 2. if len(waveforms)>1 the last segment was too short and is dropped
|
| 529 |
+
continue
|
| 530 |
+
current_ret = OceanProcessorOutput(
|
| 531 |
+
audios=[audio],
|
| 532 |
+
encoder_length=[encoder_length],
|
| 533 |
+
bridge_length=[bridge_length])
|
| 534 |
+
if ret.audios is None:
|
| 535 |
+
ret = current_ret
|
| 536 |
+
else:
|
| 537 |
+
ret = ret.concatenate(current_ret) # concatenate the slices
|
| 538 |
+
if not return_mm_data:
|
| 539 |
+
ret.audios = [None]
|
| 540 |
+
return ret
|
| 541 |
+
except Exception as e:
|
| 542 |
+
print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
|
| 543 |
+
return OceanProcessorOutput()
|
| 544 |
+
|
| 545 |
+
# @lru_cache(maxsize=1024)
|
| 546 |
+
def _get_image(self, image_info, return_mm_data = True):
|
| 547 |
+
try:
|
| 548 |
+
try: # chensong
|
| 549 |
+
image_info = ujson.loads(image_info)
|
| 550 |
+
except:
|
| 551 |
+
#image_info = image_info.replace("'", '"')
|
| 552 |
+
image_info = re.sub(r"(?<!\\)'", '"', image_info)
|
| 553 |
+
image_info = ujson.loads(image_info)
|
| 554 |
+
if 'base64' in image_info.keys():
|
| 555 |
+
image_data = base64.b64decode(image_info['base64'])
|
| 556 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
|
| 557 |
+
elif 'local' in image_info.keys():
|
| 558 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'],return_mm_data = return_mm_data)
|
| 559 |
+
elif 'path' in image_info.keys():
|
| 560 |
+
if "tos_bucket" in image_info.keys(): # every item stored on TOS must specify its bucket via the tos_bucket key
|
| 561 |
+
tos_bucket = image_info['tos_bucket']
|
| 562 |
+
image_bytes = self.tos_client(image_info['path'], tos_bucket) # fetch the image bytes from the TOS client
|
| 563 |
+
else:
|
| 564 |
+
cos_bucket = None
|
| 565 |
+
if "cos_bucket" in image_info.keys():
|
| 566 |
+
cos_bucket = image_info['cos_bucket']
|
| 567 |
+
if "bucket_name" in image_info.keys():
|
| 568 |
+
cos_bucket = image_info['bucket_name']
|
| 569 |
+
image_bytes = self.cos_client(image_info['path'], cos_bucket) # fetch the image bytes from the COS client
|
| 570 |
+
# get image_feat (the image patches), org_size (the original image size), and image_list
|
| 571 |
+
image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
|
| 572 |
+
else:
|
| 573 |
+
raise ValueError("can not find any path in image_info")
|
| 574 |
+
|
| 575 |
+
merge_length = self.visual_processor.merge_size**2
|
| 576 |
+
patch_nums = np.array(image_list).prod() // merge_length
|
| 577 |
+
|
| 578 |
+
if org_size[0] * org_size[1] > 16**2: # filter out extremely small images
|
| 579 |
+
return OceanProcessorOutput(
|
| 580 |
+
images=[image_feat],
|
| 581 |
+
patch_nums=[patch_nums],
|
| 582 |
+
crop_size=[image_list],
|
| 583 |
+
images_size= [org_size],
|
| 584 |
+
images_grid=[image_list]
|
| 585 |
+
)
|
| 586 |
+
else:
|
| 587 |
+
print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
|
| 588 |
+
return OceanProcessorOutput()
|
| 589 |
+
|
| 590 |
+
except Exception as e:
|
| 591 |
+
print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
|
| 592 |
+
return OceanProcessorOutput()
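# Note: images smaller than 16x16 pixels are rejected; patch_nums = prod(image grid) // merge_size**2
# is the number of visual tokens and later determines the pad-token count in _replace_image.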
|
| 593 |
+
|
| 594 |
+
# @lru_cache(maxsize=1024)
|
| 595 |
+
def _get_video_frame(self, video_frame_info, return_mm_data = True):
|
| 596 |
+
try:
|
| 597 |
+
pattern = r'\{.*?\}'
|
| 598 |
+
matches = re.findall(pattern, video_frame_info)
|
| 599 |
+
ret = OceanProcessorOutput()
|
| 600 |
+
# parse each match one by one
|
| 601 |
+
for match in matches:
|
| 602 |
+
video_frame_info = ujson.loads(match)
|
| 603 |
+
if 'local' in video_frame_info.keys():
|
| 604 |
+
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'],return_mm_data = return_mm_data)
|
| 605 |
+
else:
|
| 606 |
+
raise ValueError("can not find any path in image_info")
|
| 607 |
+
|
| 608 |
+
merge_length = self.video_processor.merge_size**2
|
| 609 |
+
patch_nums = np.array(image_list).prod() // merge_length
|
| 610 |
+
|
| 611 |
+
if org_size[0] * org_size[1] > 16**2: # filter out extremely small frames
|
| 612 |
+
ret = ret.concatenate(
|
| 613 |
+
OceanProcessorOutput(
|
| 614 |
+
videos=[image_feat],
|
| 615 |
+
videos_patch_nums=[patch_nums],
|
| 616 |
+
videos_crop_size=[image_list],
|
| 617 |
+
videos_size= [org_size],
|
| 618 |
+
videos_grid=[image_list]
|
| 619 |
+
)
|
| 620 |
+
)
|
| 621 |
+
else:
|
| 622 |
+
print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
|
| 623 |
+
return ret
|
| 624 |
+
|
| 625 |
+
except Exception as e:
|
| 626 |
+
print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
|
| 627 |
+
return OceanProcessorOutput()
|
| 628 |
+
|
| 629 |
+
# read the video
|
| 630 |
+
def _get_video_obj_byte(self, source, path, video_obj_json):
|
| 631 |
+
video_obj_byte = None
|
| 632 |
+
if source == "cos":
|
| 633 |
+
start_time = time.time()
|
| 634 |
+
video_obj_byte = self.cos_client(path, bucket_name=video_obj_json.get("cos_bucket", None))
|
| 635 |
+
if (time.time() - start_time) > 1.0:
|
| 636 |
+
self.reflash_cos_client()
|
| 637 |
+
if source == "local":
|
| 638 |
+
if os.path.exists(path):
|
| 639 |
+
video_obj_byte = open(path, "rb").read()
|
| 640 |
+
else:
|
| 641 |
+
video_obj_byte = None
|
| 642 |
+
if source == "base64":
|
| 643 |
+
video_obj_byte = base64.b64decode(path)
|
| 644 |
+
if source == "url":
|
| 645 |
+
video_obj_byte = requests.get(url=path).content
|
| 646 |
+
return video_obj_byte
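# Note: supported sources are "cos" object storage (with a client refresh if a single fetch takes
# longer than 1s), "local" files, inline "base64" payloads, and plain "url" downloads.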
|
| 647 |
+
|
| 648 |
+
# split the video into frames and save them to a subdirectory
|
| 649 |
+
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
|
| 650 |
+
video_path = video_info['local']
|
| 651 |
+
# local path where the frames are saved
|
| 652 |
+
frame_path = video_path.split('.')[0] + '_frames'
|
| 653 |
+
if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
|
| 654 |
+
# save the frames
|
| 655 |
+
os.makedirs(frame_path, exist_ok=True)
|
| 656 |
+
mm_obj_byte = self._get_video_obj_byte('local', video_path, video_info)
|
| 657 |
+
if mm_obj_byte is None: # the video file could not be read
|
| 658 |
+
return ""
|
| 659 |
+
frames = read_video(io.BytesIO(mm_obj_byte), max_frame_number=max_frame_number, decode_way=decode_way) # read all frames
|
| 660 |
+
for frame_idx, frame in enumerate(frames):
|
| 661 |
+
output_filename = os.path.join(frame_path, f"{frame_idx}.jpg")
|
| 662 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
| 663 |
+
cv2.imwrite(output_filename, frame)
|
| 664 |
+
|
| 665 |
+
# select frames
|
| 666 |
+
frame_number = len([filename for filename in os.listdir(frame_path) if filename.endswith('.jpg')])
|
| 667 |
+
if frame_number>max_frame_number:
|
| 668 |
+
indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
|
| 669 |
+
else:
|
| 670 |
+
indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
|
| 671 |
+
# concatenation mode: join the frame tags into one string
|
| 672 |
+
replace_str = ""
|
| 673 |
+
for idx in indices:
|
| 674 |
+
frame_str = f"{self.image_start_tag}{os.path.join(frame_path, f'{idx}.jpg')}{self.image_end_tag}"
|
| 675 |
+
replace_str += frame_str
|
| 676 |
+
return replace_str
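# Illustrative sketch (assumed numbers): with 100 decoded frames and max_frame_number == 32,
# np.linspace(0, 99, 32, dtype=int) picks 32 evenly spaced indices, and each chosen frame is
# wrapped as image_start_tag + "{frame_path}/{idx}.jpg" + image_end_tag in the returned string.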
|
| 677 |
+
|
| 678 |
+
def _get_video_frame_str(self, video_info, return_mm_data = True ):
|
| 679 |
+
try:
|
| 680 |
+
video_info = ujson.loads(video_info)
|
| 681 |
+
if 'local' in video_info.keys():
|
| 682 |
+
# get a string containing the multi-frame image paths, capped at max_frame_number frames
|
| 683 |
+
frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
|
| 684 |
+
if frames_str != "":
|
| 685 |
+
parts = frames_str.split(self.image_end_tag)
|
| 686 |
+
result = []
|
| 687 |
+
for part in parts:
|
| 688 |
+
if self.image_start_tag in part:
|
| 689 |
+
before_path, path = part.split(self.image_start_tag)
|
| 690 |
+
new_path = f'{self.image_start_tag}{{"local": "{path}"}}{self.image_end_tag}'
|
| 691 |
+
result.append(before_path + new_path)
|
| 692 |
+
else:
|
| 693 |
+
result.append(part)
|
| 694 |
+
return ''.join(result)
|
| 695 |
+
else:
|
| 696 |
+
raise ValueError('can not find localpath in video_info')
|
| 697 |
+
except Exception as e:
|
| 698 |
+
print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
|
| 699 |
+
return ""
|
| 700 |
+
|
| 701 |
+
# def _replace_audio(self, audio_text, return_mm_data = True):
|
| 702 |
+
# audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
|
| 703 |
+
# ret = self._get_audio(audio_info, return_mm_data) # repeated fetch (cached result)
|
| 704 |
+
def _replace_audio(self, audio_text, mminfo_ret_dict):
|
| 705 |
+
audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
|
| 706 |
+
# ret = self._get_audio(audio_info) # repeated fetch (cached result)
|
| 707 |
+
ret = mminfo_ret_dict.get(audio_info, OceanProcessorOutput()) # look up directly from the dict
|
| 708 |
+
if ret.bridge_length is not None: # TODO: the tokenizer becomes slow when there are many pad tokens
|
| 709 |
+
replaced_text = [self.audio_pad_tag * l for l in ret.bridge_length]
|
| 710 |
+
replaced_text = self.audio_delim_tag.join(replaced_text)
|
| 711 |
+
return self.audio_start_tag + replaced_text + self.audio_end_tag
|
| 712 |
+
return ''
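# Illustrative example (hypothetical values): for bridge_length == [4, 4] the replacement is
# audio_start_tag + 4 * audio_pad_tag + audio_delim_tag + 4 * audio_pad_tag + audio_end_tag.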
|
| 713 |
+
|
| 714 |
+
# def _replace_image(self, image_text, return_mm_data = True):
|
| 715 |
+
# image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
|
| 716 |
+
# ret = self._get_image(image_info, return_mm_data) # repeated fetch (cached result)
|
| 717 |
+
def _replace_image(self, image_text, mminfo_ret_dict):
|
| 718 |
+
image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
|
| 719 |
+
# ret = self._get_image(image_info) # repeated fetch (cached result)
|
| 720 |
+
ret = mminfo_ret_dict.get(image_info, OceanProcessorOutput()) # look up directly from the dict
|
| 721 |
+
if ret.patch_nums is None:
|
| 722 |
+
return ''
|
| 723 |
+
return self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
|
| 724 |
+
return ''
|
| 725 |
+
|
| 726 |
+
# def _replace_video_frame(self, video_frame_text, return_mm_data = True):
|
| 727 |
+
# video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
|
| 728 |
+
# ret = self._get_video_frame(video_frame_info, return_mm_data) # repeated fetch (cached result)
|
| 729 |
+
def _replace_video_frame(self, video_frame_text, mminfo_ret_dict):
|
| 730 |
+
video_frame_info = re.sub(re.compile(self.video_start_tag + '|' + self.video_end_tag), '', video_frame_text)
|
| 731 |
+
video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_info)
|
| 732 |
+
# ret = self._get_video_frame(video_frame_info) # repeated fetch (cached result)
|
| 733 |
+
ret = mminfo_ret_dict.get(video_frame_info, OceanProcessorOutput())
|
| 734 |
+
if ret.videos_patch_nums is None:
|
| 735 |
+
return ''
|
| 736 |
+
video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
|
| 737 |
+
return ''.join(video_frame_str)
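# Note: every kept frame expands to image_start_tag + patch_num * video_place_tag + image_end_tag,
# so a clip with N kept frames yields N consecutive placeholder groups in the text.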
|
| 738 |
+
|
| 739 |
+
def extract_replace_multimodal(self, text, mtype='audio', return_mm_data = True):
|
| 740 |
+
# extract the json-formatted audio/image info from the text, load it and convert it to features, estimate the number of encoder tokens, and fill in the corresponding number of pad tokens
|
| 741 |
+
if (self.audio_start_tag != None) and (mtype == 'audio'):
|
| 742 |
+
match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag)
|
| 743 |
+
drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag)
|
| 744 |
+
extract_func = self._get_audio
|
| 745 |
+
replace_func = self._replace_audio
|
| 746 |
+
elif (self.image_start_tag != None) and (mtype == 'image'):
|
| 747 |
+
match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag)
|
| 748 |
+
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
|
| 749 |
+
extract_func = self._get_image
|
| 750 |
+
replace_func = self._replace_image
|
| 751 |
+
elif (self.video_start_tag != None) and (mtype == 'video'):
|
| 752 |
+
video_match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag)
|
| 753 |
+
video_drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag)
|
| 754 |
+
# process the video: convert the video path into multiple frame-image paths
|
| 755 |
+
mm_info_list = re.findall(video_match_regex, text)
|
| 756 |
+
for mm_info in mm_info_list:
|
| 757 |
+
frame_str = self._get_video_frame_str(re.sub(video_drop_regex, '', mm_info))
|
| 758 |
+
# replace the path; if the video does not exist, it is replaced with an empty string
|
| 759 |
+
text = re.sub(mm_info, self.video_start_tag + frame_str + self.video_end_tag, text)
|
| 760 |
+
# then handle the frames with the multi-image pipeline
|
| 761 |
+
match_regex = re.compile(self.video_start_tag+r'(.*?)'+self.video_end_tag)
|
| 762 |
+
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
|
| 763 |
+
extract_func = self._get_video_frame
|
| 764 |
+
replace_func = self._replace_video_frame
|
| 765 |
+
else:
|
| 766 |
+
raise ValueError("mtype not supportted!")
|
| 767 |
+
|
| 768 |
+
mm_info_list = re.findall(match_regex, text)
|
| 769 |
+
mm_info_list = [re.sub(drop_regex, '', mm_info) for mm_info in mm_info_list]
|
| 770 |
+
|
| 771 |
+
mminfo_ret_dict = {}
|
| 772 |
+
ret = OceanProcessorOutput()
|
| 773 |
+
for mm_info in mm_info_list: # if no matching modality is found, raw_text=text is returned as-is, so the result is never None
|
| 774 |
+
mm_ret = extract_func(mm_info, return_mm_data = return_mm_data)
|
| 775 |
+
mminfo_ret_dict[mm_info] = mm_ret
|
| 776 |
+
if mm_ret.audios is None and mm_ret.images is None and mm_ret.videos is None: # the sample contains audio/image/video but extraction failed, so the whole sample is invalid (ret.raw_text stays None)
|
| 777 |
+
return ret
|
| 778 |
+
ret = ret.concatenate(mm_ret) # there may be multiple results; collect them first
|
| 779 |
+
# ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group()), text)
|
| 780 |
+
ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group(), mminfo_ret_dict), text)
|
| 781 |
+
return ret
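# Note: each unique mm_info is fetched once into mminfo_ret_dict, and the same dict backs the regex
# substitution, so every audio/image/video is loaded only a single time per sample.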
|
| 782 |
+
|
| 783 |
+
def process_one(self, text, index=0, raw_only=False, return_mm_data = True):
|
| 784 |
+
ret = OceanProcessorOutput(index=index)
|
| 785 |
+
for mtype in self.config.multimodal: # loop to fetch the audio/image results and update the raw_text field
|
| 786 |
+
mret = self.extract_replace_multimodal(text, mtype, return_mm_data = return_mm_data) # now also fetches video results
|
| 787 |
+
if mret.raw_text is None: # the sample contains audio but the audio failed to load
|
| 788 |
+
return OceanProcessorOutput(index=index)
|
| 789 |
+
ret = ret.concatenate(mret)
|
| 790 |
+
text = mret.raw_text
|
| 791 |
+
ret.raw_text = text
|
| 792 |
+
if raw_only:
|
| 793 |
+
return ret # for compatibility with SFT and other custom tokenizer logic
|
| 794 |
+
|
| 795 |
+
# handle the trainable spans of the pretraining text
|
| 796 |
+
input_ids, labels = [], []
|
| 797 |
+
trainable_sep = re.findall(r'<trainable_start>|<trainable_end>', ret.raw_text.replace('\n', '<LF>'))
|
| 798 |
+
if len(trainable_sep) <= 0:
|
| 799 |
+
input_ids = self.tokenizer(ret.raw_text, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
|
| 800 |
+
labels = [True for _ in input_ids]
|
| 801 |
+
else:
|
| 802 |
+
split_content = re.split(r'<trainable_start>|<trainable_end>', ret.raw_text)
|
| 803 |
+
for i, sc in enumerate(split_content):
|
| 804 |
+
if len(sc.strip()) == 0:
|
| 805 |
+
continue # drop redundant whitespace
|
| 806 |
+
sc_ids = self.tokenizer(sc, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
|
| 807 |
+
input_ids.extend(sc_ids)
|
| 808 |
+
if i == 0 or trainable_sep[i - 1] == '<trainable_end>': # stop gradient
|
| 809 |
+
labels.extend([False] * len(sc_ids))
|
| 810 |
+
else:
|
| 811 |
+
labels.extend([True] * len(sc_ids))
|
| 812 |
+
# input_ids += [self.tokenizer.eos_token_id]
|
| 813 |
+
# labels += [True]
|
| 814 |
+
ret.labels = [input_ids[j] if (l and input_ids[j] not in self.config.multimodal_special_token_no_loss_list) else -100 for j, l in enumerate(labels)]
|
| 815 |
+
ret.input_ids = input_ids
|
| 816 |
+
ret.index = index
|
| 817 |
+
return ret
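# Illustrative example (hypothetical text): for "Q: hi <trainable_start>A: hello<trainable_end>"
# only the tokens of "A: hello" keep their ids in labels; all other positions, and any multimodal
# special tokens, are masked to -100.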
|
| 818 |
+
|
| 819 |
+
@torch.no_grad()
|
| 820 |
+
def __call__(self, example, parallel=8):
|
| 821 |
+
# final entry point; supports 3 input forms: a pretraining string, an SFT message dict, and a list of strings for batch inference
|
| 822 |
+
if isinstance(example, Dict):
|
| 823 |
+
pass
|
| 824 |
+
elif isinstance(example, str):
|
| 825 |
+
return self.process_one(example)
|
| 826 |
+
elif isinstance(example, List): # batch inference: asynchronous multi-threaded processing
|
| 827 |
+
with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
|
| 828 |
+
future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
|
| 829 |
+
batch_data = [key.result() for key in cf.as_completed(future_list)]
|
| 830 |
+
valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
|
| 831 |
+
assert(valid_num == len(batch_data)) # inference strictly requires every sample to be valid so the counts stay aligned
|
| 832 |
+
batch_data = sorted(batch_data, key=lambda x: x.index) # keep the original order
|
| 833 |
+
|
| 834 |
+
ret = OceanProcessorOutput()
|
| 835 |
+
for i in range(len(batch_data)):
|
| 836 |
+
ret = ret.concatenate(batch_data[i])
|
| 837 |
+
self.tokenizer.padding_side = "left"
|
| 838 |
+
padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
|
| 839 |
+
ret.input_ids = padding_result["input_ids"]
|
| 840 |
+
ret.attention_mask = padding_result["attention_mask"] # batch inference is not packed, so seqlens is not needed
|
| 841 |
+
padding_result = self.tokenizer.pad({"input_ids": [r.labels for r in batch_data]}, return_tensors='pt')
|
| 842 |
+
ret.labels = padding_result["input_ids"]
|
| 843 |
+
|
| 844 |
+
if ret.audios is not None:
|
| 845 |
+
ret.audios = default_collate(ret.audios)
|
| 846 |
+
ret.encoder_length = default_collate(ret.encoder_length)
|
| 847 |
+
ret.bridge_length = default_collate(ret.bridge_length)
|
| 848 |
+
|
| 849 |
+
if ret.images is not None:
|
| 850 |
+
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
|
| 851 |
+
# else:ret.images = default_collate(ret.images)
|
| 852 |
+
# ret.patch_nums = default_collate(ret.patch_nums)
|
| 853 |
+
|
| 854 |
+
if ret.videos is not None:
|
| 855 |
+
ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]
|
| 856 |
+
|
| 857 |
+
return ret
|
| 858 |
+
|
| 859 |
+
else:
|
| 860 |
+
raise ValueError("example format supported yet")
|
| 861 |
+
|
| 862 |
+
@torch.no_grad()
|
| 863 |
+
def pack_batch_pretrain(self, raw_batch, max_sequence_length=None, parallel=8):
|
| 864 |
+
if max_sequence_length is None:
|
| 865 |
+
max_sequence_length = self.tokenizer.model_max_length
|
| 866 |
+
# pack N samples into M sequences of length max_sequence_length; each packed sequence keeps its own multimodal inputs
|
| 867 |
+
assert isinstance(raw_batch, List)
|
| 868 |
+
start_ts = time.time()
|
| 869 |
+
if parallel > 1:
|
| 870 |
+
with cf.ThreadPoolExecutor(max_workers=parallel) as executor:
|
| 871 |
+
future_list = []
|
| 872 |
+
for idx, json_text in enumerate(raw_batch):
|
| 873 |
+
try: # parse the json
|
| 874 |
+
json_obj = ujson.loads(json_text.strip())
|
| 875 |
+
except:
|
| 876 |
+
try:
|
| 877 |
+
json_obj = ast.literal_eval(json_text.strip())
|
| 878 |
+
except:
|
| 879 |
+
print("parse json obj faild: {}....".format(json_text[:300]))
|
| 880 |
+
continue
|
| 881 |
+
try: # chensong
|
| 882 |
+
if isinstance(json_obj, list):
|
| 883 |
+
content = json_obj[1]
|
| 884 |
+
elif 'raw' in json_obj.keys():
|
| 885 |
+
content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["raw"]
|
| 886 |
+
else:
|
| 887 |
+
content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["content"]
|
| 888 |
+
except:
|
| 889 |
+
print("parse json raw/content error: {}....".format(json_text[:300]))
|
| 890 |
+
continue
|
| 891 |
+
|
| 892 |
+
future_list.append(executor.submit(self.process_one, content, idx))
|
| 893 |
+
# collect the results (completion order, i.e. unordered)
|
| 894 |
+
batch_data = [key.result() for key in cf.as_completed(future_list)]
|
| 895 |
+
else: # debug only
|
| 896 |
+
|
| 897 |
+
batch_data = []
|
| 898 |
+
for json_text in raw_batch:
|
| 899 |
+
data = ujson.loads(json_text.strip())
|
| 900 |
+
if 'raw' in data.keys():
|
| 901 |
+
batch_data.append(self.process_one(data['raw'], 0))
|
| 902 |
+
else:
|
| 903 |
+
batch_data.append(self.process_one(data['content'], 0))
|
| 904 |
+
|
| 905 |
+
if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0:
|
| 906 |
+
print('[WARNING] processing each sample took more than 1.0s on average')
|
| 907 |
+
|
| 908 |
+
# pack the text inputs, without any truncation
|
| 909 |
+
current_length, packed_output, output = 0, OceanProcessorOutput(position_ids=[], seqlens=[]), []
|
| 910 |
+
empty_data = OceanProcessorOutput(input_ids=[], labels=[])
|
| 911 |
+
for idx, bd in enumerate(batch_data + [empty_data]): # append an empty sentinel so the last sample is flushed into output and not lost
|
| 912 |
+
if bd.input_ids is None and idx < len(batch_data):
|
| 913 |
+
continue # the sample failed to load and this is not the final sentinel
|
| 914 |
+
if (len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length) and idx < len(batch_data):
|
| 915 |
+
continue # drop over-long samples outright (and this is not the final sentinel)
|
| 916 |
+
if current_length + len(bd.input_ids) + 1 > max_sequence_length or idx == len(batch_data):
|
| 917 |
+
pad_nums = max_sequence_length - current_length # right padding
|
| 918 |
+
if packed_output.input_ids is None or packed_output.labels is None:
|
| 919 |
+
packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums
|
| 920 |
+
packed_output.labels = [-100] * pad_nums
|
| 921 |
+
packed_output.position_ids += [0] * (pad_nums+1)
|
| 922 |
+
else:
|
| 923 |
+
packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums
|
| 924 |
+
packed_output.labels += [-100] * pad_nums
|
| 925 |
+
packed_output.position_ids += [0] * pad_nums
|
| 926 |
+
packed_output.attention_mask = [1] * current_length + [0] * pad_nums
|
| 927 |
+
packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens))
|
| 928 |
+
output.append(packed_output)
|
| 929 |
+
packed_output = OceanProcessorOutput(position_ids=[], seqlens=[]) # reset empty
|
| 930 |
+
packed_output = packed_output.concatenate(bd)
|
| 931 |
+
packed_output.input_ids.append(self.tokenizer.eos_token_id) # the </s> (eos) token has to be appended separately
|
| 932 |
+
packed_output.labels.append(self.tokenizer.eos_token_id)
|
| 933 |
+
|
| 934 |
+
packed_output.position_ids.extend(list(range(len(bd.input_ids) + 1)))
|
| 935 |
+
packed_output.seqlens.append(len(bd.input_ids) + 1)
|
| 936 |
+
|
| 937 |
+
current_length = len(packed_output.input_ids)
|
| 938 |
+
return output
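# Note: each packed sequence is right-padded to max_sequence_length; position_ids restart from 0 at
# every sample boundary, and seqlens records each sample's length (tokens + eos), zero-padded so the
# seqlens vector has a fixed size.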
|
| 939 |
+
|
| 940 |
+
@torch.no_grad()
|
| 941 |
+
def collect_batch_pretrain(self, batch_data):
|
| 942 |
+
ret = OceanProcessorOutput()
|
| 943 |
+
for i in range(len(batch_data)):
|
| 944 |
+
ret = ret.concatenate(batch_data[i])
|
| 945 |
+
ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
|
| 946 |
+
ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
|
| 947 |
+
ret.attention_mask = default_collate([np.asarray(x.attention_mask, dtype=np.float32) for x in batch_data]).cuda(non_blocking=True)
|
| 948 |
+
ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
|
| 949 |
+
ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
|
| 950 |
+
|
| 951 |
+
ret.raw_text = None
|
| 952 |
+
if ret.audios is not None:
|
| 953 |
+
ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32)).cuda(non_blocking=True)
|
| 954 |
+
ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32)).cuda(non_blocking=True)
|
| 955 |
+
ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32)).cuda(non_blocking=True)
|
| 956 |
+
if ret.images is not None:
|
| 957 |
+
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)).cuda(non_blocking=True) for image in ret.images]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
|
| 958 |
+
ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
|
| 959 |
+
if ret.videos is not None:
|
| 960 |
+
ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)).cuda(non_blocking=True) for video in ret.videos]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
|
| 961 |
+
ret.videos_patch_nums = default_collate(np.asarray(ret.videos_patch_nums, dtype=np.int32)).cuda(non_blocking=True)
|
| 962 |
+
|
| 963 |
+
return ret
|
| 964 |
+
|
| 965 |
+
@torch.no_grad()
|
| 966 |
+
def collect_batch_sft(self, batch_data):
|
| 967 |
+
# list of dict to dataclass
|
| 968 |
+
batch_data = [OceanProcessorOutput(**bd) for bd in batch_data]
|
| 969 |
+
ret = OceanProcessorOutput()
|
| 970 |
+
for i in range(len(batch_data)):
|
| 971 |
+
ret = ret.concatenate(batch_data[i])
|
| 972 |
+
ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data])
|
| 973 |
+
ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data])
|
| 974 |
+
ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data])
|
| 975 |
+
ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data])
|
| 976 |
+
|
| 977 |
+
ret.raw_text = None
|
| 978 |
+
if ret.audios is not None:
|
| 979 |
+
ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32))
|
| 980 |
+
ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32))
|
| 981 |
+
ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32))
|
| 982 |
+
if ret.images is not None:
|
| 983 |
+
# convert each image to a torch tensor
|
| 984 |
+
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
|
| 985 |
+
if ret.videos is not None:
|
| 986 |
+
ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
|
| 987 |
+
|
| 988 |
+
# ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
|
| 989 |
+
|
| 990 |
+
ret = ret.__dict__
|
| 991 |
+
del ret['patch_nums']
|
| 992 |
+
del ret['images_size']
|
| 993 |
+
del ret['crop_size']
|
| 994 |
+
del ret['raw_text']
|
| 995 |
+
del ret['index']
|
| 996 |
+
del ret['attention_mask']
|
| 997 |
+
del ret['videos_patch_nums']
|
| 998 |
+
del ret['videos_size']
|
| 999 |
+
del ret['videos_crop_size']
|
| 1000 |
+
return ret
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
#######################################################
|
| 1004 |
+
## Unit Test Functions, usage
|
| 1005 |
+
## python processor_ocean.py <test_function_name>
|
| 1006 |
+
#######################################################
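# Example invocations via fire (assuming the hard-coded data paths are reachable):
#   python processor_ocean.py test_processor
#   python processor_ocean.py test_pack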
|
| 1007 |
+
|
| 1008 |
+
def test_img_processor():
|
| 1009 |
+
from transformers import AutoConfig
|
| 1010 |
+
from transformers.models.clip import CLIPImageProcessor
|
| 1011 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
| 1012 |
+
processor = OceanImageProcessor(config.visual_config)
|
| 1013 |
+
offical_processor = CLIPImageProcessor(size=config.visual_config.crop_size, crop_size=config.visual_config.crop_size,
|
| 1014 |
+
image_mean=config.visual_config.image_mean, image_std=config.visual_config.image_std,
|
| 1015 |
+
do_convert_rgb=True)
|
| 1016 |
+
img_files = ['sogou/7a2c8ffc1bc61146b32805c3390f42e2', 'wukong/77c1db1c0e4200d12b478c33ba3a412d', 'wukong/62e9a5c8eb8b0ea8858a34ba3f1a999f', 'wukong/fb9ab4d7c3fe9f54289948fd6a57fc30']
|
| 1017 |
+
cos_client = CosClient()
|
| 1018 |
+
for img_file in img_files:
|
| 1019 |
+
img_bytes = cos_client(img_file)
|
| 1020 |
+
img_rbg = Image.open(io.BytesIO(img_bytes))
|
| 1021 |
+
image, org_size = processor.image_transform(img_bytes)
|
| 1022 |
+
offical_image = offical_processor.preprocess([img_rbg],
|
| 1023 |
+
do_resize=True, do_center_crop=True, do_rescale=True, do_normalize=True,
|
| 1024 |
+
return_tensors='np').data['pixel_values'][0]
|
| 1025 |
+
print('-'*60)
|
| 1026 |
+
print(np.array(img_rbg).shape)
|
| 1027 |
+
print(image.shape)
|
| 1028 |
+
print(offical_image.shape)
|
| 1029 |
+
print(image - offical_image)
|
| 1030 |
+
|
| 1031 |
+
def test_audio_processor():
|
| 1032 |
+
from transformers.models.whisper import WhisperFeatureExtractor
|
| 1033 |
+
from transformers import AutoConfig
|
| 1034 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
| 1035 |
+
offical_processor = WhisperFeatureExtractor(feature_size=128)
|
| 1036 |
+
processor = OceanAudioProcessor(config.audio_config)
|
| 1037 |
+
# wave_files = glob.glob('/home/nfs_bc_alignment/sunhaoze/audio-data/openaqa/openaqa-as/audio/*')
|
| 1038 |
+
wave_files = ['/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/7ZY0U5tfKyQ.flac', '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/Osly4Shchs4.flac']
|
| 1039 |
+
for wave_file in wave_files:
|
| 1040 |
+
wave = processor.load_audio_waveform(wave_file, True, False)
|
| 1041 |
+
offical_features = offical_processor(wave[0].numpy(), do_normalize=False)
|
| 1042 |
+
feat = offical_features['input_features'][0]
|
| 1043 |
+
wave, frame_nums = processor.extract_fbank_features(wave)
|
| 1044 |
+
print("="*60)
|
| 1045 |
+
print(feat.shape)
|
| 1046 |
+
print(wave.shape, frame_nums)
|
| 1047 |
+
print('the difference between the official extractor and our implementation: {}'.format(wave_file))
|
| 1048 |
+
print(wave[:, :frame_nums] - feat[:, :frame_nums])
|
| 1049 |
+
print(wave)
|
| 1050 |
+
# print(wave[120:-1, :])
|
| 1051 |
+
# print(feat[120:-1, :wave.shape[1]])
|
| 1052 |
+
zeros_before = np.sum(wave == 0)
|
| 1053 |
+
aug = processor.data_augment(wave, frame_nums)
|
| 1054 |
+
zeros_after = np.sum(aug == 0)
|
| 1055 |
+
print(zeros_before, zeros_after)
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
def test_audio_long(): # test the truncation strategy for audio longer than 30 seconds
|
| 1059 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 1060 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
| 1061 |
+
config.audio_config.split_overlap = 1
|
| 1062 |
+
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
|
| 1063 |
+
processor = OceanMMProcessor(tokenizer, config, True)
|
| 1064 |
+
examples = ["<audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/easy_chat_xianliaohuier_30s\/easy_chat_xianliaohuier-133.mp3\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
|
| 1065 |
+
"what's the sound's energy? \n sound1 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-116.mp3\"}<audio_end_ocean> \n sound2 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-221.mp3\"}<audio_end_ocean>The speech energy is medium.",
|
| 1066 |
+
]
|
| 1067 |
+
ret = processor(examples)
|
| 1068 |
+
print(ret)
|
| 1069 |
+
print(torch.sum(ret.input_ids == 151659))
|
| 1070 |
+
print(torch.sum(ret.input_ids == 151674))
|
| 1071 |
+
|
| 1072 |
+
def test_processor():
|
| 1073 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 1074 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
| 1075 |
+
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
|
| 1076 |
+
processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
|
| 1077 |
+
examples = ["<audio_start_ocean>{\"path\": \"vggsound\/7DH5fqj8j6Q.flac\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
|
| 1078 |
+
"hello, ocean 你好 百川智能。",
|
| 1079 |
+
"what's the sound's energy? \n <audio_start_ocean>{\"path\": \"iemocap\/Ses01F_script01_3_F022.wav\"}<audio_end_ocean>The speech energy is medium.",
|
| 1080 |
+
"sound1: <audio_start_ocean>{\"path\": \"audioset_full\/9B53NVDNT8U.flac\"}<audio_end_ocean>\n sound2: \n<audio_start_ocean>{\"path\": \"audioset_full\/a2dgzb9GDSQ.flac\"}<audio_end_ocean>How is the speech speed related to the estimated speaker age?\n<trainable_start>The slow speech speed suggests a more deliberate and thoughtful approach often seen in mature individuals.<trainable_end>",
|
| 1081 |
+
"<img_start_ocean>{\"path\": \"sogou\/7351ae4f3fbe58ff0e4cc165cfabb3ed\"}<img_end_ocean>新和记潮汕牛肉火锅的牛肉丸好不好吃 用户评价口味怎么样 常州美食牛肉丸实拍图片 大众点评",
|
| 1082 |
+
"这两个图片有什么关系?图片1<img_start_ocean>{\"path\": \"sogou\/ac91d57ab68335913ed41aa283e76356\"}<img_end_ocean>图片2\n<img_start_ocean>{\"path\": \"sogou\/6ad5e632b74265d9ef689e45936ab1aa\"}<img_end_ocean>",
|
| 1083 |
+
"根据图片和语音给出描述\n图片<img_start_ocean>{\"path\": \"sogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb2\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>",
|
| 1084 |
+
"这些图片和音频不存在<img_start_ocean>{\"path\": \"soogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb_1\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>"
|
| 1085 |
+
]
|
| 1086 |
+
ret = processor(examples[4:-1])
|
| 1087 |
+
print(ret)
|
| 1088 |
+
print(torch.sum(ret.input_ids == 151659))
|
| 1089 |
+
print(torch.sum(ret.input_ids == 151662))
|
| 1090 |
+
try:
|
| 1091 |
+
print(ret.bridge_length)
|
| 1092 |
+
print(ret.patch_nums)
|
| 1093 |
+
except:
|
| 1094 |
+
pass
|
| 1095 |
+
print(torch.sum(ret.attention_mask, dim=1))
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
def test_grounding():
|
| 1099 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 1100 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
| 1101 |
+
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
|
| 1102 |
+
processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
|
| 1103 |
+
examples = ["<img_start_ocean>{\"path\": \"grit\/663423bf2f0884c034bf75279bce9694\"}<img_end_ocean>\nWhere is \"A woman\" ? Answer: <trainable_start>The bounding box is <box_start_ocean>(0.58,0.8),(0.71,1.0)<box_end_ocean><trainable_end>",
|
| 1104 |
+
"hello, ocean 你好 百川智能。",
|
| 1105 |
+
"<img_start_ocean>{\"path\": \"grit\/0e6e3952c584cbac7235940a22514656\"}<img_end_ocean> Generate the caption with grounding: <trainable_start>Photo pour Portrait of <ref_start_ocean>young Asian muslim woman wearing hijab<ref_end_ocean><box_start_ocean>(0.09,0.01),(0.77,1.0)<box_end_ocean> shows regret gesture, hand on her forehead, forget something important, against red background - image libre de droit<trainable_end>",
|
| 1106 |
+
"Recognize the object in the outlined section <img_start_ocean>{\"path\": \"grit\/045823cf6f819670f27aee20af7ae0e6\"}<img_end_ocean> of the picture.<box_start_ocean>(0.07,0.2),(0.91,0.96)<box_end_ocean>\n<trainable_start>Inflatable water trampolines<trainable_end>"
|
| 1107 |
+
]
|
| 1108 |
+
ret = processor(examples)
|
| 1109 |
+
print(ret)
|
| 1110 |
+
for i, input_ids in enumerate(ret.input_ids):
|
| 1111 |
+
print("="*60)
|
| 1112 |
+
print(ret.labels[i])
|
| 1113 |
+
|
| 1114 |
+
def test_pack():
|
| 1115 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 1116 |
+
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
|
| 1117 |
+
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=2048)
|
| 1118 |
+
processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
|
| 1119 |
+
examples = open('/cpfs/29f69eb5e2e60f26/user/sunhaoze/pretrain-v6/sogou/part-00000').readlines()[:5]
|
| 1120 |
+
examples += open('/home/nfs_bc_alignment/sunhaoze/text/openaqa-as-stage2-v1/part-00000').readlines()[:5]
|
| 1121 |
+
random.shuffle(examples)
|
| 1122 |
+
batch_output = processor.pack_batch_pretrain(examples)
|
| 1123 |
+
for i, b in enumerate(batch_output):
|
| 1124 |
+
print('='*60)
|
| 1125 |
+
try:
|
| 1126 |
+
print(b.input_ids, len(b.input_ids))
|
| 1127 |
+
print(b.labels, len(b.labels))
|
| 1128 |
+
print(b.attention_mask, len(b.attention_mask))
|
| 1129 |
+
print(b.position_ids, len(b.position_ids))
|
| 1130 |
+
print(b.seqlens, len(b.seqlens))
|
| 1131 |
+
print(b.audios)
|
| 1132 |
+
print(b.bridge_length)
|
| 1133 |
+
except:
|
| 1134 |
+
continue
|
| 1135 |
+
|
| 1136 |
+
batch_for_model = processor.collect_batch_pretrain(batch_output)
|
| 1137 |
+
print(batch_for_model.input_ids.shape)
|
| 1138 |
+
print(batch_for_model.labels.shape)
|
| 1139 |
+
print(batch_for_model.audios.shape)
|
| 1140 |
+
print(batch_for_model["bridge_length"])
|
| 1141 |
+
print(batch_for_model.images.shape)
|
| 1142 |
+
print(batch_for_model["patch_nums"])
|
| 1143 |
+
print(batch_for_model["position_ids"])
|
| 1144 |
+
print(batch_for_model["seqlens"])
|
| 1145 |
+
|
| 1146 |
+
def test_cos_audio():
|
| 1147 |
+
cos_client = CosClient()
|
| 1148 |
+
audio_bytes = cos_client('panda/data/common_voice/cv-corpus-18.0-2024-06-14/zh-CN/clips/common_voice_zh-CN_19428637.mp3', 'audio-data-1317568651')
|
| 1149 |
+
wave, sr = torchaudio.load(audio_bytes, normalize=False)
|
| 1150 |
+
print(wave.shape, sr)
|
| 1151 |
+
# torchaudio.save('tmp.flac', wave, sr)
|
| 1152 |
+
|
| 1153 |
+
if __name__ == '__main__':
|
| 1154 |
+
fire.Fire()
|
special_tokens_map.json
CHANGED
|
@@ -24,27 +24,27 @@
|
|
| 24 |
"<calc_start>",
|
| 25 |
"<calc_end>",
|
| 26 |
"<inner_think>",
|
| 27 |
-
"<
|
| 28 |
-
"<
|
| 29 |
-
"<
|
| 30 |
-
"<
|
| 31 |
-
"<
|
| 32 |
-
"<
|
| 33 |
-
"<
|
| 34 |
-
"<
|
| 35 |
-
"<
|
| 36 |
-
"<
|
| 37 |
-
"<
|
| 38 |
-
"<
|
| 39 |
-
"<
|
| 40 |
-
"<
|
| 41 |
-
"<
|
| 42 |
-
"<
|
| 43 |
"<reserved_113>",
|
| 44 |
-
"<
|
| 45 |
-
"<
|
| 46 |
-
"<
|
| 47 |
-
"<
|
| 48 |
],
|
| 49 |
"eos_token": {
|
| 50 |
"content": "<|endoftext|>",
|
|
|
|
| 24 |
"<calc_start>",
|
| 25 |
"<calc_end>",
|
| 26 |
"<inner_think>",
|
| 27 |
+
"<audio_start_ocean>",
|
| 28 |
+
"<audio_end_ocean>",
|
| 29 |
+
"<audio_pad_ocean>",
|
| 30 |
+
"<img_start_ocean>",
|
| 31 |
+
"<img_end_ocean>",
|
| 32 |
+
"<img_pad_ocean>",
|
| 33 |
+
"<img_newline_ocean>",
|
| 34 |
+
"<box_start_ocean>",
|
| 35 |
+
"<box_end_ocean>",
|
| 36 |
+
"<box_delim_ocean>",
|
| 37 |
+
"<ref_start_ocean>",
|
| 38 |
+
"<ref_end_ocean>",
|
| 39 |
+
"<img_delim_ocean>",
|
| 40 |
+
"<polygon_start_ocean>",
|
| 41 |
+
"<polygon_end_ocean>",
|
| 42 |
+
"<ocean_pad_token>",
|
| 43 |
"<reserved_113>",
|
| 44 |
+
"<audio_delim_ocean>",
|
| 45 |
+
"<video_start_ocean>",
|
| 46 |
+
"<video_end_ocean>",
|
| 47 |
+
"<video_palce_ocean>"
|
| 48 |
],
|
| 49 |
"eos_token": {
|
| 50 |
"content": "<|endoftext|>",
|
tokenizer.json
CHANGED
|
@@ -302,7 +302,7 @@
|
|
| 302 |
},
|
| 303 |
{
|
| 304 |
"id": 151676,
|
| 305 |
-
"content": "<
|
| 306 |
"single_word": false,
|
| 307 |
"lstrip": false,
|
| 308 |
"rstrip": false,
|
|
@@ -311,7 +311,7 @@
|
|
| 311 |
},
|
| 312 |
{
|
| 313 |
"id": 151677,
|
| 314 |
-
"content": "<
|
| 315 |
"single_word": false,
|
| 316 |
"lstrip": false,
|
| 317 |
"rstrip": false,
|
|
@@ -320,7 +320,7 @@
|
|
| 320 |
},
|
| 321 |
{
|
| 322 |
"id": 151678,
|
| 323 |
-
"content": "<
|
| 324 |
"single_word": false,
|
| 325 |
"lstrip": false,
|
| 326 |
"rstrip": false,
|
|
@@ -329,7 +329,7 @@
|
|
| 329 |
},
|
| 330 |
{
|
| 331 |
"id": 151679,
|
| 332 |
-
"content": "<
|
| 333 |
"single_word": false,
|
| 334 |
"lstrip": false,
|
| 335 |
"rstrip": false,
|
|
@@ -338,7 +338,7 @@
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"id": 151680,
|
| 341 |
-
"content": "<
|
| 342 |
"single_word": false,
|
| 343 |
"lstrip": false,
|
| 344 |
"rstrip": false,
|
|
@@ -347,7 +347,7 @@
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"id": 151681,
|
| 350 |
-
"content": "<
|
| 351 |
"single_word": false,
|
| 352 |
"lstrip": false,
|
| 353 |
"rstrip": false,
|
|
@@ -356,7 +356,7 @@
|
|
| 356 |
},
|
| 357 |
{
|
| 358 |
"id": 151682,
|
| 359 |
-
"content": "<
|
| 360 |
"single_word": false,
|
| 361 |
"lstrip": false,
|
| 362 |
"rstrip": false,
|
|
@@ -365,7 +365,7 @@
|
|
| 365 |
},
|
| 366 |
{
|
| 367 |
"id": 151683,
|
| 368 |
-
"content": "<
|
| 369 |
"single_word": false,
|
| 370 |
"lstrip": false,
|
| 371 |
"rstrip": false,
|
|
@@ -374,7 +374,7 @@
|
|
| 374 |
},
|
| 375 |
{
|
| 376 |
"id": 151684,
|
| 377 |
-
"content": "<
|
| 378 |
"single_word": false,
|
| 379 |
"lstrip": false,
|
| 380 |
"rstrip": false,
|
|
@@ -383,7 +383,7 @@
|
|
| 383 |
},
|
| 384 |
{
|
| 385 |
"id": 151685,
|
| 386 |
-
"content": "<
|
| 387 |
"single_word": false,
|
| 388 |
"lstrip": false,
|
| 389 |
"rstrip": false,
|
|
@@ -392,7 +392,7 @@
|
|
| 392 |
},
|
| 393 |
{
|
| 394 |
"id": 151686,
|
| 395 |
-
"content": "<
|
| 396 |
"single_word": false,
|
| 397 |
"lstrip": false,
|
| 398 |
"rstrip": false,
|
|
@@ -401,7 +401,7 @@
|
|
| 401 |
},
|
| 402 |
{
|
| 403 |
"id": 151687,
|
| 404 |
-
"content": "<
|
| 405 |
"single_word": false,
|
| 406 |
"lstrip": false,
|
| 407 |
"rstrip": false,
|
|
@@ -410,7 +410,7 @@
|
|
| 410 |
},
|
| 411 |
{
|
| 412 |
"id": 151688,
|
| 413 |
-
"content": "<
|
| 414 |
"single_word": false,
|
| 415 |
"lstrip": false,
|
| 416 |
"rstrip": false,
|
|
@@ -419,7 +419,7 @@
|
|
| 419 |
},
|
| 420 |
{
|
| 421 |
"id": 151689,
|
| 422 |
-
"content": "<
|
| 423 |
"single_word": false,
|
| 424 |
"lstrip": false,
|
| 425 |
"rstrip": false,
|
|
@@ -428,7 +428,7 @@
|
|
| 428 |
},
|
| 429 |
{
|
| 430 |
"id": 151690,
|
| 431 |
-
"content": "<
|
| 432 |
"single_word": false,
|
| 433 |
"lstrip": false,
|
| 434 |
"rstrip": false,
|
|
@@ -437,7 +437,7 @@
|
|
| 437 |
},
|
| 438 |
{
|
| 439 |
"id": 151691,
|
| 440 |
-
"content": "<
|
| 441 |
"single_word": false,
|
| 442 |
"lstrip": false,
|
| 443 |
"rstrip": false,
|
|
@@ -455,7 +455,7 @@
|
|
| 455 |
},
|
| 456 |
{
|
| 457 |
"id": 151693,
|
| 458 |
-
"content": "<
|
| 459 |
"single_word": false,
|
| 460 |
"lstrip": false,
|
| 461 |
"rstrip": false,
|
|
@@ -464,7 +464,7 @@
|
|
| 464 |
},
|
| 465 |
{
|
| 466 |
"id": 151694,
|
| 467 |
-
"content": "<
|
| 468 |
"single_word": false,
|
| 469 |
"lstrip": false,
|
| 470 |
"rstrip": false,
|
|
@@ -473,7 +473,7 @@
|
|
| 473 |
},
|
| 474 |
{
|
| 475 |
"id": 151695,
|
| 476 |
-
"content": "<
|
| 477 |
"single_word": false,
|
| 478 |
"lstrip": false,
|
| 479 |
"rstrip": false,
|
|
@@ -482,7 +482,7 @@
|
|
| 482 |
},
|
| 483 |
{
|
| 484 |
"id": 151696,
|
| 485 |
-
"content": "<
|
| 486 |
"single_word": false,
|
| 487 |
"lstrip": false,
|
| 488 |
"rstrip": false,
|
|
|
|
| 302 |
},
|
| 303 |
{
|
| 304 |
"id": 151676,
|
| 305 |
+
"content": "<audio_start_ocean>",
|
| 306 |
"single_word": false,
|
| 307 |
"lstrip": false,
|
| 308 |
"rstrip": false,
|
|
|
|
| 311 |
},
|
| 312 |
{
|
| 313 |
"id": 151677,
|
| 314 |
+
"content": "<audio_end_ocean>",
|
| 315 |
"single_word": false,
|
| 316 |
"lstrip": false,
|
| 317 |
"rstrip": false,
|
|
|
|
| 320 |
},
|
| 321 |
{
|
| 322 |
"id": 151678,
|
| 323 |
+
"content": "<audio_pad_ocean>",
|
| 324 |
"single_word": false,
|
| 325 |
"lstrip": false,
|
| 326 |
"rstrip": false,
|
|
|
|
| 329 |
},
|
| 330 |
{
|
| 331 |
"id": 151679,
|
| 332 |
+
"content": "<img_start_ocean>",
|
| 333 |
"single_word": false,
|
| 334 |
"lstrip": false,
|
| 335 |
"rstrip": false,
|
|
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"id": 151680,
|
| 341 |
+
"content": "<img_end_ocean>",
|
| 342 |
"single_word": false,
|
| 343 |
"lstrip": false,
|
| 344 |
"rstrip": false,
|
|
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"id": 151681,
|
| 350 |
+
"content": "<img_pad_ocean>",
|
| 351 |
"single_word": false,
|
| 352 |
"lstrip": false,
|
| 353 |
"rstrip": false,
|
|
|
|
| 356 |
},
|
| 357 |
{
|
| 358 |
"id": 151682,
|
| 359 |
+
"content": "<img_newline_ocean>",
|
| 360 |
"single_word": false,
|
| 361 |
"lstrip": false,
|
| 362 |
"rstrip": false,
|
|
|
|
| 365 |
},
|
| 366 |
{
|
| 367 |
"id": 151683,
|
| 368 |
+
"content": "<box_start_ocean>",
|
| 369 |
"single_word": false,
|
| 370 |
"lstrip": false,
|
| 371 |
"rstrip": false,
|
|
|
|
| 374 |
},
|
| 375 |
{
|
| 376 |
"id": 151684,
|
| 377 |
+
"content": "<box_end_ocean>",
|
| 378 |
"single_word": false,
|
| 379 |
"lstrip": false,
|
| 380 |
"rstrip": false,
|
|
|
|
| 383 |
},
|
| 384 |
{
|
| 385 |
"id": 151685,
|
| 386 |
+
"content": "<box_delim_ocean>",
|
| 387 |
"single_word": false,
|
| 388 |
"lstrip": false,
|
| 389 |
"rstrip": false,
|
|
|
|
| 392 |
},
|
| 393 |
{
|
| 394 |
"id": 151686,
|
| 395 |
+
"content": "<ref_start_ocean>",
|
| 396 |
"single_word": false,
|
| 397 |
"lstrip": false,
|
| 398 |
"rstrip": false,
|
|
|
|
| 401 |
},
|
| 402 |
{
|
| 403 |
"id": 151687,
|
| 404 |
+
"content": "<ref_end_ocean>",
|
| 405 |
"single_word": false,
|
| 406 |
"lstrip": false,
|
| 407 |
"rstrip": false,
|
|
|
|
| 410 |
},
|
| 411 |
{
|
| 412 |
"id": 151688,
|
| 413 |
+
"content": "<img_delim_ocean>",
|
| 414 |
"single_word": false,
|
| 415 |
"lstrip": false,
|
| 416 |
"rstrip": false,
|
|
|
|
| 419 |
},
|
| 420 |
{
|
| 421 |
"id": 151689,
|
| 422 |
+
"content": "<polygon_start_ocean>",
|
| 423 |
"single_word": false,
|
| 424 |
"lstrip": false,
|
| 425 |
"rstrip": false,
|
|
|
|
| 428 |
},
|
| 429 |
{
|
| 430 |
"id": 151690,
|
| 431 |
+
"content": "<polygon_end_ocean>",
|
| 432 |
"single_word": false,
|
| 433 |
"lstrip": false,
|
| 434 |
"rstrip": false,
|
|
|
|
| 437 |
},
|
| 438 |
{
|
| 439 |
"id": 151691,
|
| 440 |
+
"content": "<ocean_pad_token>",
|
| 441 |
"single_word": false,
|
| 442 |
"lstrip": false,
|
| 443 |
"rstrip": false,
|
|
|
|
| 455 |
},
|
| 456 |
{
|
| 457 |
"id": 151693,
|
| 458 |
+
"content": "<audio_delim_ocean>",
|
| 459 |
"single_word": false,
|
| 460 |
"lstrip": false,
|
| 461 |
"rstrip": false,
|
|
|
|
| 464 |
},
|
| 465 |
{
|
| 466 |
"id": 151694,
|
| 467 |
+
"content": "<video_palce_ocean>",
|
| 468 |
"single_word": false,
|
| 469 |
"lstrip": false,
|
| 470 |
"rstrip": false,
|
|
|
|
| 473 |
},
|
| 474 |
{
|
| 475 |
"id": 151695,
|
| 476 |
+
"content": "<video_start_ocean>",
|
| 477 |
"single_word": false,
|
| 478 |
"lstrip": false,
|
| 479 |
"rstrip": false,
|
|
|
|
| 482 |
},
|
| 483 |
{
|
| 484 |
"id": 151696,
|
| 485 |
+
"content": "<video_end_ocean>",
|
| 486 |
"single_word": false,
|
| 487 |
"lstrip": false,
|
| 488 |
"rstrip": false,
|
tokenizer_config.json
CHANGED
|
@@ -267,7 +267,7 @@
|
|
| 267 |
"special": true
|
| 268 |
},
|
| 269 |
"151676": {
|
| 270 |
-
"content": "<
|
| 271 |
"lstrip": false,
|
| 272 |
"normalized": false,
|
| 273 |
"rstrip": false,
|
|
@@ -275,7 +275,7 @@
|
|
| 275 |
"special": true
|
| 276 |
},
|
| 277 |
"151677": {
|
| 278 |
-
"content": "<
|
| 279 |
"lstrip": false,
|
| 280 |
"normalized": false,
|
| 281 |
"rstrip": false,
|
|
@@ -283,7 +283,7 @@
|
|
| 283 |
"special": true
|
| 284 |
},
|
| 285 |
"151678": {
|
| 286 |
-
"content": "<
|
| 287 |
"lstrip": false,
|
| 288 |
"normalized": false,
|
| 289 |
"rstrip": false,
|
|
@@ -291,7 +291,7 @@
|
|
| 291 |
"special": true
|
| 292 |
},
|
| 293 |
"151679": {
|
| 294 |
-
"content": "<
|
| 295 |
"lstrip": false,
|
| 296 |
"normalized": false,
|
| 297 |
"rstrip": false,
|
|
@@ -299,7 +299,7 @@
|
|
| 299 |
"special": true
|
| 300 |
},
|
| 301 |
"151680": {
|
| 302 |
-
"content": "<
|
| 303 |
"lstrip": false,
|
| 304 |
"normalized": false,
|
| 305 |
"rstrip": false,
|
|
@@ -307,7 +307,7 @@
|
|
| 307 |
"special": true
|
| 308 |
},
|
| 309 |
"151681": {
|
| 310 |
-
"content": "<
|
| 311 |
"lstrip": false,
|
| 312 |
"normalized": false,
|
| 313 |
"rstrip": false,
|
|
@@ -315,7 +315,7 @@
|
|
| 315 |
"special": true
|
| 316 |
},
|
| 317 |
"151682": {
|
| 318 |
-
"content": "<
|
| 319 |
"lstrip": false,
|
| 320 |
"normalized": false,
|
| 321 |
"rstrip": false,
|
|
@@ -323,7 +323,7 @@
|
|
| 323 |
"special": true
|
| 324 |
},
|
| 325 |
"151683": {
|
| 326 |
-
"content": "<
|
| 327 |
"lstrip": false,
|
| 328 |
"normalized": false,
|
| 329 |
"rstrip": false,
|
|
@@ -331,7 +331,7 @@
|
|
| 331 |
"special": true
|
| 332 |
},
|
| 333 |
"151684": {
|
| 334 |
-
"content": "<
|
| 335 |
"lstrip": false,
|
| 336 |
"normalized": false,
|
| 337 |
"rstrip": false,
|
|
@@ -339,7 +339,7 @@
|
|
| 339 |
"special": true
|
| 340 |
},
|
| 341 |
"151685": {
|
| 342 |
-
"content": "<
|
| 343 |
"lstrip": false,
|
| 344 |
"normalized": false,
|
| 345 |
"rstrip": false,
|
|
@@ -347,7 +347,7 @@
|
|
| 347 |
"special": true
|
| 348 |
},
|
| 349 |
"151686": {
|
| 350 |
-
"content": "<
|
| 351 |
"lstrip": false,
|
| 352 |
"normalized": false,
|
| 353 |
"rstrip": false,
|
|
@@ -355,7 +355,7 @@
|
|
| 355 |
"special": true
|
| 356 |
},
|
| 357 |
"151687": {
|
| 358 |
-
"content": "<
|
| 359 |
"lstrip": false,
|
| 360 |
"normalized": false,
|
| 361 |
"rstrip": false,
|
|
@@ -363,7 +363,7 @@
|
|
| 363 |
"special": true
|
| 364 |
},
|
| 365 |
"151688": {
|
| 366 |
-
"content": "<
|
| 367 |
"lstrip": false,
|
| 368 |
"normalized": false,
|
| 369 |
"rstrip": false,
|
|
@@ -371,7 +371,7 @@
|
|
| 371 |
"special": true
|
| 372 |
},
|
| 373 |
"151689": {
|
| 374 |
-
"content": "<
|
| 375 |
"lstrip": false,
|
| 376 |
"normalized": false,
|
| 377 |
"rstrip": false,
|
|
@@ -379,7 +379,7 @@
|
|
| 379 |
"special": true
|
| 380 |
},
|
| 381 |
"151690": {
|
| 382 |
-
"content": "<
|
| 383 |
"lstrip": false,
|
| 384 |
"normalized": false,
|
| 385 |
"rstrip": false,
|
|
@@ -387,7 +387,7 @@
|
|
| 387 |
"special": true
|
| 388 |
},
|
| 389 |
"151691": {
|
| 390 |
-
"content": "<
|
| 391 |
"lstrip": false,
|
| 392 |
"normalized": false,
|
| 393 |
"rstrip": false,
|
|
@@ -403,7 +403,7 @@
|
|
| 403 |
"special": true
|
| 404 |
},
|
| 405 |
"151693": {
|
| 406 |
-
"content": "<
|
| 407 |
"lstrip": false,
|
| 408 |
"normalized": false,
|
| 409 |
"rstrip": false,
|
|
@@ -411,7 +411,7 @@
|
|
| 411 |
"special": true
|
| 412 |
},
|
| 413 |
"151694": {
|
| 414 |
-
"content": "<
|
| 415 |
"lstrip": false,
|
| 416 |
"normalized": false,
|
| 417 |
"rstrip": false,
|
|
@@ -419,7 +419,7 @@
|
|
| 419 |
"special": true
|
| 420 |
},
|
| 421 |
"151695": {
|
| 422 |
-
"content": "<
|
| 423 |
"lstrip": false,
|
| 424 |
"normalized": false,
|
| 425 |
"rstrip": false,
|
|
@@ -427,7 +427,7 @@
|
|
| 427 |
"special": true
|
| 428 |
},
|
| 429 |
"151696": {
|
| 430 |
-
"content": "<
|
| 431 |
"lstrip": false,
|
| 432 |
"normalized": false,
|
| 433 |
"rstrip": false,
|
|
@@ -460,27 +460,27 @@
|
|
| 460 |
"<calc_start>",
|
| 461 |
"<calc_end>",
|
| 462 |
"<inner_think>",
|
| 463 |
-
"<
|
| 464 |
-
"<
|
| 465 |
-
"<
|
| 466 |
-
"<
|
| 467 |
-
"<
|
| 468 |
-
"<
|
| 469 |
-
"<
|
| 470 |
-
"<
|
| 471 |
-
"<
|
| 472 |
-
"<
|
| 473 |
-
"<
|
| 474 |
-
"<
|
| 475 |
-
"<
|
| 476 |
-
"<
|
| 477 |
-
"<
|
| 478 |
-
"<
|
| 479 |
"<reserved_113>",
|
| 480 |
-
"<
|
| 481 |
-
"<
|
| 482 |
-
"<
|
| 483 |
-
"<
|
| 484 |
],
|
| 485 |
"bos_token": null,
|
| 486 |
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
|
|
|
| 267 |
"special": true
|
| 268 |
},
|
| 269 |
"151676": {
|
| 270 |
+
"content": "<audio_start_ocean>",
|
| 271 |
"lstrip": false,
|
| 272 |
"normalized": false,
|
| 273 |
"rstrip": false,
|
|
|
|
| 275 |
"special": true
|
| 276 |
},
|
| 277 |
"151677": {
|
| 278 |
+
"content": "<audio_end_ocean>",
|
| 279 |
"lstrip": false,
|
| 280 |
"normalized": false,
|
| 281 |
"rstrip": false,
|
|
|
|
| 283 |
"special": true
|
| 284 |
},
|
| 285 |
"151678": {
|
| 286 |
+
"content": "<audio_pad_ocean>",
|
| 287 |
"lstrip": false,
|
| 288 |
"normalized": false,
|
| 289 |
"rstrip": false,
|
|
|
|
| 291 |
"special": true
|
| 292 |
},
|
| 293 |
"151679": {
|
| 294 |
+
"content": "<img_start_ocean>",
|
| 295 |
"lstrip": false,
|
| 296 |
"normalized": false,
|
| 297 |
"rstrip": false,
|
|
|
|
| 299 |
"special": true
|
| 300 |
},
|
| 301 |
"151680": {
|
| 302 |
+
"content": "<img_end_ocean>",
|
| 303 |
"lstrip": false,
|
| 304 |
"normalized": false,
|
| 305 |
"rstrip": false,
|
|
|
|
| 307 |
"special": true
|
| 308 |
},
|
| 309 |
"151681": {
|
| 310 |
+
"content": "<img_pad_ocean>",
|
| 311 |
"lstrip": false,
|
| 312 |
"normalized": false,
|
| 313 |
"rstrip": false,
|
|
|
|
| 315 |
"special": true
|
| 316 |
},
|
| 317 |
"151682": {
|
| 318 |
+
"content": "<img_newline_ocean>",
|
| 319 |
"lstrip": false,
|
| 320 |
"normalized": false,
|
| 321 |
"rstrip": false,
|
|
|
|
| 323 |
"special": true
|
| 324 |
},
|
| 325 |
"151683": {
|
| 326 |
+
"content": "<box_start_ocean>",
|
| 327 |
"lstrip": false,
|
| 328 |
"normalized": false,
|
| 329 |
"rstrip": false,
|
|
|
|
| 331 |
"special": true
|
| 332 |
},
|
| 333 |
"151684": {
|
| 334 |
+
"content": "<box_end_ocean>",
|
| 335 |
"lstrip": false,
|
| 336 |
"normalized": false,
|
| 337 |
"rstrip": false,
|
|
|
|
| 339 |
"special": true
|
| 340 |
},
|
| 341 |
"151685": {
|
| 342 |
+
"content": "<box_delim_ocean>",
|
| 343 |
"lstrip": false,
|
| 344 |
"normalized": false,
|
| 345 |
"rstrip": false,
|
|
|
|
| 347 |
"special": true
|
| 348 |
},
|
| 349 |
"151686": {
|
| 350 |
+
"content": "<ref_start_ocean>",
|
| 351 |
"lstrip": false,
|
| 352 |
"normalized": false,
|
| 353 |
"rstrip": false,
|
|
|
|
| 355 |
"special": true
|
| 356 |
},
|
| 357 |
"151687": {
|
| 358 |
+
"content": "<ref_end_ocean>",
|
| 359 |
"lstrip": false,
|
| 360 |
"normalized": false,
|
| 361 |
"rstrip": false,
|
|
|
|
| 363 |
"special": true
|
| 364 |
},
|
| 365 |
"151688": {
|
| 366 |
+
"content": "<img_delim_ocean>",
|
| 367 |
"lstrip": false,
|
| 368 |
"normalized": false,
|
| 369 |
"rstrip": false,
|
|
|
|
| 371 |
"special": true
|
| 372 |
},
|
| 373 |
"151689": {
|
| 374 |
+
"content": "<polygon_start_ocean>",
|
| 375 |
"lstrip": false,
|
| 376 |
"normalized": false,
|
| 377 |
"rstrip": false,
|
|
|
|
| 379 |
"special": true
|
| 380 |
},
|
| 381 |
"151690": {
|
| 382 |
+
"content": "<polygon_end_ocean>",
|
| 383 |
"lstrip": false,
|
| 384 |
"normalized": false,
|
| 385 |
"rstrip": false,
|
|
|
|
| 387 |
"special": true
|
| 388 |
},
|
| 389 |
"151691": {
|
| 390 |
+
"content": "<ocean_pad_token>",
|
| 391 |
"lstrip": false,
|
| 392 |
"normalized": false,
|
| 393 |
"rstrip": false,
|
|
|
|
| 403 |
"special": true
|
| 404 |
},
|
| 405 |
"151693": {
|
| 406 |
+
"content": "<audio_delim_ocean>",
|
| 407 |
"lstrip": false,
|
| 408 |
"normalized": false,
|
| 409 |
"rstrip": false,
|
|
|
|
| 411 |
"special": true
|
| 412 |
},
|
| 413 |
"151694": {
|
| 414 |
+
"content": "<video_palce_ocean>",
|
| 415 |
"lstrip": false,
|
| 416 |
"normalized": false,
|
| 417 |
"rstrip": false,
|
|
|
|
| 419 |
"special": true
|
| 420 |
},
|
| 421 |
"151695": {
|
| 422 |
+
"content": "<video_start_ocean>",
|
| 423 |
"lstrip": false,
|
| 424 |
"normalized": false,
|
| 425 |
"rstrip": false,
|
|
|
|
| 427 |
"special": true
|
| 428 |
},
|
| 429 |
"151696": {
|
| 430 |
+
"content": "<video_end_ocean>",
|
| 431 |
"lstrip": false,
|
| 432 |
"normalized": false,
|
| 433 |
"rstrip": false,
|
|
|
|
| 460 |
"<calc_start>",
|
| 461 |
"<calc_end>",
|
| 462 |
"<inner_think>",
|
| 463 |
+
"<audio_start_ocean>",
|
| 464 |
+
"<audio_end_ocean>",
|
| 465 |
+
"<audio_pad_ocean>",
|
| 466 |
+
"<img_start_ocean>",
|
| 467 |
+
"<img_end_ocean>",
|
| 468 |
+
"<img_pad_ocean>",
|
| 469 |
+
"<img_newline_ocean>",
|
| 470 |
+
"<box_start_ocean>",
|
| 471 |
+
"<box_end_ocean>",
|
| 472 |
+
"<box_delim_ocean>",
|
| 473 |
+
"<ref_start_ocean>",
|
| 474 |
+
"<ref_end_ocean>",
|
| 475 |
+
"<img_delim_ocean>",
|
| 476 |
+
"<polygon_start_ocean>",
|
| 477 |
+
"<polygon_end_ocean>",
|
| 478 |
+
"<ocean_pad_token>",
|
| 479 |
"<reserved_113>",
|
| 480 |
+
"<audio_delim_ocean>",
|
| 481 |
+
"<video_start_ocean>",
|
| 482 |
+
"<video_end_ocean>",
|
| 483 |
+
"<video_palce_ocean>"
|
| 484 |
],
|
| 485 |
"bos_token": null,
|
| 486 |
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
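The entries above register the new Ocean multimodal markers (audio, image, video, box, ref and polygon delimiters) as additional special tokens, and the chat_template is a Qwen2-style tool-calling template. The sketch below shows how a downstream user might load this tokenizer and render a plain prompt; the local path "./" and the sample conversation are illustrative assumptions, not part of this commit.

# Minimal sketch (assumes the files from this commit sit in the current directory).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)

# Special tokens map to single ids and are never split by the BPE merges.
print(tokenizer.convert_tokens_to_ids("<video_start_ocean>"))  # 151695, per the entries above
print(tokenizer.convert_tokens_to_ids("<video_end_ocean>"))    # 151696

# Text-only conversations can be rendered with the bundled chat template as usual.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe this image: <img_start_ocean><img_pad_ocean><img_end_ocean>"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)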
visual_modeling_ocean.py
ADDED
@@ -0,0 +1,166 @@
from typing import List, Optional, Tuple, Union
import torch, math
import torch.utils.checkpoint
from torch import nn
import transformers
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN
from PIL import Image
import io, fire
from torch.nn import functional as F


class OceanVisualAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5  # not needed when flash attention is used
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        # print("*******OceanVisualAttention monkey patched!!*******")

        # Weight init; has no effect when the model is loaded via transformers from_pretrained.
        factor = self.config.initializer_factor
        in_proj_std = (self.embed_dim**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
        out_proj_std = (self.embed_dim**-0.5) * factor
        nn.init.normal_(self.q_proj.weight, std=in_proj_std)
        nn.init.normal_(self.k_proj.weight, std=in_proj_std)
        nn.init.normal_(self.v_proj.weight, std=in_proj_std)
        nn.init.normal_(self.out_proj.weight, std=out_proj_std)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()
        src_len = tgt_len

        query_states = self.q_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)

        # Variable patch counts are not handled yet; the sequence length is fixed (256/1024 patches).
        cu_len = torch.arange(0, (bsz + 1) * tgt_len, step=tgt_len, dtype=torch.int32, device=query_states.device)
        # print(self.config.s2a, self.config.rope_scaling, cu_len, torch.sum(cu_len), q_len, kv_seq_len)
        # Flash attention only supports fp16/bf16; fall back to SDPA otherwise.
        if query_states.dtype in [torch.float16, torch.bfloat16]:
            attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, tgt_len, tgt_len, causal=False)  # (bsz * qlen, nheads, headdim)
            attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
        else:
            # SDPA expects (bsz, num_heads, seq_len, head_dim), so unpack the varlen layout first.
            query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, 0.0)
            attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None

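# Note on the varlen path above: flash_attn_varlen_func consumes packed
# (total_tokens, n_heads, head_dim) tensors plus cumulative sequence lengths.
# Because every image contributes a fixed number of patches here, cu_len is just
# an arithmetic progression, e.g. for bsz = 3 and tgt_len = 4:
#     torch.arange(0, (3 + 1) * 4, step=4, dtype=torch.int32)  ->  tensor([0, 4, 8, 12], dtype=torch.int32)
# so patches [0:4), [4:8) and [8:12) attend only within their own image.
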
# monkey patch for flash attention
# transformers.models.siglip.modeling_siglip.SiglipAttention = OceanVisualAttention
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from qwen_vl_utils import process_vision_info

class OceanVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.gradient_checkpointing = True  # force-enabled
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        del self.merger

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ):
        hidden_states = pixel_values.to(self.get_dtype())
        grid_thw = grid_thw.to(pixel_values.device)

        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        return hidden_states

    @torch.no_grad()
    def fake_input(self, device, merge_size=2):
        merge_size = max(merge_size, self.config.spatial_merge_size)
        fake_image = [torch.zeros([
            1,
            self.config.temporal_patch_size,
            3,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
        ], dtype=torch.float32, device=device)]
        return fake_image


class OceanVisualBridge(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.hidden_size = config.embed_dim * (self.merge_size**2)
        self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x


def test_vision():
    from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained("./", trust_remote_code=True)

    ae = OceanVisualEncoder(config.visual_config).cuda().to(torch.bfloat16)
    bg = OceanVisualBridge(config).cuda().to(torch.bfloat16)
    print(ae)
    pixel_input = torch.rand([4, 3, config.visual_config.image_size, config.visual_config.image_size], dtype=torch.float32).cuda()

    visual_embedding = ae(pixel_input)[0][:, 1:]  # drop the class token
    visual_proj = bg(visual_embedding)
    print(visual_proj.shape)
    print(ae.fake_input(visual_proj.device))


if __name__ == '__main__':
    fire.Fire()
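visual_modeling_ocean.py therefore contributes three pieces: a flash-attention rewrite of the SigLIP/CLIP-style visual self-attention, a Qwen2-VL vision transformer with its built-in patch merger removed, and an OceanVisualBridge MLP that folds merge_size**2 neighbouring patch embeddings into the language model's hidden size. The sketch below wires the encoder and bridge together on a single image; the checkpoint path, the visual_config / embed_dim / hidden_size fields read from the top-level config, the 16x16 patch grid, and the availability of a CUDA device with flash-attn installed are all assumptions for illustration, not the model's documented inference path.

import torch
from transformers import AutoConfig
from visual_modeling_ocean import OceanVisualEncoder, OceanVisualBridge

# Assumed: run from a local checkpoint directory containing this commit's files.
config = AutoConfig.from_pretrained("./", trust_remote_code=True)

encoder = OceanVisualEncoder(config.visual_config).cuda().to(torch.bfloat16)
bridge = OceanVisualBridge(config).cuda().to(torch.bfloat16)  # mirrors test_vision: built from the top-level config

# Qwen2-VL style packed input: patches flattened to
# channels * temporal_patch_size * patch_size * patch_size, with grid_thw = (t, h, w) per image.
vc = config.visual_config
grid_thw = torch.tensor([[1, 16, 16]], dtype=torch.long).cuda()   # 1 * 16 * 16 = 256 patches
patch_dim = 3 * vc.temporal_patch_size * vc.patch_size * vc.patch_size
pixel_values = torch.rand(int(grid_thw.prod()), patch_dim).cuda()

with torch.no_grad():
    patch_states = encoder(pixel_values, grid_thw)   # (256, embed_dim) since the merger was deleted
    image_embeds = bridge(patch_states)              # (256 / merge_size**2, LLM hidden_size)
print(image_embeds.shape)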