Rally Lin committed
Commit f0ad19d · 1 Parent(s): 9a62748

add modeling

Files changed (1)
  1. llava_qwen2.py +0 -2234
llava_qwen2.py DELETED
@@ -1,2234 +0,0 @@
1
-
2
- # Copyright 2023 Haotian Liu
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- from typing import List, Optional, Tuple, Union
18
-
19
- import re
20
- import copy
21
- from timm.models import create_model
22
- from abc import ABC, abstractmethod
23
-
24
- import torch
25
- import torch.nn as nn
26
- from torch import Tensor
27
- import torch.nn.functional as F
28
- from torch.nn.init import normal_
29
-
30
- from transformers import CLIPImageProcessor
31
- from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
32
-
33
- from transformers.modeling_outputs import CausalLMOutputWithPast
34
- from transformers.generation.utils import GenerateOutput
35
-
36
- from functools import partial
37
- from typing import List, Tuple, Optional, Union, Dict, Any
38
-
39
- from timm.models import register_model
40
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
41
- from timm.layers import DropPath, SqueezeExcite
42
-
43
- CONTROLLER_HEART_BEAT_EXPIRATION = 30
44
- WORKER_HEART_BEAT_INTERVAL = 15
45
- LOGDIR = "."
46
- # Model Constants
47
- IGNORE_INDEX = -100
48
- IMAGE_TOKEN_INDEX = -200
49
- DEFAULT_IMAGE_TOKEN = "<image>"
50
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
51
- DEFAULT_IM_START_TOKEN = "<im_start>"
52
- DEFAULT_IM_END_TOKEN = "<im_end>"
53
- IMAGE_PLACEHOLDER = "<image-placeholder>"
54
-
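For orientation, a hedged illustration (the token ids below are placeholders, not real tokenizer output) of how these sentinels are conventionally used: IMAGE_TOKEN_INDEX marks where image features get spliced into the sequence, and IGNORE_INDEX masks positions out of the language-model loss.
import torch

# Hedged illustration; the ids 101/102/103 are made up for the example.
input_ids = torch.tensor([101, 102, IMAGE_TOKEN_INDEX, 103])
labels = torch.full_like(input_ids, IGNORE_INDEX)                         # loss-masked prompt
image_slots = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]  # tensor([2])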
55
- class LlavaConfig(Qwen2Config):
56
- model_type = "llava_qwen2"
57
-
58
- def _cfg(url="", **kwargs):
59
- return {
60
- "url": url,
61
- "num_classes": 1000,
62
- "input_size": (3, 256, 256),
63
- "pool_size": None,
64
- "crop_pct": 0.95,
65
- "interpolation": "bicubic",
66
- "mean": IMAGENET_DEFAULT_MEAN,
67
- "std": IMAGENET_DEFAULT_STD,
68
- "classifier": "head",
69
- **kwargs,
70
- }
71
-
72
-
73
- default_cfgs = {
74
- "fastvit_t": _cfg(crop_pct=0.9),
75
- "fastvit_s": _cfg(crop_pct=0.9),
76
- "fastvit_m": _cfg(crop_pct=0.95),
77
- }
78
-
79
-
80
- class SEBlock(nn.Module):
81
- """Squeeze and Excite module.
82
-
83
- Pytorch implementation of `Squeeze-and-Excitation Networks` -
84
- https://arxiv.org/pdf/1709.01507.pdf
85
- """
86
-
87
- def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
88
- """Construct a Squeeze and Excite Module.
89
-
90
- Args:
91
- in_channels: Number of input channels.
92
- rd_ratio: Input channel reduction ratio.
93
- """
94
- super(SEBlock, self).__init__()
95
- self.reduce = nn.Conv2d(
96
- in_channels=in_channels,
97
- out_channels=int(in_channels * rd_ratio),
98
- kernel_size=1,
99
- stride=1,
100
- bias=True,
101
- )
102
- self.expand = nn.Conv2d(
103
- in_channels=int(in_channels * rd_ratio),
104
- out_channels=in_channels,
105
- kernel_size=1,
106
- stride=1,
107
- bias=True,
108
- )
109
-
110
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
111
- """Apply forward pass."""
112
- b, c, h, w = inputs.size()
113
- # x = F.avg_pool2d(inputs, kernel_size=[h, w])
114
- x = F.avg_pool2d(inputs, kernel_size=[16, 16])
115
- x = self.reduce(x)
116
- x = F.relu(x)
117
- x = self.expand(x)
118
- x = torch.sigmoid(x)
119
- x = x.view(-1, c, 1, 1)
120
- return inputs * x
121
-
122
-
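A hedged shape sketch, assuming the 1024-px FastViTHD default where the last-stage feature map is 16x16 (presumably why the pooling kernel is hard-coded to 16): the squeeze step collapses to a per-channel 1x1 gate that is broadcast back over the input.
import torch

se = SEBlock(in_channels=3072)                 # last-stage width under the default config
feat = torch.randn(1, 3072, 16, 16)
gated = se(feat)                               # avg_pool2d(.., 16) -> (1, 3072, 1, 1) gate
assert gated.shape == feat.shape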
123
- class MobileOneBlock(nn.Module):
124
- """MobileOne building block.
125
-
126
- This block has a multi-branched architecture at train-time
127
- and plain-CNN style architecture at inference time
128
- For more details, please refer to our paper:
129
- `An Improved One millisecond Mobile Backbone` -
130
- https://arxiv.org/pdf/2206.04040.pdf
131
- """
132
-
133
- def __init__(
134
- self,
135
- in_channels: int,
136
- out_channels: int,
137
- kernel_size: int,
138
- stride: int = 1,
139
- padding: int = 0,
140
- dilation: int = 1,
141
- groups: int = 1,
142
- inference_mode: bool = False,
143
- use_se: bool = False,
144
- use_act: bool = True,
145
- use_scale_branch: bool = True,
146
- num_conv_branches: int = 1,
147
- activation: nn.Module = nn.GELU(),
148
- ) -> None:
149
- """Construct a MobileOneBlock module.
150
-
151
- Args:
152
- in_channels: Number of channels in the input.
153
- out_channels: Number of channels produced by the block.
154
- kernel_size: Size of the convolution kernel.
155
- stride: Stride size.
156
- padding: Zero-padding size.
157
- dilation: Kernel dilation factor.
158
- groups: Group number.
159
- inference_mode: If True, instantiates model in inference mode.
160
- use_se: Whether to use SE-ReLU activations.
161
- use_act: Whether to use activation. Default: ``True``
162
- use_scale_branch: Whether to use scale branch. Default: ``True``
163
- num_conv_branches: Number of linear conv branches.
164
- """
165
- super(MobileOneBlock, self).__init__()
166
- self.inference_mode = inference_mode
167
- self.groups = groups
168
- self.stride = stride
169
- self.padding = padding
170
- self.dilation = dilation
171
- self.kernel_size = kernel_size
172
- self.in_channels = in_channels
173
- self.out_channels = out_channels
174
- self.num_conv_branches = num_conv_branches
175
-
176
- # Check if SE-ReLU is requested
177
- if use_se:
178
- self.se = SEBlock(out_channels)
179
- else:
180
- self.se = nn.Identity()
181
-
182
- if use_act:
183
- self.activation = activation
184
- else:
185
- self.activation = nn.Identity()
186
-
187
- if inference_mode:
188
- self.reparam_conv = nn.Conv2d(
189
- in_channels=in_channels,
190
- out_channels=out_channels,
191
- kernel_size=kernel_size,
192
- stride=stride,
193
- padding=padding,
194
- dilation=dilation,
195
- groups=groups,
196
- bias=True,
197
- )
198
- else:
199
- # Re-parameterizable skip connection
200
- # Fallback, sometimes batchnorm tensors
201
- # do not get instantiated correctly on some processes
202
- # when using deepspeed + accelerate
203
- norm_layer = nn.BatchNorm2d(num_features=in_channels)
204
- if norm_layer.weight.shape[0] == 0:
205
- norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
206
- if norm_layer.bias.shape[0] == 0:
207
- norm_layer.bias = nn.Parameter(torch.zeros(in_channels))
208
-
209
- self.rbr_skip = (
210
- norm_layer
211
- if out_channels == in_channels and stride == 1
212
- else None
213
- )
214
-
215
- # Re-parameterizable conv branches
216
- if num_conv_branches > 0:
217
- rbr_conv = list()
218
- for _ in range(self.num_conv_branches):
219
- rbr_conv.append(
220
- self._conv_bn(kernel_size=kernel_size, padding=padding)
221
- )
222
- self.rbr_conv = nn.ModuleList(rbr_conv)
223
- else:
224
- self.rbr_conv = None
225
-
226
- # Re-parameterizable scale branch
227
- self.rbr_scale = None
228
- if not isinstance(kernel_size, int):
229
- kernel_size = kernel_size[0]
230
- if (kernel_size > 1) and use_scale_branch:
231
- self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
232
-
233
- def forward(self, x: torch.Tensor) -> torch.Tensor:
234
- """Apply forward pass."""
235
- # Inference mode forward pass.
236
- if self.inference_mode:
237
- return self.activation(self.se(self.reparam_conv(x)))
238
-
239
- # Multi-branched train-time forward pass.
240
- # Skip branch output
241
- identity_out = 0
242
- if self.rbr_skip is not None:
243
- identity_out = self.rbr_skip(x)
244
-
245
- # Scale branch output
246
- scale_out = 0
247
- if self.rbr_scale is not None:
248
- scale_out = self.rbr_scale(x)
249
-
250
- # Other branches
251
- out = scale_out + identity_out
252
- if self.rbr_conv is not None:
253
- for ix in range(self.num_conv_branches):
254
- out += self.rbr_conv[ix](x)
255
-
256
- return self.activation(self.se(out))
257
-
258
- def reparameterize(self):
259
- """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
260
- https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
261
- architecture used at training time to obtain a plain CNN-like structure
262
- for inference.
263
- """
264
- if self.inference_mode:
265
- return
266
- kernel, bias = self._get_kernel_bias()
267
- self.reparam_conv = nn.Conv2d(
268
- in_channels=self.in_channels,
269
- out_channels=self.out_channels,
270
- kernel_size=self.kernel_size,
271
- stride=self.stride,
272
- padding=self.padding,
273
- dilation=self.dilation,
274
- groups=self.groups,
275
- bias=True,
276
- )
277
- self.reparam_conv.weight.data = kernel
278
- self.reparam_conv.bias.data = bias
279
-
280
- # Delete un-used branches
281
- self.__delattr__("rbr_conv")
282
- self.__delattr__("rbr_scale")
283
- if hasattr(self, "rbr_skip"):
284
- self.__delattr__("rbr_skip")
285
-
286
- self.inference_mode = True
287
-
288
- def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
289
- """Method to obtain re-parameterized kernel and bias.
290
- Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
291
-
292
- Returns:
293
- Tuple of (kernel, bias) after fusing branches.
294
- """
295
- # get weights and bias of scale branch
296
- kernel_scale = 0
297
- bias_scale = 0
298
- if self.rbr_scale is not None:
299
- kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
300
- # Pad scale branch kernel to match conv branch kernel size.
301
- pad = self.kernel_size // 2
302
- kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
303
-
304
- # get weights and bias of skip branch
305
- kernel_identity = 0
306
- bias_identity = 0
307
- if self.rbr_skip is not None:
308
- kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
309
-
310
- # get weights and bias of conv branches
311
- kernel_conv = 0
312
- bias_conv = 0
313
- if self.rbr_conv is not None:
314
- for ix in range(self.num_conv_branches):
315
- _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
316
- kernel_conv += _kernel
317
- bias_conv += _bias
318
-
319
- kernel_final = kernel_conv + kernel_scale + kernel_identity
320
- bias_final = bias_conv + bias_scale + bias_identity
321
- return kernel_final, bias_final
322
-
323
- def _fuse_bn_tensor(
324
- self, branch: Union[nn.Sequential, nn.BatchNorm2d]
325
- ) -> Tuple[torch.Tensor, torch.Tensor]:
326
- """Method to fuse batchnorm layer with preceeding conv layer.
327
- Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
328
-
329
- Args:
330
- branch: Sequence of ops to be fused.
331
-
332
- Returns:
333
- Tuple of (kernel, bias) after fusing batchnorm.
334
- """
335
- if isinstance(branch, nn.Sequential):
336
- kernel = branch.conv.weight
337
- running_mean = branch.bn.running_mean
338
- running_var = branch.bn.running_var
339
- gamma = branch.bn.weight
340
- beta = branch.bn.bias
341
- eps = branch.bn.eps
342
- else:
343
- assert isinstance(branch, nn.BatchNorm2d)
344
- if not hasattr(self, "id_tensor"):
345
- input_dim = self.in_channels // self.groups
346
-
347
- kernel_size = self.kernel_size
348
- if isinstance(self.kernel_size, int):
349
- kernel_size = (self.kernel_size, self.kernel_size)
350
-
351
- kernel_value = torch.zeros(
352
- (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
353
- dtype=branch.weight.dtype,
354
- device=branch.weight.device,
355
- )
356
- for i in range(self.in_channels):
357
- kernel_value[
358
- i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
359
- ] = 1
360
- self.id_tensor = kernel_value
361
- kernel = self.id_tensor
362
- running_mean = branch.running_mean
363
- running_var = branch.running_var
364
- gamma = branch.weight
365
- beta = branch.bias
366
- eps = branch.eps
367
- std = (running_var + eps).sqrt()
368
- t = (gamma / std).reshape(-1, 1, 1, 1)
369
- return kernel * t, beta - running_mean * gamma / std
370
-
371
- def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
372
- """Helper method to construct conv-batchnorm layers.
373
-
374
- Args:
375
- kernel_size: Size of the convolution kernel.
376
- padding: Zero-padding size.
377
-
378
- Returns:
379
- Conv-BN module.
380
- """
381
- # Fallback, sometimes batchnorm tensors
382
- # do not get instantiated correctly on some processes
383
- # when using deepspeed + accelerate
384
- norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
385
- if norm_layer.weight.shape[0] == 0:
386
- norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
387
- if norm_layer.bias.shape[0] == 0:
388
- norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
389
-
390
- mod_list = nn.Sequential()
391
- mod_list.add_module(
392
- "conv",
393
- nn.Conv2d(
394
- in_channels=self.in_channels,
395
- out_channels=self.out_channels,
396
- kernel_size=kernel_size,
397
- stride=self.stride,
398
- padding=padding,
399
- groups=self.groups,
400
- bias=False,
401
- ),
402
- )
403
- mod_list.add_module("bn", norm_layer)
404
- return mod_list
405
-
406
-
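A minimal usage sketch (not part of the original file) of the train-time/inference-time equivalence that `reparameterize` provides, assuming the block is in eval mode so batch-norm statistics are frozen:
import torch

block = MobileOneBlock(in_channels=8, out_channels=8, kernel_size=3,
                       stride=1, padding=1, groups=1, num_conv_branches=1)
block.eval()                                   # freeze BN so the fusion is exact
x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    y_multi_branch = block(x)                  # skip + scale + conv branches
    block.reparameterize()                     # fold everything into reparam_conv
    y_fused = block(x)
assert torch.allclose(y_multi_branch, y_fused, atol=1e-5)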
407
- class ReparamLargeKernelConv(nn.Module):
408
- """Building Block of RepLKNet
409
-
410
- This class defines the over-parameterized large kernel conv block
411
- introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
412
-
413
- Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
414
- """
415
-
416
- def __init__(
417
- self,
418
- in_channels: int,
419
- out_channels: int,
420
- kernel_size: int,
421
- stride: int,
422
- groups: int,
423
- small_kernel: int,
424
- inference_mode: bool = False,
425
- use_se: bool = False,
426
- activation: nn.Module = nn.GELU(),
427
- ) -> None:
428
- """Construct a ReparamLargeKernelConv module.
429
-
430
- Args:
431
- in_channels: Number of input channels.
432
- out_channels: Number of output channels.
433
- kernel_size: Kernel size of the large kernel conv branch.
434
- stride: Stride size. Default: 1
435
- groups: Group number. Default: 1
436
- small_kernel: Kernel size of small kernel conv branch.
437
- inference_mode: If True, instantiates model in inference mode. Default: ``False``
438
- activation: Activation module. Default: ``nn.GELU``
439
- """
440
- super(ReparamLargeKernelConv, self).__init__()
441
-
442
- self.stride = stride
443
- self.groups = groups
444
- self.in_channels = in_channels
445
- self.out_channels = out_channels
446
- self.activation = activation
447
-
448
- self.kernel_size = kernel_size
449
- self.small_kernel = small_kernel
450
- self.padding = kernel_size // 2
451
-
452
- # Check if SE is requested
453
- if use_se:
454
- self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
455
- else:
456
- self.se = nn.Identity()
457
-
458
- if inference_mode:
459
- self.lkb_reparam = nn.Conv2d(
460
- in_channels=in_channels,
461
- out_channels=out_channels,
462
- kernel_size=kernel_size,
463
- stride=stride,
464
- padding=self.padding,
465
- dilation=1,
466
- groups=groups,
467
- bias=True,
468
- )
469
- else:
470
- self.lkb_origin = self._conv_bn(
471
- kernel_size=kernel_size, padding=self.padding
472
- )
473
- if small_kernel is not None:
474
- assert (
475
- small_kernel <= kernel_size
476
- ), "The kernel size for re-param cannot be larger than the large kernel!"
477
- self.small_conv = self._conv_bn(
478
- kernel_size=small_kernel, padding=small_kernel // 2
479
- )
480
-
481
- def forward(self, x: torch.Tensor) -> torch.Tensor:
482
- """Apply forward pass."""
483
- if hasattr(self, "lkb_reparam"):
484
- out = self.lkb_reparam(x)
485
- else:
486
- out = self.lkb_origin(x)
487
- if hasattr(self, "small_conv"):
488
- out += self.small_conv(x)
489
-
490
- return self.activation(self.se(out))
491
-
492
- def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
493
- """Method to obtain re-parameterized kernel and bias.
494
- Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
495
-
496
- Returns:
497
- Tuple of (kernel, bias) after fusing branches.
498
- """
499
- eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
500
- if hasattr(self, "small_conv"):
501
- small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
502
- eq_b += small_b
503
- eq_k += nn.functional.pad(
504
- small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
505
- )
506
- return eq_k, eq_b
507
-
508
- def reparameterize(self) -> None:
509
- """
510
- Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
511
- https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
512
- architecture used at training time to obtain a plain CNN-like structure
513
- for inference.
514
- """
515
- eq_k, eq_b = self.get_kernel_bias()
516
- self.lkb_reparam = nn.Conv2d(
517
- in_channels=self.in_channels,
518
- out_channels=self.out_channels,
519
- kernel_size=self.kernel_size,
520
- stride=self.stride,
521
- padding=self.padding,
522
- dilation=self.lkb_origin.conv.dilation,
523
- groups=self.groups,
524
- bias=True,
525
- )
526
-
527
- self.lkb_reparam.weight.data = eq_k
528
- self.lkb_reparam.bias.data = eq_b
529
- self.__delattr__("lkb_origin")
530
- if hasattr(self, "small_conv"):
531
- self.__delattr__("small_conv")
532
-
533
- @staticmethod
534
- def _fuse_bn(
535
- conv: torch.Tensor, bn: nn.BatchNorm2d
536
- ) -> Tuple[torch.Tensor, torch.Tensor]:
537
- """Method to fuse batchnorm layer with conv layer.
538
-
539
- Args:
540
- conv: Convolutional kernel weights.
541
- bn: Batchnorm 2d layer.
542
-
543
- Returns:
544
- Tuple of (kernel, bias) after fusing batchnorm.
545
- """
546
- kernel = conv.weight
547
- running_mean = bn.running_mean
548
- running_var = bn.running_var
549
- gamma = bn.weight
550
- beta = bn.bias
551
- eps = bn.eps
552
- std = (running_var + eps).sqrt()
553
- t = (gamma / std).reshape(-1, 1, 1, 1)
554
- return kernel * t, beta - running_mean * gamma / std
555
-
556
- def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
557
- """Helper method to construct conv-batchnorm layers.
558
-
559
- Args:
560
- kernel_size: Size of the convolution kernel.
561
- padding: Zero-padding size.
562
-
563
- Returns:
564
- A nn.Sequential Conv-BN module.
565
- """
566
- # Fallback, sometimes batchnorm tensors
567
- # do not get instantiated correctly on some processes
568
- # when using deepspeed + accelerate
569
- norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
570
- if norm_layer.weight.shape[0] == 0:
571
- norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
572
- if norm_layer.bias.shape[0] == 0:
573
- norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
574
-
575
- mod_list = nn.Sequential()
576
- mod_list.add_module(
577
- "conv",
578
- nn.Conv2d(
579
- in_channels=self.in_channels,
580
- out_channels=self.out_channels,
581
- kernel_size=kernel_size,
582
- stride=self.stride,
583
- padding=padding,
584
- groups=self.groups,
585
- bias=False,
586
- ),
587
- )
588
- mod_list.add_module("bn", norm_layer)
589
- return mod_list
590
-
591
-
592
- def convolutional_stem(
593
- in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True,
594
- ) -> nn.Sequential:
595
- """Build convolutional stem with MobileOne blocks.
596
-
597
- Args:
598
- in_channels: Number of input channels.
599
- out_channels: Number of output channels.
600
- inference_mode: Flag to instantiate model in inference mode. Default: ``False``
601
-
602
- Returns:
603
- nn.Sequential object with stem elements.
604
- """
605
- return nn.Sequential(
606
- MobileOneBlock(
607
- in_channels=in_channels,
608
- out_channels=out_channels,
609
- kernel_size=3,
610
- stride=2,
611
- padding=1,
612
- groups=1,
613
- inference_mode=inference_mode,
614
- use_se=False,
615
- num_conv_branches=1,
616
- use_scale_branch=use_scale_branch
617
- ),
618
- MobileOneBlock(
619
- in_channels=out_channels,
620
- out_channels=out_channels,
621
- kernel_size=3,
622
- stride=2,
623
- padding=1,
624
- groups=out_channels,
625
- inference_mode=inference_mode,
626
- use_se=False,
627
- num_conv_branches=1,
628
- use_scale_branch=use_scale_branch
629
- ),
630
- MobileOneBlock(
631
- in_channels=out_channels,
632
- out_channels=out_channels,
633
- kernel_size=1,
634
- stride=1,
635
- padding=0,
636
- groups=1,
637
- inference_mode=inference_mode,
638
- use_se=False,
639
- num_conv_branches=1,
640
- use_scale_branch=use_scale_branch
641
- ),
642
- )
643
-
644
-
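A hedged shape check, assuming the definitions above are in scope: the stem stacks two stride-2 MobileOne blocks plus a pointwise block, so spatial resolution drops by 4x while channels go to `out_channels`.
import torch

stem = convolutional_stem(in_channels=3, out_channels=96)
out = stem(torch.randn(1, 3, 256, 256))
assert out.shape == (1, 96, 64, 64)            # 256 / 2 / 2 = 64 per side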
645
- class LayerNormChannel(nn.Module):
646
- """
647
- LayerNorm only for Channel Dimension.
648
- Input: tensor in shape [B, C, H, W]
649
- """
650
- def __init__(self, num_features, eps=1e-05) -> None:
651
- super().__init__()
652
- self.weight = nn.Parameter(torch.ones(num_features))
653
- self.bias = nn.Parameter(torch.zeros(num_features))
654
- self.eps = eps
655
-
656
- def forward(self, x) -> torch.Tensor:
657
- u = x.mean(1, keepdim=True)
658
- s = (x - u).pow(2).mean(1, keepdim=True)
659
- x = (x - u) / torch.sqrt(s + self.eps)
660
- x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
661
- + self.bias.unsqueeze(-1).unsqueeze(-1)
662
- return x
663
-
664
-
665
- class MHSA(nn.Module):
666
- """Multi-headed Self Attention module.
667
-
668
- Source modified from:
669
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
670
- """
671
-
672
- def __init__(
673
- self,
674
- dim: int,
675
- head_dim: int = 32,
676
- qkv_bias: bool = False,
677
- attn_drop: float = 0.0,
678
- proj_drop: float = 0.0,
679
- ) -> None:
680
- """Build MHSA module that can handle 3D or 4D input tensors.
681
-
682
- Args:
683
- dim: Number of embedding dimensions.
684
- head_dim: Number of hidden dimensions per head. Default: ``32``
685
- qkv_bias: Use bias or not. Default: ``False``
686
- attn_drop: Dropout rate for attention tensor.
687
- proj_drop: Dropout rate for projection tensor.
688
- """
689
- super().__init__()
690
- assert dim % head_dim == 0, "dim should be divisible by head_dim"
691
- self.head_dim = head_dim
692
- self.num_heads = dim // head_dim
693
- self.scale = head_dim**-0.5
694
-
695
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
696
- self.attn_drop = nn.Dropout(attn_drop)
697
- self.proj = nn.Linear(dim, dim)
698
- self.proj_drop = nn.Dropout(proj_drop)
699
-
700
- def forward(self, x: torch.Tensor) -> torch.Tensor:
701
- shape = x.shape
702
- B, C, H, W = shape
703
- N = H * W
704
- if len(shape) == 4:
705
- x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
706
- qkv = (
707
- self.qkv(x)
708
- .reshape(B, N, 3, self.num_heads, self.head_dim)
709
- .permute(2, 0, 3, 1, 4)
710
- )
711
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
712
-
713
- # trick here to make q@k.t more stable
714
- attn = (q * self.scale) @ k.transpose(-2, -1)
715
- attn = attn.softmax(dim=-1)
716
- attn = self.attn_drop(attn)
717
-
718
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
719
- x = self.proj(x)
720
- x = self.proj_drop(x)
721
- if len(shape) == 4:
722
- x = x.transpose(-2, -1).reshape(B, C, H, W)
723
-
724
- return x
725
-
726
-
727
- class PatchEmbed(nn.Module):
728
- """Convolutional patch embedding layer."""
729
-
730
- def __init__(
731
- self,
732
- patch_size: int,
733
- stride: int,
734
- in_channels: int,
735
- embed_dim: int,
736
- inference_mode: bool = False,
737
- use_se: bool = False,
738
- ) -> None:
739
- """Build patch embedding layer.
740
-
741
- Args:
742
- patch_size: Patch size for embedding computation.
743
- stride: Stride for convolutional embedding layer.
744
- in_channels: Number of channels of input tensor.
745
- embed_dim: Number of embedding dimensions.
746
- inference_mode: Flag to instantiate model in inference mode. Default: ``False``
747
- use_se: If ``True`` SE block will be used.
748
- """
749
- super().__init__()
750
- block = list()
751
- block.append(
752
- ReparamLargeKernelConv(
753
- in_channels=in_channels,
754
- out_channels=embed_dim,
755
- kernel_size=patch_size,
756
- stride=stride,
757
- groups=in_channels,
758
- small_kernel=3,
759
- inference_mode=inference_mode,
760
- use_se=use_se,
761
- )
762
- )
763
- block.append(
764
- MobileOneBlock(
765
- in_channels=embed_dim,
766
- out_channels=embed_dim,
767
- kernel_size=1,
768
- stride=1,
769
- padding=0,
770
- groups=1,
771
- inference_mode=inference_mode,
772
- use_se=False,
773
- num_conv_branches=1,
774
- )
775
- )
776
- self.proj = nn.Sequential(*block)
777
-
778
- def forward(self, x: torch.Tensor) -> torch.Tensor:
779
- x = self.proj(x)
780
- return x
781
-
782
-
783
- class RepMixer(nn.Module):
784
- """Reparameterizable token mixer.
785
-
786
- For more details, please refer to our paper:
787
- `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
788
- """
789
-
790
- def __init__(
791
- self,
792
- dim,
793
- kernel_size=3,
794
- use_layer_scale=True,
795
- layer_scale_init_value=1e-5,
796
- inference_mode: bool = False,
797
- ):
798
- """Build RepMixer Module.
799
-
800
- Args:
801
- dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
802
- kernel_size: Kernel size for spatial mixing. Default: 3
803
- use_layer_scale: If True, learnable layer scale is used. Default: ``True``
804
- layer_scale_init_value: Initial value for layer scale. Default: 1e-5
805
- inference_mode: If True, instantiates model in inference mode. Default: ``False``
806
- """
807
- super().__init__()
808
- self.dim = dim
809
- self.kernel_size = kernel_size
810
- self.inference_mode = inference_mode
811
-
812
- if inference_mode:
813
- self.reparam_conv = nn.Conv2d(
814
- in_channels=self.dim,
815
- out_channels=self.dim,
816
- kernel_size=self.kernel_size,
817
- stride=1,
818
- padding=self.kernel_size // 2,
819
- groups=self.dim,
820
- bias=True,
821
- )
822
- else:
823
- self.norm = MobileOneBlock(
824
- dim,
825
- dim,
826
- kernel_size,
827
- padding=kernel_size // 2,
828
- groups=dim,
829
- use_act=False,
830
- use_scale_branch=False,
831
- num_conv_branches=0,
832
- )
833
- self.mixer = MobileOneBlock(
834
- dim,
835
- dim,
836
- kernel_size,
837
- padding=kernel_size // 2,
838
- groups=dim,
839
- use_act=False,
840
- )
841
- self.use_layer_scale = use_layer_scale
842
- if use_layer_scale:
843
- self.layer_scale = nn.Parameter(
844
- layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
845
- )
846
-
847
- def forward(self, x: torch.Tensor) -> torch.Tensor:
848
- if hasattr(self, "reparam_conv"):
849
- x = self.reparam_conv(x)
850
- return x
851
- else:
852
- if self.use_layer_scale:
853
- x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
854
- else:
855
- x = x + self.mixer(x) - self.norm(x)
856
- return x
857
-
858
- def reparameterize(self) -> None:
859
- """Reparameterize mixer and norm into a single
860
- convolutional layer for efficient inference.
861
- """
862
- if self.inference_mode:
863
- return
864
-
865
- self.mixer.reparameterize()
866
- self.norm.reparameterize()
867
-
868
- if self.use_layer_scale:
869
- w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
870
- self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
871
- )
872
- b = torch.squeeze(self.layer_scale) * (
873
- self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
874
- )
875
- else:
876
- w = (
877
- self.mixer.id_tensor
878
- + self.mixer.reparam_conv.weight
879
- - self.norm.reparam_conv.weight
880
- )
881
- b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
882
-
883
- self.reparam_conv = nn.Conv2d(
884
- in_channels=self.dim,
885
- out_channels=self.dim,
886
- kernel_size=self.kernel_size,
887
- stride=1,
888
- padding=self.kernel_size // 2,
889
- groups=self.dim,
890
- bias=True,
891
- )
892
- self.reparam_conv.weight.data = w
893
- self.reparam_conv.bias.data = b
894
-
895
- self.__delattr__("mixer")
896
- self.__delattr__("norm")
897
- if self.use_layer_scale:
898
- self.__delattr__("layer_scale")
899
-
900
-
901
- class ConvFFN(nn.Module):
902
- """Convolutional FFN Module."""
903
-
904
- def __init__(
905
- self,
906
- in_channels: int,
907
- hidden_channels: Optional[int] = None,
908
- out_channels: Optional[int] = None,
909
- act_layer: nn.Module = nn.GELU,
910
- drop: float = 0.0,
911
- ) -> None:
912
- """Build convolutional FFN module.
913
-
914
- Args:
915
- in_channels: Number of input channels.
916
- hidden_channels: Number of channels after expansion. Default: None
917
- out_channels: Number of output channels. Default: None
918
- act_layer: Activation layer. Default: ``GELU``
919
- drop: Dropout rate. Default: ``0.0``.
920
- """
921
- super().__init__()
922
- out_channels = out_channels or in_channels
923
- hidden_channels = hidden_channels or in_channels
924
- self.conv = nn.Sequential()
925
- self.conv.add_module(
926
- "conv",
927
- nn.Conv2d(
928
- in_channels=in_channels,
929
- out_channels=out_channels,
930
- kernel_size=7,
931
- padding=3,
932
- groups=in_channels,
933
- bias=False,
934
- ),
935
- )
936
-
937
- # Fallback, sometimes batchnorm tensors
938
- # do not get instantiated correctly on some processes
939
- # when using deepspeed + accelerate
940
- norm_layer = nn.BatchNorm2d(num_features=out_channels)
941
- if norm_layer.weight.shape[0] == 0:
942
- norm_layer.weight = nn.Parameter(torch.zeros(out_channels))
943
- if norm_layer.bias.shape[0] == 0:
944
- norm_layer.bias = nn.Parameter(torch.zeros(out_channels))
945
-
946
- self.conv.add_module("bn", norm_layer)
947
- self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
948
- self.act = act_layer()
949
- self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
950
- self.drop = nn.Dropout(drop)
951
- self.apply(self._init_weights)
952
-
953
- def _init_weights(self, m: nn.Module) -> None:
954
- if isinstance(m, nn.Conv2d):
955
- normal_(m.weight, std=0.02)
956
- if m.bias is not None:
957
- nn.init.constant_(m.bias, 0)
958
-
959
- def forward(self, x: torch.Tensor) -> torch.Tensor:
960
- x = self.conv(x)
961
- x = self.fc1(x)
962
- x = self.act(x)
963
- x = self.drop(x)
964
- x = self.fc2(x)
965
- x = self.drop(x)
966
- return x
967
-
968
-
969
- class RepCPE(nn.Module):
970
- """Implementation of conditional positional encoding.
971
-
972
- For more details refer to paper:
973
- `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
974
-
975
- In our implementation, we can reparameterize this module to eliminate a skip connection.
976
- """
977
-
978
- def __init__(
979
- self,
980
- in_channels: int,
981
- embed_dim: int = 768,
982
- spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
983
- inference_mode=False,
984
- ) -> None:
985
- """Build reparameterizable conditional positional encoding
986
-
987
- Args:
988
- in_channels: Number of input channels.
989
- embed_dim: Number of embedding dimensions. Default: 768
990
- spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
991
- inference_mode: Flag to instantiate block in inference mode. Default: ``False``
992
- """
993
- super(RepCPE, self).__init__()
994
- if isinstance(spatial_shape, int):
995
- spatial_shape = tuple([spatial_shape] * 2)
996
- assert isinstance(spatial_shape, Tuple), (
997
- f'"spatial_shape" must by a sequence or int, '
998
- f"get {type(spatial_shape)} instead."
999
- )
1000
- assert len(spatial_shape) == 2, (
1001
- f'Length of "spatial_shape" should be 2, '
1002
- f"got {len(spatial_shape)} instead."
1003
- )
1004
-
1005
- self.spatial_shape = spatial_shape
1006
- self.embed_dim = embed_dim
1007
- self.in_channels = in_channels
1008
- self.groups = embed_dim
1009
-
1010
- if inference_mode:
1011
- self.reparam_conv = nn.Conv2d(
1012
- in_channels=self.in_channels,
1013
- out_channels=self.embed_dim,
1014
- kernel_size=self.spatial_shape,
1015
- stride=1,
1016
- padding=int(self.spatial_shape[0] // 2),
1017
- groups=self.embed_dim,
1018
- bias=True,
1019
- )
1020
- else:
1021
- self.pe = nn.Conv2d(
1022
- in_channels,
1023
- embed_dim,
1024
- spatial_shape,
1025
- 1,
1026
- int(spatial_shape[0] // 2),
1027
- bias=True,
1028
- groups=embed_dim,
1029
- )
1030
-
1031
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1032
- if hasattr(self, "reparam_conv"):
1033
- x = self.reparam_conv(x)
1034
- return x
1035
- else:
1036
- x = self.pe(x) + x
1037
- return x
1038
-
1039
- def reparameterize(self) -> None:
1040
- # Build equivalent Id tensor
1041
- input_dim = self.in_channels // self.groups
1042
- kernel_value = torch.zeros(
1043
- (
1044
- self.in_channels,
1045
- input_dim,
1046
- self.spatial_shape[0],
1047
- self.spatial_shape[1],
1048
- ),
1049
- dtype=self.pe.weight.dtype,
1050
- device=self.pe.weight.device,
1051
- )
1052
- for i in range(self.in_channels):
1053
- kernel_value[
1054
- i,
1055
- i % input_dim,
1056
- self.spatial_shape[0] // 2,
1057
- self.spatial_shape[1] // 2,
1058
- ] = 1
1059
- id_tensor = kernel_value
1060
-
1061
- # Reparameterize Id tensor and conv
1062
- w_final = id_tensor + self.pe.weight
1063
- b_final = self.pe.bias
1064
-
1065
- # Introduce reparam conv
1066
- self.reparam_conv = nn.Conv2d(
1067
- in_channels=self.in_channels,
1068
- out_channels=self.embed_dim,
1069
- kernel_size=self.spatial_shape,
1070
- stride=1,
1071
- padding=int(self.spatial_shape[0] // 2),
1072
- groups=self.embed_dim,
1073
- bias=True,
1074
- )
1075
- self.reparam_conv.weight.data = w_final
1076
- self.reparam_conv.bias.data = b_final
1077
-
1078
- self.__delattr__("pe")
1079
-
1080
-
1081
- class RepMixerBlock(nn.Module):
1082
- """Implementation of Metaformer block with RepMixer as token mixer.
1083
-
1084
- For more details on Metaformer structure, please refer to:
1085
- `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1086
- """
1087
-
1088
- def __init__(
1089
- self,
1090
- dim: int,
1091
- kernel_size: int = 3,
1092
- mlp_ratio: float = 4.0,
1093
- act_layer: nn.Module = nn.GELU,
1094
- drop: float = 0.0,
1095
- drop_path: float = 0.0,
1096
- use_layer_scale: bool = True,
1097
- layer_scale_init_value: float = 1e-5,
1098
- inference_mode: bool = False,
1099
- ):
1100
- """Build RepMixer Block.
1101
-
1102
- Args:
1103
- dim: Number of embedding dimensions.
1104
- kernel_size: Kernel size for repmixer. Default: 3
1105
- mlp_ratio: MLP expansion ratio. Default: 4.0
1106
- act_layer: Activation layer. Default: ``nn.GELU``
1107
- drop: Dropout rate. Default: 0.0
1108
- drop_path: Drop path rate. Default: 0.0
1109
- use_layer_scale: Flag to turn on layer scale. Default: ``True``
1110
- layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1111
- inference_mode: Flag to instantiate block in inference mode. Default: ``False``
1112
- """
1113
-
1114
- super().__init__()
1115
-
1116
- self.token_mixer = RepMixer(
1117
- dim,
1118
- kernel_size=kernel_size,
1119
- use_layer_scale=use_layer_scale,
1120
- layer_scale_init_value=layer_scale_init_value,
1121
- inference_mode=inference_mode,
1122
- )
1123
-
1124
- assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1125
- mlp_ratio
1126
- )
1127
- mlp_hidden_dim = int(dim * mlp_ratio)
1128
- self.convffn = ConvFFN(
1129
- in_channels=dim,
1130
- hidden_channels=mlp_hidden_dim,
1131
- act_layer=act_layer,
1132
- drop=drop,
1133
- )
1134
-
1135
- # Drop Path
1136
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1137
-
1138
- # Layer Scale
1139
- self.use_layer_scale = use_layer_scale
1140
- if use_layer_scale:
1141
- self.layer_scale = nn.Parameter(
1142
- layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1143
- )
1144
-
1145
- def forward(self, x):
1146
- if self.use_layer_scale:
1147
- x = self.token_mixer(x)
1148
- x = x + self.drop_path(self.layer_scale * self.convffn(x))
1149
- else:
1150
- x = self.token_mixer(x)
1151
- x = x + self.drop_path(self.convffn(x))
1152
- return x
1153
-
1154
-
1155
- class AttentionBlock(nn.Module):
1156
- """Implementation of metaformer block with MHSA as token mixer.
1157
-
1158
- For more details on Metaformer structure, please refer to:
1159
- `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1160
- """
1161
-
1162
- def __init__(
1163
- self,
1164
- dim: int,
1165
- mlp_ratio: float = 4.0,
1166
- act_layer: nn.Module = nn.GELU,
1167
- norm_layer: nn.Module = nn.BatchNorm2d,
1168
- drop: float = 0.0,
1169
- drop_path: float = 0.0,
1170
- use_layer_scale: bool = True,
1171
- layer_scale_init_value: float = 1e-5,
1172
- ):
1173
- """Build Attention Block.
1174
-
1175
- Args:
1176
- dim: Number of embedding dimensions.
1177
- mlp_ratio: MLP expansion ratio. Default: 4.0
1178
- act_layer: Activation layer. Default: ``nn.GELU``
1179
- norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
1180
- drop: Dropout rate. Default: 0.0
1181
- drop_path: Drop path rate. Default: 0.0
1182
- use_layer_scale: Flag to turn on layer scale. Default: ``True``
1183
- layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1184
- """
1185
-
1186
- super().__init__()
1187
-
1188
- # Fallback, sometimes batchnorm tensors
1189
- # do not get instantiated correctly on some processes
1190
- # when using deepspeed + accelerate
1191
- norm_layer_ = norm_layer(num_features=dim)
1192
- if norm_layer_.weight.shape[0] == 0:
1193
- norm_layer_.weight = nn.Parameter(torch.zeros(dim))
1194
- if norm_layer_.bias.shape[0] == 0:
1195
- norm_layer_.bias = nn.Parameter(torch.zeros(dim))
1196
-
1197
- self.norm = norm_layer_
1198
- self.token_mixer = MHSA(dim=dim)
1199
-
1200
- assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1201
- mlp_ratio
1202
- )
1203
- mlp_hidden_dim = int(dim * mlp_ratio)
1204
- self.convffn = ConvFFN(
1205
- in_channels=dim,
1206
- hidden_channels=mlp_hidden_dim,
1207
- act_layer=act_layer,
1208
- drop=drop,
1209
- )
1210
-
1211
- # Drop path
1212
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1213
-
1214
- # Layer Scale
1215
- self.use_layer_scale = use_layer_scale
1216
- if use_layer_scale:
1217
- self.layer_scale_1 = nn.Parameter(
1218
- layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1219
- )
1220
- self.layer_scale_2 = nn.Parameter(
1221
- layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1222
- )
1223
-
1224
- def forward(self, x):
1225
- if self.use_layer_scale:
1226
- x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
1227
- x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
1228
- else:
1229
- x = x + self.drop_path(self.token_mixer(self.norm(x)))
1230
- x = x + self.drop_path(self.convffn(x))
1231
- return x
1232
-
1233
-
1234
- def basic_blocks(
1235
- dim: int,
1236
- block_index: int,
1237
- num_blocks: List[int],
1238
- token_mixer_type: str,
1239
- kernel_size: int = 3,
1240
- mlp_ratio: float = 4.0,
1241
- act_layer: nn.Module = nn.GELU,
1242
- norm_layer: nn.Module = nn.BatchNorm2d,
1243
- drop_rate: float = 0.0,
1244
- drop_path_rate: float = 0.0,
1245
- use_layer_scale: bool = True,
1246
- layer_scale_init_value: float = 1e-5,
1247
- inference_mode=False,
1248
- ) -> nn.Sequential:
1249
- """Build FastViT blocks within a stage.
1250
-
1251
- Args:
1252
- dim: Number of embedding dimensions.
1253
- block_index: block index.
1254
- num_blocks: List containing number of blocks per stage.
1255
- token_mixer_type: Token mixer type.
1256
- kernel_size: Kernel size for repmixer.
1257
- mlp_ratio: MLP expansion ratio.
1258
- act_layer: Activation layer.
1259
- norm_layer: Normalization layer.
1260
- drop_rate: Dropout rate.
1261
- drop_path_rate: Drop path rate.
1262
- use_layer_scale: Flag to turn on layer scale regularization.
1263
- layer_scale_init_value: Layer scale value at initialization.
1264
- inference_mode: Flag to instantiate block in inference mode.
1265
-
1266
- Returns:
1267
- nn.Sequential object of all the blocks within the stage.
1268
- """
1269
- blocks = []
1270
- for block_idx in range(num_blocks[block_index]):
1271
- block_dpr = (
1272
- drop_path_rate
1273
- * (block_idx + sum(num_blocks[:block_index]))
1274
- / (sum(num_blocks) - 1)
1275
- )
1276
- if token_mixer_type == "repmixer":
1277
- blocks.append(
1278
- RepMixerBlock(
1279
- dim,
1280
- kernel_size=kernel_size,
1281
- mlp_ratio=mlp_ratio,
1282
- act_layer=act_layer,
1283
- drop=drop_rate,
1284
- drop_path=block_dpr,
1285
- use_layer_scale=use_layer_scale,
1286
- layer_scale_init_value=layer_scale_init_value,
1287
- inference_mode=inference_mode,
1288
- )
1289
- )
1290
- elif token_mixer_type == "attention":
1291
- blocks.append(
1292
- AttentionBlock(
1293
- dim,
1294
- mlp_ratio=mlp_ratio,
1295
- act_layer=act_layer,
1296
- norm_layer=norm_layer,
1297
- drop=drop_rate,
1298
- drop_path=block_dpr,
1299
- use_layer_scale=use_layer_scale,
1300
- layer_scale_init_value=layer_scale_init_value,
1301
- )
1302
- )
1303
- else:
1304
- raise ValueError(
1305
- "Token mixer type: {} not supported".format(token_mixer_type)
1306
- )
1307
- blocks = nn.Sequential(*blocks)
1308
- return blocks
1309
-
1310
-
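A hedged construction sketch: stage 0 of a hypothetical two-stage layout mixed with RepMixer blocks; spatial shape is preserved within a stage.
import torch

stage0 = basic_blocks(dim=96, block_index=0, num_blocks=[2, 2],
                      token_mixer_type="repmixer", mlp_ratio=4.0)
y = stage0(torch.randn(1, 96, 32, 32))
assert y.shape == (1, 96, 32, 32)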
1311
- class GlobalPool2D(nn.Module):
1312
- """This class implements global pooling with linear projection."""
1313
-
1314
- def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
1315
- super().__init__()
1316
- scale = in_dim**-0.5
1317
- self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
1318
- self.in_dim = in_dim
1319
- self.out_dim = out_dim
1320
-
1321
- def pool(self, x) -> Tensor:
1322
- if x.dim() == 4:
1323
- dims = [-2, -1]
1324
- elif x.dim() == 5:
1325
- dims = [-3, -2, -1]
1326
- x = torch.mean(x, dim=dims, keepdim=False)
1327
- return x
1328
-
1329
- def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
1330
- # x is of shape [batch, in_dim, in_height, in_width]
1331
- assert (
1332
- x.dim() == 4
1333
- ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
1334
- x.shape
1335
- )
1336
-
1337
- # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
1338
- x = self.pool(x)
1339
- # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
1340
- x = x @ self.proj
1341
- return x
1342
-
1343
-
1344
- class FastViT(nn.Module):
1345
- """
1346
- This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
1347
- """
1348
-
1349
- def __init__(
1350
- self,
1351
- layers,
1352
- token_mixers: Tuple[str, ...],
1353
- embed_dims=None,
1354
- mlp_ratios=None,
1355
- downsamples=None,
1356
- se_downsamples=None,
1357
- repmixer_kernel_size=3,
1358
- norm_layer: nn.Module = nn.BatchNorm2d,
1359
- act_layer: nn.Module = nn.GELU,
1360
- num_classes=1000,
1361
- pos_embs=None,
1362
- down_patch_size=7,
1363
- down_stride=2,
1364
- drop_rate=0.0,
1365
- drop_path_rate=0.0,
1366
- use_layer_scale=True,
1367
- layer_scale_init_value=1e-5,
1368
- init_cfg=None,
1369
- pretrained=None,
1370
- cls_ratio=2.0,
1371
- inference_mode=False,
1372
- stem_scale_branch=True,
1373
- **kwargs,
1374
- ) -> None:
1375
-
1376
- super().__init__()
1377
-
1378
- self.num_classes = num_classes
1379
- if len(layers) == 4:
1380
- self.out_indices = [0, 2, 4, 7]
1381
- elif len(layers) == 5:
1382
- self.out_indices = [0, 2, 4, 7, 10]
1383
- else:
1384
- raise NotImplementedError("FPN is not implemented for more than 5 stages.")
1385
-
1386
- if pos_embs is None:
1387
- pos_embs = [None] * len(layers)
1388
-
1389
- if se_downsamples is None:
1390
- se_downsamples = [False] * len(layers)
1391
-
1392
- # Convolutional stem
1393
- self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode,
1394
- use_scale_branch=stem_scale_branch)
1395
-
1396
- # Build the main stages of the network architecture
1397
- network = []
1398
- for i in range(len(layers)):
1399
- # Add position embeddings if requested
1400
- if pos_embs[i] is not None:
1401
- network.append(
1402
- pos_embs[i](
1403
- embed_dims[i], embed_dims[i], inference_mode=inference_mode
1404
- )
1405
- )
1406
- stage = basic_blocks(
1407
- embed_dims[i],
1408
- i,
1409
- layers,
1410
- token_mixer_type=token_mixers[i],
1411
- kernel_size=repmixer_kernel_size,
1412
- mlp_ratio=mlp_ratios[i],
1413
- act_layer=act_layer,
1414
- norm_layer=norm_layer,
1415
- drop_rate=drop_rate,
1416
- drop_path_rate=drop_path_rate,
1417
- use_layer_scale=use_layer_scale,
1418
- layer_scale_init_value=layer_scale_init_value,
1419
- inference_mode=inference_mode,
1420
- )
1421
- network.append(stage)
1422
- if i >= len(layers) - 1:
1423
- break
1424
-
1425
- # Patch merging/downsampling between stages.
1426
- if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
1427
- network.append(
1428
- PatchEmbed(
1429
- patch_size=down_patch_size,
1430
- stride=down_stride,
1431
- in_channels=embed_dims[i],
1432
- embed_dim=embed_dims[i + 1],
1433
- inference_mode=inference_mode,
1434
- use_se=se_downsamples[i + 1],
1435
- )
1436
- )
1437
- self.network = nn.ModuleList(network)
1438
-
1439
- # Classifier head
1440
- self.conv_exp = MobileOneBlock(
1441
- in_channels=embed_dims[-1],
1442
- out_channels=int(embed_dims[-1] * cls_ratio),
1443
- kernel_size=3,
1444
- stride=1,
1445
- padding=1,
1446
- groups=embed_dims[-1],
1447
- inference_mode=inference_mode,
1448
- use_se=True,
1449
- num_conv_branches=1,
1450
- )
1451
- self.head = (
1452
- nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
1453
- if num_classes > 0
1454
- else nn.Identity()
1455
- )
1456
- self.apply(self.cls_init_weights)
1457
- self.init_cfg = copy.deepcopy(init_cfg)
1458
-
1459
- def cls_init_weights(self, m: nn.Module) -> None:
1460
- """Init. for classification"""
1461
- if isinstance(m, nn.Linear):
1462
- normal_(m.weight, std=0.02)
1463
- if isinstance(m, nn.Linear) and m.bias is not None:
1464
- nn.init.constant_(m.bias, 0)
1465
-
1466
- def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
1467
- x = self.patch_embed(x)
1468
- return x
1469
-
1470
- def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1471
- for idx, block in enumerate(self.network):
1472
- x = block(x)
1473
- return x
1474
-
1475
- def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
1476
- # input embedding
1477
- x = self.forward_embeddings(x)
1478
- # through backbone
1479
- x = self.forward_tokens(x)
1480
- # for image classification/embedding
1481
- x = self.conv_exp(x)
1482
- cls_out = self.head(x)
1483
-
1484
- out_dict = dict()
1485
- if kwargs.get("return_image_embeddings", False):
1486
- out_dict.update({"logits": cls_out})
1487
- out_dict.update({"image_embeddings": x})
1488
- return out_dict
1489
- else:
1490
- return cls_out
1491
-
1492
-
1493
- @register_model
1494
- def fastvithd(pretrained=False, **kwargs):
1495
- """Instantiate FastViTHD model variant."""
1496
- layers = [2, 12, 24, 4, 2]
1497
- embed_dims = [96, 192, 384, 768, 1536]
1498
- mlp_ratios = [4, 4, 4, 4, 4]
1499
- downsamples = [True, True, True, True, True]
1500
- pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))]
1501
- token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
1502
- model = FastViT(
1503
- layers,
1504
- token_mixers=token_mixers,
1505
- embed_dims=embed_dims,
1506
- pos_embs=pos_embs,
1507
- mlp_ratios=mlp_ratios,
1508
- downsamples=downsamples,
1509
- norm_layer=LayerNormChannel,
1510
- stem_scale_branch=False,
1511
- inference_mode=True,
1512
- **kwargs,
1513
- )
1514
- model.default_cfg = default_cfgs["fastvit_m"]
1515
- if pretrained:
1516
- raise ValueError("Functionality not implemented.")
1517
- return model
1518
-
1519
- def load_model_config(
1520
- model_name: str,
1521
- ) -> Any:
1522
- model_cfg = {
1523
- "embed_dim": 768,
1524
- "image_cfg": {
1525
- "image_size": 1024,
1526
- "model_name": "fastvithd",
1527
- "embed_dim": 3072,
1528
- "patch_size": 64
1529
- },
1530
- "text_cfg": {
1531
- "context_length": 77,
1532
- "vocab_size": 49408,
1533
- "dim": 768,
1534
- "ffn_multiplier_per_layer": 4.0,
1535
- "n_heads_per_layer": 12,
1536
- "n_transformer_layers": 12,
1537
- "norm_layer": "layer_norm_fp32",
1538
- "causal_masking": False,
1539
- "model_name": "base"
1540
- }
1541
- }
1542
- return model_cfg
1543
-
1544
-
1545
- class MCi(nn.Module):
1546
- """
1547
- This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_
1548
- """
1549
-
1550
- def __init__(self, model_name: str, *args, **kwargs) -> None:
1551
- super().__init__()
1552
- self.projection_dim = None
1553
- if "projection_dim" in kwargs:
1554
- self.projection_dim = kwargs.get("projection_dim")
1555
-
1556
- # Create model
1557
- self.model = create_model(model_name, projection_dim=self.projection_dim)
1558
-
1559
- # Build out projection head.
1560
- if self.projection_dim is not None:
1561
- if hasattr(self.model, "head"):
1562
- self.model.head = MCi._update_image_classifier(
1563
- image_classifier=self.model.head, projection_dim=self.projection_dim
1564
- )
1565
-
1566
- def forward(self, x: Any, *args, **kwargs) -> Any:
1567
- """A forward function of the model."""
1568
- x = self.model(x, *args, **kwargs)
1569
- return x
1570
-
1571
- @staticmethod
1572
- def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
1573
- """Return the input feature dimension to the image classification head."""
1574
- in_features = None
1575
- if isinstance(image_classifier, nn.Sequential):
1576
- # Classifier that uses nn.Sequential usually has global pooling and
1577
- # multiple linear layers. Find the first linear layer and get its
1578
- # in_features
1579
- for layer in image_classifier:
1580
- if isinstance(layer, nn.Linear):
1581
- in_features = layer.in_features
1582
- break
1583
- elif isinstance(image_classifier, nn.Linear):
1584
- in_features = image_classifier.in_features
1585
-
1586
- if in_features is None:
1587
- raise NotImplementedError(
1588
- f"Cannot get input feature dimension of {image_classifier}."
1589
- )
1590
- return in_features
1591
-
1592
- @staticmethod
1593
- def _update_image_classifier(
1594
- image_classifier: nn.Module, projection_dim: int, *args, **kwargs
1595
- ) -> nn.Module:
1596
- in_features = MCi._get_in_feature_dimension(image_classifier)
1597
- new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
1598
- return new_img_classifier
1599
-
1600
-
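A hedged end-to-end sketch of how the tower below uses these pieces: `MCi` builds the timm-registered `fastvithd` backbone and swaps its classifier head for a `GlobalPool2D` projection, so a 1024-px input yields both a pooled embedding and a 16x16 spatial feature map (effective patch size 64).
import torch

encoder = MCi(model_name="fastvithd", projection_dim=768)
out = encoder(torch.randn(1, 3, 1024, 1024), return_image_embeddings=True)
# out["image_embeddings"]: (1, 3072, 16, 16) -> 256 tokens after the tower's reshape
# out["logits"]:           (1, 768)          -> globally pooled, projected embedding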
1601
- class MobileCLIPVisionTower(nn.Module):
1602
- def __init__(self, vision_tower, args, delay_load=False):
1603
- super().__init__()
1604
-
1605
- self.is_loaded = False
1606
- self.vision_tower_name = vision_tower
1607
- self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
1608
- self.input_image_size = int(vision_tower.split("_")[-1])
1609
-
1610
- # Delay load is disabled for now
1611
- if not delay_load:
1612
- self.load_model()
1613
- elif getattr(args, 'unfreeze_mm_vision_tower', False):
1614
- self.load_model()
1615
- else:
1616
- model_cfg = load_model_config(self.vision_tower_name)
1617
- self.cfg_only = model_cfg
1618
-
1619
- def load_model(self, device_map=None):
1620
- if self.is_loaded:
1621
- print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
1622
- return
1623
-
1624
- # Load model config
1625
- model_cfg = load_model_config(self.vision_tower_name)
1626
-
1627
- # Override default image resolution
1628
- model_cfg["image_cfg"]["image_size"] = self.input_image_size
1629
-
1630
- self.cfg_only = model_cfg
1631
-
1632
- # Build HF CLIPImageProcessor with MobileCLIP parameters
1633
- self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
1634
- "width": model_cfg["image_cfg"]["image_size"]},
1635
- image_mean=[0.0, 0.0, 0.0],
1636
- image_std=[1.0, 1.0, 1.0],
1637
- size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})
1638
-
1639
- # Instantiate the image encoder
1640
- self.vision_tower = MCi(model_name=model_cfg["image_cfg"]["model_name"],
1641
- projection_dim=model_cfg["embed_dim"])
1642
-
1643
- if not self.tune_vision_tower:
1644
- self.vision_tower.requires_grad_(False)
1645
-
1646
- self.is_loaded = True
1647
-
1648
- def feature_select(self, image_forward_outs):
1649
- # Features from penultimate layer
1650
- image_features = image_forward_outs["image_embeddings"]
1651
-
1652
- # Reshape 4D tensor to 3D
1653
- B, C, H, W = image_features.shape
1654
- image_features = image_features.reshape(B, C, H*W)
1655
- image_features = image_features.transpose(1, 2)
1656
- return image_features
1657
-
1658
- def forward(self, images):
1659
- if self.tune_vision_tower:
1660
- return self.forward_images(images)
1661
- else:
1662
- with torch.no_grad():
1663
- return self.forward_images(images)
1664
-
1665
- def forward_images(self, images):
1666
- if type(images) is list:
1667
- image_features = []
1668
- for image in images:
1669
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
1670
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
1671
- image_features.append(image_feature)
1672
- else:
1673
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
1674
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
1675
-
1676
- return image_features
1677
-
1678
- @property
1679
- def dummy_feature(self):
1680
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
1681
-
1682
- @property
1683
- def dtype(self):
1684
- return next(self.vision_tower.parameters()).dtype
1685
-
1686
- @property
1687
- def device(self):
1688
- return next(self.vision_tower.parameters()).device
1689
-
1690
- @property
1691
- def config(self):
1692
- return self.cfg_only
1693
-
1694
- @property
1695
- def hidden_size(self):
1696
- return self.config["image_cfg"]["embed_dim"]
1697
-
1698
- @property
1699
- def num_patches_per_side(self):
1700
- return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]
1701
-
1702
- @property
1703
- def num_patches(self):
1704
- return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2
1705
-
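Worked out for the default FastViTHD configuration above (image_size 1024, patch_size 64), these properties evaluate to:
num_patches_per_side = 1024 // 64        # 16
num_patches = num_patches_per_side ** 2  # 256 visual tokens per image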
1706
- class IdentityMap(nn.Module):
1707
- def __init__(self):
1708
- super().__init__()
1709
-
1710
- def forward(self, x, *args, **kwargs):
1711
- return x
1712
-
1713
- @property
1714
- def config(self):
1715
- return {"mm_projector_type": 'identity'}
1716
-
1717
- def build_vision_projector(config, delay_load=False, **kwargs):
1718
- projector_type = getattr(config, 'mm_projector_type', 'linear')
1719
-
1720
- if projector_type == 'linear':
1721
- return nn.Linear(config.mm_hidden_size, config.hidden_size)
1722
-
1723
- mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
1724
- if mlp_gelu_match:
1725
- mlp_depth = int(mlp_gelu_match.group(1))
1726
- modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
1727
- for _ in range(1, mlp_depth):
1728
- modules.append(nn.GELU())
1729
- modules.append(nn.Linear(config.hidden_size, config.hidden_size))
1730
- return nn.Sequential(*modules)
1731
-
1732
- if projector_type == 'identity':
1733
- return IdentityMap()
1734
-
1735
- raise ValueError(f'Unknown projector type: {projector_type}')
1736
-
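A hedged sketch of the `mlp2x_gelu` path (the hidden sizes below are illustrative, not taken from this file): the regex-parsed depth yields Linear -> GELU -> Linear from the vision width to the LLM width.
import torch

cfg = LlavaConfig(mm_hidden_size=3072, hidden_size=896, mm_projector_type="mlp2x_gelu")
projector = build_vision_projector(cfg)
tokens = projector(torch.randn(1, 256, 3072))
assert tokens.shape == (1, 256, 896)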
1737
- def build_vision_tower(vision_tower_cfg, **kwargs):
1738
- vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
1739
- return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
1740
-
1741
- class LlavaMetaModel:
1742
-
1743
- def __init__(self, config):
1744
- super(LlavaMetaModel, self).__init__(config)
1745
-
1746
- if hasattr(config, "mm_vision_tower"):
1747
- self.vision_tower = build_vision_tower(config, delay_load=True)
1748
- self.mm_projector = build_vision_projector(config)
1749
-
1750
- if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
1751
- self.image_newline = nn.Parameter(
1752
- torch.empty(config.hidden_size, dtype=self.dtype)
1753
- )
1754
-
1755
- def get_vision_tower(self):
1756
- vision_tower = getattr(self, 'vision_tower', None)
1757
- if type(vision_tower) is list:
1758
- vision_tower = vision_tower[0]
1759
- return vision_tower
1760
-
1761
- def initialize_vision_modules(self, model_args, fsdp=None):
1762
- vision_tower = model_args.vision_tower
1763
- mm_vision_select_layer = model_args.mm_vision_select_layer
1764
- mm_vision_select_feature = model_args.mm_vision_select_feature
1765
- pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
1766
- mm_patch_merge_type = model_args.mm_patch_merge_type
1767
-
1768
- self.config.mm_vision_tower = vision_tower
1769
-
1770
- if self.get_vision_tower() is None:
1771
- vision_tower = build_vision_tower(model_args)
1772
-
1773
- if fsdp is not None and len(fsdp) > 0:
1774
- self.vision_tower = [vision_tower]
1775
- else:
1776
- self.vision_tower = vision_tower
1777
- else:
1778
- if fsdp is not None and len(fsdp) > 0:
1779
- vision_tower = self.vision_tower[0]
1780
- else:
1781
- vision_tower = self.vision_tower
1782
- vision_tower.load_model()
1783
-
1784
- self.config.use_mm_proj = True
1785
- self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
1786
- self.config.mm_hidden_size = vision_tower.hidden_size
1787
- self.config.mm_vision_select_layer = mm_vision_select_layer
1788
- self.config.mm_vision_select_feature = mm_vision_select_feature
1789
- self.config.mm_patch_merge_type = mm_patch_merge_type
1790
-
1791
- if getattr(self, 'mm_projector', None) is None:
1792
- self.mm_projector = build_vision_projector(self.config)
1793
-
1794
- if 'unpad' in mm_patch_merge_type:
1795
- embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
1796
- self.image_newline = nn.Parameter(
1797
- torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
1798
- )
1799
- else:
1800
- # In case it is frozen by LoRA
1801
- for p in self.mm_projector.parameters():
1802
- p.requires_grad = True
1803
-
1804
- if pretrain_mm_mlp_adapter is not None:
1805
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
1806
-
1807
- def get_w(weights, keyword):
1808
- return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
1809
-
1810
- self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
1811
-
1812
- def select_best_resolution(original_size, possible_resolutions):
1813
- """
1814
- Selects the best resolution from a list of possible resolutions based on the original size.
1815
-
1816
- Args:
1817
- original_size (tuple): The original size of the image in the format (width, height).
1818
- possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
1819
-
1820
- Returns:
1821
- tuple: The best fit resolution in the format (width, height).
1822
- """
1823
- original_width, original_height = original_size
1824
- best_fit = None
1825
- max_effective_resolution = 0
1826
- min_wasted_resolution = float('inf')
1827
-
1828
- for width, height in possible_resolutions:
1829
- scale = min(width / original_width, height / original_height)
1830
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
1831
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
1832
- wasted_resolution = (width * height) - effective_resolution
1833
-
1834
- if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
1835
- max_effective_resolution = effective_resolution
1836
- min_wasted_resolution = wasted_resolution
1837
- best_fit = (width, height)
1838
-
1839
- return best_fit
1840
-
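- # Worked example (illustrative numbers): for original_size (1000, 500) and possible_resolutions
- # [(672, 672), (336, 672), (672, 336)], the function returns (672, 336): it ties for the largest
- # effective resolution (672 x 336 usable pixels) while wasting no padding area.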
1841
- def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
1842
- """
1843
- Calculate the shape of the image patch grid after preprocessing an image of arbitrary resolution.
1844
-
1845
- Args:
1846
- image_size (tuple): The size of the input image in the format (width, height).
1847
- grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
1848
- patch_size (int): The size of each image patch.
1849
-
1850
- Returns:
1851
- tuple: The shape of the image patch grid in the format (width, height).
1852
- """
1853
- import ast
1854
- if type(grid_pinpoints) is list:
1855
- possible_resolutions = grid_pinpoints
1856
- else:
1857
- possible_resolutions = ast.literal_eval(grid_pinpoints)
1858
- width, height = select_best_resolution(image_size, possible_resolutions)
1859
- return width // patch_size, height // patch_size
1860
-
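- # Worked example (illustrative numbers): image_size (1000, 500), grid_pinpoints
- # "[(672, 672), (336, 672), (672, 336)]" and patch_size 336 select (672, 336) above,
- # so the returned grid shape is (2, 1).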
1861
- class LlavaMetaForCausalLM(ABC):
1862
-
1863
- @abstractmethod
1864
- def get_model(self):
1865
- pass
1866
-
1867
- def get_vision_tower(self):
1868
- return self.get_model().get_vision_tower()
1869
-
1870
- def encode_images(self, images):
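- # Run the vision tower on the pixel batch, then use mm_projector to map the features from
- # mm_hidden_size into the language model's hidden_size.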
1871
- image_features = self.get_model().get_vision_tower()(images)
1872
- image_features = self.get_model().mm_projector(image_features)
1873
- return image_features
1874
-
1875
- def prepare_inputs_labels_for_multimodal(
1876
- self, input_ids, position_ids, attention_mask, past_key_values, labels,
1877
- images, image_sizes=None
1878
- ):
1879
- vision_tower = self.get_vision_tower()
1880
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
1881
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
1882
-
1883
- if type(images) is list or images.ndim == 5:
1884
- if type(images) is list:
1885
- images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
1886
- concat_images = torch.cat([image for image in images], dim=0)
1887
- image_features = self.encode_images(concat_images)
1888
- split_sizes = [image.shape[0] for image in images]
1889
- image_features = torch.split(image_features, split_sizes, dim=0)
1890
- mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
1891
- image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
1892
- if mm_patch_merge_type == 'flat':
1893
- image_features = [x.flatten(0, 1) for x in image_features]
1894
- elif mm_patch_merge_type.startswith('spatial'):
1895
- new_image_features = []
1896
- for image_idx, image_feature in enumerate(image_features):
1897
- if image_feature.shape[0] > 1:
1898
- base_image_feature = image_feature[0]
1899
- image_feature = image_feature[1:]
1900
- height = width = self.get_vision_tower().num_patches_per_side
1901
- assert height * width == base_image_feature.shape[0]
1902
- if image_aspect_ratio == 'anyres':
1903
- if hasattr(self.get_vision_tower(), 's2_image_size'):
1904
- img_size = self.get_vision_tower().s2_image_size
1905
- elif isinstance(self.get_vision_tower().config, dict):
1906
- img_size = self.get_vision_tower().config["image_cfg"]["image_size"]
1907
- else:
1908
- img_size = self.get_vision_tower().config.image_size
1909
-
1910
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size)
1911
- image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
1912
- else:
1913
- raise NotImplementedError
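- # 'unpad' removes the rows/columns that were only resize padding in the anyres grid and then
- # appends the learned image_newline embedding at the end of every patch row, so the language
- # model can recover where each row of the image ends.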
1914
- if 'unpad' in mm_patch_merge_type:
1915
- image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
1916
- image_feature = image_feature.flatten(1, 2).flatten(2, 3)
1917
- image_feature = unpad_image(image_feature, image_sizes[image_idx])
1918
- image_feature = torch.cat((
1919
- image_feature,
1920
- self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
1921
- ), dim=-1)
1922
- image_feature = image_feature.flatten(1, 2).transpose(0, 1)
1923
- else:
1924
- image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
1925
- image_feature = image_feature.flatten(0, 3)
1926
- image_feature = torch.cat((base_image_feature, image_feature), dim=0)
1927
- else:
1928
- image_feature = image_feature[0]
1929
- if 'unpad' in mm_patch_merge_type:
1930
- image_feature = torch.cat((
1931
- image_feature,
1932
- self.model.image_newline[None].to(image_feature.device)
1933
- ), dim=0)
1934
- new_image_features.append(image_feature)
1935
- image_features = new_image_features
1936
- else:
1937
- raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
1938
- else:
1939
- image_features = self.encode_images(images)
1940
-
1941
- # TODO: image start / end is not implemented here to support pretraining.
1942
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
1943
- raise NotImplementedError
1944
-
1945
- # Let's just add dummy tensors if they do not exist,
1946
- # it is a headache to deal with None all the time.
1947
- # But it is not ideal, and if you have a better idea,
1948
- # please open an issue / submit a PR, thanks.
1949
- _labels = labels
1950
- _position_ids = position_ids
1951
- _attention_mask = attention_mask
1952
- if attention_mask is None:
1953
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1954
- else:
1955
- attention_mask = attention_mask.bool()
1956
- if position_ids is None:
1957
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
1958
- if labels is None:
1959
- labels = torch.full_like(input_ids, IGNORE_INDEX)
1960
-
1961
- # remove the padding using attention_mask -- FIXME
1962
- _input_ids = input_ids
1963
- input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
1964
- labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
1965
-
1966
- new_input_embeds = []
1967
- new_labels = []
1968
- cur_image_idx = 0
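- # Split every sequence at its IMAGE_TOKEN_INDEX placeholders, embed the text chunks, and splice
- # the matching image features (with IGNORE_INDEX labels) in between. For example, token ids
- # [A, B, IMG, C] become embed(A, B) + image_features + embed(C).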
1969
- for batch_idx, cur_input_ids in enumerate(input_ids):
1970
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1971
- if num_images == 0:
1972
- cur_image_features = image_features[cur_image_idx]
1973
- cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1974
- cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
1975
- new_input_embeds.append(cur_input_embeds)
1976
- new_labels.append(labels[batch_idx])
1977
- cur_image_idx += 1
1978
- continue
1979
-
1980
- image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
1981
- cur_input_ids_noim = []
1982
- cur_labels = labels[batch_idx]
1983
- cur_labels_noim = []
1984
- for i in range(len(image_token_indices) - 1):
1985
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
1986
- cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
1987
- split_sizes = [x.shape[0] for x in cur_labels_noim]
1988
- cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
1989
- cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1990
- cur_new_input_embeds = []
1991
- cur_new_labels = []
1992
-
1993
- for i in range(num_images + 1):
1994
- cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1995
- cur_new_labels.append(cur_labels_noim[i])
1996
- if i < num_images:
1997
- cur_image_features = image_features[cur_image_idx]
1998
- cur_image_idx += 1
1999
- cur_new_input_embeds.append(cur_image_features)
2000
- cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
2001
-
2002
- cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
2003
-
2004
- cur_new_input_embeds = torch.cat(cur_new_input_embeds)
2005
- cur_new_labels = torch.cat(cur_new_labels)
2006
-
2007
- new_input_embeds.append(cur_new_input_embeds)
2008
- new_labels.append(cur_new_labels)
2009
-
2010
- # Truncate sequences to max length as image embeddings can make the sequence longer
2011
- tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
2012
- if tokenizer_model_max_length is not None:
2013
- new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
2014
- new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
2015
-
2016
- # Combine them
2017
- max_len = max(x.shape[0] for x in new_input_embeds)
2018
- batch_size = len(new_input_embeds)
2019
-
2020
- new_input_embeds_padded = []
2021
- new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
2022
- attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
2023
- position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
2024
-
2025
- for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
2026
- cur_len = cur_new_embed.shape[0]
2027
- if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
2028
- new_input_embeds_padded.append(torch.cat((
2029
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
2030
- cur_new_embed
2031
- ), dim=0))
2032
- if cur_len > 0:
2033
- new_labels_padded[i, -cur_len:] = cur_new_labels
2034
- attention_mask[i, -cur_len:] = True
2035
- position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
2036
- else:
2037
- new_input_embeds_padded.append(torch.cat((
2038
- cur_new_embed,
2039
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
2040
- ), dim=0))
2041
- if cur_len > 0:
2042
- new_labels_padded[i, :cur_len] = cur_new_labels
2043
- attention_mask[i, :cur_len] = True
2044
- position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
2045
-
2046
- new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
2047
-
2048
- if _labels is None:
2049
- new_labels = None
2050
- else:
2051
- new_labels = new_labels_padded
2052
-
2053
- if _attention_mask is None:
2054
- attention_mask = None
2055
- else:
2056
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
2057
-
2058
- if _position_ids is None:
2059
- position_ids = None
2060
-
2061
- return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
2062
-
2063
- def initialize_vision_tokenizer(self, model_args, tokenizer):
2064
- if model_args.mm_use_im_patch_token:
2065
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
2066
- self.resize_token_embeddings(len(tokenizer))
2067
-
2068
- if model_args.mm_use_im_start_end:
2069
- num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
2070
- self.resize_token_embeddings(len(tokenizer))
2071
-
2072
- if num_new_tokens > 0:
2073
- input_embeddings = self.get_input_embeddings().weight.data
2074
- output_embeddings = self.get_output_embeddings().weight.data
2075
-
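- # The rows for the newly added special tokens are initialized to the mean of the existing
- # embedding rows, keeping them in distribution with the pretrained vocabulary.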
2076
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
2077
- dim=0, keepdim=True)
2078
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
2079
- dim=0, keepdim=True)
2080
-
2081
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
2082
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
2083
-
2084
- if model_args.tune_mm_mlp_adapter:
2085
- for p in self.get_input_embeddings().parameters():
2086
- p.requires_grad = True
2087
- for p in self.get_output_embeddings().parameters():
2088
- p.requires_grad = False
2089
-
2090
- if model_args.pretrain_mm_mlp_adapter:
2091
- mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
2092
- embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
2093
- assert num_new_tokens == 2
2094
- if input_embeddings.shape == embed_tokens_weight.shape:
2095
- input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
2096
- elif embed_tokens_weight.shape[0] == num_new_tokens:
2097
- input_embeddings[-num_new_tokens:] = embed_tokens_weight
2098
- else:
2099
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
2100
- elif model_args.mm_use_im_patch_token:
2101
- if model_args.tune_mm_mlp_adapter:
2102
- for p in self.get_input_embeddings().parameters():
2103
- p.requires_grad = False
2104
- for p in self.get_output_embeddings().parameters():
2105
- p.requires_grad = False
2106
-
2107
-
2108
- class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
2109
- config_class = LlavaConfig
2110
-
2111
- def __init__(self, config: Qwen2Config):
2112
- super(LlavaQwen2Model, self).__init__(config)
2113
-
2114
-
2115
- class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
2116
- config_class = LlavaConfig
2117
-
2118
- def __init__(self, config):
2119
- super(Qwen2ForCausalLM, self).__init__(config)
2120
- self.model = LlavaQwen2Model(config)
2121
- # self.pretraining_tp = config.pretraining_tp
2122
- self.vocab_size = config.vocab_size
2123
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2124
-
2125
- # Initialize weights and apply final processing
2126
- self.post_init()
2127
-
2128
- def get_model(self):
2129
- return self.model
2130
-
2131
- def forward(
2132
- self,
2133
- input_ids: torch.LongTensor = None,
2134
- attention_mask: Optional[torch.Tensor] = None,
2135
- position_ids: Optional[torch.LongTensor] = None,
2136
- past_key_values: Optional[List[torch.FloatTensor]] = None,
2137
- inputs_embeds: Optional[torch.FloatTensor] = None,
2138
- labels: Optional[torch.LongTensor] = None,
2139
- use_cache: Optional[bool] = None,
2140
- output_attentions: Optional[bool] = None,
2141
- output_hidden_states: Optional[bool] = None,
2142
- images: Optional[torch.FloatTensor] = None,
2143
- image_sizes: Optional[List[List[int]]] = None,
2144
- return_dict: Optional[bool] = None,
2145
- cache_position=None,
2146
- ) -> Union[Tuple, CausalLMOutputWithPast]:
2147
-
2148
- if inputs_embeds is None:
2149
- (
2150
- input_ids,
2151
- position_ids,
2152
- attention_mask,
2153
- past_key_values,
2154
- inputs_embeds,
2155
- labels
2156
- ) = self.prepare_inputs_labels_for_multimodal(
2157
- input_ids,
2158
- position_ids,
2159
- attention_mask,
2160
- past_key_values,
2161
- labels,
2162
- images,
2163
- image_sizes
2164
- )
2165
-
2166
- return super().forward(
2167
- input_ids=input_ids,
2168
- attention_mask=attention_mask,
2169
- position_ids=position_ids,
2170
- past_key_values=past_key_values,
2171
- inputs_embeds=inputs_embeds,
2172
- labels=labels,
2173
- use_cache=use_cache,
2174
- output_attentions=output_attentions,
2175
- output_hidden_states=output_hidden_states,
2176
- return_dict=return_dict
2177
- )
2178
-
2179
- @torch.no_grad()
2180
- def generate(
2181
- self,
2182
- inputs: Optional[torch.Tensor] = None,
2183
- images: Optional[torch.Tensor] = None,
2184
- image_sizes: Optional[torch.Tensor] = None,
2185
- **kwargs,
2186
- ) -> Union[GenerateOutput, torch.LongTensor]:
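- # Image features are spliced into the prompt below, so only inputs_embeds (never raw input_ids)
- # is passed on to the underlying Qwen2 generate call.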
2187
- position_ids = kwargs.pop("position_ids", None)
2188
- attention_mask = kwargs.pop("attention_mask", None)
2189
- if "inputs_embeds" in kwargs:
2190
- raise NotImplementedError("`inputs_embeds` is not supported")
2191
-
2192
- if images is not None:
2193
- (
2194
- inputs,
2195
- position_ids,
2196
- attention_mask,
2197
- _,
2198
- inputs_embeds,
2199
- _
2200
- ) = self.prepare_inputs_labels_for_multimodal(
2201
- inputs,
2202
- position_ids,
2203
- attention_mask,
2204
- None,
2205
- None,
2206
- images,
2207
- image_sizes=image_sizes
2208
- )
2209
- else:
2210
- inputs_embeds = self.get_model().embed_tokens(inputs)
2211
-
2212
- return super().generate(
2213
- position_ids=position_ids,
2214
- attention_mask=attention_mask,
2215
- inputs_embeds=inputs_embeds,
2216
- **kwargs
2217
- )
2218
-
2219
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
2220
- inputs_embeds=None, **kwargs):
2221
- images = kwargs.pop("images", None)
2222
- image_sizes = kwargs.pop("image_sizes", None)
2223
- inputs = super().prepare_inputs_for_generation(
2224
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
2225
- )
2226
- if images is not None:
2227
- inputs['images'] = images
2228
- if image_sizes is not None:
2229
- inputs['image_sizes'] = image_sizes
2230
- return inputs
2231
-
2232
-
2233
- AutoConfig.register("llava_qwen2", LlavaConfig)
2234
- AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)
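- # Minimal loading sketch (the checkpoint path below is a placeholder, not part of this file):
- #   from transformers import AutoModelForCausalLM, AutoTokenizer
- #   model = AutoModelForCausalLM.from_pretrained("path/to/llava-qwen2-checkpoint", trust_remote_code=True)
- #   tokenizer = AutoTokenizer.from_pretrained("path/to/llava-qwen2-checkpoint")
- #   # preprocessed image tensors are passed to model.generate(..., images=images, image_sizes=image_sizes)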