# pylint: disable=too-many-arguments, too-many-positional-arguments
"""
Author : Fabien FURFARO
"""

from typing import Optional, Union

from transformers import PreTrainedModel, TrainerCallback

from .modeling_tptt import LiZAttention


class LiZACallback(TrainerCallback):
"""
TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
Modes:
- "gradual": linear interpolation from initial_weight to final_weight.
- "cyclic": alternate between values in weight_list at each step.
- "switch": alternately enable/disable linear attention at each step.
"""
    def __init__(
        self,
        model: PreTrainedModel,
        mode: str = "gradual",
        initial_weight: float = 0.0,
        final_weight: float = 0.5,
        transition_step: Union[int, tuple, list] = 100,
        weight_list: Optional[list] = None,
        switch_period: int = 1,  # period (in steps) for "switch" mode
    ):
        self.model = model
        self.mode = mode
        # Ensure initial_weight and final_weight are float scalars, not tuples/lists
        if isinstance(initial_weight, (tuple, list)):
            initial_weight = initial_weight[0]
        if isinstance(final_weight, (tuple, list)):
            final_weight = final_weight[0]
        self.initial_weight = float(initial_weight)
        self.final_weight = float(final_weight)
        # Ensure transition_step is an int scalar, not a tuple/list
        self.transition_step = ensure_int(transition_step)
        if self.mode == "constant":
            # transition_step is unused in constant mode; pin the weight
            self.initial_weight = self.final_weight
        # For cyclic mode: ensure all weights are float scalars
        if weight_list is not None:
            self.weight_list = [
                float(w[0]) if isinstance(w, (tuple, list)) else float(w)
                for w in weight_list
            ]
        else:
            self.weight_list = [self.initial_weight, self.final_weight]
        # For "switch" mode
        self.switch_period = int(switch_period)
    def on_step_end(self, args, state, control, **kwargs):
        """Update mag_weight or toggle linear attention according to the schedule."""
        # Ensure current_step and transition_step are plain ints
        current_step = ensure_int(state.global_step)
        transition_step = ensure_int(self.transition_step)
        # Select mag_weight or enable/disable linear attention according to mode
        if self.mode == "constant":
            # Set mag_weight to final_weight for constant mode
            weight = self.final_weight
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.mag_weight = weight
        elif self.mode == "gradual":
            if current_step <= transition_step:
                # max(..., 1) guards against division by zero when transition_step == 0
                weight = self.initial_weight + (
                    self.final_weight - self.initial_weight
                ) * (current_step / max(transition_step, 1))
            else:
                weight = self.final_weight
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.mag_weight = weight
        elif self.mode == "cyclic":
            idx = current_step % len(self.weight_list)
            weight = self.weight_list[idx]
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.mag_weight = weight
        elif self.mode == "switch":
            # Alternately enable/disable linear attention every switch_period steps
            disable = (current_step // self.switch_period) % 2 == 0
            for _, module in self.model.named_modules():
                if isinstance(module, LiZAttention):
                    module.disable_linear_attn = disable
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
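
    # Worked example of the "gradual" schedule (illustrative values, not defaults
    # enforced anywhere): with initial_weight=0.0, final_weight=0.5 and
    # transition_step=100, step 50 gives weight = 0.0 + (0.5 - 0.0) * (50 / 100) = 0.25,
    # and every step past 100 holds weight at 0.5.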
    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log the current mag_weight and linear-attention state."""
        mag_weight = None
        disable_linear_attn = None
        # Read the state from the first LiZAttention module found
        for _, module in self.model.named_modules():
            if isinstance(module, LiZAttention):
                mag_weight = getattr(module, "mag_weight", None)
                disable_linear_attn = getattr(module, "disable_linear_attn", None)
                break
        if logs is not None:
            if mag_weight is not None:
                logs["mag_weight"] = float(mag_weight)
            if disable_linear_attn is not None:
                # Log whether linear attention is currently enabled (negation of the flag)
                logs["linear_attn_enabled"] = not bool(disable_linear_attn)


def ensure_int(value: Union[int, tuple, list]) -> int:
    """Ensure the value is a plain integer."""
    if isinstance(value, (tuple, list)):
        value = int(value[0])
    if hasattr(value, "item"):
        value = int(value.item())
    return value
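

# A minimal usage sketch (assumptions: `model` is a TPTT model whose attention
# layers are LiZAttention instances, and `train_dataset` already exists; the
# Trainer/TrainingArguments wiring below is illustrative, not part of this module):
#
#     from transformers import Trainer, TrainingArguments
#
#     args = TrainingArguments(output_dir="out", max_steps=200, logging_steps=10)
#     callback = LiZACallback(model, mode="gradual",
#                             initial_weight=0.0, final_weight=0.5,
#                             transition_step=100)
#     trainer = Trainer(model=model, args=args,
#                       train_dataset=train_dataset, callbacks=[callback])
#     trainer.train()  # mag_weight ramps 0.0 -> 0.5 over the first 100 steps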


class SaveBestModelCallback(TrainerCallback):
    """TrainerCallback to save the best model based on evaluation loss."""

    def __init__(self):
        self.best_metric = float("inf")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Trigger a checkpoint save only when eval_loss improves."""
        if metrics is not None and "eval_loss" in metrics:
            if metrics["eval_loss"] < self.best_metric:
                self.best_metric = metrics["eval_loss"]
                control.should_save = True  # Trigger save
            else:
                control.should_save = False  # Skip save
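

# A minimal sketch of wiring SaveBestModelCallback into a Trainer (assumptions:
# `model`, `train_dataset` and `eval_dataset` exist, and evaluation is enabled so
# that `eval_loss` is produced; on older transformers versions the argument is
# named `evaluation_strategy` instead of `eval_strategy`):
#
#     from transformers import Trainer, TrainingArguments
#
#     args = TrainingArguments(output_dir="out", eval_strategy="steps",
#                              eval_steps=50, save_strategy="steps", save_steps=50)
#     trainer = Trainer(model=model, args=args, train_dataset=train_dataset,
#                       eval_dataset=eval_dataset,
#                       callbacks=[SaveBestModelCallback()])
#     trainer.train()  # checkpoints are written only when eval_loss improves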