File size: 18,709 Bytes
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d860d3
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d860d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d860d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aac4a9c
3d860d3
 
 
 
 
 
 
 
 
 
 
 
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d860d3
 
 
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
3d860d3
 
 
 
 
 
 
 
 
 
 
 
 
 
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d860d3
aac4a9c
 
 
 
 
3d860d3
aac4a9c
 
 
 
 
 
 
 
 
3d860d3
aac4a9c
 
 
 
3d860d3
aac4a9c
 
 
 
3d860d3
aac4a9c
 
 
 
 
 
 
 
3d860d3
 
 
aac4a9c
 
 
 
3d860d3
aac4a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
from typing import Callable, Optional, Union

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2MLP,
    Qwen2Attention,
    apply_rotary_pos_emb,
    eager_attention_forward,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    Qwen2Model,
    Qwen2ForCausalLM,
)

from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)

from .configuration_fp8_qwen2 import FP8Qwen2Config

from torchao.float8.float8_training_tensor import Float8TrainingTensor

from quasar.module import (
    FP8Quant,
    FP8RMSNorm,
    FP8DSLinearWithCoat,
    FP8DSLinearWithCoatWeightBlock,
    FP8FusedSiLUMul,
    FP8Identity,
)

from quasar.kernel.configs import FP8RMSNormConfig, QuantType, FP8MulConfig, FP8DSLinearWithCoatConfig, FP8QuantConfig
from quasar.kernel.quant.quantize_hp2pb import fp8_quantize_hp2pb
from quasar.kernel.quant.dequantize_pb2hp import fp8_dequantize_pb2hp

logger = logging.get_logger(__name__)


class FP8Qwen2MLP(Qwen2MLP):
    """Qwen2 gated MLP whose projections and SiLU-gate product use quasar FP8 modules.

    The ``gate_proj`` / ``up_proj`` / ``down_proj`` layers built by ``Qwen2MLP.__init__``
    are replaced by FP8 linear modules: ``FP8DSLinearWithCoat`` when
    ``config.fp8_config.training_mode`` is truthy, ``FP8DSLinearWithCoatWeightBlock``
    otherwise. Only ``config.hidden_act == "silu"`` is supported.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(self, config: FP8Qwen2Config):
        super().__init__(config)
        if config.fp8_config.training_mode:
            linear_cls = FP8DSLinearWithCoat
        else:
            linear_cls = FP8DSLinearWithCoatWeightBlock

        def make_linear(name, in_features, out_features):
            # All MLP projections are bias-free and keep float32 quantization scales.
            return linear_cls(
                in_features,
                out_features,
                bias=False,
                dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=name, scale_dtype=torch.float32),
            )

        self.gate_proj = make_linear("gate_proj", self.hidden_size, self.intermediate_size)
        self.up_proj = make_linear("up_proj", self.hidden_size, self.intermediate_size)
        self.down_proj = make_linear("down_proj", self.intermediate_size, self.hidden_size)

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation function: {config.hidden_act}")
        # A single fused kernel computes silu(gate) * up.
        self.act_fn = FP8FusedSiLUMul(FP8MulConfig(quant_type=QuantType.MUL))

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_proj(x), up_proj(x)))``."""
        gated = self.act_fn(self.gate_proj(x), self.up_proj(x))
        return self.down_proj(gated)


class FP8Qwen2Attention(Qwen2Attention):
    """Qwen2 multi-headed self-attention with quasar FP8 linear projections.

    The q/k/v/o projections created by ``Qwen2Attention.__init__`` are replaced by
    FP8 linear modules (``FP8DSLinearWithCoat`` when ``config.fp8_config.training_mode``
    is truthy, otherwise ``FP8DSLinearWithCoatWeightBlock``). The attention output is
    quantized by ``o_proj_quant`` before being fed to ``o_proj``.
    """

    def __init__(self, config: FP8Qwen2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        # One FP8 linear implementation is chosen for all four projections.
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        # q/k/v projections carry a bias (bias=True); o_proj below does not.
        self.q_proj = linear_module(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"q_proj", scale_dtype=torch.float32),
        )
        # k/v use num_key_value_heads, so they may be narrower than q (grouped-query attention).
        self.k_proj = linear_module(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"k_proj", scale_dtype=torch.float32),
        )
        self.v_proj = linear_module(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name=f"v_proj", scale_dtype=torch.float32),
        )

        # In both training and inference, we quantize the output of the attention layer.
        # DIV quant type, block size mm_block_size, float32 scales.
        self.o_proj_quant = FP8Quant(
            quant_config=FP8QuantConfig(
                float8_dtype=config.fp8_config.float8_dtype,
                quant_type=QuantType.DIV,
                fwd_block_size=config.fp8_config.mm_block_size,
                layer_name=f"o_proj_quant",
                scale_dtype=torch.float32,
            )
        )
        # o_proj's input quant type is DIV, matching the quant_type of o_proj_quant
        # whose output it consumes in forward().
        self.o_proj = linear_module(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(
                fwd_input_quant_type=QuantType.DIV,
                layer_name=f"o_proj",
                scale_dtype=torch.float32,
            ),
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute self-attention over ``hidden_states`` using the FP8 projections.

        Args:
            hidden_states: Layer input, either a regular tensor or a
                ``Float8TrainingTensor`` (handled specially when deriving shapes).
            position_embeddings: ``(cos, sin)`` pair passed to ``apply_rotary_pos_emb``.
            attention_mask: Forwarded unchanged to the attention backend.
            past_key_values: Optional KV cache; when given, this layer's keys/values
                are passed through ``past_key_values.update`` before attention.
            cache_position: Forwarded to the cache update.
            **kwargs: Extra arguments forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``: the ``o_proj`` output and whatever the
            attention backend returned as weights.
        """
        if isinstance(hidden_states, Float8TrainingTensor):
            # Float8Tensor's last dim is quantize group size, not hidden size.
            # So the token dims are taken as all but the last two dims.
            input_shape = hidden_states.shape[:-2]
        else:
            input_shape = hidden_states.shape[:-1]
        # Per-head view of each projection; -1 lets view() infer the head count
        # (which differs between q and k/v).
        hidden_shape = (*input_shape, -1, self.head_dim)

        # transpose(1, 2) swaps axes 1 and 2 of the per-head view; this presumably
        # assumes input_shape is (batch, seq), giving (batch, heads, seq, head_dim) — TODO confirm.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # TODO: Add quantization

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Eager attention unless the config selects a registered backend (sdpa, flash, ...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # attribute set by the parent Qwen2Attention.__init__
            **kwargs,
        )

        # Flatten everything after the token dims into a single feature axis for o_proj.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

        # Quantize the output of the attention layer.
        attn_output = self.o_proj_quant(attn_output)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class FP8Qwen2DecoderLayer(GradientCheckpointingLayer):
    """One Qwen2 transformer block built from FP8 attention, FP8 MLP and FP8 RMSNorms.

    Both sub-blocks are pre-norm: the input is normalized, transformed, and added
    back onto the un-normalized input.
    """

    def __init__(self, config: FP8Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = FP8Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = FP8Qwen2MLP(config)

        def make_norm():
            # Both norms share this configuration (MUL quant, FP8 input kept).
            norm_cfg = FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size,
                quant_type=QuantType.MUL,
                save_fp8_input=True,
            )
            return FP8RMSNorm(config.hidden_size, eps=config.rms_norm_eps, norm_config=norm_cfg)

        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the attention sub-block, then the MLP sub-block, each with a residual add.

        All keyword arguments are forwarded to ``self.self_attn``; the attention
        weights it returns are discarded.
        """
        # Self-attention sub-block; the left operand is the residual (un-normalized input).
        hidden_states = hidden_states + self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )[0]

        # Feed-forward sub-block, same residual pattern.
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))


@auto_docstring
class FP8Qwen2PreTrainedModel(PreTrainedModel):
    # Config binding and the attribute name of the base model in subclasses.
    config_class = FP8Qwen2Config
    config: FP8Qwen2Config
    base_model_prefix = "model"

    # Placement hints: decoder layers must stay whole on one device, and the
    # KV cache is excluded from automatic device placement.
    _no_split_modules = ["FP8Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Capability flags advertised to transformers.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # Output name -> module class whose outputs populate that output.
    _can_record_outputs = {
        "hidden_states": FP8Qwen2DecoderLayer,
        "attentions": FP8Qwen2Attention,
    }


@auto_docstring
class FP8Qwen2Model(FP8Qwen2PreTrainedModel):
    # The forward pass is reused verbatim from Qwen2Model; only the submodules
    # created in __init__ (FP8 decoder layers) differ from the stock model.
    forward = Qwen2Model.forward

    def __init__(self, config: FP8Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            FP8Qwen2DecoderLayer(config, layer_index) for layer_index in range(config.num_hidden_layers)
        )
        # The final norm is the regular (non-FP8) Qwen2RMSNorm.
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing.
        self.post_init()


@auto_docstring
class FP8Qwen2ForCausalLM(FP8Qwen2PreTrainedModel, GenerationMixin):
    config_class = FP8Qwen2Config
    # lm_head shares its weight with the input embeddings when tying is enabled.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = FP8Qwen2Model(config)
        self.vocab_size = config.vocab_size
        # The LM head is a regular (non-FP8) nn.Linear.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> # FP8Qwen2ForCausalLM is the class defined in this module.
        >>> model = FP8Qwen2ForCausalLM.from_pretrained("path/to/fp8-qwen2-checkpoint")
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/fp8-qwen2-checkpoint")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute logits for the positions selected by logits_to_keep: an int keeps the
        # last N positions (0 -> slice(0, None), i.e. all of them); a tensor is used as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FP8Qwen2ForSequenceClassification(GenericForSequenceClassification, FP8Qwen2PreTrainedModel):
    """FP8 Qwen2 variant of transformers' generic sequence-classification model."""


class FP8Qwen2ForTokenClassification(GenericForTokenClassification, FP8Qwen2PreTrainedModel):
    """FP8 Qwen2 variant of transformers' generic token-classification model."""


class FP8Qwen2ForQuestionAnswering(GenericForQuestionAnswering, FP8Qwen2PreTrainedModel):
    """FP8 Qwen2 variant of transformers' generic question-answering model."""

    # Backward compatibility: this head historically stored its base model under
    # `transformer` rather than `model`.
    base_model_prefix = "transformer"


# Public API of this module.
__all__ = [
    "FP8Qwen2PreTrainedModel",
    "FP8Qwen2Model",
    "FP8Qwen2ForCausalLM",
    "FP8Qwen2ForSequenceClassification",
    "FP8Qwen2ForTokenClassification",
    "FP8Qwen2ForQuestionAnswering",
]


# Register the custom classes with the corresponding transformers Auto* classes.
for _model_cls, _auto_cls_name in (
    (FP8Qwen2Model, "AutoModel"),
    (FP8Qwen2ForCausalLM, "AutoModelForCausalLM"),
):
    _model_cls.register_for_auto_class(_auto_cls_name)
del _model_cls, _auto_cls_name


def make_state_dict_compatible_with_hf(
    state_dict: dict[str, torch.Tensor],
    linear_keys: list[str],
    undesired_linear_keys: list[str],
    config: Optional["FP8Qwen2Config"] = None,
    already_fp8: bool = False,
) -> dict[str, torch.Tensor]:
    """Convert a state dict into the FP8 HuggingFace checkpoint layout.

    Each key is classified by substring match against ``linear_keys`` (layers stored
    in FP8) and ``undesired_linear_keys`` (layers stored in high precision). All
    other keys are copied unchanged.

    * ``linear_keys`` layers:
        - ``already_fp8=True``: every entry (``weight``, ``weight_scale_inv``, bias)
          is copied as-is.
        - ``already_fp8=False``: each ``...weight`` entry is quantized with
          ``fp8_quantize_hp2pb`` (block size ``config.fp8_config.mm_block_size``,
          float32 scales). It is stored as ``...weight`` (quantized data) plus
          ``...weight_scale_inv`` (scales). Other entries such as biases are
          copied unchanged.
    * ``undesired_linear_keys`` layers:
        - ``already_fp8=True``: each ``...weight_scale_inv`` entry is combined with
          its ``...weight`` entry via ``fp8_dequantize_pb2hp``. The result is stored
          under ``...weight`` and the scale is dropped. Other entries (biases, or a
          weight without a scale) are copied unchanged.
        - ``already_fp8=False``: every entry is copied unchanged.

    Args:
        state_dict: Parameter name -> tensor mapping. It is not modified.
        linear_keys: Name substrings selecting layers to keep/store in FP8.
        undesired_linear_keys: Name substrings selecting layers to keep/store in
            high precision. Must not share entries with ``linear_keys``.
        config: Config providing ``fp8_config``. A fresh ``FP8Qwen2Config()`` is
            created when omitted.
        already_fp8: Whether ``state_dict`` already holds FP8 weights together with
            ``weight_scale_inv`` entries.

    Returns:
        A new dict with the converted entries.

    Raises:
        ValueError: If ``linear_keys`` and ``undesired_linear_keys`` overlap.
    """
    if not set(linear_keys).isdisjoint(undesired_linear_keys):
        overlap = sorted(set(linear_keys) & set(undesired_linear_keys))
        raise ValueError(f"linear_keys and undesired_linear_keys must be disjoint; overlap: {overlap}")
    if config is None:
        # Created per call rather than once at import time as a default argument.
        config = FP8Qwen2Config()

    weight_suffix = "weight"
    scale_suffix = "weight_scale_inv"
    compatible_state_dict = {}

    for key, value in state_dict.items():
        if any(k in key for k in linear_keys):
            if already_fp8 or not key.endswith(weight_suffix):
                # Already-quantized entries, and non-weight entries such as biases,
                # keep their name and value. Biases must not be quantized: they have
                # no "weight" in their name, so a scale entry would overwrite them.
                compatible_state_dict[key] = value
            else:
                # We need to use float32 for the scale, since we are using DeepGEMM.
                quant_cfg = FP8QuantConfig(
                    float8_dtype=config.fp8_config.float8_dtype,
                    quant_type=config.fp8_config.quant_type,
                    fwd_block_size=config.fp8_config.mm_block_size,
                    scale_dtype=torch.float32,
                )
                quant_weight, scale_weight = fp8_quantize_hp2pb(
                    value, quant_cfg, block_size=config.fp8_config.mm_block_size
                )
                compatible_state_dict[key] = quant_weight
                compatible_state_dict[key + "_scale_inv"] = scale_weight

        elif any(k in key for k in undesired_linear_keys):
            if not already_fp8:
                # High-precision weights of undesired layers are kept as-is.
                compatible_state_dict[key] = value
            elif key.endswith(scale_suffix):
                # Dequantize once, when the scale entry is visited.
                weight_key = key[: -len(scale_suffix)] + weight_suffix
                # NOTE(review): dequantization receives config.fp8_config while quantization
                # builds an FP8QuantConfig; confirm fp8_dequantize_pb2hp expects this object.
                compatible_state_dict[weight_key] = fp8_dequantize_pb2hp(
                    state_dict[weight_key], value, config.fp8_config, block_size=config.fp8_config.mm_block_size
                )
            elif key.endswith(weight_suffix) and key + "_scale_inv" in state_dict:
                # Quantized weight whose dequantized value is written when its
                # scale entry is visited; copying it here would clobber that value.
                pass
            else:
                # Biases and unscaled weights pass through instead of being dropped.
                compatible_state_dict[key] = value

        else:
            compatible_state_dict[key] = value
    return compatible_state_dict


def set_named_weight_to_fp8(
    model: FP8Qwen2ForCausalLM,
    linear_keys: list[str] = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
):
    """
    Set the dtype of the weight of the linear layers to FP8.
    Also set layer name for debugging.
    """
    for name, module in model.named_modules():
        if any(k in name for k in linear_keys):
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
            module.layer_name = name

    return model