import math
import os
import sys

from transformers import Trainer

from swift.trainers.optimizers.galore import create_optimizer_and_scheduler
from swift.utils import get_dist_setting


def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
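    """Return the total number of optimizer update steps for the run.

    ``args.max_steps`` takes precedence when positive; otherwise the count is
    derived from the dataset length, the effective global batch size and
    ``args.num_train_epochs``.
    """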
    if args.max_steps and args.max_steps > 0:
        max_steps = args.max_steps
    else:
        len_dataset = len(dataset)
        _, _, world_size, _ = get_dist_setting()
        # Effective batch size per optimizer update across all ranks.
        total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
        num_update_steps_per_epoch = len_dataset // total_train_batch_size
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
    return max_steps


def create_galore_optimizer(args, model, dataset):
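    """Create a GaLore optimizer and lr scheduler via swift's GaLore helper."""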
    training_steps = calculate_max_steps(args, dataset)
    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
    # The GaLore config has been consumed above; clear it on args.
    args.galore_config = None
    return optimizer, lr_scheduler


def create_lorap_optimizer(args, model, dataset):
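    """Create an optimizer using LoRA+ parameter groups when the model provides them."""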
    optimizer_grouped_parameters = None
    if hasattr(model, 'create_optimizer_param_groups'):
        # Prefer parameter groups supplied by the model itself
        # (e.g. LoRA+ groups with their own learning rates).
        optimizer_grouped_parameters = model.create_optimizer_param_groups(
            lr=args.learning_rate, weight_decay=args.weight_decay)

    if optimizer_grouped_parameters is None:
        # Fall back to the standard HF Trainer grouping: decay vs. no-decay parameters.
        decay_parameters = Trainer.get_decay_parameter_names(None, model)
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)],
                'weight_decay': args.weight_decay,
            },
            {
                'params': [p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
                'weight_decay': 0.0,
            },
        ]

    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None


def create_muon_optimizer(args, model, dataset):
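    """Create a Muon optimizer from MoonshotAI's Moonlight reference implementation."""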
    from swift.llm import git_clone_github, get_model_arch
    if not args.local_repo_path:
        args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git')
    sys.path.append(os.path.join(args.local_repo_path, 'examples'))
    from toy_train import Muon

    # Parse extra optimizer arguments of the form 'key1=value1,key2=value2'.
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    model_arch = get_model_arch(model.model_meta.model_arch)
    embed_key = model_arch.embedding or 'embed_tokens'
    lm_head_key = model_arch.lm_head or 'lm_head'
    # Muon updates the 2-D (matrix) weights; embeddings, the lm_head and 1-D
    # parameters are handled by the optimizer's internal AdamW.
    muon_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and p.ndim >= 2 and embed_key not in n and lm_head_key not in n
    ]
    adamw_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and not (p.ndim >= 2 and embed_key not in n and lm_head_key not in n)
    ]

    return Muon(
        lr=args.learning_rate,
        wd=args.weight_decay,
        muon_params=muon_params,
        adamw_params=adamw_params,
        adamw_betas=(args.adam_beta1, args.adam_beta2),
        adamw_eps=args.adam_epsilon,
        **optim_args,
    ), None


# Maps optimizer names to factory functions; each factory returns (optimizer, lr_scheduler).
optimizers_map = {
    'galore': create_galore_optimizer,
    'lorap': create_lorap_optimizer,
    'muon': create_muon_optimizer,
}
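
# A minimal usage sketch (hypothetical names; the actual wiring inside swift's
# trainer may differ):
#
#   create_fn = optimizers_map['lorap']
#   optimizer, lr_scheduler = create_fn(args, model, train_dataset)
#   # The pair can then be passed to a transformers Trainer via its `optimizers` argument.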