guoxy25 committed
Commit e869cba · verified · 1 Parent(s): 282a438

Delete visual_modeling_baichuan.py

Files changed (1)
  1. visual_modeling_baichuan.py +0 -166
visual_modeling_baichuan.py DELETED
@@ -1,166 +0,0 @@
-
- from typing import List, Optional, Tuple, Union
- import torch, math
- import torch.utils.checkpoint
- from torch import nn
- import transformers
- from flash_attn import flash_attn_varlen_func
- from transformers.activations import ACT2FN
- from PIL import Image
- import io, fire
- from torch.nn import functional as F
-
-
- class BaichuanVisualAttention(nn.Module):
-     """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.embed_dim = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = self.embed_dim // self.num_heads
-         if self.head_dim * self.num_heads != self.embed_dim:
-             raise ValueError(
-                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                 f" {self.num_heads})."
-             )
-         self.scale = self.head_dim**-0.5  # not needed when flash attention is used
-         self.dropout = config.attention_dropout
-
-         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
-         # print("*******BaichuanVisualAttention monkey patched!!*******")
-
-         # initialization; has no effect when weights are loaded via transformers from_pretrained
-         factor = self.config.initializer_factor
-         in_proj_std = (self.embed_dim**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
-         out_proj_std = (self.embed_dim**-0.5) * factor
-         nn.init.normal_(self.q_proj.weight, std=in_proj_std)
-         nn.init.normal_(self.k_proj.weight, std=in_proj_std)
-         nn.init.normal_(self.v_proj.weight, std=in_proj_std)
-         nn.init.normal_(self.out_proj.weight, std=out_proj_std)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         causal_attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = False,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         """Input shape: Batch x Time x Channel"""
-
-         bsz, tgt_len, embed_dim = hidden_states.size()
-         src_len = tgt_len
-
-         query_states = self.q_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
-         key_states = self.k_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
-         value_states = self.v_proj(hidden_states).view(bsz * tgt_len, self.num_heads, self.head_dim)
-
-         # variable patch counts are not handled yet; the sequence length is fixed (256/1024)
-         cu_len = torch.arange(0, (bsz + 1) * tgt_len, step=tgt_len, dtype=torch.int32, device=query_states.device)
-         # print(self.config.s2a, self.config.rope_scaling, cu_len, torch.sum(cu_len), q_len, kv_seq_len)
-         # use flash attention only for fp16/bf16 inputs
-         if query_states.dtype in [torch.float16, torch.bfloat16]:
-             attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, tgt_len, tgt_len, causal=False)  # (bsz * qlen, nheads, headdim)
-             attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
-         else:
-             # scaled_dot_product_attention expects (bsz, num_heads, seq_len, head_dim), so undo the packed layout first
-             query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
-             key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
-             value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
-             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
-                 attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, 0.0)
-             attn_output = attn_output.transpose(1, 2)
-         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-         attn_output = self.out_proj(attn_output)
-
-         return attn_output, None
-
- # monkey patch for flash attention
- # transformers.models.siglip.modeling_siglip.SiglipAttention = BaichuanVisualAttention
- from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
- from qwen_vl_utils import process_vision_info
-
- class BaichuanVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-         self.gradient_checkpointing = True  # force-enable gradient checkpointing
-         self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
-         del self.merger
-
-     def forward(
-         self,
-         pixel_values: torch.Tensor,
-         grid_thw: torch.Tensor,
-     ):
-         hidden_states = pixel_values.to(self.get_dtype())
-         grid_thw = grid_thw.to(pixel_values.device)
-
-         hidden_states = self.patch_embed(hidden_states)
-         rotary_pos_emb = self.rot_pos_emb(grid_thw)
-
-         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-             dim=0, dtype=torch.int32
-         )
-         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-
-         for blk in self.blocks:
-             if self.gradient_checkpointing and self.training:
-                 hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
-             else:
-                 hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
-
-         return hidden_states
-
-     @torch.no_grad()
-     def fake_input(self, device, merge_size=2):
-         merge_size = max(merge_size, self.config.spatial_merge_size)
-         fake_image = [torch.zeros([
-             1,
-             self.config.temporal_patch_size,
-             3,
-             merge_size // self.config.spatial_merge_size,
-             self.config.spatial_merge_size,
-             self.config.patch_size,
-             merge_size // self.config.spatial_merge_size,
-             self.config.spatial_merge_size,
-             self.config.patch_size,
-         ], dtype=torch.float32, device=device)]
-         return fake_image
-
-
- class BaichuanVisualBridge(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
-         self.hidden_size = config.embed_dim * (self.merge_size**2)
-         self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
-         self.mlp = nn.Sequential(
-             nn.Linear(self.hidden_size, self.hidden_size),
-             nn.GELU(),
-             nn.Linear(self.hidden_size, config.hidden_size),
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
-         return x
-
-
- def test_vision():
-     from transformers import AutoConfig
-     config = AutoConfig.from_pretrained("./", trust_remote_code=True)
-
-     ae = BaichuanVisualEncoder(config.visual_config).cuda().to(torch.bfloat16)
-     bg = BaichuanVisualBridge(config).cuda().to(torch.bfloat16)
-     print(ae)
-     # The Qwen2-style encoder takes flattened patches plus a per-image (t, h, w) patch grid;
-     # a single image with a 16x16 patch grid is assumed for this smoke test.
-     vc = config.visual_config
-     grid_thw = torch.tensor([[1, 16, 16]], dtype=torch.int64).cuda()
-     patch_dim = 3 * vc.temporal_patch_size * vc.patch_size * vc.patch_size
-     pixel_input = torch.rand([16 * 16, patch_dim], dtype=torch.float32).cuda()
-
-     visual_embedding = ae(pixel_input, grid_thw)  # per-patch features; there is no class token to strip
-     visual_proj = bg(visual_embedding)
-     print(visual_proj.shape)
-     print(ae.fake_input(visual_proj.device))
-
- if __name__ == '__main__':
-     fire.Fire()
-
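
For context on the `cu_seqlens` bookkeeping that the deleted `BaichuanVisualEncoder.forward` relies on: `grid_thw` stores per-image (t, h, w) patch counts, and the cumulative lengths mark where each image's patch sequence begins and ends in the packed, variable-length batch. A minimal standalone sketch of that computation, using made-up grid values purely for illustration:

    import torch
    from torch.nn import functional as F

    # Hypothetical grids: image A is 1 x 16 x 16 patches, image B is 1 x 8 x 12 patches.
    grid_thw = torch.tensor([[1, 16, 16], [1, 8, 12]], dtype=torch.int32)

    # Per-frame sequence length h * w, repeated t times per image, then accumulated.
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0, dtype=torch.int32
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
    print(cu_seqlens)  # tensor([  0, 256, 352], dtype=torch.int32): image boundaries in the packed patch sequence

The same idea appears in `BaichuanVisualAttention.forward`, where `cu_len` is an evenly spaced version of these offsets because every image there contributes a fixed number of patches.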