File size: 5,489 Bytes
717a938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4d2ee1
 
 
717a938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4d2ee1
 
 
 
 
 
 
 
717a938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# pylint: disable=too-many-arguments, too-many-positional-arguments

"""
Author : Fabien FURFARO
"""

from typing import Optional, Union

from transformers import PreTrainedModel, TrainerCallback

from .modeling_tptt import LiZAttention


class LiZACallback(TrainerCallback):
    """
    TrainerCallback to schedule mag_weight or enable/disable linear attention during training.

    Modes:
        - "constant": keep mag_weight fixed at final_weight.
        - "gradual": linear interpolation from initial_weight to final_weight
          over the first transition_step steps.
        - "cyclic": alternate between values in weight_list at each step.
        - "switch": alternately enable/disable linear attention every
          switch_period steps.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        mode: str = "gradual",
        initial_weight: float = 0.0,
        final_weight: float = 0.5,
        transition_step: Union[int, tuple, list] = 100,
        weight_list: Optional[list] = None,
        switch_period: int = 1,  # period (in steps) for "switch" mode
    ):
        self.model = model
        self.mode = mode

        # Ensure initial_weight/final_weight are float scalars, not tuple/list
        if isinstance(initial_weight, (tuple, list)):
            initial_weight = initial_weight[0]
        if isinstance(final_weight, (tuple, list)):
            final_weight = final_weight[0]
        self.initial_weight = float(initial_weight)
        self.final_weight = float(final_weight)

        # Ensure transition_step is an int scalar, not tuple/list
        self.transition_step = ensure_int(transition_step)
        if self.mode == "constant":
            # For constant mode, transition_step is not used and the weight
            # never moves away from final_weight.
            self.initial_weight = self.final_weight
        # For cyclic mode: ensure all weights are float scalars
        if weight_list is not None:
            self.weight_list = [
                float(w[0]) if isinstance(w, (tuple, list)) else float(w)
                for w in weight_list
            ]
        else:
            self.weight_list = [self.initial_weight, self.final_weight]

        # For switch mode
        self.switch_period = int(switch_period)

    def _set_mag_weight(self, weight: float) -> None:
        """Propagate *weight* to every LiZAttention module of the model."""
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                module.mag_weight = weight

    def on_step_end(self, args, state, control, **kwargs):
        """Update mag_weight / disable_linear_attn according to the mode.

        Raises:
            ValueError: if self.mode is not one of the supported modes.
        """
        # Ensure current_step and transition_step are plain ints
        current_step = ensure_int(state.global_step)
        transition_step = ensure_int(self.transition_step)

        if self.mode == "constant":
            self._set_mag_weight(self.final_weight)

        elif self.mode == "gradual":
            # Guard transition_step > 0 to avoid a ZeroDivisionError; with a
            # zero (or negative) transition the schedule jumps straight to
            # final_weight.
            if transition_step > 0 and current_step <= transition_step:
                progress = current_step / transition_step
                weight = self.initial_weight + (
                    self.final_weight - self.initial_weight
                ) * progress
            else:
                weight = self.final_weight
            self._set_mag_weight(weight)

        elif self.mode == "cyclic":
            idx = current_step % len(self.weight_list)
            self._set_mag_weight(self.weight_list[idx])

        elif self.mode == "switch":
            # Alternately enable/disable linear attention every switch_period steps
            disable = (current_step // self.switch_period) % 2 == 0
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.disable_linear_attn = disable

        else:
            raise ValueError(f"Unknown mode: {self.mode}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Mirror the first LiZAttention module's state into the trainer logs."""
        mag_weight = None
        disable_linear_attn = None
        # Read the state from the first LiZAttention found; all modules are
        # kept in sync by on_step_end, so one representative is enough.
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                mag_weight = getattr(module, "mag_weight", None)
                disable_linear_attn = getattr(module, "disable_linear_attn", None)
                break
        if mag_weight is not None and logs is not None:
            logs["mag_weight"] = float(mag_weight)
        if disable_linear_attn is not None and logs is not None:
            # Log the actual flag value; the previous code negated it, which
            # contradicted the "disable_linear_attn" key.
            logs["disable_linear_attn"] = bool(disable_linear_attn)


def ensure_int(value: Union[int, tuple, list]) -> int:
    """Coerce *value* to a plain Python int.

    Accepts a scalar, an object exposing ``.item()`` (e.g. a torch/numpy
    0-d tensor), or a tuple/list whose first element is such a scalar.

    The previous version returned non-tuple, non-``.item()`` inputs (e.g. a
    bare float) unchanged, contradicting the name and the ``-> int``
    annotation; the final ``int(...)`` now guarantees the return type.
    """
    if isinstance(value, (tuple, list)):
        # Only the first element is meaningful for scheduling parameters.
        value = value[0]
    if hasattr(value, "item"):
        # Unwrap tensor-like scalars to a Python number.
        value = value.item()
    return int(value)


class SaveBestModelCallback(TrainerCallback):
    """TrainerCallback that triggers a checkpoint save whenever eval_loss improves."""

    def __init__(self):
        # Lowest eval_loss seen so far; +inf so the first evaluation always saves.
        self.best_metric = float("inf")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Request a save iff this evaluation beats the best eval_loss so far."""
        eval_loss = (metrics or {}).get("eval_loss")
        if eval_loss is None:
            # No metric to compare against — leave control untouched.
            return
        improved = eval_loss < self.best_metric
        if improved:
            self.best_metric = eval_loss
        control.should_save = improved