diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6658a952ab2e7255c50ca4c5451060cbecb288a2
--- /dev/null
+++ b/swift/megatron/model/gpt/config.py
@@ -0,0 +1,13 @@
+from typing import Any, Dict
+
+from ..config import convert_hf_config
+
+
+def convert_gpt_hf_config(config) -> Dict[str, Any]:
+ res = convert_hf_config(config)
+ model_type = res.get('model_type')
+ if model_type in {'qwen3', 'qwen3_moe'}:
+ res['qk_layernorm'] = True
+ if model_type in {'qwen2_moe', 'qwen3_moe'}:
+ res.pop('ffn_hidden_size', None)
+ return res
diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py
new file mode 100644
index 0000000000000000000000000000000000000000..46525df3c757c6e83aaf0a87a783a7acfde68135
--- /dev/null
+++ b/swift/megatron/model/gpt/hf2mcore.py
@@ -0,0 +1,74 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+from megatron.training import get_args
+
+
+def set_attn_state(args, mg_attn, hf_attn):
+ num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
+
+ # Copy weights
+ mg_attn.linear_qkv.weight.data.copy_(
+ torch.cat([
+ hf_attn.q_proj.weight.reshape((num_query_groups, -1, args.hidden_size)),
+ hf_attn.k_proj.weight.reshape((num_query_groups, -1, args.hidden_size)),
+ hf_attn.v_proj.weight.reshape((num_query_groups, -1, args.hidden_size)),
+ ],
+ dim=1).reshape((-1, args.hidden_size)))
+ mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight)
+
+ # Copy bias
+ if args.add_qkv_bias:
+ mg_attn.linear_qkv.bias.data.copy_(
+ torch.cat([
+ hf_attn.q_proj.bias.reshape((num_query_groups, -1)),
+ hf_attn.k_proj.bias.reshape((num_query_groups, -1)),
+ hf_attn.v_proj.bias.reshape((num_query_groups, -1)),
+ ],
+ dim=1).reshape(-1))
+ if args.qk_layernorm:
+ mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight)
+ mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight)
+
+
+def _set_mlp_state(mg_mlp, hf_mlp):
+ mg_mlp.linear_fc1.weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0))
+ mg_mlp.linear_fc2.weight.data.copy_(hf_mlp.down_proj.weight)
+
+
+def set_mlp_state(args, mg_mlp, hf_mlp):
+ if args.num_experts:
+ mg_mlp.router.weight.data.copy_(hf_mlp.gate.weight)
+ if mg_mlp.shared_experts is not None:
+ mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight)
+ for expert_idx in range(args.num_experts):
+ _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx])
+
+ if mg_mlp.shared_experts is not None:
+ _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert)
+ else:
+ _set_mlp_state(mg_mlp, hf_mlp)
+
+
+def set_layer_state(args, mg_model, hf_model, layer_idx):
+ mg_layer = mg_model.decoder.layers[layer_idx]
+ hf_layer = hf_model.model.layers[layer_idx]
+
+ set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn)
+ set_mlp_state(args, mg_layer.mlp, hf_layer.mlp)
+
+ post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight
+ if args.num_experts:
+ mg_layer.pre_mlp_layernorm.weight.data.copy_(post_attention_layernorm_weight)
+ else:
+ mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_attention_layernorm_weight)
+ mg_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight)
+
+
+def convert_hf2mcore(hf_model, mg_model):
+ args = get_args()
+ mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight)
+ if args.untie_embeddings_and_output_weights:
+ mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight)
+ mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight)
+ for layer_idx in range(args.num_layers):
+ set_layer_state(args, mg_model, hf_model, layer_idx)
diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f29abaf0e63482ef7a538f1171a74be3f5ea162
--- /dev/null
+++ b/swift/megatron/model/gpt/mcore2hf.py
@@ -0,0 +1,70 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from megatron.training import get_args
+
+
+def set_attn_state(args, mg_attn, hf_attn):
+ num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
+ # Copy weights
+ mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size))
+ q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[
+ 0] // num_query_groups
+ hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size))
+ hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size))
+ hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size))
+ hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight)
+
+ # Copy bias
+ if args.add_qkv_bias:
+ mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1))
+ hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1))
+ hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
+ hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))
+
+ if args.qk_layernorm:
+ hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight)
+ hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight)
+
+
+def _set_mlp_state(mg_mlp, hf_mlp):
+ ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0]
+ hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:ffn_hidden_size])
+ hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[ffn_hidden_size:])
+ hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight)
+
+
+def set_mlp_state(args, mg_mlp, hf_mlp):
+ if args.num_experts:
+ hf_mlp.gate.weight.data.copy_(mg_mlp.router.weight)
+ if mg_mlp.shared_experts is not None:
+ hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight)
+ for expert_idx in range(args.num_experts):
+ _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx])
+
+ if mg_mlp.shared_experts is not None:
+ _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert)
+ else:
+ _set_mlp_state(mg_mlp, hf_mlp)
+
+
+def set_layer_state(args, mg_model, hf_model, layer_idx):
+ mg_layer = mg_model.decoder.layers[layer_idx]
+ hf_layer = hf_model.model.layers[layer_idx]
+ set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn)
+ set_mlp_state(args, mg_layer.mlp, hf_layer.mlp)
+
+ post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight
+ if args.num_experts:
+ post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight)
+ else:
+ post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight)
+ hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight)
+
+
+def convert_mcore2hf(hf_model, mg_model):
+ args = get_args()
+ hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight)
+ if args.untie_embeddings_and_output_weights:
+ hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight)
+ hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
+ for layer_idx in range(args.num_layers):
+ set_layer_state(args, mg_model, hf_model, layer_idx)
diff --git a/swift/megatron/model/gpt/model.py b/swift/megatron/model/gpt/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc6bf4fbc32dead7f5cea13cb3eae754c832b3e
--- /dev/null
+++ b/swift/megatron/model/gpt/model.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from megatron.core.models.gpt import GPTModel
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.training import get_args
+from megatron.training.arguments import core_transformer_config_from_args
+
+from ..rope import update_rope_inv_freq
+
+
+def model_provider(pre_process=True, post_process=True):
+ args = get_args()
+ config = core_transformer_config_from_args(args)
+ config.variable_seq_lengths = True
+ transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm,
+ args.qk_layernorm, args.multi_latent_attention)
+ if args.num_experts and args.moe_shared_expert_intermediate_size:
+ # qwen2_moe/qwen3_moe
+ transformer_layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True}
+ model = GPTModel(
+ config=config,
+ transformer_layer_spec=transformer_layer_spec,
+ vocab_size=args.padded_vocab_size,
+ max_sequence_length=args.max_position_embeddings,
+ pre_process=pre_process,
+ post_process=post_process,
+ fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+ parallel_output=True,
+ share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+ position_embedding_type=args.position_embedding_type,
+ rotary_percent=args.rotary_percent,
+ rotary_base=args.rotary_base,
+ rope_scaling=args.use_rope_scaling,
+ rope_scaling_factor=args.rope_scaling_factor,
+ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
+ if args.rope_scaling:
+ update_rope_inv_freq(model.rotary_pos_emb.inv_freq, args.rope_scaling)
+ return model
diff --git a/swift/megatron/train/__init__.py b/swift/megatron/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6a98be92e5e625a4295b74dee1e80cf0200608
--- /dev/null
+++ b/swift/megatron/train/__init__.py
@@ -0,0 +1,2 @@
+from .pt import megatron_pt_main
+from .sft import megatron_sft_main
diff --git a/swift/megatron/train/patcher.py b/swift/megatron/train/patcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a9862421746a4f8e20f92473269c3f596ce81e
--- /dev/null
+++ b/swift/megatron/train/patcher.py
@@ -0,0 +1,64 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from contextlib import contextmanager
+from functools import wraps
+
+import torch
+from megatron.training import get_args, global_vars, initialize, training
+
+from swift.utils import JsonlWriter, is_master
+
+
+@contextmanager
+def patch_training_log():
+ jsonl_writer = None
+ origin_training_log = training.training_log
+
+ @wraps(origin_training_log)
+ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
+ report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs):
+ nonlocal jsonl_writer
+ args = get_args()
+ if is_master() and iteration % args.log_interval == 0:
+ logging_path = os.path.join(args.save, 'logging.jsonl')
+ logs = {}
+ for k, v in loss_dict.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ logs[k] = round(v, 8)
+ for k in {'grad_norm', 'params_norm', 'learning_rate'}:
+ v = locals()[k]
+ if v is not None:
+ logs[k] = round(v, 8)
+ logs['consumed_samples'] = args.consumed_train_samples
+ logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
+ if jsonl_writer is None:
+ jsonl_writer = JsonlWriter(logging_path, enable_async=True)
+ jsonl_writer.append(logs)
+ return origin_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
+ loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm,
+ num_zeros_in_grad, *_args, **kwargs)
+
+ training.training_log = training_log
+ try:
+ yield
+ finally:
+ training.training_log = origin_training_log
+
+
+@contextmanager
+def patch_megatron_data_collator(data_collator):
+ origin_build_pretraining_data_loader = training.build_pretraining_data_loader
+
+ def build_pretraining_data_loader(*_args, **kwargs):
+ args = get_args()
+ res = origin_build_pretraining_data_loader(*_args, **kwargs)
+ if res is not None and args.dataloader_type != 'external':
+ res.collate_fn = data_collator
+ return res
+
+ training.build_pretraining_data_loader = build_pretraining_data_loader
+ try:
+ yield
+ finally:
+ training.build_pretraining_data_loader = origin_build_pretraining_data_loader
diff --git a/swift/megatron/train/pt.py b/swift/megatron/train/pt.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f4bcd5905615776b0ec04d915f2548213f4e77
--- /dev/null
+++ b/swift/megatron/train/pt.py
@@ -0,0 +1,19 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import List, Union
+
+from ..argument import MegatronTrainArguments
+from .sft import MegatronSft
+
+
+class MegatronPt(MegatronSft):
+ args_class = MegatronTrainArguments
+ args: args_class
+
+ def _prepare_template(self) -> None:
+ self.args.use_chat_template = False
+ super()._prepare_template()
+ self.template.loss_scale = 'all'
+
+
+def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None):
+ return MegatronPt(args).main()
diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fa3e24f18e381f8f3e8d6b778e9138fbe048dfd
--- /dev/null
+++ b/swift/megatron/train/sft.py
@@ -0,0 +1,65 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from typing import List, Union
+
+from megatron.core.enums import ModelType
+from megatron.training import pretrain
+
+from swift.llm.train import SwiftSft
+from swift.utils import get_logger, is_master, plot_images
+from ..argument import MegatronTrainArguments
+from ..utils import patch_megatron_tokenizer
+from .patcher import patch_megatron_data_collator, patch_training_log
+from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider
+
+logger = get_logger()
+
+
+class MegatronSft(SwiftSft):
+ args_class = MegatronTrainArguments
+ args: args_class
+
+ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None:
+ self.train_msg = {}
+ super(SwiftSft, self).__init__(args)
+ args = self.args
+ _, self.processor = args.get_model_processor(load_model=False)
+ patch_megatron_tokenizer(self.processor)
+ args.init_model_args(self.processor.model_info.config)
+ self._prepare_template()
+ self.template.use_megatron = True
+ args.save_args(args.save)
+
+ def run(self):
+ args = self.args
+
+ train_dataset, val_dataset = self._get_dataset()
+ train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
+ data_collator = self.template.data_collator
+ if args.streaming:
+ train_dataset = build_streaming_dataloader(args, train_dataset, data_collator)
+ if val_dataset is not None:
+ val_dataset = build_streaming_dataloader(args, val_dataset, data_collator)
+ datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset)
+ datasets_provider.is_distributed = True
+
+ logging_path = os.path.join(args.save, 'logging.jsonl')
+ logger.info(f'The logging file will be saved in: {logging_path}')
+ try:
+ with patch_training_log(), patch_megatron_data_collator(data_collator):
+ pretrain(
+ datasets_provider,
+ args.megatron_model_meta.model_provider,
+ ModelType.encoder_or_decoder,
+ forward_step,
+ args_defaults=args.extra_args)
+ finally:
+ # Visualization
+ if is_master():
+ images_dir = os.path.join(args.save, 'images')
+ logger.info(f'images_dir: {images_dir}')
+ plot_images(images_dir, args.tensorboard_dir)
+
+
+def megatron_sft_main(args: Union[List[str], MegatronTrainArguments, None] = None):
+ return MegatronSft(args).main()
diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..69caa161d16d091fa530c7715f06b6ca95f40d6f
--- /dev/null
+++ b/swift/megatron/train/utils.py
@@ -0,0 +1,229 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from functools import partial
+from typing import Any, Dict, Optional
+
+import torch
+from megatron.core import mpu
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers
+from megatron.training.training import cyclic_iter
+
+from swift.llm import DataLoaderDispatcher
+
+stimer = StragglerDetector()
+
+
+def get_swift_datasets_provider(train_dataset, val_dataset):
+
+ def swift_datasets_provider(train_val_test_num_samples):
+ return train_dataset, val_dataset, None
+
+ return swift_datasets_provider
+
+
+class MegatronDataLoaderDispatcher(DataLoaderDispatcher):
+
+ @property
+ def group(self):
+ return mpu.get_data_parallel_group()
+
+
+def build_streaming_dataloader(args, dataset, collate_fn):
+ base_dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ collate_fn=collate_fn,
+ batch_size=args.micro_batch_size,
+ prefetch_factor=args.dataloader_prefetch_factor,
+ persistent_workers=args.dataloader_persistent_workers,
+ )
+ return iter(cyclic_iter(MegatronDataLoaderDispatcher(base_dataloader)))
+
+
+def get_batch_on_this_tp_rank(data_iterator):
+ # copy from megatron-lm
+
+ args = get_args()
+
+ def _broadcast(item):
+ if item is not None:
+ torch.distributed.broadcast(
+ item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+
+ if mpu.get_tensor_model_parallel_rank() == 0:
+
+ try:
+ data = next(data_iterator)
+ except StopIteration:
+ seq_length = -1
+ else:
+ tokens = data['input_ids']
+ seq_length = tokens.shape[1]
+ batch = {
+ 'tokens': tokens.cuda(non_blocking=True),
+ 'labels': data['labels'].cuda(non_blocking=True),
+ 'attention_mask':
+ None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
+ 'position_ids': data['position_ids'].cuda(non_blocking=True)
+ }
+ seq_length = torch.tensor(seq_length).cuda(non_blocking=True)
+ _broadcast(seq_length)
+ if seq_length.item() == -1:
+ return {}
+ if args.pipeline_model_parallel_size == 1:
+ _broadcast(batch['tokens'])
+ _broadcast(batch['labels'])
+ _broadcast(batch['attention_mask'])
+ _broadcast(batch['position_ids'])
+
+ elif mpu.is_pipeline_first_stage():
+ _broadcast(batch['tokens'])
+ _broadcast(batch['attention_mask'])
+ _broadcast(batch['position_ids'])
+
+ elif mpu.is_pipeline_last_stage():
+ _broadcast(batch['labels'])
+ _broadcast(batch['attention_mask'])
+ _broadcast(batch['position_ids'])
+
+ else:
+ seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+ _broadcast(seq_length)
+ if seq_length.item() == -1:
+ return {}
+ micro_batch_size = 1 # use qkv_format 'thd'
+ tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
+ labels = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
+ if args.create_attention_mask_in_dataloader:
+ attention_mask = torch.empty((micro_batch_size, 1, seq_length, seq_length),
+ dtype=torch.bool,
+ device=torch.cuda.current_device())
+ else:
+ attention_mask = None
+ position_ids = torch.empty((micro_batch_size, seq_length),
+ dtype=torch.int64,
+ device=torch.cuda.current_device())
+
+ if args.pipeline_model_parallel_size == 1:
+ _broadcast(tokens)
+ _broadcast(labels)
+ _broadcast(attention_mask)
+ _broadcast(position_ids)
+
+ elif mpu.is_pipeline_first_stage():
+ labels = None
+
+ _broadcast(tokens)
+ _broadcast(attention_mask)
+ _broadcast(position_ids)
+
+ elif mpu.is_pipeline_last_stage():
+ tokens = None
+
+ _broadcast(labels)
+ _broadcast(attention_mask)
+ _broadcast(position_ids) # compat packing & cp
+
+ batch = {'tokens': tokens, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids}
+
+ return batch
+
+
+def get_packed_seq_params(position_ids: torch.Tensor) -> Optional[PackedSeqParams]:
+ position_ids_f = position_ids.flatten()
+ indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)
+
+ cu_seqlens = torch.cat([
+ indices_q[position_ids_f == 0],
+ torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
+ ])
+
+ max_length = position_ids_f.max() + 1
+ return PackedSeqParams(
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_kv=cu_seqlens,
+ max_seqlen_q=max_length,
+ max_seqlen_kv=max_length,
+ qkv_format='thd')
+
+
+def _split_tokens(tokens, cu_seqlens):
+ assert tokens.shape[0] == 1, f'tokens.shape: {tokens.shape}'
+ new_tokens = []
+ cp_size = mpu.get_context_parallel_world_size()
+ cp_rank = mpu.get_context_parallel_rank()
+ for i in range(cu_seqlens.shape[0] - 1):
+ val = tokens[:, cu_seqlens[i]:cu_seqlens[i + 1]]
+ val = val.view(
+ tokens.shape[0],
+ 2 * cp_size,
+ val.shape[1] // (2 * cp_size),
+ )
+ index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
+ pin_memory=True).cuda(non_blocking=True)
+ val = val.index_select(1, index)
+ new_tokens.append(val.view(tokens.shape[0], -1))
+ return torch.cat(new_tokens, dim=1)
+
+
+def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
+ """Slice batch input along sequence dimension into multiple chunks,
+ which are parallelized across GPUs in a context parallel group.
+ """
+
+ # With causal masking, each token only attends to its prior tokens. Simply split
+ # sequence into CP chunks can result in severe load imbalance. That's to say, chunks
+ # at the end of sequence have bigger workload than others. To address this issue,
+ # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
+ # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
+ # that we can get balanced workload among GPUs in a context parallel group.
+ cp_size = mpu.get_context_parallel_world_size()
+ if cp_size > 1:
+ packed_seq_params = batch['packed_seq_params']
+ for key, val in batch.items():
+ if key == 'packed_seq_params':
+ continue
+ if val is not None:
+ batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q)
+
+ return batch
+
+
+def get_batch(data_iterator):
+ """Generate a batch."""
+
+ # TODO: this is pretty hacky, find a better way
+ if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+ return None, None, None, None, None
+
+ # get batches based on the TP rank you are on
+ batch = get_batch_on_this_tp_rank(data_iterator)
+ if not batch:
+ return batch
+ batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids'])
+ # slice batch along sequence dimension for context parallelism
+ batch = get_batch_on_this_cp_rank(batch)
+ return batch.values()
+
+
+def forward_step(data_iterator, model):
+ from pretrain_gpt import loss_func
+
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch-generator', log_level=2).start()
+ global stimer
+ with stimer(bdata=True):
+ data = get_batch(data_iterator)
+ if not data:
+ raise StopIteration
+ tokens, labels, attention_mask, position_ids, packed_seq_params = data
+ timers('batch-generator').stop()
+
+ with stimer:
+ output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)
+ loss_mask = None if labels is None else (labels != -100).float()
+ return output_tensor, partial(loss_func, loss_mask)
diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d2b722a2cf06a94691e9546b94247bca0998367
--- /dev/null
+++ b/swift/megatron/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .convert import convert_hf2mcore, convert_mcore2hf
+from .patcher import patch_megatron_tokenizer
diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..42d37b945e1372af1662c8ce80e8eeea98523815
--- /dev/null
+++ b/swift/megatron/utils/convert.py
@@ -0,0 +1,122 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+
+import torch
+from megatron.training.checkpointing import load_checkpoint
+from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint
+from megatron.training.initialize import initialize_megatron
+from megatron.training.utils import get_ltor_masks_and_position_ids
+
+from swift.llm import ExportArguments, get_model_tokenizer, get_template, save_checkpoint
+from swift.utils import get_logger, get_n_params_grads
+from ..argument import MegatronArguments
+from ..model import get_megatron_model_meta
+from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard
+
+logger = get_logger()
+
+
+def test_convert_precision(hf_model, mg_model, processor):
+ torch_dtype = hf_model.dtype
+ template = get_template(hf_model.model_meta.template, processor)
+ input_ids = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})['input_ids']
+ input_ids = torch.tensor(input_ids)[None].to('cuda')
+ hf_model.to('cuda')
+ hf_model.to(torch.float32)
+ with torch.inference_mode():
+ hf_logits = hf_model(input_ids).logits
+ hf_model.to(torch_dtype)
+ hf_model.to('cpu')
+
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
+ mg_model.to('cuda')
+ mg_model.to(torch.float32)
+ with torch.inference_mode():
+ mg_logits = mg_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
+ mg_model.to(torch_dtype)
+ mg_model.to('cpu')
+
+ mean_diff = (mg_logits - hf_logits).abs().mean().item()
+ max_diff = (mg_logits - hf_logits).abs().max().item()
+ print(f'mean_diff: {mean_diff}, max_diff: {max_diff}')
+ hf_tokens = hf_logits.argmax(-1)
+ mg_tokens = mg_logits.argmax(-1)
+ print(f'hf_tokens: {hf_tokens[0].tolist()}\nmg_tokens: {mg_tokens[0].tolist()}')
+ assert mean_diff < 0.1
+ assert (hf_tokens == mg_tokens).all()
+
+
+convert_kwargs = {
+ 'use_cpu_initialization': True,
+ 'no_save_optim': True,
+ 'no_save_rng': True,
+ 'no_load_optim': True,
+ 'no_load_rng': True,
+ 'no_masked_softmax_fusion': True,
+ 'no_bias_dropout_fusion': True,
+ 'no_bias_swiglu_fusion': True,
+ 'no_rope_fusion': True
+}
+
+
+def convert_hf2mcore(args: ExportArguments) -> None:
+ kwargs = args.get_model_kwargs()
+ hf_model, processor = get_model_tokenizer(**kwargs)
+ if args.thread_count is None:
+ checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
+ args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB
+ patch_torch_dist_shard(args.thread_count)
+
+ megatron_model_meta = get_megatron_model_meta(args.model_type)
+ assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
+ kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
+ megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
+ patch_megatron_tokenizer(processor)
+ extra_args = megatron_args.parse_to_megatron()
+ initialize_megatron(args_defaults=extra_args)
+
+ mg_model = megatron_model_meta.model_provider()
+ logger.info('Megatron model created successfully.')
+ megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
+ if args.test_convert_precision:
+ test_convert_precision(hf_model, mg_model, processor)
+ logger.info('Successfully transferred HF model weights to MG model.')
+ mg_save_checkpoint(1, [mg_model], None, None, 0)
+ args.save_args()
+ logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
+
+
+def convert_mcore2hf(args: ExportArguments) -> None:
+ kwargs = args.get_model_kwargs()
+ hf_model, processor = get_model_tokenizer(**kwargs)
+ if args.thread_count is None:
+ checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
+ args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB
+ patch_torch_dist_shard(args.thread_count)
+
+ megatron_model_meta = get_megatron_model_meta(args.model_type)
+ assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
+ kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
+ megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype)
+ patch_megatron_tokenizer(processor)
+ extra_args = megatron_args.parse_to_megatron()
+ initialize_megatron(args_defaults=extra_args)
+
+ mg_model = megatron_model_meta.model_provider()
+ load_checkpoint([mg_model], None, None, strict=True)
+ logger.info('Megatron model created successfully.')
+ megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
+ if args.test_convert_precision:
+ test_convert_precision(hf_model, mg_model, processor)
+ logger.info('Successfully transferred MG model weights to HF model.')
+ save_checkpoint(
+ hf_model,
+ processor,
+ args.output_dir,
+ safe_serialization=args.safe_serialization,
+ model_dirs=[args.mcore_model, args.model_dir],
+ max_shard_size=args.max_shard_size,
+ additional_saved_files=hf_model.model_meta.additional_saved_files)
+ args.save_args()
+ logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a4aed76fcb7e0dd6aff7b31641d34b619f29a8a
--- /dev/null
+++ b/swift/megatron/utils/patcher.py
@@ -0,0 +1,26 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
+from megatron.training import get_args, global_vars, initialize, training
+
+from swift.utils import get_logger
+
+logger = get_logger()
+
+
+def patch_megatron_tokenizer(tokenizer):
+
+ def build_tokenizer(args):
+ args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
+ return tokenizer
+
+ global_vars.build_tokenizer = build_tokenizer
+
+
+def patch_torch_dist_shard(thread_count):
+ __init__ = TorchDistSaveShardedStrategy.__init__
+
+ def __new_init__(*args, **kwargs):
+ kwargs['thread_count'] = thread_count
+ return __init__(*args, **kwargs)
+
+ TorchDistSaveShardedStrategy.__init__ = __new_init__
diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..109a4294314c7869d1b7e2cd7f1003c0c23aa50a
--- /dev/null
+++ b/swift/plugin/__init__.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from swift.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+ from .callback import extra_callbacks
+ from .loss import LOSS_MAPPING, get_loss_func
+ from .loss_scale import loss_scale_map
+ from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu
+ from .optimizer import optimizers_map
+ from .agent_template import agent_templates
+ from .tuner import Tuner, extra_tuners, PeftTuner
+ from .prm import prms, PRM
+ from .orm import orms, ORM
+ from .multi_turn import multi_turns
+ from .rm_plugin import rm_plugins
+
+else:
+ _import_structure = {
+ 'callback': ['extra_callbacks'],
+ 'loss': ['LOSS_MAPPING', 'get_loss_func'],
+ 'loss_scale': ['loss_scale_map'],
+ 'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'],
+ 'optimizer': ['optimizers_map'],
+ 'agent_template': ['agent_templates'],
+ 'tuner': ['Tuner', 'extra_tuners', 'PeftTuner'],
+ 'prm': ['prms', 'PRM'],
+ 'orm': ['orms', 'ORM'],
+ 'multi_turn': ['multi_turns'],
+ 'rm_plugin': ['rm_plugins']
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/swift/plugin/agent_template/__init__.py b/swift/plugin/agent_template/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..35f40f9308aa70ad0a608cb3158fd0207578c5e9
--- /dev/null
+++ b/swift/plugin/agent_template/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .base import BaseAgentTemplate
+from .extra import ReactGRPOAgentTemplate
+from .glm4 import GLM4_0414AgentTemplate, GLM4AgentTemplate
+from .hermes import HermesAgentTemplate
+from .llama import Llama3AgentTemplate, Llama4AgentTemplate
+from .qwen import QwenEnAgentTemplate, QwenEnParallelAgentTemplate, QwenZhAgentTemplate, QwenZhParallelAgentTemplate
+from .react import ReactEnAgentTemplate, ReactZnAgentTemplate
+from .toolbench import ToolBenchAgentTemplate
+
+agent_templates = {
+ # ref: https://qwen.readthedocs.io/zh-cn/latest/framework/function_call.html#function-calling-templates
+ 'react_en': ReactEnAgentTemplate,
+ 'react_zh': ReactZnAgentTemplate,
+ # ref: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/qwen_fncall_prompt.py
+ 'qwen_en': QwenEnAgentTemplate,
+ 'qwen_zh': QwenZhAgentTemplate,
+ 'qwen_en_parallel': QwenEnParallelAgentTemplate,
+ 'qwen_zh_parallel': QwenZhParallelAgentTemplate,
+ 'hermes': HermesAgentTemplate,
+ 'toolbench': ToolBenchAgentTemplate, # ref: https://modelscope.cn/datasets/swift/ToolBench
+ 'glm4': GLM4AgentTemplate,
+ 'glm4_0414': GLM4_0414AgentTemplate, # ref: https://modelscope.cn/models/ZhipuAI/GLM-4-9B-0414
+ 'llama3': Llama3AgentTemplate,
+ 'llama4': Llama4AgentTemplate,
+ # extra
+ 'react_grpo': ReactGRPOAgentTemplate
+}
diff --git a/swift/plugin/agent_template/base.py b/swift/plugin/agent_template/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24fc9d49b804e0fa2eefe8f8f8803cf70a7ddaa
--- /dev/null
+++ b/swift/plugin/agent_template/base.py
@@ -0,0 +1,158 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import ast
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
+
+import json
+
+if TYPE_CHECKING:
+ from swift.llm.infer import Function
+ from swift.llm.template import Prompt
+
+
+@dataclass
+class AgentKeyword:
+ action: str = 'Action:'
+ action_input: str = 'Action Input:'
+ observation: str = 'Observation:'
+
+
+@dataclass
+class ToolDesc:
+ name_for_model: str
+ name_for_human: str
+ description_for_model: str
+ parameters: str
+ args_format: str
+
+
+class ReactCompatMixin:
+ keyword = AgentKeyword()
+
+ @staticmethod
+ def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']:
+ from swift.llm.template import split_str_parts_by
+ from swift.llm.infer import Function
+ agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))
+ functions = []
+ action_content = None
+
+ for part in agent_parts:
+ key, content = part['key'].lower(), part['content']
+ if action_content is None and key == keyword.action.lower():
+ action_content = content
+ elif action_content is not None and key == keyword.action_input.lower():
+ functions.append(Function(name=action_content, arguments=content))
+ action_content = None
+
+ return functions
+
+ def get_toolcall(self, response: str) -> List['Function']:
+ functions = self._split_action_action_input(response, self.keyword)
+ if len(functions) == 0 and self.keyword != ReactCompatMixin.keyword:
+ # compat react
+ functions = self._split_action_action_input(response, ReactCompatMixin.keyword)
+ return functions
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages,
+ ) -> Tuple[str, 'Prompt']:
+ assert len(tool_messages) > 0
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ if not assistant_content.endswith(self.keyword.observation):
+ if not assistant_content.endswith('\n'):
+ assistant_content += '\n'
+ assistant_content += self.keyword.observation
+ res = []
+ for i, tool_message in enumerate(tool_messages):
+ if i > 0:
+ res.append(self.keyword.observation)
+ tool_content = tool_message['content']
+ res.append(tool_content)
+ if not tool_content.endswith('\n'):
+ res.append('\n')
+ else:
+ res = []
+ for tool_message in tool_messages:
+ res.append(tool_message['content'])
+ return assistant_content, res
+
+ @staticmethod
+ def _parse_tool_call(content) -> Dict[str, Any]:
+ obj = BaseAgentTemplate._parse_json(content)
+ name = obj['name']
+ arguments = obj.get('arguments') or obj.get('parameters')
+ arguments = BaseAgentTemplate._parse_json(arguments)
+ assert arguments is not None, f'content: {content}'
+ return {'name': name, 'arguments': arguments}
+
+ def _format_tool_calls(self, tool_call_messages) -> str:
+ # -> assistant_content
+ tool_calls = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n'
+ f'{self.keyword.action_input} {tool_call["arguments"]}\n')
+ tool_calls.append(self.keyword.observation)
+ return ''.join(tool_calls)
+
+
+class BaseAgentTemplate(ReactCompatMixin, ABC):
+
+ @staticmethod
+ def _get_tool_name(tool):
+ return tool.get('name_for_model') or tool.get('name')
+
+ @staticmethod
+ def unwrap_tool(tool):
+ assert isinstance(tool, dict), f'tool: {tool}'
+ if 'type' in tool and 'function' in tool:
+ tool = tool['function']
+ return tool
+
+ @staticmethod
+ def wrap_tool(tool):
+ assert isinstance(tool, dict), f'tool: {tool}'
+ if 'type' not in tool and 'function' not in tool:
+ tool = {'type': 'function', 'function': tool}
+ return tool
+
+ @staticmethod
+ def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
+ tool = BaseAgentTemplate.unwrap_tool(tool)
+ name_for_model = BaseAgentTemplate._get_tool_name(tool)
+ name_for_human = tool.get('name_for_human') or name_for_model
+
+ description = tool.get('description') or tool.get('description_for_model')
+ parameters = tool.get('parameters') or {}
+ parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False)
+ args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.'
+ tool_desc = ToolDesc(
+ name_for_model=name_for_model,
+ name_for_human=name_for_human,
+ description_for_model=description,
+ parameters=parameters,
+ args_format=args_format)
+ assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}'
+ return tool_desc
+
+ @staticmethod
+ def _parse_json(json_str: str) -> Optional[Any]:
+ if not isinstance(json_str, str):
+ return json_str
+ try:
+ res = json.loads(json_str)
+ except json.JSONDecodeError:
+ try:
+ res = ast.literal_eval(json_str)
+ except Exception:
+ return
+ return res
+
+ @abstractmethod
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ pass
diff --git a/swift/plugin/agent_template/extra.py b/swift/plugin/agent_template/extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..019f05a786c1a178a715c3a1522690351617c5fc
--- /dev/null
+++ b/swift/plugin/agent_template/extra.py
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import List, Union
+
+from .base import BaseAgentTemplate
+
+
+class ReactGRPOAgentTemplate(BaseAgentTemplate):
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names = []
+ tool_descs = []
+ for tool in tools:
+ tool_desc = self._parse_tool(tool, 'en')
+ tool_names.append(tool_desc.name_for_model)
+ tool_descs.append(
+ f'{tool_desc.name_for_model}: Call this tool to interact with the {tool_desc.name_for_human} API. '
+ f'What is the {tool_desc.name_for_human} API useful for? {tool_desc.description_for_model} '
+ f'Parameters: {tool_desc.parameters} {tool_desc.args_format}')
+
+ return """A conversation for tool calling between User and Assistant. The user asks a question which may be solved by calling tools, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process should be enclosed within tags and answer should follow the ReACT format(Action:xxx\nAction Input:xxx), i.e., reasoning process here Action: action here\nAction Input: parameters here
+
+Answer the following questions as best as you can. You have access to the following tools:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+Use the following format:
+
+you should always think about what to do
+Action: the action to take, should be one of [{','.join(tool_names)}]
+Action Input: the input to the action
+Observation: the result of the action, given by the actual calling
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Final Answer: the final answer to the original input question
+
+Begin!
+""" # noqa
diff --git a/swift/plugin/agent_template/glm4.py b/swift/plugin/agent_template/glm4.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfea2ab651d316085e042b53765ab71e562bf9f
--- /dev/null
+++ b/swift/plugin/agent_template/glm4.py
@@ -0,0 +1,79 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import re
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import json
+
+from .base import BaseAgentTemplate
+
+if TYPE_CHECKING:
+ from swift.llm.infer import Function
+ from swift.llm.template import Prompt
+
+
+class GLM4AgentTemplate(BaseAgentTemplate):
+ is_glm4_0414 = False
+
+ @staticmethod
+ def _find_function_call(single_content: str) -> Optional['Function']:
+ from swift.llm.infer import Function
+ single_content = single_content.replace('<|observation|>', '')
+ pattern = re.compile(r'([^\n`]*?)\n({.*?})(?=\w*\n|$)', re.DOTALL)
+ matches = pattern.findall(single_content)
+ if not matches:
+ return
+
+ name, arguments = matches[0]
+ return Function(name=name, arguments=arguments)
+
+ def get_toolcall(self, response: str) -> List['Function']:
+ toolcall_list = response.split('<|assistant|>')
+ functions = []
+ for toolcall in toolcall_list:
+ function = self._find_function_call(toolcall)
+ if function:
+ functions.append(function)
+ if len(functions) == 0:
+ # compat react_en
+ return super().get_toolcall(response)
+ return functions
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_descs = []
+ for tool in tools:
+ tool = self.unwrap_tool(tool)
+ name = self._get_tool_name(tool)
+ tool_descs.append(f'## {name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n'
+ '在调用上述函数时,请使用 Json 格式表示调用的参数。')
+ glm4_system = '你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n' # noqa
+ return ('' if self.is_glm4_0414 else glm4_system) + """# 可用工具
+
+""" + '\n'.join(tool_descs)
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages,
+ ) -> Tuple[str, 'Prompt']:
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ return super()._format_tool_responses(assistant_content, tool_messages)
+ res = ['\n']
+ for i, tool_message in enumerate(tool_messages):
+ tool_content = tool_message['content']
+ if i > 0:
+ res.append('<|observation|>\n')
+ res.append(tool_content)
+ res.append('<|assistant|>\n')
+ return assistant_content, res
+
+ def _format_tool_calls(self, tool_call_messages) -> str:
+ tool_calls = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ tool_calls.append(f'{tool_call["name"]}\n{tool_call["arguments"]}')
+ return '<|assistant|>'.join(tool_calls) + '<|observation|>'
+
+
+class GLM4_0414AgentTemplate(GLM4AgentTemplate):
+ is_glm4_0414 = True
diff --git a/swift/plugin/agent_template/hermes.py b/swift/plugin/agent_template/hermes.py
new file mode 100644
index 0000000000000000000000000000000000000000..28ab23fa3d803a1f62b209cffcd168a361512483
--- /dev/null
+++ b/swift/plugin/agent_template/hermes.py
@@ -0,0 +1,78 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import re
+from typing import TYPE_CHECKING, List, Tuple, Union
+
+import json
+
+from .base import BaseAgentTemplate
+
+if TYPE_CHECKING:
+ from swift.llm.infer import Function
+ from swift.llm.template import Prompt
+
+
+class HermesAgentTemplate(BaseAgentTemplate):
+
+ def get_toolcall(self, response: str) -> List['Function']:
+ from swift.llm.infer import Function
+ res_list = re.findall(r'(.+?)', response, re.DOTALL)
+ functions = []
+ for res in res_list:
+ res = self._parse_json(res)
+ if isinstance(res, dict) and 'name' in res and 'arguments' in res:
+ functions.append(Function(name=res['name'], arguments=res['arguments']))
+ if len(functions) == 0:
+ # compat react_en
+ return super().get_toolcall(response)
+ return functions
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages,
+ ) -> Tuple[str, 'Prompt']:
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ return super()._format_tool_responses(assistant_content, tool_messages)
+ if hasattr(self, 'template_meta'):
+ prompt = self.template_meta.prompt
+ chat_sep = self.template_meta.chat_sep
+ else:
+ prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
+ chat_sep = ['<|im_end|>\n']
+ res = chat_sep.copy()
+ res_tool = []
+ for tool_message in tool_messages:
+ tool_content = tool_message['content']
+ res_tool.append(f'\n{tool_content}\n')
+ total_tool = '\n'.join(res_tool)
+ for context in prompt:
+ if isinstance(context, str):
+ context = context.replace('{{QUERY}}', total_tool)
+ res.append(context)
+ return assistant_content, res
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_descs = [json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools]
+ return f"""{system}
+
+# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within XML tags:
+
+""" + '\n'.join(tool_descs) + """
+
+
+For each function call, return a json object with function name and arguments within XML tags:
+
+{"name": , "arguments": }
+"""
+
+ def _format_tool_calls(self, tool_call_messages):
+ tool_calls = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ tool_calls.append(f'\n{json.dumps(tool_call, ensure_ascii=False)}\n')
+ return '\n'.join(tool_calls)
diff --git a/swift/plugin/agent_template/llama.py b/swift/plugin/agent_template/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..a247d8420a13d11ad68fbd97bc669d20741edd87
--- /dev/null
+++ b/swift/plugin/agent_template/llama.py
@@ -0,0 +1,78 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import re
+from typing import TYPE_CHECKING, List, Tuple, Union
+
+import json
+
+from .base import BaseAgentTemplate
+
+if TYPE_CHECKING:
+ from swift.llm.infer import Function
+ from swift.llm.template import Prompt
+
+
+class Llama3AgentTemplate(BaseAgentTemplate):
+ eom_token = '<|eom_id|>'
+ start_token = '<|start_header_id|>'
+ end_token = '<|end_header_id|>'
+ eot_token = '<|eot_id|>'
+
+ def get_toolcall(self, response: str) -> List['Function']:
+ from swift.llm.infer import Function
+ if response.endswith(self.eom_token):
+ response = response[:-len(self.eom_token)]
+ functions = []
+ res_list = re.findall(r'{[^{]*?"name":.*?"parameters":\s*?{.*?}\s*?}', response, re.DOTALL)
+ for res in res_list:
+ res = self._parse_json(res)
+ if isinstance(res, dict) and 'name' in res and 'parameters' in res:
+ functions.append(Function(name=res['name'], arguments=res['parameters']))
+ if len(functions) == 0:
+ # compat react_en
+ return super().get_toolcall(response)
+ return functions
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages,
+ ) -> Tuple[str, 'Prompt']:
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ return super()._format_tool_responses(assistant_content, tool_messages)
+ res = [self.eot_token]
+ for tool_message in tool_messages:
+ tool_content = tool_message['content']
+ res.append(f'{self.start_token}tool{self.end_token}\n\n{tool_content}{self.eot_token}')
+ res.append(f'{self.start_token}assistant{self.end_token}\n\n')
+ return assistant_content, res
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ assert user_message is not None
+ user_content = user_message['content']
+ tool_descs = [json.dumps(tool, ensure_ascii=False, indent=4) for tool in tools]
+ new_user_content = """Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+
+Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+{user_content}""" # noqa
+ user_message['content'] = new_user_content
+ return system
+
+ def _format_tool_calls(self, tool_call_messages) -> str:
+ tool_calls = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ tool_call['parameters'] = tool_call.pop('arguments')
+ tool_calls.append(json.dumps(tool_call, ensure_ascii=False))
+ return '\n'.join(tool_calls)
+
+
+class Llama4AgentTemplate(Llama3AgentTemplate):
+ eom_token = '<|eom|>'
+ start_token = '<|header_start|>'
+ end_token = '<|header_end|>'
+ eot_token = '<|eot|>'
+ toolcall_pattern = r'(.+?)<\|eom\|>'
diff --git a/swift/plugin/agent_template/qwen.py b/swift/plugin/agent_template/qwen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6443a12d44e9e705ca5ea6a0fbe248bc093a2c21
--- /dev/null
+++ b/swift/plugin/agent_template/qwen.py
@@ -0,0 +1,130 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import List, Union
+
+from .base import AgentKeyword, BaseAgentTemplate
+
+keyword = AgentKeyword(
+ action='✿FUNCTION✿:',
+ action_input='✿ARGS✿:',
+ observation='✿RESULT✿:',
+)
+
+
+class QwenEnAgentTemplate(BaseAgentTemplate):
+ keyword = keyword
+
+ def _get_tool_names_descs(self, tools):
+ tool_names = []
+ tool_descs = []
+ for tool in tools:
+ tool_desc = self._parse_tool(tool, 'en')
+ tool_names.append(tool_desc.name_for_model)
+ tool_descs.append(f'### {tool_desc.name_for_human}\n\n'
+ f'{tool_desc.name_for_model}: {tool_desc.description_for_model} '
+ f'Parameters: {tool_desc.parameters} {tool_desc.args_format}')
+ return tool_names, tool_descs
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names, tool_descs = self._get_tool_names_descs(tools)
+ return f"""{system}
+
+# Tools
+
+## You have access to the following tools:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs:
+
+✿FUNCTION✿: The tool to use, should be one of [{','.join(tool_names)}]
+✿ARGS✿: The input of the tool
+✿RESULT✿: Tool results
+✿RETURN✿: Reply based on tool results. Images need to be rendered as """ # noqa
+
+
+class QwenZhAgentTemplate(BaseAgentTemplate):
+ keyword = keyword
+
+ def _get_tool_names_descs(self, tools):
+ tool_names = []
+ tool_descs = []
+ for tool in tools:
+ tool_desc = self._parse_tool(tool, 'zh')
+ tool_names.append(tool_desc.name_for_model)
+ tool_descs.append(f'### {tool_desc.name_for_human}\n\n'
+ f'{tool_desc.name_for_model}: {tool_desc.description_for_model} '
+ f'输入参数:{tool_desc.parameters} {tool_desc.args_format}')
+ return tool_names, tool_descs
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names, tool_descs = self._get_tool_names_descs(tools)
+ return f"""{system}
+
+# 工具
+
+## 你拥有如下工具:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+## 你可以在回复中插入零次、一次或多次以下命令以调用工具:
+
+✿FUNCTION✿: 工具名称,必须是[{','.join(tool_names)}]之一。
+✿ARGS✿: 工具输入
+✿RESULT✿: 工具结果
+✿RETURN✿: 根据工具结果进行回复,需将图片用渲染出来""" # noqa
+
+
+class QwenEnParallelAgentTemplate(QwenEnAgentTemplate):
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names, tool_descs = self._get_tool_names_descs(tools)
+ return f"""{system}
+
+# Tools
+
+## You have access to the following tools:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+## Insert the following command in your reply when you need to call N tools in parallel:
+
+✿FUNCTION✿: The name of tool 1, should be one of [{','.join(tool_names)}]
+✿ARGS✿: The input of tool 1
+✿FUNCTION✿: The name of tool 2
+✿ARGS✿: The input of tool 2
+...
+✿FUNCTION✿: The name of tool N
+✿ARGS✿: The input of tool N
+✿RESULT✿: The result of tool 1
+✿RESULT✿: The result of tool 2
+...
+✿RESULT✿: he result of tool N
+✿RETURN✿: Reply based on tool results. Images need to be rendered as """ # noqa
+
+
+class QwenZhParallelAgentTemplate(QwenZhAgentTemplate):
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names, tool_descs = self._get_tool_names_descs(tools)
+ return f"""{system}
+
+# 工具
+
+## 你拥有如下工具:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+## 你可以在回复中插入以下命令以并行调用N个工具:
+
+✿FUNCTION✿: 工具1的名称,必须是[{','.join(tool_names)}]之一
+✿ARGS✿: 工具1的输入
+✿FUNCTION✿: 工具2的名称
+✿ARGS✿: 工具2的输入
+...
+✿FUNCTION✿: 工具N的名称
+✿ARGS✿: 工具N的输入
+✿RESULT✿: 工具1的结果
+✿RESULT✿: 工具2的结果
+...
+✿RESULT✿: 工具N的结果
+✿RETURN✿: 根据工具结果进行回复,需将图片用渲染出来""" # noqa
diff --git a/swift/plugin/agent_template/react.py b/swift/plugin/agent_template/react.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bfa5b820c611f9651890e13705e37a3be3e0933
--- /dev/null
+++ b/swift/plugin/agent_template/react.py
@@ -0,0 +1,66 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import List, Union
+
+from .base import BaseAgentTemplate
+
+
+class ReactEnAgentTemplate(BaseAgentTemplate):
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names = []
+ tool_descs = []
+ for tool in tools:
+ tool_desc = self._parse_tool(tool, 'en')
+ tool_names.append(tool_desc.name_for_model)
+ tool_descs.append(
+ f'{tool_desc.name_for_model}: Call this tool to interact with the {tool_desc.name_for_human} API. '
+ f'What is the {tool_desc.name_for_human} API useful for? {tool_desc.description_for_model} '
+ f'Parameters: {tool_desc.parameters} {tool_desc.args_format}')
+
+ return """Answer the following questions as best you can. You have access to the following tools:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{','.join(tool_names)}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+"""
+
+
+class ReactZnAgentTemplate(BaseAgentTemplate):
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ tool_names = []
+ tool_descs = []
+ for tool in tools:
+ tool_desc = self._parse_tool(tool, 'zh')
+ tool_names.append(tool_desc.name_for_model)
+ tool_descs.append(f'{tool_desc.name_for_model}: 调用此工具与 {tool_desc.name_for_human} API 进行交互。'
+ f'{tool_desc.name_for_human} 有什么用?{tool_desc.description_for_model} '
+ f'输入参数:{tool_desc.parameters} {tool_desc.args_format}')
+ return """尽可能地回答以下问题。你可以使用以下工具:
+
+""" + '\n\n'.join(tool_descs) + f"""
+
+请按照以下格式进行:
+
+Question: 需要你回答的输入问题
+Thought: 你应该总是思考该做什么
+Action: 需要使用的工具,应该是[{','.join(tool_names)}]中的一个
+Action Input: 传入工具的内容
+Observation: 行动的结果
+... (这个Thought/Action/Action Input/Observation可以重复N次)
+Thought: 我现在知道最后的答案
+Final Answer: 对原始输入问题的最终答案
+
+现在开始!
+"""
diff --git a/swift/plugin/agent_template/toolbench.py b/swift/plugin/agent_template/toolbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..54404e9f8e9faa75e5b9ecac1110d371e49318ba
--- /dev/null
+++ b/swift/plugin/agent_template/toolbench.py
@@ -0,0 +1,39 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import List, Union
+
+import json
+
+from .base import BaseAgentTemplate
+
+
+class ToolBenchAgentTemplate(BaseAgentTemplate):
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
+ for i, tool in enumerate(tools):
+ tools[i] = self.unwrap_tool(tool)
+ tools = json.dumps(tools, ensure_ascii=False)
+ return f"""You can use many tools(functions) to do the following task.
+First I will give you the task description, and your task start.
+At each step, you need to give your thought to analyze the status now and what to do next, \
+with a function call to actually execute your step. Your output should follow this format:
+Thought:
+Action:
+Action Input:
+
+After the call, you will get the call result, and you are now in a new state.
+Then you will analyze your status now, then decide what to do next...
+After many (Thought-call) pairs, you finally perform the task, then you can give your final answer.
+Remember:
+1.the state change is irreversible, you can't go back to one of the former state, if you want to restart the task, \
+say \"I give up and restart\".
+2.All the thought is short, at most in 5 sentence.
+3.You can do more then one try, so if your plan is to continuously try some conditions, \
+you can do one of the conditions per try.
+Let's Begin!
+Task description: You should use functions to help handle the real time user queries. Remember:
+1.ALWAYS call \"Finish\" function at the end of the task. And the final answer should contain enough information \
+to show to the user,If you can't handle the task, \
+or you find that function calls always fail(the function is not valid now), \
+use function Finish->give_up_and_restart.
+2.Do not use origin tool names, use only subfunctions' names.
+Specifically, you have access to the following APIs: {tools}"""
diff --git a/swift/plugin/callback.py b/swift/plugin/callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..01db43c9b014ae33d02e43bd6d3ee30eadbbeda5
--- /dev/null
+++ b/swift/plugin/callback.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+
+from swift.utils import get_logger
+
+logger = get_logger()
+
+
+class EarlyStopCallback(TrainerCallback):
+ """An early stop implementation"""
+
+ def __init__(self, total_interval=3):
+ self.best_metric = None
+ self.interval = 0
+ self.total_interval = total_interval
+
+ def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ operator = np.greater if args.greater_is_better else np.less
+ if self.best_metric is None or operator(state.best_metric, self.best_metric):
+ self.best_metric = state.best_metric
+ else:
+ self.interval += 1
+
+ if self.interval >= self.total_interval:
+ logger.info(f'Training stop because of eval metric is stable at step {state.global_step}')
+ control.should_training_stop = True
+
+
+extra_callbacks = []
+# This example shows a simple example of EarlyStop Callback, uncomment this to use
+# extra_callbacks = [EarlyStopCallback()]
diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad82a5deef5b373e4a55eddeb1d136b65f13b06
--- /dev/null
+++ b/swift/plugin/loss.py
@@ -0,0 +1,388 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from enum import Enum
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from accelerate.utils import gather_object
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss
+from transformers.utils import strtobool
+
+
+class LossType:
+ loss_scale = 'loss_scale'
+ cosine_similarity = 'cosine_similarity'
+ contrastive = 'contrastive'
+ online_contrastive = 'online_contrastive'
+ infonce = 'infonce'
+
+
+LOSS_MAPPING = {}
+
+
+def register_loss_func(loss_type: str, loss_func: Optional[Callable] = None):
+ loss_info = {}
+
+ if loss_func is not None:
+ loss_info['loss_func'] = loss_func
+ LOSS_MAPPING[loss_type] = loss_info
+ return
+
+ def _register_loss_func(loss_func: Callable) -> Callable:
+ loss_info['loss_func'] = loss_func
+ LOSS_MAPPING[loss_type] = loss_info
+ return loss_func
+
+ return _register_loss_func
+
+
+def ce_loss_func(outputs, labels):
+ logits = outputs.logits
+ device = logits.device
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:].to(device)
+ # Save memory
+ masks = shift_labels != -100
+ shift_logits = shift_logits[masks]
+ shift_labels = shift_labels[masks]
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(reduction='none')
+ loss = loss_fct(shift_logits, shift_labels)
+ return loss, masks
+
+
+# Use @register_loss_func to decorate your own loss, use --loss_type xxx to train
+@register_loss_func(LossType.loss_scale)
+def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+ """Loss func
+
+ Args:
+ outputs: The model outputs
+ labels: The labels
+ loss_scale: The loss scale
+ num_items_in_batch: Number of tokens in the labels of gradient accumulation round that are not -100.
+
+ Returns:
+
+ """
+ loss, masks = ce_loss_func(outputs, labels)
+ if loss_scale is not None:
+ shift_scale = loss_scale[..., 1:].to(masks.device)
+ shift_scale = shift_scale[masks]
+ loss = (shift_scale * loss)
+ if num_items_in_batch is None:
+ loss = loss.mean()
+ else:
+ # compat transformers>=4.46
+ loss = loss.sum() / num_items_in_batch
+ return loss
+
+
+def _parse_pair_sentence(outputs):
+ if isinstance(outputs, dict):
+ last_hidden_state = outputs['last_hidden_state']
+ else:
+ last_hidden_state = outputs
+ batch_size = last_hidden_state.shape[0]
+ shape_len = len(last_hidden_state.shape)
+ first_sentence = list(range(0, batch_size, 2))
+ second_sentence = list(range(1, batch_size, 2))
+ if shape_len == 3:
+ sentence1 = last_hidden_state[first_sentence][:, 0].squeeze(dim=1)
+ sentence2 = last_hidden_state[second_sentence][:, 0].squeeze(dim=1)
+ else:
+ sentence1 = last_hidden_state[first_sentence]
+ sentence2 = last_hidden_state[second_sentence]
+ return sentence1, sentence2
+
+
+# Code borrowed from sentence_transformers
+class SiameseDistanceMetric(Enum):
+ """The metric for the contrastive loss"""
+
+ EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) # noqa
+ MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) # noqa
+ COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) # noqa
+
+
+@register_loss_func(LossType.cosine_similarity)
+def cosine_similarity_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+ cos_score_transformation = nn.Identity()
+ loss_fct = MSELoss()
+ sentence1, sentence2 = _parse_pair_sentence(outputs)
+ output = cos_score_transformation(torch.cosine_similarity(sentence1, sentence2))
+ return loss_fct(output, labels.to(output.dtype).view(-1))
+
+
+@register_loss_func(LossType.contrastive)
+def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+ sentence1, sentence2 = _parse_pair_sentence(outputs)
+ distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
+ distances = distance_metric(sentence1, sentence2)
+ margin = 0.5
+ labels = labels.to(sentence1.dtype)
+ losses = 0.5 * (labels * distances.pow(2) + (1 - labels) * F.relu(margin - distances).pow(2))
+ return losses.mean()
+
+
+def calculate_paired_metrics(embeddings, labels):
+ from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \
+ paired_manhattan_distances
+ from scipy.stats import pearsonr, spearmanr
+
+ embeddings1, embeddings2 = _parse_pair_sentence(embeddings)
+ cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
+ manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
+ euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
+ dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
+
+ eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)
+ eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)
+
+ eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)
+ eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)
+
+ eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)
+ eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)
+
+ eval_pearson_dot, _ = pearsonr(labels, dot_products)
+ eval_spearman_dot, _ = spearmanr(labels, dot_products)
+
+ return {
+ 'pearson_cosine': eval_pearson_cosine,
+ 'pearson_euclidean': eval_pearson_manhattan,
+ 'pearson_manhattan': eval_pearson_euclidean,
+ 'pearson_dot_product': eval_pearson_dot,
+ 'spearman_cosine': eval_spearman_cosine,
+ 'spearman_euclidean': eval_spearman_manhattan,
+ 'spearman_manhattan': eval_spearman_euclidean,
+ 'spearman_dot_product': eval_spearman_dot,
+ }
+
+
+def calculate_infonce_metrics(embeddings, labels):
+ from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \
+ paired_manhattan_distances
+ from scipy.stats import pearsonr, spearmanr
+ hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None)
+ use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True'))
+ split_tensors = _parse_multi_negative_sentences(torch.tensor(embeddings), torch.tensor(labels), hard_negatives)
+ split_tensors = [t.numpy() for t in split_tensors]
+ can_batched = hard_negatives is not None
+ if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1:
+ can_batched = True
+ all_similarity_matrix = []
+ all_labels = []
+ pos_neg_margins = []
+ if not use_batch:
+ if can_batched:
+ sentences = np.stack(split_tensors, axis=0)
+ similarity_matrix = np.matmul(sentences[:, 0:1], sentences[:, 1:].transpose((0, 2, 1))).squeeze(1)
+ all_similarity_matrix.append(similarity_matrix)
+ labels = np.zeros_like(similarity_matrix)
+ labels[:, 0] = 1
+ all_labels.append(labels)
+ else:
+ for tensor in split_tensors:
+ similarity_matrix = np.matmul(tensor[0], tensor[1:].T)
+ all_similarity_matrix.append(similarity_matrix)
+ labels = np.zeros_like(similarity_matrix)
+ labels[0] = 1
+ all_labels.append(labels)
+ max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1)
+ pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item())
+ else:
+ if can_batched:
+ sentences = np.stack(split_tensors, axis=0)
+ similarity_matrix = np.matmul(sentences[:, 0], sentences[:, 1:].reshape(-1, sentences.shape[2]).T)
+ all_similarity_matrix.append(similarity_matrix)
+ labels = np.zeros_like(similarity_matrix)
+ for row, col in enumerate(range(0, sentences.shape[0] * (sentences.shape[1] - 1), sentences.shape[1] - 1)):
+ labels[row, col] = 1
+ all_labels.append(labels)
+ else:
+ all_tensors = []
+ for tensor in split_tensors:
+ all_tensors.append(tensor[1:])
+ sentences = np.concatenate(all_tensors, axis=0)
+ length = 0
+ for idx, tensor in enumerate(split_tensors):
+ similarity_matrix = np.matmul(tensor[0], sentences.T)
+ all_similarity_matrix.append(similarity_matrix)
+ labels = np.zeros_like(similarity_matrix)
+ labels[length] = 1
+ all_labels.append(labels)
+ length += tensor.shape[0] - 1
+ max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1)
+ pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item())
+
+ similarity_matrix = np.concatenate(all_similarity_matrix, axis=0)
+ labels = np.concatenate(all_labels, axis=0)
+ if can_batched:
+ pos_scores = similarity_matrix[labels == 1].reshape(similarity_matrix.shape[0], -1)
+ neg_scores = similarity_matrix[labels == 0].reshape(similarity_matrix.shape[0], -1)
+ max_neg_scores = np.max(neg_scores, axis=-1)
+ pos_neg_margin = np.mean(pos_scores - max_neg_scores).item()
+ else:
+ pos_scores = similarity_matrix[labels == 1]
+ neg_scores = similarity_matrix[labels == 0]
+ pos_neg_margin = np.mean(pos_neg_margins)
+
+ mean_neg = np.mean(neg_scores)
+ mean_pos = np.mean(pos_scores)
+ return {'margin': pos_neg_margin, 'mean_neg': mean_neg, 'mean_pos': mean_pos}
+
+
+def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
+ split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist()
+ if isinstance(split_indices, int):
+ split_indices = [split_indices]
+ split_indices.append(len(labels))
+ split_indices = np.array(split_indices) + np.array(list(range(len(split_indices))))
+ split_tensors = []
+
+ for i in range(len(split_indices) - 1):
+ start = split_indices[i]
+ end = split_indices[i + 1]
+ split_part = sentences[start:end]
+ if hard_negatives is not None:
+ negatives = len(split_part) - 2
+ assert negatives > 0
+ if negatives > hard_negatives:
+ split_part = split_part[:hard_negatives + 2]
+ elif negatives < hard_negatives:
+ selected = np.random.choice(list(range(negatives)), size=hard_negatives - negatives, replace=True)
+ selected += 1 # skip positive
+ split_part = torch.cat((split_part, split_part[selected]), dim=0)
+ split_tensors.append(split_part)
+ return split_tensors
+
+
+@register_loss_func(LossType.infonce)
+def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+ temperature = float(os.environ.get('INFONCE_TEMPERATURE', '0.01')) # temperature
+ # calculate CE across the batch, meaning all samples will be negative except the matching positive
+ use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True'))
+ hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None) # how many negative prompts kept in one sample
+ # mask out fake negatives
+ infonce_mask_fake_negative = strtobool(os.environ.get('INFONCE_MASK_FAKE_NEGATIVE', 'False'))
+ if hard_negatives is not None:
+ hard_negatives = int(hard_negatives)
+ from swift.utils import get_dist_setting
+ rank, _, world_size, _ = get_dist_setting()
+ # repeat of anchor(1)+positive(1)+negatives(n)
+ sentences = outputs['last_hidden_state']
+
+ if world_size > 1 and use_batch:
+ # gather all the sentences and labels across the gpus when calculate loss across all batches of all gpus
+ all_sentences = gather_object(sentences.unsqueeze(0))
+ labels = gather_object(labels)
+ # override the gathered one
+ all_sentences[rank] = sentences
+ for idx in range(len(all_sentences)):
+ if idx == rank:
+ continue
+ # we don't calculate grad from other gpus
+ all_sentences[idx] = all_sentences[idx].detach().to(sentences.device)
+ sentences = torch.cat(all_sentences, dim=0)
+ labels = [tensor.to(sentences.device) for tensor in labels]
+ labels = torch.stack(labels, dim=0)
+
+ # split tensors into single sample
+ # for example: batch_size=2 with tensor anchor(1)+positive(1)+negatives(3) + anchor(1)+positive(1)+negatives(2)
+ # labels will be [1,0,0,0,1,0,0], meaning 1 positive, 3 negatives, 1 positive, 2 negatives
+ split_tensors = _parse_multi_negative_sentences(sentences, labels, hard_negatives)
+ loss = 0
+ can_batched = hard_negatives is not None
+ if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1:
+ # all tensors have the same batch size
+ can_batched = True
+ if not use_batch:
+ # only calculate loss inside one sample
+ if can_batched:
+ # negative numbers are equal
+ # [B, neg+2, D]
+ sentences = torch.stack(split_tensors, dim=0)
+ # [B, 1, D] * [B, neg+1, D]
+ similarity_matrix = torch.matmul(sentences[:, 0:1], sentences[:, 1:].transpose(1, 2)) / temperature
+ # The positive one is the first element
+ labels = torch.zeros(len(split_tensors), dtype=torch.int64).to(sentences.device)
+ loss = nn.CrossEntropyLoss()(similarity_matrix.squeeze(1), labels)
+ else:
+ # the negative numbers may be different, use for loop
+ for tensor in split_tensors:
+ # [D] * [neg+1, D]
+ similarity_matrix = torch.matmul(tensor[0], tensor[1:].T) / temperature
+ # The positive one is the first element
+ labels = torch.tensor(0).to(tensor.device)
+ loss += nn.CrossEntropyLoss()(similarity_matrix, labels)
+ # avg between all batches in one gpu
+ loss /= len(split_tensors)
+ else:
+
+ def mask_fake_negative(sim_matrix, sim_labels):
+ thresholds = sim_matrix[torch.arange(sim_matrix.size(0)), sim_labels].view(-1, 1) + 0.1
+ thresholds = thresholds.detach()
+ mask = sim_matrix > thresholds
+ sim_matrix[mask] = float('-inf')
+
+ if can_batched:
+ # [B, neg+2, D]
+ sentences = torch.stack(split_tensors, dim=0)
+ # [B, D] * [B*(neg+1), D]
+ similarity_matrix = torch.matmul(sentences[:, 0].squeeze(1), sentences[:,
+ 1:].reshape(-1, sentences.size(2)).T)
+ labels = torch.tensor(range(0,
+ sentences.size(0) * (sentences.size(1) - 1),
+ sentences.size(1) - 1)).view(-1).to(sentences.device)
+ if infonce_mask_fake_negative:
+ mask_fake_negative(similarity_matrix, labels)
+ similarity_matrix = similarity_matrix / temperature
+ # every neg+1 is positive start from 0
+ loss = nn.CrossEntropyLoss()(similarity_matrix, labels) / world_size # avoid duplicate
+ else:
+ all_tensors = []
+ for tensor in split_tensors:
+ all_tensors.append(tensor[1:])
+ # cat all neg+1 tensors
+ sentences = torch.cat(all_tensors, dim=0)
+ length = 0
+ for idx, tensor in enumerate(split_tensors):
+ # [D] * [B*(neg+1), D], neg numbers are different
+ similarity_matrix = torch.matmul(tensor[0], sentences.T) / temperature
+ labels = torch.tensor(length).to(tensor.device)
+ loss += nn.CrossEntropyLoss()(similarity_matrix, labels)
+ # next positive is neg+1
+ length += tensor.size(0) - 1
+ loss /= len(split_tensors)
+ loss /= world_size # avoid duplicate
+ return loss
+
+
+@register_loss_func(LossType.online_contrastive)
+def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
+ sentence1, sentence2 = _parse_pair_sentence(outputs)
+ distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
+ distance_matrix = distance_metric(sentence1, sentence2)
+ negs = distance_matrix[labels == 0]
+ poss = distance_matrix[labels == 1]
+
+ # select hard positive and hard negative pairs
+ negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())]
+ positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())]
+
+ positive_loss = positive_pairs.pow(2).sum()
+ margin = 0.5
+ negative_loss = F.relu(margin - negative_pairs).pow(2).sum()
+ loss = positive_loss + negative_loss
+ return loss
+
+
+def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
+ if loss_type is None:
+ return None
+ return LOSS_MAPPING[loss_type]['loss_func']
diff --git a/swift/plugin/loss_scale/__init__.py b/swift/plugin/loss_scale/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..579be3b98ca209fb7f868a601bda14b64bbf561c
--- /dev/null
+++ b/swift/plugin/loss_scale/__init__.py
@@ -0,0 +1 @@
+from .loss_scale import loss_scale_map
diff --git a/swift/plugin/loss_scale/config/agentflan.json b/swift/plugin/loss_scale/config/agentflan.json
new file mode 100644
index 0000000000000000000000000000000000000000..2751fea02b15587835f21577221d155417d129ea
--- /dev/null
+++ b/swift/plugin/loss_scale/config/agentflan.json
@@ -0,0 +1,22 @@
+{
+ "response":{
+ "Name:": [1.0, 3.0],
+ "Action:": [1.0, 3.0],
+ "ACTION:": [1.0,3.0],
+ "Tool:": [1.0, 3.0],
+ "Command": [1.0, 3.0],
+ "Arguments:": [1.0, 3.0],
+ "action input": [1.0, 3.0],
+ "ACTION_INPUT:":[1.0, 3.0],
+ "Action Input:": [1.0, 3.0],
+ "Thought:": [1.0, 1.0],
+ "Final Answer:": [1.0, 1.0],
+ "Observation:": [2.0, 0.0]
+ },
+ "query":{
+ "What is the tool you want to use": [3.0],
+ "What are the required parameter names": [3.0],
+ "What is the value of": [3.0],
+ "What are the required parameter names for this tool": [3.0]
+ }
+}
diff --git a/swift/plugin/loss_scale/config/alpha_umi.json b/swift/plugin/loss_scale/config/alpha_umi.json
new file mode 100644
index 0000000000000000000000000000000000000000..fcdcbcb185066da0b768263562729d8361ebaa01
--- /dev/null
+++ b/swift/plugin/loss_scale/config/alpha_umi.json
@@ -0,0 +1,8 @@
+{
+ "Action:": [2.0, 2.0],
+ "Action Input:": [2.0, 2.0],
+ "Thought:": [1.0, 1.0],
+ "Final Answer:": [1.0, 1.0],
+ "Observation:": [2.0, 0.0],
+ "Next:": [2,0, 2.0]
+}
diff --git a/swift/plugin/loss_scale/config/hermes.json b/swift/plugin/loss_scale/config/hermes.json
new file mode 100644
index 0000000000000000000000000000000000000000..e8bfee3fc5d6cd8aa79c99f0f9b4fcd15b623645
--- /dev/null
+++ b/swift/plugin/loss_scale/config/hermes.json
@@ -0,0 +1,3 @@
+{
+ ".+?": [2.0]
+}
diff --git a/swift/plugin/loss_scale/config/ignore_empty_think.json b/swift/plugin/loss_scale/config/ignore_empty_think.json
new file mode 100644
index 0000000000000000000000000000000000000000..c7c2395fbb78294a543f09072620895e76ef1ea9
--- /dev/null
+++ b/swift/plugin/loss_scale/config/ignore_empty_think.json
@@ -0,0 +1,3 @@
+{
+ "\n\n\n\n": [0.0]
+}
diff --git a/swift/plugin/loss_scale/config/qwen.json b/swift/plugin/loss_scale/config/qwen.json
new file mode 100644
index 0000000000000000000000000000000000000000..731ba5340387e8a3467831877fdfb1cdd19fdc90
--- /dev/null
+++ b/swift/plugin/loss_scale/config/qwen.json
@@ -0,0 +1,6 @@
+{
+ "✿FUNCTION✿:": [2.0, 2.0],
+ "✿ARGS✿:": [2.0, 2.0],
+ "✿RETURN✿:": [1.0, 1.0],
+ "✿RESULT✿:": [2.0, 0.0]
+}
diff --git a/swift/plugin/loss_scale/config/react.json b/swift/plugin/loss_scale/config/react.json
new file mode 100644
index 0000000000000000000000000000000000000000..006f92948e1a6de28a1825fa2ef256dc1b09de81
--- /dev/null
+++ b/swift/plugin/loss_scale/config/react.json
@@ -0,0 +1,7 @@
+{
+ "Action:": [2.0, 2.0],
+ "Action Input:": [2.0, 2.0],
+ "Thought:": [1.0, 1.0],
+ "Final Answer:": [1.0, 1.0],
+ "Observation:": [2.0, 0.0]
+}
diff --git a/swift/plugin/loss_scale/loss_scale.py b/swift/plugin/loss_scale/loss_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..1540169e00f3e14dba1c019536d50fa3f9536c6f
--- /dev/null
+++ b/swift/plugin/loss_scale/loss_scale.py
@@ -0,0 +1,136 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from typing import List, Optional, Tuple
+
+import json
+
+from swift.llm import Messages
+from swift.llm.template.utils import ContextType
+from .utils import calculate_loss_scale
+
+
+class LossScale:
+ loss_scale_config = None # path
+
+ def __init__(self):
+ if self.loss_scale_config is not None:
+ path = os.path.dirname(os.path.abspath(__file__))
+ config_path = os.path.join(path, 'config', self.loss_scale_config)
+ with open(config_path, 'r', encoding='utf-8') as json_file:
+ self.loss_scale_map = json.load(json_file)
+ else:
+ self.loss_scale_map = None
+
+ def get_loss_scale(self,
+ context: str,
+ context_type: ContextType,
+ is_last_round: bool,
+ *,
+ query: Optional[str] = None) -> Tuple[List[str], List[float]]:
+ """Calculate loss scale
+
+ Args:
+ context: The input context
+ context_type: The type of this context, like response/suffix(eos token)/other(query/system, etc.)
+ is_last_round: If this is the last round of messages.
+ query: The query of this round.
+
+ Returns:
+ A tuple, list of context and list of loss_scales
+ """
+ if context_type in {ContextType.RESPONSE, ContextType.SUFFIX}:
+ loss_scale = 1.
+ else:
+ loss_scale = 0.
+ return [context], [loss_scale]
+
+ def __call__(self, context_list: List[str], context_types: List[ContextType], messages: Messages,
+ **kwargs) -> Tuple[List[str], List[float]]:
+ res_context_list = []
+ res_loss_scale = []
+ i = 0
+ n_round = len(messages) // 2
+ for context, context_type in zip(context_list, context_types):
+ is_last_round = i + 1 == n_round
+ if context_type == ContextType.RESPONSE:
+ query = messages[2 * i]['content']
+ assert context == messages[2 * i + 1]['content']
+ kwargs = {'query': query}
+ i += 1
+ new_context, loss_scale = self.get_loss_scale(context, context_type, is_last_round, **kwargs)
+ res_context_list += new_context
+ res_loss_scale += loss_scale
+ return res_context_list, res_loss_scale
+
+
+class LastRoundLossScale(LossScale):
+
+ def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
+ if context_type == ContextType.RESPONSE:
+ return [context], [float(is_last_round)]
+ return super().get_loss_scale(context, context_type, is_last_round)
+
+
+class AgentFlanLossScale(LossScale):
+ loss_scale_config = 'agentflan.json'
+
+ def get_loss_scale(self,
+ context: str,
+ context_type: ContextType,
+ is_last_round: bool,
+ *,
+ query: Optional[str] = None):
+ if context_type == ContextType.RESPONSE:
+ return calculate_loss_scale(query, context, self.loss_scale_map['response'], self.loss_scale_map['query'])
+ return super().get_loss_scale(context, context_type, is_last_round)
+
+
+class REACTLossScale(LossScale):
+ loss_scale_config = 'react.json'
+
+ def get_loss_scale(self,
+ context: str,
+ context_type: ContextType,
+ is_last_round: bool,
+ *,
+ query: Optional[str] = None):
+ if context_type == ContextType.RESPONSE:
+ return calculate_loss_scale(query, context, self.loss_scale_map)
+ return super().get_loss_scale(context, context_type, is_last_round)
+
+
+class QwenLossScale(REACTLossScale):
+ loss_scale_config = 'qwen.json'
+
+
+class HermesLossScale(REACTLossScale):
+ loss_scale_config = 'hermes.json'
+
+
+class AlphaUmiLossScale(REACTLossScale):
+ loss_scale_config = 'alpha_umi.json'
+
+
+class TrainAllLossScale(LossScale):
+
+ def get_loss_scale(self, context: str, context_type: ContextType, *args, **kwargs):
+ return [context], [1.]
+
+
+class IgnoreEmptyThink(REACTLossScale):
+ loss_scale_config = 'ignore_empty_think.json'
+
+
+# Add your loss scale here, use --loss_scale xxx to train
+loss_scale_map = {
+ 'last_round': LastRoundLossScale(),
+ 'default': LossScale(),
+ 'all': TrainAllLossScale(),
+ 'ignore_empty_think': IgnoreEmptyThink(),
+ # agent
+ 'react': REACTLossScale(),
+ 'hermes': HermesLossScale(),
+ 'qwen': QwenLossScale(),
+ 'agentflan': AgentFlanLossScale(),
+ 'alpha_umi': AlphaUmiLossScale(),
+}
diff --git a/swift/plugin/loss_scale/utils.py b/swift/plugin/loss_scale/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60c592a5d025e689d2a232648fa54d19ca71ff0
--- /dev/null
+++ b/swift/plugin/loss_scale/utils.py
@@ -0,0 +1,58 @@
+from typing import Dict, List, Optional, Tuple
+
+from swift.llm.template import split_str_parts_by
+
+
+def calculate_loss_scale(query: str,
+ response: str,
+ response_loss_scale_map: Dict[str, list],
+ query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
+ """Calculate the loss scale by splitting the agent response.
+
+ This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
+
+ Agent response format:
+
+ ```text
+ Thought: you should always think about what to do
+ Action: the action to take, should be one of the above tools[fire_recognition,
+ fire_alert, call_police, call_fireman]
+ Action Input: the input to the action
+ Observation: the result of the action
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+ Thought: I now know the final answer
+ Final Answer: the final answer to the original input question
+ ```
+ Returns:
+ A tuple of agent response parts and their weights.
+ """
+ # query loss scale map
+ if query_loss_scale_map is not None:
+ for key in query_loss_scale_map.keys():
+ if key in query:
+ if isinstance(query_loss_scale_map[key], (float, int)):
+ query_loss_scale_map[key] = [query_loss_scale_map[key]]
+ loss_scale_value = query_loss_scale_map[key][0]
+ return [response], [float(loss_scale_value)]
+ delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
+ if delimiters:
+ agent_parts = split_str_parts_by(response, delimiters)
+ else:
+ regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1]
+ agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True)
+ weights = []
+ agent_content = []
+ for c in agent_parts:
+ if c['key'] in response_loss_scale_map:
+ loss_scale = response_loss_scale_map[c['key']]
+ assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}'
+ if len(loss_scale) == 1:
+ weights += loss_scale
+ agent_content.append(c['content'])
+ else:
+ weights += loss_scale
+ agent_content += [c['key'], c['content']]
+ else:
+ weights.append(1.)
+ agent_content.append(c['content'])
+ return agent_content, weights
diff --git a/swift/plugin/metric.py b/swift/plugin/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..410449815c27d6591290b4e4458888d758721a14
--- /dev/null
+++ b/swift/plugin/metric.py
@@ -0,0 +1,189 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import time
+from abc import ABC, abstractmethod
+from typing import Dict, List, Literal
+
+import numpy as np
+import torch
+from transformers.trainer_utils import EvalPrediction
+
+from swift.utils import Serializer, get_logger
+
+logger = get_logger()
+
+
+class Metric(ABC):
+
+ def __init__(self):
+ self._default = {}
+ self._default_factory = {}
+
+ def add_state(self, name: str, default=None, default_factory=None) -> None:
+ if not hasattr(self, '_default'):
+ raise AttributeError('Please call super().__init__() first.')
+ if default is None:
+ self._default_factory[name] = default_factory
+ assert name not in self._default, f'self._default: {self._default}'
+ default = default_factory()
+ else:
+ self._default[name] = default
+ assert name not in self._default_factory, f'self._default_factory: {self._default_factory}'
+ setattr(self, name, default)
+
+ def reset(self):
+ for k, v in self._default.items():
+ setattr(self, k, v)
+ for k, v in self._default_factory.items():
+ setattr(self, k, v())
+
+ @abstractmethod
+ def update(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def compute(self):
+ pass
+
+
+class InferStats(Metric):
+
+ def __init__(self):
+ super().__init__()
+ self.add_state('start_runtime', default_factory=lambda: time.perf_counter())
+ self.add_state('num_prompt_tokens', default_factory=dict)
+ self.add_state('num_generated_tokens', default_factory=dict)
+
+ def update(self, output):
+ id_ = output.id
+ self.num_prompt_tokens[id_] = output.usage.prompt_tokens
+ self.num_generated_tokens[id_] = output.usage.completion_tokens
+
+ def compute(self):
+ runtime = time.perf_counter() - self.start_runtime
+ num_samples = len(self.num_generated_tokens)
+ num_generated_tokens = sum(self.num_generated_tokens.values())
+ return {
+ 'num_prompt_tokens': sum(self.num_prompt_tokens.values()),
+ 'num_generated_tokens': num_generated_tokens,
+ 'num_samples': num_samples,
+ 'runtime': runtime,
+ 'samples/s': num_samples / runtime,
+ 'tokens/s': num_generated_tokens / runtime,
+ }
+
+
+class MeanMetric(Metric):
+
+ def __init__(self, nan_value=0):
+ super().__init__()
+ self.nan_value = nan_value
+ self.add_state('state', default=0.)
+ self.add_state('count', default=0)
+
+ def update(self, state: torch.Tensor):
+ if isinstance(state, (torch.Tensor, np.ndarray)):
+ state = state.tolist()
+
+ if isinstance(state, (list, tuple)):
+ count = len(state)
+ state = sum(state)
+ else:
+ count = 1
+
+ self.state += state
+ self.count += count
+
+ def compute(self):
+ return {
+ 'value': self.state / self.count if self.count > 0 else self.nan_value,
+ }
+
+
+def compute_rouge_bleu(preds: List[str], labels: List[str]):
+ import jieba
+ from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+ from rouge.rouge import Rouge
+ score_dict = {key: MeanMetric() for key in ['rouge-1', 'rouge-2', 'rouge-l', 'bleu-4']}
+
+ for pred, label in zip(preds, labels):
+ hypothesis = list(jieba.cut(pred))
+ reference = list(jieba.cut(label))
+ if not hypothesis or not reference:
+ continue
+ rouge = Rouge()
+ scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))[0]
+ for k, v in scores.items():
+ score_dict[k].update(v['f'])
+ bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
+ score_dict['bleu-4'].update(bleu_score)
+
+ return {k: round(v.compute()['value'] * 100, 6) for k, v in score_dict.items()}
+
+
+def compute_nlg_metrics(prediction) -> Dict[str, float]:
+ preds, labels = prediction[0], prediction[1]
+ new_preds, new_labels = [], []
+ for i in range(preds.shape[0]):
+ new_preds.append(Serializer.from_tensor(preds[i]))
+ new_labels.append(Serializer.from_tensor(labels[i]))
+ return compute_rouge_bleu(new_preds, new_labels)
+
+
+def compute_acc(preds,
+ labels,
+ *,
+ acc_strategy: Literal['token', 'seq'] = 'token',
+ is_encoder_decoder: bool = False) -> Dict[str, List[float]]:
+
+ if isinstance(preds, torch.Tensor):
+ if torch.is_floating_point(labels):
+ return {}
+ preds = preds.cpu().numpy()
+ labels = labels.cpu().numpy()
+ if preds.ndim >= 2 and not is_encoder_decoder:
+ labels = labels[..., 1:]
+ preds = preds[..., :-1]
+ if np.issubdtype(labels.dtype, np.floating) or preds.shape != labels.shape:
+ return {}
+
+ masks = labels != -100
+ if acc_strategy == 'token' or preds.ndim == 1:
+ acc_list = (preds[masks] == labels[masks]).tolist()
+ else:
+ acc_list = []
+ for i, m in enumerate(masks):
+ acc_list.append(np.all(preds[i, m] == labels[i, m]))
+ return {f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc': acc_list}
+
+
+def compute_acc_metrics(eval_prediction: EvalPrediction,
+ *,
+ acc_strategy: Literal['token', 'seq'] = 'token',
+ is_encoder_decoder: bool = False) -> Dict[str, float]:
+
+ metric = compute_acc(
+ eval_prediction.predictions,
+ eval_prediction.label_ids,
+ acc_strategy=acc_strategy,
+ is_encoder_decoder=is_encoder_decoder)
+ if len(metric) == 0:
+ return {}
+ return {k: sum(v) / len(v) for k, v in metric.items()}
+
+
+def preprocess_logits_for_acc(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ if isinstance(logits, (list, tuple)):
+ logits = logits[0]
+ preds = logits.argmax(dim=-1)
+ return preds
+
+
+# Add your own metric calculation method here, use --metric xxx to train
+METRIC_MAPPING = {
+ 'acc': (compute_acc_metrics, preprocess_logits_for_acc),
+ 'nlg': (compute_nlg_metrics, None),
+}
+
+
+def get_metric(metric: str):
+ return METRIC_MAPPING[metric]
diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e9881892eaf26e4ee7c2b2ebd7702264f748f03
--- /dev/null
+++ b/swift/plugin/multi_turn.py
@@ -0,0 +1,42 @@
+def check_math_result_and_give_tips(inputs):
+ from .orm import MathAccuracy
+ acc = MathAccuracy()
+ # a trick
+ prompt = 'But wait... It seems I made a mistake,'
+ contents = [input['messages'][-1]['content'] for input in inputs]
+ rewards = acc(contents, [input['solution'] for input in inputs])
+ for reward, input in zip(rewards, inputs):
+ content = input['messages'][-1]['content']
+ if reward < 1 and prompt not in content:
+ if '' in content:
+ content = content[:content.index('')]
+ if '' in content:
+ content = content[:content.index('')]
+ content += prompt
+ input['messages'][-1]['content'] = content
+ input['finished'] = False
+ else:
+ input['finished'] = True
+ return inputs
+
+
+def check_math_result_and_give_tips_multi_turn(inputs):
+ from .orm import MathAccuracy
+ acc = MathAccuracy()
+ prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.'
+ contents = [input['messages'][-1]['content'] for input in inputs]
+ rewards = acc(contents, [input['solution'] for input in inputs])
+ for reward, input in zip(rewards, inputs):
+ content = input['messages'][-2]['content']
+ if reward < 1 and prompt not in content:
+ input['messages'].append({'role': 'user', 'content': prompt})
+ input['finished'] = False
+ else:
+ input['finished'] = True
+ return inputs
+
+
+multi_turns = {
+ 'math_tip_trick': check_math_result_and_give_tips,
+ 'math_tip_trick_multi_turn': check_math_result_and_give_tips_multi_turn,
+}
diff --git a/swift/plugin/optimizer.py b/swift/plugin/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a4b6ef8da78a9cf0662d04b887ca5f84aafb54
--- /dev/null
+++ b/swift/plugin/optimizer.py
@@ -0,0 +1,100 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os
+import sys
+
+from transformers import Trainer
+
+from swift.trainers.optimizers.galore import create_optimizer_and_scheduler
+from swift.utils import get_dist_setting
+
+
+def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
+ if args.max_steps and args.max_steps > 0:
+ max_steps = args.max_steps
+ else:
+ len_dataset = len(dataset)
+ _, _, world_size, _ = get_dist_setting()
+ total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
+ num_update_steps_per_epoch = len_dataset // total_train_batch_size
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ return max_steps
+
+
+def create_galore_optimizer(args, model, dataset):
+ training_steps = calculate_max_steps(args, dataset)
+ optimizer, lr_scheduler = create_optimizer_and_scheduler(
+ model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
+ # trainer cannot serialize galore_config
+ args.galore_config = None
+ return optimizer, lr_scheduler
+
+
+def create_lorap_optimizer(args, model, dataset):
+ optimizer_grouped_parameters = None
+ if hasattr(model, 'create_optimizer_param_groups'):
+ # Lora+ parameter groups
+ optimizer_grouped_parameters = model.create_optimizer_param_groups(
+ lr=args.learning_rate, weight_decay=args.weight_decay)
+
+ if optimizer_grouped_parameters is None:
+ # Default parameter groups
+ decay_parameters = Trainer.get_decay_parameter_names(None, model)
+ optimizer_grouped_parameters = [
+ {
+ 'params': [p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)],
+ 'weight_decay': args.weight_decay,
+ },
+ {
+ 'params': [p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
+ 'weight_decay': 0.0,
+ },
+ ]
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
+ return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None
+
+
+def create_muon_optimizer(args, model, dataset):
+ from swift.llm import git_clone_github, get_model_arch
+ if not args.local_repo_path:
+ args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git')
+ sys.path.append(os.path.join(args.local_repo_path, 'examples'))
+ from toy_train import Muon
+
+ # parse args.optim_args
+ optim_args = {}
+ if args.optim_args:
+ for mapping in args.optim_args.replace(' ', '').split(','):
+ key, value = mapping.split('=')
+ optim_args[key] = value
+
+ model_arch = get_model_arch(model.model_meta.model_arch)
+ embed_key = model_arch.embedding or 'embed_tokens'
+ lm_head_key = model_arch.lm_head or 'lm_head'
+ muon_params = [
+ p for n, p in model.named_parameters()
+ if p.requires_grad and p.ndim >= 2 and embed_key not in n and lm_head_key not in n
+ ]
+ adamw_params = [
+ p for n, p in model.named_parameters()
+ if p.requires_grad and not (p.ndim >= 2 and embed_key not in n and lm_head_key not in n)
+ ]
+
+ return Muon(
+ lr=args.learning_rate,
+ wd=args.weight_decay,
+ muon_params=muon_params,
+ adamw_params=adamw_params,
+ adamw_betas=(args.adam_beta1, args.adam_beta2),
+ adamw_eps=args.adam_epsilon,
+ **optim_args,
+ ), None
+
+
+# Add your own optimizers here, use --optimizer xxx to train
+optimizers_map = {
+ 'galore': create_galore_optimizer,
+ 'lorap': create_lorap_optimizer,
+ 'muon': create_muon_optimizer,
+}
diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5f1980f9067eab862bae2e01d09129d0d4fa750
--- /dev/null
+++ b/swift/plugin/orm.py
@@ -0,0 +1,406 @@
+import os
+import re
+from typing import Dict, List, Union
+
+import json
+
+from swift.llm import InferRequest
+
+
+class ORM:
+
+ def __call__(self, **kwargs) -> List[float]:
+ raise NotImplementedError
+
+
+class ReactORM(ORM):
+
+ @staticmethod
+ def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list):
+ f1 = []
+ for i in range(len(action_pred)):
+ ref_action = action_ref[i]
+ pred_action = action_pred[i]
+
+ ref_input = ref_list[i]
+ cand_input = cand_list[i]
+
+ ref_is_json = False
+ try:
+ ref_input_json = json.loads(ref_input)
+ ref_is_json = True
+ except Exception:
+ ref_input_json = ref_input
+
+ cand_is_json = False
+ try:
+ cand_input_json = json.loads(cand_input)
+ cand_is_json = True
+ except Exception:
+ cand_input_json = cand_input
+
+ if ref_action != pred_action or (ref_is_json ^ cand_is_json):
+ f1.append(0)
+ elif not ref_is_json and not cand_is_json:
+ rougel = ReactORM.evaluate_rougel([ref_input_json], [cand_input_json])
+ if rougel is None or rougel < 10:
+ f1.append(0)
+ elif 10 <= rougel < 20:
+ f1.append(0.1)
+ else:
+ f1.append(1)
+ else:
+ if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict):
+ # This cannot be happen, but:
+ # line 62, in evaluate_action_reward
+ # for k, v in ref_input_json.items():
+ # AttributeError: 'str' object has no attribute 'items'
+ # print(f'>>>>>>ref_input_json: {ref_input_json}, cand_input_json: {cand_input_json}')
+ f1.append(0)
+ continue
+
+ half_match = 0
+ full_match = 0
+ if ref_input_json == {}:
+ if cand_input_json == {}:
+ f1.append(1)
+ else:
+ f1.append(0)
+ else:
+ for k, v in ref_input_json.items():
+ if k in cand_input_json.keys():
+ if cand_input_json[k] == v:
+ full_match += 1
+ else:
+ half_match += 1
+
+ recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30)
+ precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30)
+ try:
+ f1.append((2 * recall * precision) / (recall + precision))
+ except Exception:
+ f1.append(0.0)
+
+ if f1[0] == 1.0:
+ return True
+ else:
+ return False
+
+ @staticmethod
+ def parse_action(text):
+ if 'Action Input:' in text:
+ input_idx = text.rindex('Action Input:')
+ action_input = text[input_idx + len('Action Input:'):].strip()
+ else:
+ action_input = '{}'
+
+ if 'Action:' in text:
+ action_idx = text.rindex('Action:')
+ action = text[action_idx + len('Action:'):].strip()
+ if 'Action Input:' in action:
+ input_idx = action.index('Action Input:')
+ action = action[:input_idx].strip()
+ else:
+ action = 'none'
+ return action, action_input
+
+ @staticmethod
+ def parse_output(text):
+ action, action_input = ReactORM.parse_action(text)
+ return action, action_input
+
+ def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]:
+ rewards = []
+ if not isinstance(infer_requests[0], str):
+ predictions = [request['messages'][-1]['content'] for request in infer_requests]
+ else:
+ predictions = infer_requests
+ for prediction, ground_truth in zip(predictions, solution):
+ if prediction.endswith('Observation:'):
+ prediction = prediction[:prediction.index('Observation:')].strip()
+ action_ref = []
+ action_input_ref = []
+ action_pred = []
+ action_input_pred = []
+ reference = ground_truth
+ prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip()
+ ref_action, ref_input = ReactORM.parse_output(reference)
+ pred_action, pred_input = ReactORM.parse_output(prediction)
+ action_ref.append(ref_action)
+ action_input_ref.append(ref_input)
+ if pred_action is None:
+ action_pred.append('none')
+ else:
+ action_pred.append(pred_action)
+
+ if pred_input is None:
+ action_input_pred.append('{}')
+ else:
+ action_input_pred.append(pred_input)
+
+ reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref)
+ rewards.append(float(reward))
+ return rewards
+
+ @staticmethod
+ def evaluate_rougel(cand_list: list, ref_list: list):
+ if len(ref_list) == 0:
+ return None
+ try:
+ from rouge import Rouge
+ rouge = Rouge()
+ rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
+ rougel = rouge_score['rouge-l']['f']
+ return rougel
+ except Exception:
+ return None
+
+
+class MathORM(ORM):
+
+ def __init__(self):
+ from transformers.utils import strtobool
+ self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False'))
+ if self.use_opencompass:
+ from opencompass.datasets.math import MATHEvaluator
+ self.evaluator = MATHEvaluator()
+
+ @staticmethod
+ def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
+ if isinstance(answers, str):
+ answers = [answers]
+ results = []
+ for answer in answers:
+ results.append('\\boxed' in answer)
+ return results
+
+ @staticmethod
+ def extract_boxed_result(text):
+ pattern = r'\\boxed{([^}]*)}'
+ match = re.search(pattern, text)
+ if match:
+ return match.group(1).strip()
+ else:
+ return text
+
+ @staticmethod
+ def clean_latex(latex_str):
+ latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str)
+ latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '')
+ return latex_str.strip()
+
+ @staticmethod
+ def parse_expression(latex_str):
+ from sympy import simplify
+ from sympy.parsing.latex import parse_latex
+ try:
+ expr = parse_latex(latex_str)
+ return simplify(expr)
+ except Exception:
+ return None
+
+ @staticmethod
+ def compare_consecutive(first, second):
+ cleaned_list = [MathORM.clean_latex(latex) for latex in [first, second]]
+ parsed_exprs = [MathORM.parse_expression(latex) for latex in cleaned_list]
+ if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'):
+ value = parsed_exprs[0].equals(parsed_exprs[1])
+ else:
+ value = parsed_exprs[0] == parsed_exprs[1]
+ if value is None:
+ value = False
+ return value
+
+ def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
+ **kwargs) -> List[float]:
+ rewards = []
+ predictions = [request.messages[-1]['content'] for request in infer_requests]
+ for prediction, ground_truth in zip(predictions, ground_truths):
+ if '# Answer' in prediction:
+ prediction = prediction.split('# Answer')[1]
+ if '# Answer' in ground_truth:
+ ground_truth = ground_truth.split('# Answer')[1]
+ prediction = prediction.strip()
+ ground_truth = ground_truth.strip()
+ prediction = MathORM.extract_boxed_result(prediction)
+ ground_truth = MathORM.extract_boxed_result(ground_truth)
+ if self.use_opencompass:
+ reward = self.evaluator.is_equiv(prediction, ground_truth)
+ else:
+ reward = MathORM.compare_consecutive(prediction, ground_truth)
+ rewards.append(float(reward))
+ return rewards
+
+
+class MathAccuracy(ORM):
+
+ def __init__(self):
+ import importlib.util
+ assert importlib.util.find_spec('math_verify') is not None, (
+ "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")
+
+ def __call__(self, completions, solution, **kwargs) -> List[float]:
+ from latex2sympy2_extended import NormalizationConfig
+ from math_verify import LatexExtractionConfig, parse, verify
+ rewards = []
+ for content, sol in zip(completions, solution):
+ gold_parsed = parse(sol, extraction_mode='first_match')
+ if len(gold_parsed) != 0:
+ # We require the answer to be provided in correct latex (no malformed operators)
+ answer_parsed = parse(
+ content,
+ extraction_config=[
+ LatexExtractionConfig(
+ normalization_config=NormalizationConfig(
+ nits=False,
+ malformed_operators=False,
+ basic_latex=True,
+ equations=True,
+ boxed=True,
+ units=True,
+ ),
+ # Ensures that boxed is tried first
+ boxed_match_priority=0,
+ try_extract_without_anchor=False,
+ )
+ ],
+ extraction_mode='first_match',
+ )
+ # edge case
+ try:
+ reward = float(verify(gold_parsed, answer_parsed))
+ except Exception:
+ reward = 0.0
+ else:
+ # If the gold solution is not parseable, we reward 0 to skip this example
+ reward = 0.0
+ rewards.append(reward)
+ return rewards
+
+
+class Format(ORM):
+
+ def __call__(self, completions, **kwargs) -> List[float]:
+ """Reward function that checks if the completion has a specific format."""
+ pattern = r'^.*?\s*.*?(?![\s\S])'
+ matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
+ return [1.0 if match else 0.0 for match in matches]
+
+
+class ReActFormat(ORM):
+
+ def __call__(self, completions, **kwargs) -> List[float]:
+ """Reward function that checks if the completion has a specific format."""
+ pattern = r'^.*?\s*Action:.*?Action Input:.*?$'
+ matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
+ return [1.0 if match else 0.0 for match in matches]
+
+
+class CosineReward(ORM):
+ # https://arxiv.org/abs/2502.03373
+ def __init__(self,
+ tokenizer=None,
+ cosine_min_len_value_wrong: float = -0.5,
+ cosine_max_len_value_wrong: float = 0.0,
+ cosine_min_len_value_correct: float = 1.0,
+ cosine_max_len_value_correct: float = 0.5,
+ cosine_max_len: int = 1000,
+ accuracy_orm=None):
+ self.tokenizer = tokenizer
+ self.min_len_value_wrong = cosine_min_len_value_wrong
+ self.max_len_value_wrong = cosine_max_len_value_wrong
+ self.min_len_value_correct = cosine_min_len_value_correct
+ self.max_len_value_correct = cosine_max_len_value_correct
+ self.max_len = cosine_max_len
+ self.accuracy_orm = accuracy_orm or MathAccuracy()
+
+ @staticmethod
+ def cosfn(t, T, min_value, max_value):
+ import math
+ return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2
+
+ def __call__(self, completions, solution, **kwargs) -> List[float]:
+ acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
+ rewards = []
+ for content, acc_reward in zip(completions, acc_rewards):
+ is_correct = acc_reward >= 1.
+ if is_correct:
+ # Swap min/max for correct answers
+ min_value = self.max_len_value_correct
+ max_value = self.min_len_value_correct
+ else:
+ min_value = self.max_len_value_wrong
+ max_value = self.min_len_value_wrong
+ gen_len = len(self.tokenizer.encode(content))
+ reward = self.cosfn(gen_len, self.max_len, min_value, max_value)
+ rewards.append(reward)
+ return rewards
+
+
+class RepetitionPenalty(ORM):
+ # https://arxiv.org/abs/2502.03373
+ def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
+ self.ngram_size = repetition_n_grams
+ self.max_penalty = repetition_max_penalty
+
+ @staticmethod
+ def zipngram(text: str, ngram_size: int):
+ words = text.lower().split()
+ return zip(*[words[i:] for i in range(ngram_size)])
+
+ def __call__(self, completions, **kwargs) -> List[float]:
+ """
+ reward function the penalizes repetitions
+
+ Args:
+ completions: List of model completions
+ """
+ rewards = []
+ for completion in completions:
+ if completion == '':
+ rewards.append(0.0)
+ continue
+ if len(completion.split()) < self.ngram_size:
+ rewards.append(0.0)
+ continue
+
+ ngrams = set()
+ total = 0
+ for ng in self.zipngram(completion, self.ngram_size):
+ ngrams.add(ng)
+ total += 1
+
+ scaling = 1 - len(ngrams) / total
+ reward = scaling * self.max_penalty
+ rewards.append(reward)
+ return rewards
+
+
+class SoftOverlong(ORM):
+
+ def __init__(self, tokenizer, soft_max_length, soft_cache_length):
+ self.tokenizer = tokenizer
+ assert soft_cache_length < soft_max_length
+ self.soft_max_length = soft_max_length
+ self.soft_cache_length = soft_cache_length
+
+ def __call__(self, completions, **kwargs) -> List[float]:
+ rewards = []
+ for completion in completions:
+ completion_length = len(self.tokenizer.encode(completion))
+ expected_len = self.soft_max_length - self.soft_cache_length
+ exceed_len = completion_length - expected_len
+ rewards.append(min(-exceed_len / self.soft_cache_length, 0))
+ return rewards
+
+
+orms = {
+ 'toolbench': ReactORM,
+ 'math': MathORM,
+ 'accuracy': MathAccuracy,
+ 'format': Format,
+ 'react_format': ReActFormat,
+ 'cosine': CosineReward,
+ 'repetition': RepetitionPenalty,
+ 'soft_overlong': SoftOverlong,
+}
diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f2b833128f4faefc18b4b4cddf204501fcd4a9a
--- /dev/null
+++ b/swift/plugin/prm.py
@@ -0,0 +1,154 @@
+import os
+from typing import Any, Dict, List, Union
+
+import json
+
+from swift.llm import InferRequest
+
+
+class PRM:
+
+ def __call__(self, **kwargs) -> List[Any]:
+ raise NotImplementedError
+
+
+SYSTEM = """
+You are a process reward model, give the reward value of the answer, you must follow the instructions below:
+
+1. Output a float reward value between -1.0 and 1.0, -1.0 means the worst answer, 1.0 means the best answer, please think step by step to give your reasons and thoughts, but the reward must appare at the end with this format: **Reward: your-reward-value**.
+
+2. The answer may be incomplete, you must give the reward by the existing part of the answer, taking into account semantic coherence, logical correctness, and clarity.
+
+3. A ground truth answer will be given to you, it may be not the best one, consider it as a reference example.
+
+Begin!
+""" # noqa
+
+QUERY = """
+The original question or the previous conversation:
+
+#query#
+
+Here is the ground truth as the reference:
+
+#ground_truth#
+
+Given the upper information, give your reward(-1.0~1.0) of the following answer:
+
+#response#
+"""
+
+
+class QwenMaxPRM(PRM):
+
+ def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
+ **kwargs) -> List[float]:
+ # TODO: check request_config
+ rewards = []
+
+ from openai import OpenAI
+
+ client = OpenAI(
+ api_key=os.getenv('DASHSCOPE_API_KEY'),
+ base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
+ )
+
+ for request, ground_truth in zip(infer_requests, ground_truths):
+ previous = request['messages'][:-1]
+ if previous[0]['role'] == 'system':
+ previous = previous[1:]
+
+ assert request['messages'][-1]['role'] == 'assistant'
+ query = QUERY.replace('#query#', json.dumps(previous))
+ query = query.replace('#ground_truth#', ground_truth)
+ query = query.replace('#response#', request['messages'][-1]['content'])
+ messages = [
+ {
+ 'role': 'system',
+ 'content': SYSTEM
+ },
+ {
+ 'role': 'user',
+ 'content': query
+ },
+ ]
+ completion = client.chat.completions.create(
+ model='qwen-max',
+ messages=messages,
+ )
+
+ content = completion.choices[0].message.content
+ if 'Reward:' not in content:
+ rewards.append(0.)
+ else:
+ try:
+ reward = float(content.split('Reward:')[1].strip().replace('*', ''))
+ rewards.append(reward)
+ except Exception:
+ rewards.append(0.)
+
+ return rewards
+
+
+class ClientPRM(PRM):
+
+ def __init__(self, api_key=None, base_url=None, model=None):
+ from swift.llm import InferClient
+ import os
+ if api_key is None:
+ api_key = os.getenv('DASHSCOPE_API_KEY')
+ if base_url is None:
+ base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+ if model is None:
+ model = 'qwen-plus'
+ self.infer_engine = InferClient(base_url=base_url, api_key=api_key)
+ self.infer_engine.strict = False
+ self.infer_kwargs = {
+ 'model': model,
+ }
+
+ def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
+ **kwargs) -> List[float]:
+ prm_infer_requests = []
+ request_config = kwargs.get('request_config')
+ for request, ground_truth in zip(infer_requests, ground_truths):
+ previous = request['messages'][:-1]
+ if previous[0]['role'] == 'system':
+ previous = previous[1:]
+
+ assert request['messages'][-1]['role'] == 'assistant'
+ query = QUERY.replace('#query#', json.dumps(previous))
+ query = query.replace('#ground_truth#', ground_truth)
+ query = query.replace('#response#', request['messages'][-1]['content'])
+ messages = [
+ {
+ 'role': 'system',
+ 'content': SYSTEM
+ },
+ {
+ 'role': 'user',
+ 'content': query
+ },
+ ]
+
+ prm_infer_requests.append(InferRequest(messages=messages))
+
+ responses = self.infer_engine.infer(prm_infer_requests, request_config=request_config, **self.infer_kwargs)
+ rewards = []
+ for response in responses:
+ content = response.choices[0].message.content
+ if 'Reward:' not in content:
+ rewards.append(0.)
+ else:
+ try:
+ reward = float(content.split('Reward:')[1].strip().replace('*', ''))
+ rewards.append(reward)
+ except Exception:
+ rewards.append(0.)
+ return rewards
+
+
+prms = {
+ 'qwen_max': QwenMaxPRM,
+ 'client': ClientPRM,
+}
diff --git a/swift/plugin/rm_plugin.py b/swift/plugin/rm_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..136223542992a01e574bd80418fec1e5bc8a505a
--- /dev/null
+++ b/swift/plugin/rm_plugin.py
@@ -0,0 +1,229 @@
+import re
+import textwrap
+from copy import deepcopy
+from typing import Dict, List
+
+import torch
+
+from swift.llm import PtEngine, RequestConfig, Template, to_device
+from swift.llm.infer.protocol import ChatCompletionResponse
+from swift.utils import get_logger
+
+logger = get_logger()
+
+
+class DefaultRMPlugin:
+ """
+ Default Reward Model Plugin
+
+ This class implements the default processing logic for reward models.
+ It assumes that `self.model` is a classification model with a value head(output dimmension 1).
+ The first logits value from the model's output is used as the reward score.
+ """
+
+ def __init__(self, model, template):
+ self.model = model
+ self.template: Template = template
+
+ def __call__(self, inputs):
+ batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
+ reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
+ reward_inputs.pop('labels')
+
+ with torch.inference_mode():
+ return self.model(**reward_inputs).logits[:, 0]
+
+
+class GenRMPlugin(DefaultRMPlugin):
+
+ def __init__(self, model, template):
+ """
+ Generative Reward Model Plugin Example.
+
+ This method sets up the reward model plugin by initializing the PtEngine for efficient inference,
+ configuring the request parameters, and defining the system prompt that guides the reward model in
+ evaluating responses.
+
+ Args:
+ model (torch.nn.Module): The generative reward model.
+ template (Template): The template used for encoding input data.
+ """
+
+ super().__init__(model, template)
+ # initilize PTEngine to infer
+ self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit
+ self.request_config = RequestConfig() # customise your request config here
+ self.system = textwrap.dedent("""
+ Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant.
+ Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct.
+ Before finishing your response, please assign a reward using the following format:
+
+ Reward: {reward}
+
+ For example:
+ Reward: 0.85
+ """) # noqa
+
+ def __call__(self, inputs):
+ """
+ Compute reward scores for the provided inputs.
+
+ This method processes each input by converting dialogue messages into a query, sending the query to the
+ reward model for inference, and extracting the reward scores from the model's responses. The final reward
+ for each input is the average of all extracted scores.
+ Args:
+ inputs (List[Dict]): A list of input requests. Each input request is a dictionary containing:
+ - 'messages' (List[Dict]): messages from the training model. Each message dictionary includes:
+ - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
+ - 'content' (str): The content of the message.
+ - Additional dataset columns as key-value pairs (e.g., 'solutions', 'images').
+ Returns:
+ torch.Tensor: A tensor containing the average reward scores for each input. The tensor has a shape of (N,),
+ where N is the number of input requests.
+ """
+
+ rm_inputs = self.prepare_rm_inputs(inputs)
+ results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
+ rewards = self.compute_rewards(results)
+ return torch.tensor(rewards, dtype=torch.float32)
+
+ def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]:
+ """
+ Prepare inputs for the reward model by converting messages into queries.
+
+ Args:
+ inputs (List[Dict]): A list of input requests.
+
+ Returns:
+ List[Dict]: Processed inputs for the reward model.
+ """
+ rm_inputs = []
+ for idx, infer_request in enumerate(inputs):
+ # Deep copy to prevent modification of original input
+ rm_infer_request = deepcopy(infer_request)
+
+ # Extract and convert messages to a single query string
+ messages = rm_infer_request.get('messages')
+ query = self.messages_to_query(messages)
+
+ # Construct new messages tailored for the reward model
+ rm_messages = [{'role': 'system', 'content': self.system}, {'role': 'user', 'content': query}]
+
+ # Update the messages in the reward infer request
+ rm_infer_request['messages'] = rm_messages
+ rm_inputs.append(rm_infer_request)
+ return rm_inputs
+
+ @staticmethod
+ def extract_reward(model_output: str) -> float:
+ """
+ Extract the reward score from the model's output.
+
+ Args:
+ model_output (str): The model's output string, expected to follow the format "Reward: {reward}".
+
+ Returns:
+ float: The extracted reward score.
+
+ Raises:
+ ValueError: If the reward score cannot be extracted or the format is incorrect.
+ """
+ match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output)
+ if match:
+ return float(match.group(1))
+ else:
+ logger.warning("Unable to extract reward score from the model's output, set reward to 0")
+ return None
+
+ @staticmethod
+ def messages_to_query(messages):
+ """
+ Compress a list of message dictionaries into a single query string.
+
+ Args:
+ messages (list[dict]): A list of message dictionaries, each containing:
+ - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
+ - 'content' (str): The content of the message.
+
+ Returns:
+ str: A single string that concatenates all messages in a formatted manner.
+
+ Example:
+ >>> messages = [
+ ... {'role': 'user', 'content': 'Hello, how are you?'},
+ ... {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'},
+ ... {'role': 'user', 'content': 'Can you help me with my homework?'}
+ ... ]
+ >>> print(messages_to_query(messages))
+ User: Hello, how are you?
+ Assistant: I am fine, thank you! How can I assist you today?
+ User: Can you help me with my homework?
+ """
+ # Initialize an empty list to hold formatted messages
+ formatted_messages = []
+
+ # Define a mapping for role capitalization if needed
+ role_mapping = {
+ 'user': 'User',
+ 'assistant': 'Assistant',
+ 'system': 'System'
+ # Add more roles here as needed
+ }
+
+ for idx, message in enumerate(messages):
+ if not isinstance(message, dict):
+ raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.')
+
+ # Extract 'role' and 'content' from each message
+ role = message.get('role')
+ content = message.get('content')
+ if not content:
+ continue
+
+ # Capitalize the role using the mapping, default to capitalized original role
+ role_formatted = role_mapping.get(role.lower(), role.capitalize())
+
+ # Append the formatted message to the list
+ formatted_messages.append(f'{role_formatted}: {content}')
+
+ # Join all formatted messages with newline characters
+ query = '\n'.join(formatted_messages)
+
+ return query
+
+ def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
+ """
+ Compute average reward scores from the reward model's outputs.
+
+ Args:
+ results (List[ChatCompletionResponse]): A list of results from the reward model.
+
+ Returns:
+ List[float]: A list of average reward scores.
+ """
+ rewards = []
+ for idx, output in enumerate(results):
+ try:
+ cur_rewards = []
+ for choice in output.choices:
+ response = choice.message.content
+ reward = self.extract_reward(response)
+ cur_rewards.append(reward)
+ cur_rewards = [r for r in cur_rewards if r is not None]
+ if cur_rewards:
+ average_reward = sum(cur_rewards) / len(cur_rewards)
+ else:
+ average_reward = 0.0
+ logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')
+
+ rewards.append(average_reward)
+ except Exception as e:
+ logger.error(f'Error computing reward: {e}')
+ rewards.append(0.0) # Assign default reward score on failure
+ return rewards
+
+
+rm_plugins = {
+ 'default': DefaultRMPlugin,
+ 'genrm': GenRMPlugin,
+}
diff --git a/swift/plugin/tuner.py b/swift/plugin/tuner.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8cb44d5251d92749f8aeef189df4a3f572b506e
--- /dev/null
+++ b/swift/plugin/tuner.py
@@ -0,0 +1,92 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Optional
+
+import torch
+from peft import IA3Config, PeftModel, get_peft_model
+
+from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
+from swift.utils import find_all_linears
+
+
+class Tuner:
+
+ @staticmethod
+ def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
+ """Prepare a new model with a tuner
+
+ Args:
+ args: The training arguments
+ model: The model instance
+
+ Returns:
+ The wrapped model
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def save_pretrained(
+ model: torch.nn.Module,
+ save_directory: str,
+ state_dict: Optional[dict] = None,
+ safe_serialization: bool = True,
+ **kwargs,
+ ) -> None:
+ """Save when save_steps reaches
+
+ Args:
+ model: The wrapped model by `prepare_model`
+ save_directory: The directory to save
+ safe_serialization: Use safetensors or not
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
+ """Load the ckpt_dir
+
+ Args:
+ model: The original model instance.
+ model_id: The model id or ckpt_dir to load
+ Returns:
+ The wrapped model instance
+ """
+ raise NotImplementedError
+
+
+class PeftTuner(Tuner):
+
+ @staticmethod
+ def save_pretrained(
+ model: torch.nn.Module,
+ save_directory: str,
+ state_dict: Optional[dict] = None,
+ safe_serialization: bool = True,
+ **kwargs,
+ ) -> None:
+ model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)
+
+ @staticmethod
+ def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
+ return PeftModel.from_pretrained(model, model_id, **kwargs)
+
+
+# Here gives a simple example of IA3
+class IA3(PeftTuner):
+
+ @staticmethod
+ def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
+ model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
+ ia3_config = IA3Config(
+ target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
+ return get_peft_model(model, ia3_config)
+
+
+class DummyTuner(PeftTuner):
+
+ @staticmethod
+ def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
+ return model
+
+
+# Add your own tuner here, use --train_type xxx to begin
+extra_tuners = {'ia3': IA3, 'dummy': DummyTuner}
diff --git a/swift/trainers/__init__.py b/swift/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ae3dfe72c7ad9b0041e25932103e3495f60019
--- /dev/null
+++ b/swift/trainers/__init__.py
@@ -0,0 +1,49 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from transformers.trainer_callback import TrainerCallback
+from transformers.trainer_utils import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy,
+ SchedulerType)
+
+from swift.utils.import_utils import _LazyModule
+from . import callback
+
+try:
+ # https://github.com/huggingface/transformers/pull/25702
+ from transformers.trainer_utils import ShardedDDPOption
+except ImportError:
+ ShardedDDPOption = None
+
+if TYPE_CHECKING:
+ from .arguments import Seq2SeqTrainingArguments, TrainingArguments
+ from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer,
+ RewardTrainer, GRPOTrainer)
+ from .rlhf_arguments import DPOConfig, CPOConfig, KTOConfig, ORPOConfig, PPOConfig, RewardConfig
+ from .trainer_factory import TrainerFactory
+ from .trainers import Seq2SeqTrainer, Trainer, EmbeddingTrainer
+ from .mixin import SwiftMixin
+
+else:
+ _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
+ _import_structure = {
+ 'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'],
+ 'rlhf_arguments':
+ ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig'],
+ 'rlhf_trainer': [
+ 'CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer', 'RewardTrainer',
+ 'GRPOTrainer'
+ ],
+ 'trainer_factory': ['TrainerFactory'],
+ 'trainers': ['Seq2SeqTrainer', 'Trainer', 'EmbeddingTrainer'],
+ 'mixin': ['SwiftMixin'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects=_extra_objects,
+ )
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c98b5c1a7a14b6cd361565e3382688aeeddcb1
--- /dev/null
+++ b/swift/trainers/arguments.py
@@ -0,0 +1,214 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os
+import platform
+from dataclasses import dataclass, field
+from functools import wraps
+from typing import List, Literal, Optional, Union
+
+import torch
+import torch.utils.checkpoint
+from transformers.training_args import TrainingArguments as HfTrainingArguments
+from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments
+
+from swift.utils import get_dist_setting, get_logger, is_liger_available, use_torchacc
+from .optimizers.galore import GaLoreConfig
+
+logger = get_logger()
+
+
+@dataclass
+class TrainArgumentsMixin:
+ """
+ check_model (bool): Flag to check the model is latest. Default is True.
+ acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'.
+ """
+ per_device_train_batch_size: int = 1
+ per_device_eval_batch_size: int = 1
+ gradient_accumulation_steps: Optional[int] = None
+
+ gradient_checkpointing: bool = True
+ gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
+ logging_first_step: bool = True
+ logging_steps: int = 5
+
+ weight_decay: float = 0.1
+ adam_beta2: float = 0.95
+ lr_scheduler_type: str = 'cosine'
+ lr_scheduler_kwargs: Optional[Union[dict, str]] = None
+ report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
+ dataloader_num_workers: Optional[int] = None
+ dataloader_prefetch_factor: Optional[int] = None
+ use_liger_kernel: bool = False
+
+ # extra
+ check_model: bool = True
+ acc_strategy: Literal['token', 'seq'] = 'token'
+ train_dataloader_shuffle: bool = True
+ max_epochs: Optional[int] = None
+
+ # torchacc
+ metric_warmup_step: Optional[float] = 0
+ fsdp_num: int = 1
+ acc_steps: int = 1
+
+ # train-eval loop args
+ eval_use_evalscope: bool = False
+ eval_datasets: List[str] = field(default_factory=list)
+ eval_limit: Optional[int] = None
+ eval_datasets_args: Optional[Union[str, dict]] = None
+ eval_generation_config: Optional[Union[str, dict]] = None
+
+ def _fix_gradient_checkpointing(self):
+ # fix use_reentrant
+ if hasattr(torch.utils.checkpoint, '_old_checkpoint'): # avoid double patching
+ return
+ # Consistent with the default behavior of transformers.
+ use_reentrant_ = (
+ self.gradient_checkpointing_kwargs.get('use_reentrant', True)
+ if self.gradient_checkpointing_kwargs else True)
+ _old_checkpoint = torch.utils.checkpoint.checkpoint
+
+ @wraps(_old_checkpoint)
+ def _new_checkpoint(*args, use_reentrant=None, **kwargs):
+ return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs)
+
+ torch.utils.checkpoint._old_checkpoint = _old_checkpoint
+ torch.utils.checkpoint.checkpoint = _new_checkpoint
+ try:
+ # Fix the old version of transformers.
+ import transformers.modeling_utils
+ transformers.modeling_utils.checkpoint = _new_checkpoint
+ except (ImportError, AttributeError):
+ pass
+
+ def _init_liger(self):
+ if self.use_liger_kernel:
+ assert is_liger_available(), 'use_liger_kernel requires liger_kernels, try `pip install liger-kernel`'
+
+ def __post_init__(self):
+ from swift.llm.argument.base_args.model_args import ModelArguments
+ if use_torchacc():
+ self.dataloader_drop_last = True
+ if self.gradient_accumulation_steps is None:
+ world_size = get_dist_setting()[2]
+ self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size))
+ logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}')
+ if self.lr_scheduler_kwargs:
+ self.lr_scheduler_kwargs = ModelArguments.parse_to_dict(self.lr_scheduler_kwargs)
+ if self.gradient_checkpointing_kwargs:
+ self.gradient_checkpointing_kwargs = ModelArguments.parse_to_dict(self.gradient_checkpointing_kwargs)
+ self._fix_gradient_checkpointing()
+ self._init_liger()
+ if self.dataloader_num_workers is None:
+ if platform.system() == 'Windows':
+ self.dataloader_num_workers = 0
+ else:
+ self.dataloader_num_workers = 1
+ logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}')
+ if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0:
+ self.dataloader_prefetch_factor = 10
+ if self.eval_use_evalscope:
+ try:
+ import evalscope
+ except ImportError:
+ raise ImportError('evalscope is not installed, please install it by `pip install evalscope`')
+ self.eval_datasets_args = ModelArguments.parse_to_dict(self.eval_datasets_args)
+ self.eval_generation_config = ModelArguments.parse_to_dict(self.eval_generation_config)
+
+ super().__post_init__()
+
+
+@dataclass
+class SwiftArgumentsMixin(TrainArgumentsMixin):
+ # Value copied from TrainArguments
+ train_type: Optional[str] = None
+ optimizer: Optional[str] = None
+ local_repo_path: Optional[str] = None
+ galore_config: Optional[GaLoreConfig] = None
+
+ def __post_init__(self):
+ if hasattr(self, 'output_dir'):
+ self.output_dir = os.path.abspath(os.path.expanduser(self.output_dir))
+ super().__post_init__()
+
+ @property
+ def place_model_on_device(self):
+ return False if use_torchacc() else super().place_model_on_device
+
+
+@dataclass
+class GRPOArgumentsMixin:
+ epsilon: float = 0.2
+ epsilon_high: Optional[float] = None
+ top_k: int = 50
+ top_p: float = 0.9
+ repetition_penalty: float = 1.
+ num_infer_workers: int = 1
+ # vllm
+ vllm_device: List[str] = field(default_factory=lambda: ['auto'])
+ vllm_gpu_memory_utilization: float = 0.9
+ vllm_max_model_len: Optional[int] = None
+ vllm_max_num_seqs: int = 256
+ vllm_enforce_eager: bool = False
+ vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
+ vllm_enable_prefix_caching: bool = True
+ # reward function args, see details in swift/plugin/orm.py
+ # cosine reward, https://arxiv.org/abs/2502.03373
+ cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length.
+ cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length.
+ cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length.
+ cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length.
+ cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length
+ # repetition penalty, https://arxiv.org/abs/2502.03373
+ repetition_n_grams: int = 3
+ repetition_max_penalty: float = -1.0
+
+ reward_model: Optional[List[str]] = None
+ reward_model_plugin: Optional[List[str]] = None
+ # LMDeploy in GRPO
+ use_lmdeploy: bool = False
+ lmdeploy_device: Optional[str] = 'auto'
+ lmdeploy_session_len: Optional[int] = None
+ lmdeploy_cache_max_entry_count: float = 0.8
+
+ async_generate: bool = False
+ tensor_parallel_size: int = 1
+ sleep_level: int = 0
+ move_model_batches: Optional[int] = None
+ offload_optimizer: bool = False
+ offload_model: bool = False
+ gc_collect_after_offload: bool = False
+ multi_turn_func: Optional[str] = None
+
+ # DAPO, https://arxiv.org/abs/2503.14476
+ dynamic_sample: bool = False
+ max_resample_times: int = 3
+ overlong_filter: bool = False
+ soft_max_length: Optional[int] = None
+ soft_cache_length: Optional[int] = None
+
+ # Dr. GRPO, https://arxiv.org/abs/2503.20783
+ scale_rewards: bool = True
+
+ # compatible with trl main branch(0.17.0.dev0)
+ wandb_log_unique_prompts: Optional[bool] = None
+
+ # external vllm
+ vllm_server_host: Optional[str] = None
+ vllm_server_port: int = 8000
+ vllm_server_timeout: float = 240.0
+ vllm_client = None
+
+ # dataset
+ dataset_shuffle: Optional[bool] = True
+
+
+@dataclass
+class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
+ pass
+
+
+@dataclass
+class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments):
+ pass
diff --git a/swift/trainers/callback.py b/swift/trainers/callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d0343d88fb9e59ef7e91d4e50e3494e4652cb23
--- /dev/null
+++ b/swift/trainers/callback.py
@@ -0,0 +1,124 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os
+import time
+
+from tqdm import tqdm
+from transformers import trainer
+from transformers.trainer_callback import (DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerControl,
+ TrainerState)
+from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics
+
+from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
+from ..utils.utils import format_time
+from .arguments import TrainingArguments
+
+
+def add_train_message(logs, state, start_time) -> None:
+ logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}'
+ train_percentage = state.global_step / state.max_steps if state.max_steps else 0.
+ logs['percentage'] = f'{train_percentage * 100:.2f}%'
+ elapsed = time.time() - start_time
+ logs['elapsed_time'] = format_time(elapsed)
+ if train_percentage != 0:
+ logs['remaining_time'] = format_time(elapsed / train_percentage - elapsed)
+ for k, v in logs.items():
+ if isinstance(v, float):
+ logs[k] = round(logs[k], 8)
+
+
+class ProgressCallbackNew(ProgressCallback):
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ if state.is_world_process_zero:
+ self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True)
+ self.current_step = 0
+ self.start_time = time.time()
+ if use_torchacc():
+ self.warmup_start_time = 0
+ self.warmup_metric = None
+ self.metric_warmup_step = int(args.metric_warmup_step
+ * state.max_steps) if args.metric_warmup_step < 1 else args.metric_warmup_step
+
+ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs):
+ if state.is_world_process_zero and has_length(eval_dataloader):
+ if self.prediction_bar is None:
+ if self.training_bar is not None:
+ self.training_bar.fp.write('\n')
+ self.prediction_bar = tqdm(
+ desc='Val', total=len(eval_dataloader), leave=True, dynamic_ncols=True, position=0)
+ self.prediction_bar.update()
+
+ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
+
+ if use_torchacc():
+ if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
+ self.warmup_start_time = time.time()
+ self.metric_warmup_step = state.global_step
+ if state.max_steps == state.global_step and self.warmup_metric is None:
+ num_steps = state.max_steps - self.metric_warmup_step
+ num_total_samples = args.train_dataset_sample
+ num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps)
+ self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples,
+ num_steps)
+ self.warmup_metric['num_total_samples'] = num_total_samples
+ self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples
+ if 'train_samples_per_second' in logs:
+ logs.update(self.warmup_metric)
+ state.log_history[-1] = logs
+
+ add_train_message(logs, state, self.start_time)
+ if not is_pai_training_job() and state.is_world_process_zero:
+ jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
+ append_to_jsonl(jsonl_path, logs)
+ super().on_log(args, state, control, logs, **kwargs)
+ if state.is_world_process_zero and self.training_bar is not None:
+ self.training_bar.refresh()
+
+
+class DefaultFlowCallbackNew(DefaultFlowCallback):
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ control = super().on_step_end(args, state, control, **kwargs)
+ # save the last ckpt
+ evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
+ if state.global_step == state.max_steps:
+ if evaluation_strategy != IntervalStrategy.NO:
+ control.should_evaluate = True
+ if args.save_strategy != IntervalStrategy.NO:
+ control.should_save = True
+ return control
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ control = super().on_epoch_end(args, state, control, **kwargs)
+ evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
+ if args.max_epochs is not None and args.max_epochs <= math.ceil(state.epoch):
+ if evaluation_strategy != IntervalStrategy.NO:
+ control.should_evaluate = True
+ if args.save_strategy != IntervalStrategy.NO:
+ control.should_save = True
+ control.should_training_stop = True
+ return control
+
+
+class PrinterCallbackNew(PrinterCallback):
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ self.start_time = time.time()
+ return super().on_train_begin(args, state, control, **kwargs)
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ add_train_message(logs, state, self.start_time)
+ if not is_pai_training_job() and state.is_world_process_zero:
+ jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
+ append_to_jsonl(jsonl_path, logs)
+
+ _ = logs.pop('total_flos', None)
+ if state.is_world_process_zero:
+ print(logs, flush=True)
+
+
+# monkey patching
+trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
+trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
+trainer.PrinterCallback = PrinterCallbackNew
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbd382d99f394e16eb362ecb58da969eccef066c
--- /dev/null
+++ b/swift/trainers/mixin.py
@@ -0,0 +1,516 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from huggingface/transformers.
+import inspect
+import os
+import shutil
+import time
+from contextlib import contextmanager
+from copy import copy
+from functools import partial
+from types import MethodType
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import safetensors
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import transformers
+from datasets import Dataset as HfDataset
+from modelscope import check_local_model_is_latest
+from packaging import version
+from peft import PeftModel
+from torch.nn import Module
+from torch.utils.data import DataLoader
+from transformers import PreTrainedModel
+from transformers.data.data_collator import DataCollator
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import unwrap_model
+from transformers.trainer import TrainerCallback
+from transformers.trainer_utils import EvalPrediction, IntervalStrategy
+from transformers.utils import is_torch_npu_available
+
+from swift.hub import get_hub
+from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
+from swift.plugin import MeanMetric, compute_acc, extra_tuners
+from swift.tuners import SwiftModel
+from swift.utils import get_logger, is_mp_ddp, use_torchacc
+from swift.utils.torchacc_utils import ta_trim_graph
+from ..utils.torch_utils import get_device_count
+from .arguments import TrainingArguments
+from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model
+
+try:
+ from trl import AutoModelForCausalLMWithValueHead
+except (ImportError, RuntimeError):
+ AutoModelForCausalLMWithValueHead = None
+
+logger = get_logger()
+
+
+class SwiftMixin:
+
+ def __init__(self,
+ model: Union[PreTrainedModel, Module] = None,
+ args: TrainingArguments = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[HfDataset] = None,
+ eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
+ template: Optional[Template] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_loss_func: Optional[Callable] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ **kwargs) -> None:
+ if not hasattr(train_dataset, '__len__') and args.dataloader_num_workers > 1:
+ args.dataloader_num_workers = 1
+ logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.')
+
+ if args.check_model and hasattr(model, 'model_dir'):
+ from swift.utils.logger import ms_logger_ignore_error
+ with ms_logger_ignore_error():
+ check_local_model_is_latest(
+ model.model_dir, user_agent={
+ 'invoked_by': 'local_trainer',
+ 'third_party': 'swift',
+ })
+ if eval_dataset is None and args:
+ args.evaluation_strategy = IntervalStrategy.NO
+ args.eval_strategy = IntervalStrategy.NO
+
+ self._custom_metrics = {}
+ self.template = template
+ self.max_memory = 0
+ self.hub = get_hub()
+
+ self.model_meta = model.model_meta
+ with self.hub.patch_hub():
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ tokenizer=template.tokenizer,
+ model_init=model_init,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ **kwargs)
+
+ self.compute_loss_func = compute_loss_func
+ if get_function(model.__class__.forward) is not get_function(model.forward):
+ self.label_names = find_labels(model)
+ self.can_return_loss = can_return_loss(model)
+ self.label_names = self.label_names or ['labels']
+ self.start_time = time.time()
+ if self.template.sequence_parallel_size > 1:
+ from swift.trainers.sequence_parallel import sequence_parallel
+ sequence_parallel.prepare_trainer(self)
+
+ def _save_initial_model(self, output_dir):
+ # pissa/olora/lora-ga
+ model = unwrap_model(self.model)
+ if isinstance(model, PeftModel):
+ config = model.peft_config.get('default')
+ init_lora_weights = getattr(config, 'init_lora_weights', None)
+ if (isinstance(init_lora_weights, str)
+ and any(s in init_lora_weights for s in ('pissa', 'olora', 'lora-ga'))):
+ config.init_lora_weights = True
+ model.save_pretrained(os.path.join(output_dir, 'initial_model'))
+ config.init_lora_weights = init_lora_weights
+
+ def _save_converted_model(self, output_dir):
+ # pissa/olora/lora-ga
+ model = unwrap_model(self.model)
+ if isinstance(model, PeftModel):
+ config = model.peft_config.get('default')
+ init_lora_weights = getattr(config, 'init_lora_weights', None)
+ if isinstance(init_lora_weights, str):
+ config = copy(config)
+ os.makedirs(os.path.join(output_dir, 'converted'), exist_ok=True)
+ if 'lora-ga' in init_lora_weights:
+ try:
+ from lora_ga.entrypoint import LoraGAContext
+ with LoraGAContext(model):
+ model.save_pretrained(
+ os.path.join(output_dir, 'converted', 'default'),
+ path_initial_model_for_weight_conversion=os.path.join(
+ os.path.dirname(output_dir), 'initial_model'),
+ )
+ model.peft_config['default'] = config
+ except ImportError as e:
+ error_message = """
+ Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub.
+ Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.
+ """
+ logger.info(error_message)
+ raise RuntimeError(error_message) from e
+ elif 'pissa' in init_lora_weights or 'olora' in init_lora_weights:
+ model.save_pretrained(
+ os.path.join(output_dir, 'converted', 'default'),
+ path_initial_model_for_weight_conversion=os.path.join(
+ os.path.dirname(output_dir), 'initial_model'),
+ )
+ model.peft_config['default'] = config
+
+ def _load_optimizer_and_scheduler(self, *args, **kwargs):
+ super()._load_optimizer_and_scheduler(*args, **kwargs)
+ if is_mp_ddp():
+ # fix mp+ddp adamw
+ for v in self.optimizer.state.values():
+ if 'step' in v:
+ # not on the same device
+ device_set = set([t.device for t in v.values()]) - {v['step'].device, torch.device('cpu')}
+ if len(device_set) >= 1:
+ v['step'] = v['step'].to('cpu')
+
+ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
+ # model
+ supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
+ supported_names = ('SentenceTransformer')
+ if AutoModelForCausalLMWithValueHead is not None:
+ supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, )
+ save_safetensors = self.args.save_safetensors
+ if not isinstance(self.model, supported_classes) and self.model.__class__.__name__ not in supported_names:
+ if state_dict is None:
+ state_dict = self.model.state_dict()
+
+ _unwrap_model = unwrap_model(self.model)
+ if isinstance(_unwrap_model, supported_classes):
+ _unwrap_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
+ else:
+ logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
+ if save_safetensors:
+ safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
+ else:
+ torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
+ elif AutoModelForCausalLMWithValueHead and isinstance(self.model, AutoModelForCausalLMWithValueHead):
+ # save reward model
+ state_dict = self.model.state_dict()
+ decoder_state_dict, v_head_state_dict = {}, {}
+ for name, param in state_dict.items():
+ if name.startswith('v_head.'):
+ v_head_state_dict[name] = param
+ else:
+ decoder_state_dict[name.replace('pretrained_model.', '', 1)] = param
+ self.model.pretrained_model.save_pretrained(
+ output_dir, state_dict=decoder_state_dict or None, safe_serialization=save_safetensors)
+ if save_safetensors:
+ from safetensors.torch import save_file
+ save_file(
+ v_head_state_dict, os.path.join(output_dir, 'value_head.safetensors'), metadata={'format': 'pt'})
+ else:
+ torch.save(v_head_state_dict, os.path.join(output_dir, 'value_head.bin'))
+ elif is_instance_of_ms_model(self.model):
+ PreTrainedModel.save_pretrained(
+ self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
+ elif self.args.train_type in extra_tuners:
+ extra_tuners[self.args.train_type].save_pretrained(
+ self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
+ else:
+ if self.model.__class__.__name__ != 'SentenceTransformer':
+ self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
+ else:
+
+ @contextmanager
+ def save_context():
+ save_pretrained = self.model[0].auto_model.save_pretrained
+ _state_dict = {
+ key[len('0.auto_model.'):] if 'auto_model' in key else key: value
+ for key, value in state_dict.items()
+ }
+ self.model[0].auto_model.save_pretrained = partial(
+ self.model[0].auto_model.save_pretrained, state_dict=_state_dict)
+ yield
+ self.model[0].auto_model.save_pretrained = save_pretrained
+
+ with save_context():
+ self.model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+ # copy sentencetransformers files
+ from swift.utils import copy_files_by_pattern
+ copy_files_by_pattern(self.model.model_dir, output_dir, '*.py')
+ copy_files_by_pattern(self.model.model_dir, output_dir, '*.json')
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ """Compatible with swift and peft"""
+ # If we are executing this function, we are the process zero, so we don't check for that.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ self._save_model(output_dir, state_dict)
+ # training_args.bin
+ torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
+ self._save_converted_model(output_dir)
+ # args.json
+ args_path = os.path.join(os.path.dirname(output_dir), 'args.json')
+ if os.path.exists(args_path):
+ shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
+ # predict.jsonl
+ predict_jsonl = os.path.join(os.path.dirname(output_dir), 'predict.jsonl')
+ if os.path.exists(predict_jsonl):
+ shutil.move(predict_jsonl, os.path.join(output_dir, 'predict.jsonl'))
+
+ is_adapter = isinstance(self.model, (SwiftModel, PeftModel))
+ # tokenizer
+ if not is_adapter:
+ from swift.llm import save_checkpoint
+ additional_saved_files = self.model_meta.additional_saved_files
+ save_checkpoint(
+ None,
+ self.template.processor,
+ output_dir,
+ model_dirs=[self.model.model_dir],
+ additional_saved_files=additional_saved_files)
+ if getattr(self.model, 'origin_generation_config', None):
+ self.model.origin_generation_config.save_pretrained(output_dir)
+
+ def _fix_zero3_gather_all_parameters(self) -> None:
+ if is_deepspeed_zero3_enabled() and not hasattr(self.deepspeed, '_zero3_consolidated_16bit_state_dict_origin'):
+ parameters = inspect.signature(self.deepspeed._zero3_consolidated_16bit_state_dict).parameters
+ if 'exclude_frozen_parameters' in parameters:
+
+ def _zero3_consolidated_16bit_state_dict(model, exclude_frozen_parameters=False):
+ unwrapped = unwrap_model(model)
+ exclude_frozen_parameters = False
+ if isinstance(unwrapped, SwiftModel) and unwrapped.has_additional_modules:
+ exclude_frozen_parameters = True
+ if isinstance(unwrapped, PeftModel):
+ exclude_frozen_parameters = True
+ return model._zero3_consolidated_16bit_state_dict_origin(exclude_frozen_parameters)
+
+ self.deepspeed._zero3_consolidated_16bit_state_dict_origin = (
+ self.deepspeed._zero3_consolidated_16bit_state_dict)
+ self.deepspeed._zero3_consolidated_16bit_state_dict = MethodType(_zero3_consolidated_16bit_state_dict,
+ self.deepspeed)
+
+ def _save_checkpoint(self, *args, **kwargs):
+ self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}')
+ self._fix_zero3_gather_all_parameters()
+ result = super()._save_checkpoint(*args, **kwargs)
+ logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}')
+ return result
+
+ @staticmethod
+ @contextmanager
+ def _fix_grad_norm_nan():
+ from accelerate import Accelerator
+ origin_clip_grad_norm_ = Accelerator.clip_grad_norm_
+
+ def clip_grad_norm_(self, parameters, *args, **kwargs):
+ # If NaN occurs, ignore weight updates.
+ parameters = list(parameters)
+ grad_norm = origin_clip_grad_norm_(self, parameters, *args, **kwargs)
+ if isinstance(grad_norm, torch.Tensor) and grad_norm.isnan().item():
+ for p in parameters:
+ p.grad = None
+ return grad_norm
+
+ Accelerator.clip_grad_norm_ = clip_grad_norm_
+ try:
+ yield
+ finally:
+ Accelerator.clip_grad_norm_ = origin_clip_grad_norm_
+
+ def train(self, *args, **kwargs):
+ if self.model_meta.is_multimodal:
+ models = []
+ for model_name in ['model', 'ref_model', 'value_model']:
+ model = getattr(self, model_name, None)
+ if isinstance(model, nn.Module):
+ models.append(model)
+
+ reward_model = getattr(self, 'reward_model', None)
+ if reward_model is not None:
+ if isinstance(reward_model, list):
+ models.extend([m for m in reward_model if isinstance(m, nn.Module)])
+ elif isinstance(reward_model, nn.Module):
+ models.append(reward_model)
+
+ models = list(set(models)) # Deduplicate
+ self.template.register_post_encode_hook(models)
+ logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}.')
+ self._save_initial_model(self.args.output_dir)
+ with self.hub.patch_hub(), self._fix_grad_norm_nan():
+ res = super().train(*args, **kwargs)
+ self.template.remove_post_encode_hook()
+ return res
+
+ def push_to_hub(self, *args, **kwargs):
+ with self.hub.patch_hub():
+ return super().push_to_hub(*args, **kwargs)
+
+ def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) -> float:
+ if device is None:
+ mems = [torch.cuda.max_memory_reserved(device=device) for device in range(get_device_count())]
+ else:
+ mems = [torch.cuda.max_memory_reserved(device=device)]
+ mem = sum(mems) / 1024**3
+ self.max_memory = max(self.max_memory, mem)
+ return mem
+
+ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+ self.control.should_log = False
+
+ # all_gather + mean() to get average loss over all processes
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+ loss = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
+ logs: Dict[str, float] = {'loss': loss} # loss first
+
+ for k, metric in self._custom_metrics.items():
+ value = metric.compute()
+ if len(value) == 1:
+ val = list(value.values())[0]
+ logs[k] = val
+ else:
+ for k_suffix, val in value.items():
+ new_k = f'{k}_{k_suffix}'
+ logs[new_k] = val
+ metric.reset()
+
+ if version.parse(transformers.__version__) >= version.parse('4.38'):
+ grad_norm = args[0]
+ if grad_norm is not None:
+ logs['grad_norm'] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+ logs['learning_rate'] = self._get_learning_rate()
+ if not is_torch_npu_available():
+ logs['memory(GiB)'] = round(self.get_max_cuda_memory(), 2)
+
+ elapse_time = time.time() - self.start_time
+ logs['train_speed(iter/s)'] = round(self.state.global_step / elapse_time, 6)
+ for k in list(logs.keys()):
+ if logs[k] is None:
+ logs.pop(k)
+ tr_loss -= tr_loss
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+ self.log(logs)
+
+ if self.args.eval_use_evalscope and self.control.should_evaluate:
+ self._evalscope_eval()
+ super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)
+
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
+ if self.args.optimizer is not None:
+ from swift.plugin import optimizers_map
+ optimizer_callback = optimizers_map[self.args.optimizer]
+ self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
+ if self.optimizer is None:
+ self.create_optimizer()
+ if self.lr_scheduler is None:
+ self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+ else:
+ super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)
+
+ def _compute_acc(self, outputs, labels) -> None:
+ args = self.args
+ acc_steps = args.acc_steps
+ preds = outputs.logits.argmax(dim=-1)
+ if self.state.global_step % acc_steps == 0:
+ if use_torchacc():
+ ta_trim_graph()
+ preds = preds.to('cpu')
+ labels = labels.to('cpu')
+ metrics = compute_acc(
+ preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
+ for k, v in metrics.items():
+ if k not in self._custom_metrics:
+ self._custom_metrics[k] = MeanMetric(nan_value=None)
+ self._custom_metrics[k].update(v)
+
+ @torch.no_grad()
+ def _evalscope_eval(self):
+ from ..llm.eval.utils import EvalModel
+ from evalscope import TaskConfig, run_task
+ from evalscope.constants import EvalType
+
+ self.model.eval()
+ max_batch_size = self.args.per_device_eval_batch_size
+ custom_model = EvalModel(
+ self.model, self.template, max_batch_size=max_batch_size, model_name=f'model-step{self.state.global_step}')
+ task_config = TaskConfig(
+ model=custom_model,
+ eval_type=EvalType.CUSTOM,
+ datasets=self.args.eval_datasets,
+ dataset_args=self.args.eval_datasets_args,
+ limit=self.args.eval_limit,
+ work_dir=os.path.join(self.args.output_dir, 'eval'),
+ eval_batch_size=max_batch_size,
+ generation_config=self.args.eval_generation_config or {'max_tokens': 512},
+ )
+ # start evaluation
+ eval_report = run_task(task_config)
+ # convert to dict
+ eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
+ self.log(eval_dict)
+
+ self.model.train()
+ return eval_dict
+
+ def get_batch_samples(self, *args, **kwargs):
+ res = super().get_batch_samples(*args, **kwargs)
+ if self.template.sequence_parallel_size == 1:
+ return res
+
+ batch_samples, num_items_in_batch = res
+ if num_items_in_batch is None:
+ num_items_in_batch = torch.tensor(0).to(args[2])
+ from swift.trainers.sequence_parallel import sequence_parallel
+ dist.all_reduce(num_items_in_batch, dist.ReduceOp.SUM, sequence_parallel.sp_group)
+ return batch_samples, num_items_in_batch
+
+
+class DataLoaderMixin:
+
+ def get_train_dataloader(self):
+ dataloader = None
+ if self.template.sequence_parallel_size > 1:
+ from swift.trainers.sequence_parallel import sequence_parallel
+ dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size)
+ if dataloader is None:
+ # Higher efficiency
+ if self.train_dataset is None:
+ raise ValueError('Trainer: training requires a train_dataset.')
+ args = self.args
+ train_dataset = self.train_dataset
+
+ dataloader_params = {
+ 'collate_fn': self.data_collator,
+ 'num_workers': args.dataloader_num_workers,
+ 'pin_memory': args.dataloader_pin_memory,
+ 'persistent_workers': args.dataloader_persistent_workers,
+ 'prefetch_factor': args.dataloader_prefetch_factor
+ }
+ batch_sampler_params = {
+ 'drop_last': args.dataloader_drop_last,
+ 'shuffle': args.train_dataloader_shuffle,
+ 'data_seed': args.data_seed,
+ }
+
+ if hasattr(train_dataset, '__len__'):
+ batch_sampler = BatchSamplerShard(
+ len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
+ dataloader = DataLoaderShard(train_dataset, batch_sampler, **dataloader_params)
+ else:
+ # IterableDataset
+ if dist.is_initialized() and dataloader_params['prefetch_factor']:
+ dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
+ dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params)
+ dataloader = DataLoaderDispatcher(dataloader)
+
+ return dataloader
+
+ def get_eval_dataloader(self, eval_dataset=None):
+ dataloader = None
+ if self.template.sequence_parallel_size > 1:
+ from swift.trainers.sequence_parallel import sequence_parallel
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError('Trainer: evaluation requires an eval_dataset.')
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
+ if dataloader is None:
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
+ return dataloader
diff --git a/swift/trainers/optimizers/__init__.py b/swift/trainers/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5
--- /dev/null
+++ b/swift/trainers/optimizers/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
diff --git a/swift/trainers/optimizers/galore/__init__.py b/swift/trainers/optimizers/galore/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..822853cd8c7f8a585138c45fbc9e5a44f749efb5
--- /dev/null
+++ b/swift/trainers/optimizers/galore/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import TYPE_CHECKING
+
+from swift.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+ from .utils import create_optimizer_and_scheduler, GaLoreConfig
+ from .adafactor import GaLoreAdafactor
+ from .adamw8bit import GaLoreAdamW8bit
+ from .adamw import GaLoreAdamW
+else:
+ _import_structure = {
+ 'utils': ['GaLoreConfig', 'create_optimizer_and_scheduler'],
+ 'adafactor': ['GaLoreAdafactor'],
+ 'adamw8bit': ['GaLoreAdamW8bit'],
+ 'adamw': ['GaLoreAdamW'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/swift/trainers/optimizers/galore/adafactor.py b/swift/trainers/optimizers/galore/adafactor.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ab26477ad4d53ad1dc7de19324794cf24ae001
--- /dev/null
+++ b/swift/trainers/optimizers/galore/adafactor.py
@@ -0,0 +1,272 @@
+# copy dependencies from transformers/optimization.py
+# code borrowed from https://github.com/jiaweizzhao/GaLore
+import math
+
+import torch
+from torch.optim import Optimizer
+from transformers.utils.versions import require_version
+
+from .galore_projector import GaLoreProjector
+
+
+class Adafactor(Optimizer):
+ """
+ AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
+ https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
+
+ Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
+ this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
+ `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
+ `relative_step=False`.
+
+ Arguments:
+ params (`Iterable[nn.parameter.Parameter]`):
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
+ lr (`float`, *optional*):
+ The external learning rate.
+ eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
+ Regularization constants for square gradient and parameter scale respectively
+ clip_threshold (`float`, *optional*, defaults to 1.0):
+ Threshold of root mean square of final gradient update
+ decay_rate (`float`, *optional*, defaults to -0.8):
+ Coefficient used to compute running averages of square
+ beta1 (`float`, *optional*):
+ Coefficient used for computing running averages of gradient
+ weight_decay (`float`, *optional*, defaults to 0.0):
+ Weight decay (L2 penalty)
+ scale_parameter (`bool`, *optional*, defaults to `True`):
+ If True, learning rate is scaled by root mean square
+ relative_step (`bool`, *optional*, defaults to `True`):
+ If True, time-dependent learning rate is computed instead of external learning rate
+ warmup_init (`bool`, *optional*, defaults to `False`):
+ Time-dependent learning rate computation depends on whether warm-up initialization is being used
+
+ This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
+
+ Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
+
+ - Training without LR warmup or clip_threshold is not recommended.
+
+ - use scheduled LR warm-up to fixed LR
+ - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
+ - Disable relative updates
+ - Use scale_parameter=False
+ - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
+
+ Example:
+
+ ```python
+ Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
+ ```
+
+ Others reported the following combination to work well:
+
+ ```python
+ Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+ ```
+
+ When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
+ scheduler as following:
+
+ ```python
+ from transformers.optimization import Adafactor, AdafactorSchedule
+
+ optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+ lr_scheduler = AdafactorSchedule(optimizer)
+ trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
+ ```
+
+ Usage:
+
+ ```python
+ # replace AdamW with Adafactor
+ optimizer = Adafactor(
+ model.parameters(),
+ lr=1e-3,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ relative_step=False,
+ scale_parameter=False,
+ warmup_init=False,
+ )
+ ```"""
+
+ def __init__(
+ self,
+ params,
+ lr=None,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ scale_parameter=True,
+ relative_step=True,
+ warmup_init=False,
+ ):
+ require_version('torch>=1.5.0') # add_ with alpha
+ if lr is not None and relative_step:
+ raise ValueError('Cannot combine manual `lr` and `relative_step=True` options')
+ if warmup_init and not relative_step:
+ raise ValueError('`warmup_init=True` requires `relative_step=True`')
+
+ defaults = {
+ 'lr': lr,
+ 'eps': eps,
+ 'clip_threshold': clip_threshold,
+ 'decay_rate': decay_rate,
+ 'beta1': beta1,
+ 'weight_decay': weight_decay,
+ 'scale_parameter': scale_parameter,
+ 'relative_step': relative_step,
+ 'warmup_init': warmup_init,
+ }
+ super().__init__(params, defaults)
+
+ @staticmethod
+ def _get_lr(param_group, param_state):
+ rel_step_sz = param_group['lr']
+ if param_group['relative_step']:
+ min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
+ rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
+ param_scale = 1.0
+ if param_group['scale_parameter']:
+ param_scale = max(param_group['eps'][1], param_state['RMS'])
+ return param_scale * rel_step_sz
+
+ @staticmethod
+ def _get_options(param_group, param_shape):
+ factored = len(param_shape) >= 2
+ use_first_moment = param_group['beta1'] is not None
+ return factored, use_first_moment
+
+ @staticmethod
+ def _rms(tensor):
+ return tensor.norm(2) / (tensor.numel()**0.5)
+
+ @staticmethod
+ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+ # copy from fairseq's adafactor implementation:
+ # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+ return torch.mul(r_factor, c_factor)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """
+ Performs a single optimization step
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.dtype in {torch.float16, torch.bfloat16}:
+ grad = grad.float()
+ if grad.is_sparse:
+ raise RuntimeError('Adafactor does not support sparse gradients.')
+
+ state = self.state[p]
+
+ if 'step' not in state:
+ state['step'] = 0
+
+ # GaLore Projection
+ if 'rank' in group:
+ if 'projector' not in state:
+ state['projector'] = GaLoreProjector(
+ group['rank'],
+ update_proj_gap=group['update_proj_gap'],
+ scale=group['scale'],
+ proj_type=group['proj_type'])
+
+ grad = state['projector'].project(grad, state['step'])
+
+ grad_shape = grad.shape
+
+ factored, use_first_moment = self._get_options(group, grad_shape)
+ # State Initialization
+ if 'RMS' not in state:
+ state['step'] = 0
+
+ if use_first_moment:
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(grad)
+ if factored:
+ state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
+ state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+ else:
+ state['exp_avg_sq'] = torch.zeros_like(grad)
+
+ state['RMS'] = 0
+ else:
+ if use_first_moment:
+ state['exp_avg'] = state['exp_avg'].to(grad)
+ if factored:
+ state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
+ state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
+ else:
+ state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
+
+ p_data_fp32 = p
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p_data_fp32 = p_data_fp32.float()
+
+ state['step'] += 1
+ state['RMS'] = self._rms(p_data_fp32)
+ lr = self._get_lr(group, state)
+
+ beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
+ update = (grad**2) + group['eps'][0]
+ if factored:
+ exp_avg_sq_row = state['exp_avg_sq_row']
+ exp_avg_sq_col = state['exp_avg_sq_col']
+
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+ # Approximation of exponential moving average of square of gradient
+ update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+ update.mul_(grad)
+ else:
+ exp_avg_sq = state['exp_avg_sq']
+
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+ update = exp_avg_sq.rsqrt().mul_(grad)
+
+ update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
+ update.mul_(lr)
+
+ if use_first_moment:
+ exp_avg = state['exp_avg']
+ exp_avg.mul_(group['beta1']).add_(update, alpha=(1 - group['beta1']))
+ update = exp_avg
+
+ # GaLore Projection Back
+ if 'rank' in group:
+ update = state['projector'].project_back(update)
+
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha=(-group['weight_decay'] * lr))
+
+ p_data_fp32.add_(-update)
+
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p.copy_(p_data_fp32)
+
+ return loss
+
+
+GaLoreAdafactor = Adafactor
diff --git a/swift/trainers/optimizers/galore/adamw.py b/swift/trainers/optimizers/galore/adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..7396334a32d974a3631e30862a384f908a6816f4
--- /dev/null
+++ b/swift/trainers/optimizers/galore/adamw.py
@@ -0,0 +1,141 @@
+# copy dependencies from transformers/optimization.py
+# code borrowed from https://github.com/jiaweizzhao/GaLore
+import math
+from typing import Callable, Iterable, Tuple
+
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from transformers.utils.versions import require_version
+
+from .galore_projector import GaLoreProjector
+
+
+class AdamW(Optimizer):
+ """
+ Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
+ Regularization](https://arxiv.org/abs/1711.05101).
+
+ Parameters:
+ params (`Iterable[nn.parameter.Parameter]`):
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
+ lr (`float`, *optional*, defaults to 0.001):
+ The learning rate to use.
+ betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
+ Adam's betas parameters (b1, b2).
+ eps (`float`, *optional*, defaults to 1e-06):
+ Adam's epsilon for numerical stability.
+ weight_decay (`float`, *optional*, defaults to 0.0):
+ Decoupled weight decay to apply.
+ correct_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
+ no_deprecation_warning (`bool`, *optional*, defaults to `False`):
+ A flag used to disable the deprecation warning (set to `True` to disable the warning).
+ """
+
+ def __init__(
+ self,
+ params: Iterable[nn.parameter.Parameter],
+ lr: float = 1e-3,
+ betas: Tuple[float, float] = (0.9, 0.999),
+ eps: float = 1e-6,
+ weight_decay: float = 0.0,
+ correct_bias: bool = True,
+ no_deprecation_warning: bool = False,
+ ):
+ require_version('torch>=1.5.0') # add_ with alpha
+ if lr < 0.0:
+ raise ValueError(f'Invalid learning rate: {lr} - should be >= 0.0')
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError(f'Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)')
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError(f'Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)')
+ if not 0.0 <= eps:
+ raise ValueError(f'Invalid epsilon value: {eps} - should be >= 0.0')
+ defaults = {'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, 'correct_bias': correct_bias}
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure: Callable = None):
+ """
+ Performs a single optimization step.
+
+ Arguments:
+ closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
+ state = self.state[p]
+
+ if 'step' not in state:
+ state['step'] = 0
+
+ # GaLore Projection
+ if 'rank' in group:
+ if 'projector' not in state:
+ state['projector'] = GaLoreProjector(
+ group['rank'],
+ update_proj_gap=group['update_proj_gap'],
+ scale=group['scale'],
+ proj_type=group['proj_type'])
+
+ grad = state['projector'].project(grad, state['step'])
+
+ # State initialization
+ if 'exp_avg' not in state:
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(grad)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(grad)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+
+ # Decay the first and second moment running average coefficient
+ # In-place operations to update the averages at the same time
+ exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ step_size = group['lr']
+ if group['correct_bias']: # No bias correction for Bert
+ bias_correction1 = 1.0 - beta1**state['step']
+ bias_correction2 = 1.0 - beta2**state['step']
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
+
+ # compute norm gradient
+ norm_grad = exp_avg / denom
+
+ # GaLore Projection Back
+ if 'rank' in group:
+ norm_grad = state['projector'].project_back(norm_grad)
+
+ p.add_(norm_grad, alpha=-step_size)
+
+ # Just adding the square of the weights to the loss function is *not*
+ # the correct way of using L2 regularization/weight decay with Adam,
+ # since that will interact with the m and v parameters in strange ways.
+ #
+ # Instead we want to decay the weights in a manner that doesn't interact
+ # with the m/v parameters. This is equivalent to adding the square
+ # of the weights to the loss with plain (non-momentum) SGD.
+ # Add weight decay at the end (fixed version)
+ if group['weight_decay'] > 0.0:
+ p.add_(p, alpha=(-group['lr'] * group['weight_decay']))
+
+ return loss
+
+
+GaLoreAdamW = AdamW
diff --git a/swift/trainers/optimizers/galore/adamw8bit.py b/swift/trainers/optimizers/galore/adamw8bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..66b0c5b621369ec16577729df5251848a8796e90
--- /dev/null
+++ b/swift/trainers/optimizers/galore/adamw8bit.py
@@ -0,0 +1,112 @@
+# code borrowed from https://github.com/jiaweizzhao/GaLore
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+
+from .galore_projector import GaLoreProjector
+
+
+class AdamW8bit(Optimizer2State):
+
+ def __init__(self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ is_paged=False):
+ super().__init__(
+ 'adam',
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ is_paged=is_paged)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ if not self.initialized:
+ self.check_overrides()
+ self.to_gpu() # needed for fairseq pure fp16 training
+ self.initialized = True
+
+ # if self.is_paged: self.page_mng.prefetch_all()
+ for gindex, group in enumerate(self.param_groups):
+ for pindex, p in enumerate(group['params']):
+ if p.grad is None:
+ continue
+ state = self.state[p]
+
+ if 'step' not in state:
+ state['step'] = 0
+
+ # GaLore Projection
+ if 'rank' in group:
+ if 'projector' not in state:
+ state['projector'] = GaLoreProjector(
+ group['rank'],
+ update_proj_gap=group['update_proj_gap'],
+ scale=group['scale'],
+ proj_type=group['proj_type'])
+
+ if 'weight_decay' in group and group['weight_decay'] > 0:
+ # ensure that the weight decay is not applied to the norm grad
+ group['weight_decay_saved'] = group['weight_decay']
+ group['weight_decay'] = 0
+
+ grad = state['projector'].project(p.grad, state['step'])
+
+ # suboptimal implementation
+ p.saved_data = p.data.clone()
+ p.data = grad.clone().to(p.data.dtype).to(p.data.device)
+ p.data.zero_()
+ p.grad = grad
+
+ if 'state1' not in state:
+ self.init_state(group, p, gindex, pindex)
+
+ self.prefetch_state(p)
+ self.update_step(group, p, gindex, pindex)
+ torch.cuda.synchronize()
+
+ # GaLore Projection Back
+ if 'rank' in group:
+ p.data = p.saved_data.add_(state['projector'].project_back(p.data))
+
+ # apply weight decay
+ if 'weight_decay_saved' in group:
+ p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay_saved'])
+ group['weight_decay'] = group['weight_decay_saved']
+ del group['weight_decay_saved']
+
+ if self.is_paged:
+ # all paged operation are asynchronous, we need
+ # to sync to make sure all tensors are in the right state
+ torch.cuda.synchronize()
+
+ return loss
+
+
+GaLoreAdamW8bit = AdamW8bit
diff --git a/swift/trainers/optimizers/galore/galore_projector.py b/swift/trainers/optimizers/galore/galore_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..52fa1f0f3a3abcb92cc029f29ce390a3760667cf
--- /dev/null
+++ b/swift/trainers/optimizers/galore/galore_projector.py
@@ -0,0 +1,109 @@
+# code borrowed from https://github.com/jiaweizzhao/GaLore
+
+import torch
+
+
+class GaLoreProjector:
+
+ def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
+ self.rank = rank
+ self.verbose = verbose
+ self.update_proj_gap = update_proj_gap
+ self.scale = scale
+ self.ortho_matrix = None
+ self.proj_type = proj_type
+
+ def project(self, full_rank_grad, iter):
+
+ if self.proj_type == 'std':
+ if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
+ low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+ else:
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
+ low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+ elif self.proj_type == 'reverse_std':
+ if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
+ low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+ else:
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
+ low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+ elif self.proj_type == 'right':
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
+ low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+ elif self.proj_type == 'left':
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
+ low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+ elif self.proj_type == 'full':
+ if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
+ low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
+
+ return low_rank_grad
+
+ def project_back(self, low_rank_grad):
+
+ if self.proj_type == 'std':
+ if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
+ full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+ else:
+ full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+ elif self.proj_type == 'reverse_std':
+ if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
+ full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+ else:
+ full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+ elif self.proj_type == 'right':
+ full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+ elif self.proj_type == 'left':
+ full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+ elif self.proj_type == 'full':
+ full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
+
+ return full_rank_grad * self.scale
+
+ # svd decomposition
+ def get_orthogonal_matrix(self, weights, rank, type):
+ module_params = weights
+
+ if module_params.data.dtype != torch.float:
+ float_data = False
+ original_type = module_params.data.dtype
+ original_device = module_params.data.device
+ matrix = module_params.data.float()
+ else:
+ float_data = True
+ matrix = module_params.data
+
+ U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)
+
+ # make the smaller matrix always to be orthogonal matrix
+ if type == 'right':
+ A = U[:, :rank] @ torch.diag(s[:rank])
+ B = Vh[:rank, :]
+
+ if not float_data:
+ B = B.to(original_device).type(original_type)
+ return B
+ elif type == 'left':
+ A = U[:, :rank]
+ B = torch.diag(s[:rank]) @ Vh[:rank, :]
+ if not float_data:
+ A = A.to(original_device).type(original_type)
+ return A
+ elif type == 'full':
+ A = U[:, :rank]
+ B = Vh[:rank, :]
+ if not float_data:
+ A = A.to(original_device).type(original_type)
+ B = B.to(original_device).type(original_type)
+ return [A, B]
+ else:
+ raise ValueError('type should be left, right or full')
diff --git a/swift/trainers/optimizers/galore/utils.py b/swift/trainers/optimizers/galore/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9f243f8cba23547e5a0147d9b236c13cf7dfdc
--- /dev/null
+++ b/swift/trainers/optimizers/galore/utils.py
@@ -0,0 +1,214 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import importlib
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from transformers import Trainer, TrainingArguments, get_scheduler
+
+from swift.utils import get_logger
+
+try:
+ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+except ImportError:
+ from torch.optim.lr_scheduler import LRScheduler
+
+logger = get_logger()
+
+
+@dataclass
+class GaLoreConfig:
+ """
+ The configuration class for the Galore module.
+
+
+ See https://arxiv.org/abs/2403.03507
+
+ Args:
+ rank (`int`): The galore rank
+ target_modules (`Union[str, List[str]]`): The target modules to use, if `None`,
+ will use all attn and mlp linears
+ update_proj_gap(`int`): The projection update interval for galore
+ proj_type(`str`) The project type of Galore, valid values are `std`,
+ `reverse_std`, `right`, `left`, `full`
+ galore_scale(float): the scale of gradient
+ optim_per_parameter(bool): Gives one optimizer per parameter
+ """
+ rank: int = 128
+ target_modules: Union[str, List[str]] = None
+ update_proj_gap: int = 50
+ galore_scale: float = 1.0
+ proj_type: str = 'std'
+ optim_per_parameter: bool = False
+ quantize: bool = False
+ proj_quant: bool = False
+ proj_bits: int = 4
+ proj_group_size: int = 256
+ cos_threshold: float = 0.4
+ gamma_proj: int = 2
+ queue_size: int = 5
+
+
+class GaloreOptimizerWrapper(Optimizer):
+
+ def __init__(self, optimizers: Dict[Any, Optimizer]):
+ self.optimizers = optimizers
+ super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})
+
+ def zero_grad(self, *args, **kwargs) -> None:
+ for optim in self.optimizers.values():
+ optim.zero_grad(*args, **kwargs)
+
+ def step(self, *args, **kwargs) -> None:
+ for optim in self.optimizers.values():
+ optim.step(*args, **kwargs)
+
+
+class GaloreSchedulerWrapper(LRScheduler):
+
+ def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
+ self.lr_schedulers = lr_schedulers
+
+ def step(self, *args, **kwargs) -> None:
+ for lr_scheduler in self.lr_schedulers.values():
+ lr_scheduler.step(*args, **kwargs)
+ self._last_lr = lr_scheduler.get_last_lr()
+
+
+def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps,
+ **defaults):
+ galore_params = []
+ for module_name, module in model.named_modules():
+ if not isinstance(module, (nn.Linear, nn.Embedding)) or \
+ not any(target_key in module_name for target_key in config.target_modules):
+ continue
+
+ if not module.weight.requires_grad:
+ continue
+
+ logger.info(f'Enable GaLore for weights in module: {module_name}')
+ galore_params.append(module.weight)
+
+ id_galore_params = [id(p) for p in galore_params]
+ galore_defaults = {
+ 'rank': config.rank,
+ 'update_proj_gap': config.update_proj_gap,
+ 'scale': config.galore_scale,
+ 'proj_type': config.proj_type,
+ **defaults
+ }
+ if config.quantize:
+ galore_defaults['quant'] = config.proj_quant
+ galore_defaults['quant_n_bit'] = config.proj_bits
+ galore_defaults['quant_group_size'] = config.proj_group_size
+ galore_defaults['cos_threshold'] = config.cos_threshold
+ galore_defaults['gamma_proj'] = config.gamma_proj
+ galore_defaults['queue_size'] = config.queue_size
+ optim_cls, optim_kwargs = get_optimizer(args, config)
+
+ if config.optim_per_parameter and not config.quantize:
+ # q-galore does not support optim_per_parameter
+ optimizer_dict = {}
+ galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
+ for p in model.parameters():
+ if p.requires_grad:
+ if id(p) in id_galore_params:
+ optimizer_dict[p] = optim_cls([{'params': [p], **galore_defaults}], **optim_kwargs)
+ else:
+ optimizer_dict[p] = optim_cls([{'params': [p], **defaults}], **optim_kwargs)
+
+ # get scheduler dict
+ scheduler_dict = {}
+ for p in model.parameters():
+ if p.requires_grad:
+ scheduler_dict[p] = get_scheduler(
+ optimizer=optimizer_dict[p],
+ name=args.lr_scheduler_type,
+ num_training_steps=max_steps * 2,
+ num_warmup_steps=args.warmup_steps * 2,
+ scheduler_specific_kwargs=args.lr_scheduler_kwargs,
+ )
+
+ return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)
+ else:
+ decay_parameters = Trainer.get_decay_parameter_names(Trainer, model)
+ param_groups = [{
+ 'params': galore_params,
+ **galore_defaults,
+ }]
+ param_groups.extend([
+ {
+ 'params': [
+ p for n, p in model.named_parameters()
+ if (n in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
+ ],
+ 'weight_decay':
+ defaults['weight_decay'],
+ },
+ {
+ 'params': [
+ p for n, p in model.named_parameters()
+ if (n not in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
+ ],
+ 'weight_decay':
+ 0.0,
+ },
+ ])
+ optim = optim_cls(param_groups, **optim_kwargs)
+ scheduler = get_scheduler(
+ optimizer=optim,
+ name=args.lr_scheduler_type,
+ num_training_steps=max_steps,
+ num_warmup_steps=args.warmup_steps,
+ scheduler_specific_kwargs=args.lr_scheduler_kwargs,
+ )
+ return optim, scheduler
+
+
+def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
+ # parse args.optim_args
+ optim_args = {}
+ if args.optim_args:
+ for mapping in args.optim_args.replace(' ', '').split(','):
+ key, value = mapping.split('=')
+ optim_args[key] = value
+
+ optimizer_kwargs = {'lr': args.learning_rate}
+
+ adam_kwargs = {
+ 'betas': (args.adam_beta1, args.adam_beta2),
+ 'eps': args.adam_epsilon,
+ }
+ if args.optim == 'adafactor':
+ from .adafactor import GaLoreAdafactor
+ optimizer_cls = GaLoreAdafactor
+ optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
+ elif args.optim in ('adamw_hf', 'adamw_torch'):
+ if config.quantize:
+ assert importlib.util.find_spec('q_galore_torch') is not None, \
+ 'Please install q-galore by `pip install q_galore_torch`'
+ logger.info('If you encounter `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
+ from swift.utils import get_dist_setting
+ _, _, world_size, _ = get_dist_setting()
+ if world_size > 1:
+ # from q_galore_torch import QGaLoreAdamW8bit_simulate as GaLoreAdamW
+ from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
+ else:
+ from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
+ else:
+ from .adamw import GaLoreAdamW
+ optimizer_cls = GaLoreAdamW
+ optimizer_kwargs.update(adam_kwargs)
+ elif 'adamw' in args.optim and '8bit' in args.optim:
+ try:
+ from .adamw8bit import GaLoreAdamW8bit
+ optimizer_cls = GaLoreAdamW8bit
+ optimizer_kwargs.update(adam_kwargs)
+ optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
+ except ImportError:
+ raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!')
+ else:
+ raise ValueError(f'Galore not supported for optimizer type: {args.optim}')
+ return optimizer_cls, optimizer_kwargs
diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..268bca7aad8cfca2e57a589db6ec60b9d3f8feef
--- /dev/null
+++ b/swift/trainers/rlhf_arguments.py
@@ -0,0 +1,63 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from trl import CPOConfig as HfCPOConfig
+from trl import DPOConfig as HfDPOConfig
+from trl import GRPOConfig as HfGRPOConfig
+from trl import KTOConfig as HfKTOConfig
+from trl import ORPOConfig as HfORPOConfig
+from trl import PPOConfig as HfPPOConfig
+from trl import RewardConfig as HfRewardConfig
+
+from .arguments import GRPOArgumentsMixin, SwiftArgumentsMixin
+
+
+@dataclass
+class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
+ pass
+
+
+@dataclass
+class CPOConfig(SwiftArgumentsMixin, HfCPOConfig):
+ pass
+
+
+@dataclass
+class ORPOConfig(SwiftArgumentsMixin, HfORPOConfig):
+ pass
+
+
+@dataclass
+class KTOConfig(SwiftArgumentsMixin, HfKTOConfig):
+ pass
+
+
+@dataclass
+class RewardConfig(SwiftArgumentsMixin, HfRewardConfig):
+ pass
+
+
+@dataclass
+class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
+ pass
+
+
+@dataclass
+class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
+ stop_words: List[str] = field(default_factory=list)
+
+ def __post_init__(self):
+ from swift.llm.argument.base_args.model_args import ModelArguments
+ super().__post_init__()
+ if self.cosine_max_len is None:
+ self.cosine_max_len = self.max_completion_length
+ self.vllm_limit_mm_per_prompt = ModelArguments.parse_to_dict(self.vllm_limit_mm_per_prompt)
+
+ if self.deepspeed and 'zero_optimization' in self.deepspeed and self.deepspeed['zero_optimization'][
+ 'stage'] == 3:
+ # https://github.com/modelscope/ms-swift/issues/3237
+ self.deepspeed['zero_optimization']['stage3_prefetch_bucket_size'] = 0
+ self.deepspeed_plugin.hf_ds_config.config['zero_optimization']['stage3_prefetch_bucket_size'] = 0
+
+ # https://github.com/modelscope/ms-swift/issues/3863
+ self.dataloader_drop_last = True
diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b6d6a7fa3c254acb5ab1ae855de18b0c70ceaaa
--- /dev/null
+++ b/swift/trainers/rlhf_trainer/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from swift.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+ from .cpo_trainer import CPOTrainer
+ from .dpo_trainer import DPOTrainer
+ from .grpo_trainer import GRPOTrainer
+ from .kto_trainer import KTOTrainer
+ from .orpo_trainer import ORPOTrainer
+ from .ppo_trainer import PPOTrainer
+ from .reward_trainer import RewardTrainer
+ from .rlhf_mixin import RLHFTrainerMixin
+ from .utils import _split_into_mini_batches, patch_lora_merge, patch_lora_unmerge, round_robin
+else:
+ _import_structure = {
+ 'cpo_trainer': ['CPOTrainer'],
+ 'dpo_trainer': ['DPOTrainer'],
+ 'grpo_trainer': ['GRPOTrainer'],
+ 'kto_trainer': ['KTOTrainer'],
+ 'orpo_trainer': ['ORPOTrainer'],
+ 'ppo_trainer': ['PPOTrainer'],
+ 'reward_trainer': ['RewardTrainer'],
+ 'rlhf_mixin': ['RLHFTrainerMixin'],
+ 'utils': ['_split_into_mini_batches', 'patch_lora_merge', 'patch_lora_unmerge', 'round_robin'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/swift/trainers/rlhf_trainer/cpo_trainer.py b/swift/trainers/rlhf_trainer/cpo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e4c93578d7d732e581ddfac46420bf5ffe6548
--- /dev/null
+++ b/swift/trainers/rlhf_trainer/cpo_trainer.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import warnings
+from typing import Optional, Union
+
+import torch.nn as nn
+from transformers import PreTrainedModel
+from trl import CPOTrainer as HFCPOTrainer
+
+from ..mixin import SwiftMixin
+from .rlhf_mixin import RLHFTrainerMixin
+
+del HFCPOTrainer.__init__
+
+
+class CPOTrainer(RLHFTrainerMixin, SwiftMixin, HFCPOTrainer):
+
+ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
+ ref_model = kwargs.get('ref_model')
+ assert ref_model is None, 'CPO/SimPO does not require a ref_model.'
+
+ args = kwargs['args']
+ self.label_smoothing = args.label_smoothing
+ self.loss_type = args.loss_type
+ self.cpo_alpha = args.cpo_alpha
+ if args.loss_type == 'simpo':
+ self.simpo_gamma = args.simpo_gamma
+ if self.cpo_alpha > 0:
+ warnings.warn('You are using CPO-SimPO method because you set a non-zero cpo_alpha. '
+ 'This will result in the CPO-SimPO method '
+ '(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). '
+ 'If you want to use a pure SimPO method, please set cpo_alpha to 0.')
+ super().__init__(model, *_args, **kwargs)
diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f03af82120fe16d29424383b3c68765d8e90355
--- /dev/null
+++ b/swift/trainers/rlhf_trainer/dpo_trainer.py
@@ -0,0 +1,129 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from peft import PeftModel
+from transformers import PreTrainedModel
+from trl import DPOTrainer as HFDPOTrainer
+
+from ..mixin import DataLoaderMixin, SwiftMixin
+from .rlhf_mixin import RLHFTrainerMixin
+
+del HFDPOTrainer.__init__
+
+
+class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer):
+
+ def __init__(self,
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ *_args,
+ **kwargs):
+ from trl.trainer import FDivergenceConstants
+ args = kwargs['args']
+ self.label_smoothing = args.label_smoothing
+ self.loss_type = args.loss_type
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
+ self.f_divergence_type = args.f_divergence_type
+ self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
+ self.is_peft_model = isinstance(model, PeftModel)
+
+ self.ref_adapter_name = args.ref_adapter_name
+ self.reference_free = args.reference_free
+ self.use_weighting = False
+
+ super().__init__(model, ref_model, *_args, **kwargs)
+
+ def get_nll_loss(self, logits, labels):
+ if not self.is_encoder_decoder:
+ # Shift so that tokens < n predict n
+ logits = logits[..., :-1, :].contiguous()
+ labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
+ logits = logits.view(-1, logits.shape[-1])
+ labels = labels.view(-1)
+ # Enable model parallelism
+ labels = labels.to(logits.device)
+ return loss_fct(logits, labels)
+
+ def concatenated_forward(
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ batch = batch.copy()
+ num_examples = batch['labels'].shape[0] // 2
+ labels = batch.pop('labels', None)
+ if self.is_encoder_decoder:
+ batch['labels'] = labels
+
+ if self.aux_loss_enabled:
+ batch['output_router_logits'] = True
+ outputs = model(**batch, use_cache=False)
+ batch['labels'] = labels
+ if outputs.logits.shape[1] != labels.shape[1]:
+ # for llava, the model returns logits for the entire sequence, including the image tokens
+ # (placed before the text tokens)
+ outputs.logits = outputs.logits[:, -labels.shape[1]:]
+ for key in ['input_ids', 'attention_mask', 'labels']:
+ batch[f'concatenated_{key}'] = batch.pop(key, None)
+ if self.__class__.__name__ == 'ORPOTrainer': # Pass-through labels
+ batch['concatenated_input_ids'] = batch['concatenated_labels']
+
+ all_logits = outputs.logits
+
+ if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]:
+ # for llava, the model returns logits for the entire sequence,
+ # including the image tokens (placed before the text tokens)
+ seq_len = batch['concatenated_labels'].shape[1]
+ all_logits = all_logits[:, -seq_len:]
+
+ all_logps, size_completion = self.get_batch_logps(
+ all_logits,
+ batch['concatenated_labels'],
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ output = {}
+
+ if self.args.rpo_alpha is not None:
+ labels = batch['concatenated_labels'].clone()
+ output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples])
+
+ if self.loss_type == 'ipo':
+ all_logps = all_logps / size_completion
+
+ output['chosen_logps'] = all_logps[:num_examples]
+ output['rejected_logps'] = all_logps[num_examples:]
+ output['mean_chosen_logits'] = all_logits[:num_examples].mean()
+ output['mean_rejected_logits'] = all_logits[num_examples:].mean()
+
+ if self.aux_loss_enabled:
+ output['aux_loss'] = outputs.aux_loss
+
+ return output
+
+ @staticmethod
+ def get_batch_logps(
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ label_pad_token_id: int = -100,
+ is_encoder_decoder: bool = False,
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+ if logits.shape[:-1] != labels.shape:
+ raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]}'
+ 'and labels must have the same shape {labels.shape}')
+ if not is_encoder_decoder:
+ labels = labels[:, 1:].clone()
+ logits = logits[:, :-1, :]
+ else:
+ labels = labels.clone()
+
+ loss_mask = labels != label_pad_token_id
+
+ labels[labels == label_pad_token_id] = 0
+
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+
+ return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
diff --git a/swift/trainers/rlhf_trainer/kto_trainer.py b/swift/trainers/rlhf_trainer/kto_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f56d0fd6056fe3eb1001bc862bc1f807621264aa
--- /dev/null
+++ b/swift/trainers/rlhf_trainer/kto_trainer.py
@@ -0,0 +1,69 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from peft import PeftModel
+from transformers import PreTrainedModel
+from trl import KTOTrainer as HFKTOTrainer
+
+from swift.utils import get_logger
+from ..mixin import SwiftMixin
+from .rlhf_mixin import RLHFTrainerMixin
+
+logger = get_logger()
+
+del HFKTOTrainer.__init__
+
+
+class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer):
+
+ def __init__(self,
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ *_args,
+ **kwargs):
+ args = kwargs['args']
+ args.disable_dropout = True
+ self.desirable_weight = args.desirable_weight
+ self.undesirable_weight = args.undesirable_weight
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
+ self.is_peft_model = isinstance(model, PeftModel)
+ if hasattr(args, 'loss_type'):
+ self.loss_type = args.loss_type
+ else:
+ self.loss_type = 'kto'
+
+ self.ref_adapter_name = None
+ # Not all losses require a KL calculation
+ self.calculate_KL = True
+ if self.loss_type in ['apo_zero_unpaired']:
+ self.calculate_KL = False
+ super().__init__(model, ref_model, *_args, **kwargs)
+
+ def forward(
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ is_kl = True
+
+ def _add_data_hook(model, args, kwargs):
+ nonlocal is_kl
+ if is_kl:
+ kwargs = {k[len('KL_completion_'):]: v for k, v in batch.items() if k.startswith('KL_completion_')}
+ else:
+ kwargs = {k[len('completion_'):]: v for k, v in batch.items() if k.startswith('completion_')}
+ is_kl = not is_kl
+ return (), kwargs
+
+ @contextmanager
+ def _patch_model_call():
+ handle = model.register_forward_pre_hook(_add_data_hook, with_kwargs=True, prepend=True)
+
+ try:
+ yield
+ finally:
+ handle.remove()
+
+ with _patch_model_call():
+ return super().forward(model, batch)
diff --git a/swift/trainers/sequence_parallel/xtuner.py b/swift/trainers/sequence_parallel/xtuner.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e43b6bb65aeeee18b6ba40fb42e44db9c4394d
--- /dev/null
+++ b/swift/trainers/sequence_parallel/xtuner.py
@@ -0,0 +1,127 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any
+
+import datasets
+import torch
+import torch.distributed as dist
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers.trainer_utils import seed_worker
+
+from .base import SequenceParallel
+
+
+class XTuner(SequenceParallel):
+
+ @staticmethod
+ def assert_xtuner_runtime_condition():
+ from swift.utils import is_xtuner_available
+ assert is_xtuner_available(), \
+ ('Please install XTuner first to pack dataset to `max_length`.'
+ '`pip install -U \'xtuner[deepspeed]\'`')
+ assert dist.is_initialized(), 'pack_to_max_length is only available with distributed training.'
+
+ def pack_dataset_xtuner(self, dataset: Dataset, args: Any) -> Any:
+ self.assert_xtuner_runtime_condition()
+ if dist.get_rank() == 0:
+ ds = [i[0] for i in dataset.data]
+ train_dataset = Dataset.from_list(ds)
+ from xtuner.dataset.huggingface import pack_dataset
+ train_dataset = pack_dataset(
+ train_dataset,
+ max_length=args.max_length,
+ use_varlen_attn=False,
+ shuffle_before_pack=True,
+ map_num_proc=16)
+ objects = [train_dataset]
+ train_dataset.save_to_disk('alpaca_pack')
+ else:
+ objects = [None]
+ dist.broadcast_object_list(objects, src=0)
+ train_dataset = objects[0]
+ return train_dataset
+
+ @property
+ def sp_group(self):
+ from xtuner.parallel.sequence import get_sequence_parallel_group
+ return get_sequence_parallel_group()
+
+ def init_sequence_parallel(self, size):
+ self.assert_xtuner_runtime_condition()
+ from xtuner.parallel.sequence import init_sequence_parallel
+ init_sequence_parallel(size)
+
+ def prepare_model(self, model, tokenizer, split_in_forward):
+ self.assert_xtuner_runtime_condition()
+ from xtuner.model.modules.dispatch import dispatch_modules
+ dispatch_modules(model)
+
+ def pad_and_split_inputs(self,
+ tokenizer,
+ input_ids,
+ input_embeds,
+ labels,
+ position_ids,
+ attention_mask,
+ loss_scale,
+ embed_tokens=None):
+ self.assert_xtuner_runtime_condition()
+ from xtuner.parallel.sequence import (pad_for_sequence_parallel, split_for_sequence_parallel,
+ get_sequence_parallel_group)
+ input_ids = pad_for_sequence_parallel(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
+ labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
+ position_ids = pad_for_sequence_parallel(position_ids, padding_value=0, dim=-1)
+ if attention_mask is not None:
+ attention_mask = pad_for_sequence_parallel(attention_mask, padding_value=0, dim=-1)
+
+ sp_group = get_sequence_parallel_group()
+ input_ids = split_for_sequence_parallel(input_ids, dim=1, sp_group=sp_group)
+ labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
+ position_ids = split_for_sequence_parallel(position_ids, dim=1, sp_group=sp_group)
+ if attention_mask is not None:
+ attention_mask = split_for_sequence_parallel(attention_mask, dim=-1, sp_group=sp_group)
+ if loss_scale is not None:
+ loss_scale = pad_for_sequence_parallel(loss_scale, padding_value=0., dim=-1)
+ loss_scale = split_for_sequence_parallel(loss_scale, dim=1, sp_group=sp_group)
+
+ return input_ids, None, labels, position_ids, attention_mask, loss_scale
+
+ def reduce_outputs(self, loss, labels):
+ from xtuner.parallel.sequence import (reduce_sequence_parallel_loss, get_sequence_parallel_group)
+ # reduce loss for logging correctly
+ num_tokens = (labels != -100).sum()
+ return reduce_sequence_parallel_loss(loss, num_tokens, get_sequence_parallel_group())
+
+ def world_size(self):
+ self.assert_xtuner_runtime_condition()
+ from xtuner.parallel.sequence import get_sequence_parallel_world_size
+ return get_sequence_parallel_world_size()
+
+ def prepare_trainer(self, trainer):
+ pass
+
+ def get_dataloader(self, trainer, dataset, batch_size):
+ # modified from HFTrainer.get_train_dataloader
+ # RandomSampler -> SequenceParallelSampler
+ self.assert_xtuner_runtime_condition()
+ data_collator = trainer.data_collator
+ if isinstance(dataset, datasets.Dataset):
+ dataset = trainer._remove_unused_columns(dataset, description='training')
+ else:
+ data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')
+
+ dataloader_params = {
+ 'batch_size': batch_size,
+ 'collate_fn': data_collator,
+ 'num_workers': trainer.args.dataloader_num_workers,
+ 'pin_memory': trainer.args.dataloader_pin_memory,
+ 'persistent_workers': trainer.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(dataset, torch.utils.data.IterableDataset):
+ from xtuner.parallel import SequenceParallelSampler
+ dataloader_params['sampler'] = SequenceParallelSampler(dataset, seed=1024)
+ dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
+ dataloader_params['worker_init_fn'] = seed_worker
+
+ return DataLoader(dataset, **dataloader_params)
diff --git a/swift/trainers/torchacc_mixin.py b/swift/trainers/torchacc_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb373794be9040aa4d0bd56b96d9a1fccf14812
--- /dev/null
+++ b/swift/trainers/torchacc_mixin.py
@@ -0,0 +1,156 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+from typing import Optional
+
+from transformers import PreTrainedModel, is_datasets_available
+
+from swift.utils import use_torchacc
+from swift.utils.torchacc_utils import (patch_clip_grad_norm, save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint,
+ ta_eval_dataloader, ta_load_optimizer_and_scheduler,
+ ta_save_optimizer_and_scheduler, ta_test_dataloader, ta_train_dataloader,
+ ta_trim_graph)
+
+
+class TorchAccMixin:
+
+ def __init__(self, *args, **kwargs):
+ if use_torchacc():
+ patch_clip_grad_norm(self.accelerator)
+ super().__init__(*args, **kwargs)
+
+ def get_train_dataloader(self):
+ if not use_torchacc():
+ return super().get_train_dataloader()
+
+ if is_datasets_available():
+ import datasets
+
+ if self.train_dataset is None:
+ raise ValueError('Trainer: training requires a train_dataset.')
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+ train_dataset = self._remove_unused_columns(train_dataset, description='training')
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
+
+ return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args,
+ self._train_batch_size)
+
+ def get_eval_dataloader(self, eval_dataset=None):
+
+ if not use_torchacc():
+ return super().get_eval_dataloader(eval_dataset)
+
+ if is_datasets_available():
+ import datasets
+
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError('Trainer: evaluation requires an eval_dataset.')
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+ eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation')
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation')
+
+ return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args)
+
+ def get_test_dataloader(self, test_dataset):
+
+ if not use_torchacc():
+ return super().get_test_dataloader(test_dataset)
+
+ if is_datasets_available():
+ import datasets
+
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+ test_dataset = self._remove_unused_columns(test_dataset, description='test')
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description='test')
+
+ return ta_test_dataloader(test_dataset, data_collator, self._get_eval_sampler(test_dataset), self.args)
+
+ def _save_tpu(self, output_dir: Optional[str] = None):
+
+ if not use_torchacc():
+ return super()._save_tpu(output_dir)
+
+ import torch_xla.core.xla_model as xm
+
+ # Compatible with swift and peft
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+
+ if xm.is_master_ordinal(local=False):
+ os.makedirs(output_dir, exist_ok=True)
+ # configuration.json
+ model_dir = getattr(self.model, 'model_dir', None)
+ if model_dir is not None:
+ src_path = os.path.join(model_dir, 'configuration.json')
+ dst_path = os.path.join(output_dir, 'configuration.json')
+ if os.path.exists(src_path):
+ shutil.copy(src_path, dst_path)
+ else:
+ self._create_configuration_file(self.model, output_dir)
+ self._save_sft_args(output_dir)
+ # generation_config
+ generation_config = getattr(self.args, 'generation_config', None)
+ if generation_config is not None:
+ generation_config.save_pretrained(output_dir)
+
+ # model
+ if self.args.fsdp_num > 1:
+ save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
+ else:
+ save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
+
+ # additional files
+ if xm.is_master_ordinal(local=False):
+ if self.args is not None and self.args.sft_type == 'full':
+ additional_files = getattr(self.args, 'additional_saved_files',
+ None) or [] + ['preprocessor_config.json']
+ if model_dir is not None:
+ for file in additional_files:
+ src_path = os.path.join(model_dir, file)
+ dst_path = os.path.join(output_dir, file)
+ if os.path.isfile(src_path):
+ shutil.copy(src_path, dst_path)
+ elif os.path.isdir(src_path):
+ shutil.copytree(src_path, dst_path)
+
+ def _load_optimizer_and_scheduler(self, checkpoint):
+
+ if not use_torchacc() or self.args.fsdp_num == 1:
+ return super()._load_optimizer_and_scheduler(checkpoint)
+
+ self.optimizer, self.lr_scheduler = ta_load_optimizer_and_scheduler(self.optimizer, self.lr_scheduler,
+ checkpoint, self.args.device)
+
+ def _save_optimizer_and_scheduler(self, output_dir):
+ if not use_torchacc() or not self.args.fsdp_num == 1:
+ return super()._save_optimizer_and_scheduler(output_dir)
+
+ return ta_save_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, output_dir)
+
+ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
+ if use_torchacc() and self.control.should_log:
+ ta_trim_graph()
+ super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)
+
+ def _load_from_checkpoint(self, resume_from_checkpoint: str, model=None) -> None:
+ if use_torchacc():
+ if model is None:
+ model = self.model
+ # Loading checkpoint of TorchAcc has been done in tuner.py when
+ # sft_type is 'full'.
+ if self.args.fsdp_num > 1:
+ model = model._get_underlay_model().module.module
+ if isinstance(model, PreTrainedModel):
+ return
+ return super()._load_from_checkpoint(resume_from_checkpoint, model)
diff --git a/swift/trainers/trainer_factory.py b/swift/trainers/trainer_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..87657d45d41d4606535549af69da3a9962865b6f
--- /dev/null
+++ b/swift/trainers/trainer_factory.py
@@ -0,0 +1,64 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import importlib.util
+import inspect
+from dataclasses import asdict
+from typing import Dict
+
+from swift.utils import get_logger
+
+logger = get_logger()
+
+
+class TrainerFactory:
+ TRAINER_MAPPING = {
+ 'causal_lm': 'swift.trainers.Seq2SeqTrainer',
+ 'seq_cls': 'swift.trainers.Trainer',
+ 'embedding': 'swift.trainers.EmbeddingTrainer',
+ 'dpo': 'swift.trainers.DPOTrainer',
+ 'orpo': 'swift.trainers.ORPOTrainer',
+ 'kto': 'swift.trainers.KTOTrainer',
+ 'cpo': 'swift.trainers.CPOTrainer',
+ 'rm': 'swift.trainers.RewardTrainer',
+ 'ppo': 'swift.trainers.PPOTrainer',
+ 'grpo': 'swift.trainers.GRPOTrainer'
+ }
+
+ TRAINING_ARGS_MAPPING = {
+ 'causal_lm': 'swift.trainers.Seq2SeqTrainingArguments',
+ 'seq_cls': 'swift.trainers.TrainingArguments',
+ 'embedding': 'swift.trainers.TrainingArguments',
+ 'dpo': 'swift.trainers.DPOConfig',
+ 'orpo': 'swift.trainers.ORPOConfig',
+ 'kto': 'swift.trainers.KTOConfig',
+ 'cpo': 'swift.trainers.CPOConfig',
+ 'rm': 'swift.trainers.RewardConfig',
+ 'ppo': 'swift.trainers.PPOConfig',
+ 'grpo': 'swift.trainers.GRPOConfig',
+ }
+
+ @staticmethod
+ def get_cls(args, mapping: Dict[str, str]):
+ if hasattr(args, 'rlhf_type'):
+ train_method = args.rlhf_type
+ else:
+ train_method = args.task_type
+ module_path, class_name = mapping[train_method].rsplit('.', 1)
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
+
+ @classmethod
+ def get_trainer_cls(cls, args):
+ return cls.get_cls(args, cls.TRAINER_MAPPING)
+
+ @classmethod
+ def get_training_args(cls, args):
+ training_args_cls = cls.get_cls(args, cls.TRAINING_ARGS_MAPPING)
+ args_dict = asdict(args)
+ parameters = inspect.signature(training_args_cls).parameters
+
+ for k in list(args_dict.keys()):
+ if k not in parameters:
+ args_dict.pop(k)
+
+ args._prepare_training_args(args_dict)
+ return training_args_cls(**args_dict)
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
new file mode 100644
index 0000000000000000000000000000000000000000..24bd3e42826cab35f8953daecae37c515c766845
--- /dev/null
+++ b/swift/trainers/trainers.py
@@ -0,0 +1,208 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from huggingface/transformers.
+import os
+from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from peft import PeftModel
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+from transformers import EvalPrediction
+from transformers import Seq2SeqTrainer as HfSeq2SeqTrainer
+from transformers import Trainer as HfTrainer
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.utils import is_peft_available
+
+from swift.utils import JsonlWriter, Serializer, gc_collect
+from .arguments import Seq2SeqTrainingArguments, TrainingArguments
+from .mixin import DataLoaderMixin, SwiftMixin
+
+
+class Trainer(SwiftMixin, HfTrainer):
+ args: TrainingArguments
+
+ @contextmanager
+ def _patch_loss_function(self):
+ model = self.model
+ if isinstance(model, PeftModel):
+ model = model.model
+ model_cls = model.__class__
+ if not hasattr(model_cls, 'loss_function'):
+ yield
+ return
+
+ loss_function = model.loss_function
+ _old_loss_function = model_cls.loss_function
+
+ @staticmethod
+ @wraps(loss_function)
+ def new_loss_function(logits, labels, **kwargs):
+ labels = labels.to(logits.device) # fix device_map
+ return loss_function(logits=logits, labels=labels, **kwargs)
+
+ model_cls.loss_function = new_loss_function
+ try:
+ yield
+ finally:
+ model_cls.loss_function = _old_loss_function
+
+ def train(self, *args, **kwargs):
+ with self._patch_loss_function():
+ return super().train(*args, **kwargs)
+
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
+ if inputs.get('labels') is not None:
+ self._compute_acc(outputs, inputs['labels'])
+ if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
+ loss /= self.args.gradient_accumulation_steps
+ return (loss, outputs) if return_outputs else loss
+
+
+class EmbeddingTrainer(Trainer):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.compute_metrics = self.calculate_metric
+ self.preprocess_logits_for_metrics = None
+ self.label_names = ['labels']
+
+ def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
+ from swift.plugin.loss import infonce_loss, calculate_paired_metrics, calculate_infonce_metrics
+ if self.compute_loss_func is infonce_loss:
+ return calculate_infonce_metrics(eval_prediction.predictions, eval_prediction.label_ids)
+ else:
+ return calculate_paired_metrics(eval_prediction.predictions, eval_prediction.label_ids)
+
+
+class Seq2SeqTrainer(SwiftMixin, DataLoaderMixin, HfSeq2SeqTrainer):
+ args: Seq2SeqTrainingArguments
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.model_accepts_loss_kwargs = True # fix transformers>=4.46.2
+ if self.args.predict_with_generate:
+ from swift.llm import PtEngine
+ self.infer_engine = PtEngine.from_model_template(
+ self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size)
+ self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl'))
+
+ @staticmethod
+ def _predict_data_collator(batch):
+ return {'_data': batch}
+
+ @contextmanager
+ def _patch_predict_with_generate(self):
+ origin_mode = self.template.mode
+ self.template.set_mode('pt')
+ is_multimodal = self.model.model_meta.is_multimodal
+ origin_data_collator = self.data_collator
+
+ if is_multimodal:
+ models = self.template.remove_post_encode_hook()
+ self.data_collator = self._predict_data_collator
+ try:
+ yield
+ finally:
+ if is_multimodal:
+ self.template.register_post_encode_hook(models)
+ self.data_collator = origin_data_collator
+ self.template.set_mode(origin_mode)
+
+ def evaluate(self, *args, **kwargs):
+ context = self._patch_predict_with_generate() if self.args.predict_with_generate else nullcontext()
+ with context:
+ res = super().evaluate(*args, **kwargs)
+ gc_collect()
+ return res
+
+ def prediction_step(
+ self,
+ model: nn.Module,
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ **gen_kwargs,
+ ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ if not self.args.predict_with_generate or prediction_loss_only:
+ return super().prediction_step(
+ model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
+ from swift.llm import RequestConfig, InferRequest
+ data_list = inputs['_data']
+ labels_list = [InferRequest.remove_response(data['messages']) for data in data_list]
+ resp_list = self.infer_engine.infer(
+ data_list,
+ RequestConfig(max_tokens=self.model.generation_config.max_new_tokens),
+ use_tqdm=False,
+ template=self.template)
+
+ response_list = []
+ jsonl_cache = []
+ device = self.args.device
+ for data, resp, labels in zip(data_list, resp_list, labels_list):
+ response = resp.choices[0].message.content
+ jsonl_cache.append({'response': response, 'labels': labels, **data})
+ response_list.append(Serializer.to_tensor(resp.choices[0].message.content).to(device=device))
+ self.jsonl_writer.append(jsonl_cache, gather_obj=True)
+ labels_list = [Serializer.to_tensor(labels).to(device=device) for labels in labels_list]
+ response_list = pad_sequence(response_list, batch_first=True, padding_value=0)
+ labels_list = pad_sequence(labels_list, batch_first=True, padding_value=0)
+ return None, response_list, labels_list
+
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ loss_kwargs = {}
+ labels = None
+ if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs:
+ labels = inputs.pop('labels')
+
+ loss_scale = inputs.pop('loss_scale', None)
+ if loss_scale is not None:
+ loss_kwargs['loss_scale'] = loss_scale
+
+ with self.template.compute_loss_context(self.model, inputs):
+ outputs = model(**inputs)
+ # Save past state if it exists
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index]
+
+ if labels is None:
+ labels = inputs['labels']
+ outputs.loss = outputs.loss.to(labels.device)
+ # fix https://github.com/huggingface/transformers/issues/34263
+ if num_items_in_batch is not None:
+ outputs.loss = outputs.loss * (labels[:, 1:] != -100).sum() / num_items_in_batch
+
+ if isinstance(outputs, dict) and 'loss' not in outputs:
+ raise ValueError(
+ 'The model did not return a loss from the inputs, only the following keys: '
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}.")
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
+ loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
+ else:
+ unwrapped_model = self.accelerator.unwrap_model(model)
+ if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+ model_name = unwrapped_model.model._get_name()
+ else:
+ model_name = unwrapped_model._get_name()
+ # User-defined compute_loss function
+ if self.compute_loss_func is not None:
+ loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, **loss_kwargs)
+ elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
+ else:
+ loss = self.label_smoother(outputs, labels)
+
+ if self.template.sequence_parallel_size > 1:
+ from swift.trainers.sequence_parallel import sequence_parallel
+ loss = sequence_parallel.reduce_outputs(loss, labels)
+
+ if getattr(self.args, 'average_tokens_across_devices', False) and self.model_accepts_loss_kwargs:
+ loss *= self.accelerator.num_processes
+
+ if outputs.logits is not None and labels is not None:
+ # Liger does not have logits
+ self._compute_acc(outputs, labels)
+ return (loss, outputs) if return_outputs else loss
diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5540f9f13062a1e974d0c2ed12b71caa2d659d1f
--- /dev/null
+++ b/swift/trainers/utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from huggingface/transformers.
+import inspect
+from types import FunctionType, MethodType
+from typing import List, Union
+
+from peft import PeftModel
+from torch.nn import Module
+
+from swift.utils import get_logger
+
+logger = get_logger()
+
+
+def can_return_loss(model: Module) -> bool:
+ """Check if a given model can return loss."""
+ if isinstance(model, PeftModel):
+ signature = inspect.signature(model.model.forward)
+ else:
+ signature = inspect.signature(model.forward)
+ for p in signature.parameters:
+ if p == 'return_loss' and signature.parameters[p].default is True:
+ return True
+ return False
+
+
+def find_labels(model: Module) -> List[str]:
+ """Find the labels used by a given model."""
+ model_name = model.__class__.__name__
+ if isinstance(model, PeftModel):
+ signature = inspect.signature(model.model.forward)
+ else:
+ signature = inspect.signature(model.forward)
+ if 'QuestionAnswering' in model_name:
+ return [p for p in signature.parameters if 'label' in p or p in ('start_positions', 'end_positions')]
+ else:
+ return [p for p in signature.parameters if 'label' in p]
+
+
+def get_function(method_or_function: Union[MethodType, FunctionType]) -> FunctionType:
+ if isinstance(method_or_function, MethodType):
+ method_or_function = method_or_function.__func__
+ return method_or_function
+
+
+def is_instance_of_ms_model(model: Module) -> bool:
+ """avoid import modelscope: circular dependency problem"""
+ for m_cls in model.__class__.__mro__:
+ cls_name = m_cls.__name__
+ cls_module = m_cls.__module__
+ if cls_name == 'Model' and cls_module.startswith('modelscope'):
+ return True
+ return False
diff --git a/swift/tuners/__init__.py b/swift/tuners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..35eb48aa897aaeb6426fd28a94cbe561927210d8
--- /dev/null
+++ b/swift/tuners/__init__.py
@@ -0,0 +1,57 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from swift.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+ from .adapter import Adapter, AdapterConfig, AdapterModule
+ from .base import SwiftModel, Swift
+ from .lora import LoRA, LoRAConfig
+ from .mapping import SWIFT_MAPPING, SwiftTuners
+ from .side import Side, SideConfig, SideModule
+ from .neftune import NEFTune, NEFTuneConfig
+ from .longlora.longlora import LongLoRAModelType, LongLoRAConfig, LongLoRA
+ from .restuning import ResTuning, ResTuningConfig, ResTuningBypassModule
+ from .reft import Reft, ReftConfig
+ from .llamapro import LLaMAPro, LLaMAProConfig
+ from .peft import (AdaLoraConfig, LoftQConfig, LoHaConfig, LoKrConfig, LoraConfig, VeraConfig, BOFTConfig,
+ OFTConfig, PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
+ PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig,
+ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, get_peft_config, get_peft_model,
+ get_peft_model_state_dict)
+ from .prompt import Prompt, PromptConfig, PromptModule
+ from .scetuning.scetuning import SCETuning, SCETuningConfig
+ from .utils import SwiftConfig, SwiftOutput, swift_to_peft_format
+else:
+ _import_structure = {
+ 'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'],
+ 'base': ['SwiftModel', 'Swift'],
+ 'lora': ['LoRA', 'LoRAConfig'],
+ 'longlora.longlora': ['LongLoRAModelType', 'LongLoRAConfig', 'LongLoRA'],
+ 'mapping': ['SWIFT_MAPPING', 'SwiftTuners'],
+ 'side': ['Side', 'SideConfig', 'SideModule'],
+ 'reft': ['Reft', 'ReftConfig'],
+ 'llamapro': ['LLaMAPro', 'LLaMAProConfig'],
+ 'neftune': ['NEFTune', 'NEFTuneConfig'],
+ 'restuning': ['ResTuning', 'ResTuningConfig', 'ResTuningBypassModule'],
+ 'peft': [
+ 'AdaLoraConfig', 'LoftQConfig', 'LoHaConfig', 'LoKrConfig', 'LoraConfig', 'VeraConfig', 'BOFTConfig',
+ 'OFTConfig', 'PeftConfig', 'PeftModel', 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM',
+ 'PeftModelForSequenceClassification', 'PeftModelForTokenClassification', 'PrefixTuningConfig',
+ 'PromptEncoderConfig', 'PromptLearningConfig', 'PromptTuningConfig', 'get_peft_config', 'get_peft_model',
+ 'get_peft_model_state_dict'
+ ],
+ 'prompt': ['Prompt', 'PromptConfig', 'PromptModule'],
+ 'scetuning': ['SCETuning', 'SCETuningConfig'],
+ 'utils': ['SwiftConfig', 'SwiftOutput', 'swift_to_peft_format'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..290040b551b5e969eeb7b59bcc7dfd63536b57e3
--- /dev/null
+++ b/swift/tuners/adapter.py
@@ -0,0 +1,189 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import inspect
+import re
+import types
+from dataclasses import dataclass, field
+from typing import List, Union
+
+import torch
+from torch import nn
+from transformers.activations import ACT2CLS
+
+from swift.utils.torch_utils import find_sub_module, get_logger
+from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class AdapterConfig(SwiftConfig):
+ """
+ The configuration class for the adapter module.
+
+ Adapters project input tokens by an MLP layer.
+ 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
+ See http://arxiv.org/abs/1902.00751
+
+ Args:
+ dim(`int`): The dimension of the hidden states
+ target_modules(`Union[str, List[str]]`): The feedforward module to be replaced.
+ in regex format if this argument is str, else will match with `end with` if List[str].
+ hidden_pos(`Union[str, int]`): The position of the hidden state to be passed into the adapter,
+ can be int (args) or str (kwargs)
+ method_name(`str`): The method to be replaced, default is `forward`
+ adapter_length: The length of the adapter length (intermediate length)
+ act_layer: The activation layer of the adapter
+ """
+
+ dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'})
+
+ target_modules: Union[str, List[str]] = field(
+ default=None,
+ metadata={
+ 'help':
+ 'The feedforward module to be replaced. in regex format if this argument is str, '
+ 'else will match with `end with` if List[str].'
+ })
+
+ hidden_pos: Union[str, int] = field(
+ default=None,
+ metadata={
+ 'help': 'The position of the hidden state to be passed into the adapter, can be int (args) or str (kwargs)'
+ })
+
+ method_name: str = field(default='forward', metadata={'help': 'The method to be replaced, default is `forward`'})
+
+ adapter_length: int = field(
+ default=128, metadata={'help': 'The length of the adapter length (intermediate length)'})
+
+ act_layer: str = field(default='gelu', metadata={'help': 'The activation layer of the adapter'})
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.ADAPTER
+
+
+class Adapter(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: AdapterConfig, adapter_name: str) -> SwiftOutput:
+ """Prepare a model with `AdapterConfig`"""
+ module_keys = [key for key, _ in model.named_modules()]
+
+ for module_key in module_keys:
+ if isinstance(config.target_modules, str):
+ target_module_found = re.fullmatch(config.target_modules, module_key)
+ else:
+ target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)
+
+ if target_module_found: # noqa
+ module = model.get_submodule(module_key)
+
+ def _forward(self, *args, **kwargs):
+ args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
+ if isinstance(args, (tuple, list, dict)):
+ if isinstance(config.hidden_pos, int):
+ _type = type(args)
+ args = list(args)
+ args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos])
+ args = _type(args)
+ else:
+ args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos])
+ elif isinstance(args, torch.Tensor):
+ args = getattr(self, f'adapter_{adapter_name}')(args)
+ return args
+
+ def _feed_forward_chunk(self, attention_output):
+ return _forward(self, attention_output)
+
+ # TODO The `config.method_name` method should not be replaced twice.
+
+ setattr(module, f'forward_origin_{adapter_name}', getattr(module, config.method_name))
+ num_args_in_forward_chunk_fn = len(
+ inspect.signature(getattr(module, f'forward_origin_{adapter_name}')).parameters)
+ if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1:
+ setattr(module, config.method_name, types.MethodType(_feed_forward_chunk, module))
+ else:
+ setattr(module, config.method_name, types.MethodType(_forward, module))
+ adapter_module = AdapterModule(config.dim, adapter_name, module_key, config.adapter_length,
+ ACT2CLS[config.act_layer])
+ setattr(module, f'adapter_{adapter_name}', adapter_module)
+ logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}')
+
+ def state_dict_callback(state_dict, adapter_name: str, **kwargs):
+ return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key}
+
+ def mark_trainable_callback(model):
+ return
+
+ return SwiftOutput(
+ config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ modules = find_sub_module(module, f'adapter_{adapter_name}')
+ for _module in modules:
+ _module: ActivationMixin
+ _module: nn.Module
+ _module.set_activation(adapter_name, activate)
+ SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
+
+
+class AdapterModule(nn.Module, ActivationMixin):
+ """The implementation of adapter tuning method.
+
+ Adapters project input tokens by an MLP layer.
+ 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
+ See http://arxiv.org/abs/1902.00751
+
+ Args:
+ dim: An integer indicating the embedding dimension.
+ adapter_length: An integer indicating the length of adapter tuning.
+ """
+
+ def __init__(
+ self,
+ dim,
+ adapter_name,
+ module_key,
+ adapter_length=None,
+ act_layer=nn.GELU,
+ ):
+ super(AdapterModule, self).__init__()
+ super(nn.Module, self).__init__(module_key)
+ self.dim = dim
+ self.adapter_name = adapter_name
+ self.adapter_length = adapter_length
+ self.linear1 = nn.Linear(dim, adapter_length)
+ self.act = act_layer()
+ self.linear2 = nn.Linear(adapter_length, dim)
+ self.init_weights()
+ self._prepared = False
+ self.mark_all_sub_modules_as_plugin()
+
+ def init_weights(self):
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.normal_(m.bias, std=1e-6)
+
+ self.apply(_init_weights)
+
+ def forward(self, x, identity=None):
+ if not self.is_activated(self.adapter_name):
+ return x
+ if not self._prepared:
+ self.linear1.to(x.device)
+ self.act.to(x.device)
+ self.linear2.to(x.device)
+ self._prepared = True
+
+ x_dtype = x.dtype
+ x = x.to(self.linear1.weight.dtype)
+ out = self.linear2(self.act(self.linear1(x)))
+ if identity is None:
+ identity = x
+ identity = identity.to(out.dtype)
+ out = identity + out
+ return out.to(x_dtype)
diff --git a/swift/tuners/base.py b/swift/tuners/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fafc0883abce55d975055352ade4d9f5b3cbdd58
--- /dev/null
+++ b/swift/tuners/base.py
@@ -0,0 +1,926 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright 2023-present the HuggingFace Inc. team.
+import os
+import re
+import shutil
+import tempfile
+from contextlib import contextmanager
+from copy import copy
+from functools import partial
+from inspect import Parameter, Signature, signature
+from types import MethodType
+from typing import Dict, List, Literal, Optional, Union
+
+import json
+import torch
+from modelscope import snapshot_download
+from peft.utils import CONFIG_NAME
+from peft.utils.other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
+from torch import nn
+from transformers import Trainer
+
+from swift.utils.constants import DEFAULT_ADAPTER, SWIFT_TYPE_KEY
+from swift.utils.logger import get_logger
+from ..utils.torch_utils import get_device_count
+from .mapping import SwiftTuners
+from .peft import PeftConfig, PeftModel, get_peft_model
+from .utils import SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+class SwiftModel(nn.Module):
+ """The Swift wrapper model.
+
+ Args:
+ model (`Union[nn.Module, 'SwiftModel']`) A module to be tuned by Swift.
+ config (`Union[SwiftConfig, Dict[str, SwiftConfig]]`) A config or a dict of {adapter_name: SwiftConfig}.
+ If it's a config class, the adapter_name will be `default`
+ extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved.
+ inference_mode (bool, `optional`): Load model at inference mode, default False.
+ """
+
+ EXTRA_STATE_DIR = 'extra_states'
+
+ def __init__(self,
+ model: Union[nn.Module, 'SwiftModel'],
+ config: Union[SwiftConfig, Dict[str, SwiftConfig]],
+ extra_state_keys: List[str] = None,
+ inference_mode: bool = False,
+ **kwargs):
+ super().__init__()
+ self.adapters = {}
+ self.active_adapters = set()
+ if isinstance(model, SwiftModel):
+ self.adapters = model.adapters
+ extra_state_keys = extra_state_keys or []
+ extra_state_keys.extend(model.extra_state_keys)
+ self.active_adapters = model.active_adapters
+ model = model.base_model
+
+ self.base_model = model
+ new_adapters = []
+ if isinstance(config, SwiftConfig):
+ if DEFAULT_ADAPTER not in self.adapters:
+ all_parts = self._deactivate_all_parts()
+ self.adapters[DEFAULT_ADAPTER] = self._prepare_model(model, config, DEFAULT_ADAPTER)
+ for part in all_parts:
+ self.activate_adapter(part)
+ new_adapters.append(DEFAULT_ADAPTER)
+ if self.adapters[DEFAULT_ADAPTER].model is not None:
+ self.base_model = self.adapters[DEFAULT_ADAPTER].model
+ else:
+ logger.warn(f'Adapter {DEFAULT_ADAPTER} has been patched, skip.')
+ elif isinstance(config, dict):
+ assert (all(isinstance(c, SwiftConfig) for c in config.values()))
+ for adapter_name, _config in config.items():
+ if adapter_name not in self.adapters:
+ all_parts = self._deactivate_all_parts()
+ self.adapters[adapter_name] = self._prepare_model(model, _config, adapter_name)
+ for part in all_parts:
+ self.activate_adapter(part)
+ new_adapters.append(adapter_name)
+ if self.adapters[adapter_name].model is not None:
+ self.base_model = self.adapters[adapter_name].model
+ else:
+ logger.warn(f'Adapter {adapter_name} has been patched, skip.')
+
+ self.extra_state_keys = extra_state_keys or []
+ self.has_additional_modules = any([c.config.has_additional_modules for c in self.adapters.values()])
+
+ def forward(self, *args, **kwargs):
+ return self.base_model(*args, **kwargs)
+
+ _parameters = [Parameter('self', Parameter.POSITIONAL_ONLY)]
+ _parameters += list(signature(self.base_model.forward).parameters.values())
+ forward.__signature__ = Signature(_parameters)
+ self.forward = MethodType(forward, self)
+ for adapter_name in new_adapters:
+ self.activate_adapter(adapter_name)
+
+ if inference_mode:
+ self.eval()
+ else:
+ for key, output in self.adapters.items():
+ if key in new_adapters:
+ output.mark_trainable_callback(model)
+ if self.extra_state_keys:
+ for n, p in model.named_parameters():
+ if any(re.fullmatch(extra_key, n) for extra_key in self.extra_state_keys):
+ p.requires_grad = True
+
+ @property
+ def model(self):
+ return self.base_model
+
+ def _deactivate_all_parts(self):
+ deactivated = []
+ for adapter in self.active_adapters:
+ output = self.adapters[adapter]
+ if output.config.swift_type == SwiftTuners.PART:
+ deactivated.append(adapter)
+ self.deactivate_adapter(adapter)
+ return deactivated
+
+ def load_state_dict(self, state_dict, strict=True, adapter_name: str = None):
+ if adapter_name is not None:
+ output: SwiftOutput = self.adapters[adapter_name]
+ if getattr(output.config, 'modules_to_save', None):
+ for key, value in copy(state_dict).items():
+ for module_name in output.config.modules_to_save:
+ if module_name in key:
+ state_dict.pop(key)
+ key = key.replace(module_name, f'{module_name}.modules_to_save.{adapter_name}')
+ break
+ state_dict[key] = value
+
+ for key, value in copy(state_dict).items():
+ if key.startswith('base_model.model.'):
+ state_dict.pop(key, None)
+ key = key[len('base_model.model.'):]
+ if f'lora_A.{adapter_name}.' not in key and 'lora_A' in key:
+ state_dict.pop(key, None)
+ key = key.replace('lora_A.', f'lora_A.{adapter_name}.')
+ if f'lora_B.{adapter_name}.' not in key and 'lora_B' in key:
+ state_dict.pop(key, None)
+ key = key.replace('lora_B.', f'lora_B.{adapter_name}.')
+ if f'lora_embedding_A.{adapter_name}.' not in key and 'lora_embedding_A' in key:
+ state_dict.pop(key, None)
+ key = key.replace('lora_embedding_A.', f'lora_embedding_A.{adapter_name}.')
+ if f'lora_embedding_B.{adapter_name}.' not in key and 'lora_embedding_B' in key:
+ state_dict.pop(key, None)
+ key = key.replace('lora_embedding_B.', f'lora_embedding_B.{adapter_name}.')
+ state_dict[key] = value
+
+ if output.load_state_dict_callback:
+ state_dict = output.load_state_dict_callback(self.base_model, adapter_name, state_dict)
+
+ incompatible_keys = self.base_model.load_state_dict(state_dict, False)
+ if incompatible_keys and len(incompatible_keys[1]) > 0:
+ logger.error(f'Load state dict with unexpected keys: {incompatible_keys[1]}')
+
+ def state_dict(self,
+ *args,
+ destination=None,
+ prefix='',
+ keep_vars=False,
+ adapter_name: str = None,
+ peft_format: bool = False,
+ **kwargs):
+ """
+ Args:
+ destination (`dict`, `optional`): If provided, the state of module will
+ be updated into the dict and the same object is returned.
+ Otherwise, an ``OrderedDict`` will be created and returned.
+ Default: ``None``.
+ prefix (`str`, `optional`): a prefix added to parameter and buffer
+ names to compose the keys in state_dict. Default: ``''``.
+ keep_vars (`bool`, `optional`): by default the :class:`~torch.Tensor` s
+ returned in the state dict are detached from autograd. If it's
+ set to ``True``, detaching will not be performed.
+ Default: ``False``.
+ adapter_name (`str`, `optional`): The name of the adapter's parameters to be saved,
+ `None` input will save all adapters.
+ peft_format (`bool`, `optional`): Save with peft format (extra `base_model.model.` prefix)
+ **kwargs:
+ save_adapter(`bool`): Save adapters or not, default True
+ save_extra_states(`bool`): Save extra states or not, default True
+ Returns:
+ The state dict to be saved.
+ """
+ state_dict = kwargs.get('state_dict')
+ if state_dict is None:
+ state_dict = self.base_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+ state_dict = {
+ key[len('base_model.'):] if key.startswith('base_model.') else key: value
+ for key, value in state_dict.items()
+ }
+ if not self.has_additional_modules:
+ return state_dict
+
+ state_dicts = {}
+ if kwargs.get('save_adapter', True):
+ for name, output in self.adapters.items():
+ if (adapter_name == name or adapter_name is None) and output.config.has_additional_modules: # noqa
+ state_dicts.update(output.state_dict_callback(state_dict, name))
+ modules_to_save_names = [
+ sub_name for sub_name, _ in self.base_model.named_parameters()
+ if f'modules_to_save.{name}' in sub_name
+ ]
+ for module_name in modules_to_save_names:
+ if f'modules_to_save.{name}' in module_name:
+ state_dicts[module_name.replace(f'modules_to_save.{name}.', '')] = state_dict[module_name]
+ if kwargs.get('save_extra_states', True):
+ state_dicts.update({
+ k: v
+ for k, v in state_dict.items() if any(
+ re.fullmatch(extra_key, k) for extra_key in self.extra_state_keys)
+ })
+ if peft_format:
+ new_state_dict = {}
+ for key, value in state_dicts.items():
+ if not key.startswith('base_model.model.'):
+ key = 'base_model.model.' + key
+ key = key.replace(f'lora_A.{adapter_name}.', 'lora_A.')
+ key = key.replace(f'lora_B.{adapter_name}.', 'lora_B.')
+ key = key.replace(f'lora_embedding_A.{adapter_name}.', 'lora_embedding_A.')
+ key = key.replace(f'lora_embedding_B.{adapter_name}.', 'lora_embedding_B.')
+ new_state_dict[key] = value
+ state_dicts = new_state_dict
+ return state_dicts
+
+ def __getattr__(self, key: str):
+ """Forward missing attributes to the wrapped module."""
+ try:
+ return super().__getattr__(key)
+ except AttributeError:
+ if 'base_model' in dir(self):
+ return getattr(self.base_model, key)
+ raise
+
+ @staticmethod
+ def load_state_file(path, device: Optional[str] = None):
+ """Load a state dict file by the input path.
+
+ Args:
+ path: The local dir to load the state file.
+
+ Returns:
+ The state dict.
+ """
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
+ filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
+ from safetensors.torch import load_file as safe_load_file
+ return safe_load_file(filename, device=device)
+ elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
+ filename = os.path.join(path, WEIGHTS_NAME)
+ return torch.load(filename, map_location=device)
+ return None
+
+ def create_optimizer_param_groups(self, **defaults):
+ all_param_names = set()
+ param_groups = []
+ for output in self.adapters.values():
+ if output.optimizer_group_callback:
+ param_names, param_group = output.optimizer_group_callback(self.model, **defaults)
+ if param_names and all_param_names & param_names:
+ raise ValueError('Cannot set one parameter to different param groups')
+ if param_names and param_group:
+ all_param_names.update(param_names)
+ param_groups.extend(param_group)
+
+ decay_parameters = Trainer.get_decay_parameter_names(None, self.model)
+ param_groups.extend([
+ {
+ 'params': [
+ p for n, p in self.model.named_parameters()
+ if (n in decay_parameters and n not in all_param_names and p.requires_grad)
+ ],
+ 'weight_decay':
+ defaults['weight_decay'],
+ },
+ {
+ 'params': [
+ p for n, p in self.model.named_parameters()
+ if (n not in decay_parameters and n not in all_param_names and p.requires_grad)
+ ],
+ 'weight_decay':
+ 0.0,
+ },
+ ])
+
+ return param_groups
+
+ @classmethod
+ def from_pretrained(cls,
+ model: Union[nn.Module, 'SwiftModel'],
+ model_id: str = None,
+ adapter_name: Union[str, List[str], Dict[str, str]] = None,
+ inference_mode: bool = True,
+ revision: str = None,
+ **kwargs):
+ """Load a set of tuners and corresponding weights by a model_id.
+
+ Args:
+ model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned,
+ if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped..
+ model_id (`str`): The model_id or a local model dir of tuners to use to tune the model.
+ adapter_name (`Union[str, List[str], Dict[str, str]]`): The adapter_names saved in the model repo to load.
+ Default `None`, means load all tuners saved in the model_id
+ inference_mode (`bool`): Use in the inference mode or not.
+ revision (`str`): The model revision to use.
+ **kwargs:
+ extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved.
+ Other parameters will be passed to the device_map.
+ Returns:
+ The `SwiftModel` instance.
+ """
+ adapters = {}
+ model_dir = model_id
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir, revision=revision)
+ if os.path.isfile(model_dir):
+ raise ValueError(f'Please pass in a local dir or a model id, not a local file: {model_dir}')
+ extra_state_keys = kwargs.pop('extra_state_keys', None)
+ if extra_state_keys is None and os.path.isfile(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME)):
+ with open(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME), 'r', encoding='utf-8') as file:
+ _json = json.load(file)
+ extra_state_keys = _json.get('extra_state_keys')
+ if adapter_name is None:
+ adapter_name = [
+ sub_dir for sub_dir in os.listdir(model_dir)
+ if os.path.isfile(os.path.join(model_dir, sub_dir, CONFIG_NAME)) and sub_dir != cls.EXTRA_STATE_DIR
+ ]
+ for _name in adapter_name if isinstance(adapter_name,
+ list) else [adapter_name] \
+ if isinstance(adapter_name, str) else adapter_name.keys():
+ sub_folder = os.path.join(model_dir, _name)
+ config_file = os.path.join(sub_folder, CONFIG_NAME)
+
+ if not os.path.isfile(config_file):
+ logger.warning(f'{_name} is not a valid tuner')
+ continue
+
+ with open(config_file, 'r', encoding='utf-8') as file:
+ json_object = json.load(file)
+
+ if SWIFT_TYPE_KEY not in json_object:
+ raise ValueError('Mixed using with peft is not allowed now.')
+ else:
+ key = _name if not isinstance(adapter_name, dict) else adapter_name[_name]
+ adapters[key] = SwiftConfig.from_pretrained(sub_folder)
+
+ self = SwiftModel(model, adapters, extra_state_keys, inference_mode, **kwargs)
+ for _name in adapter_name if isinstance(adapter_name,
+ list) else [adapter_name] \
+ if isinstance(adapter_name, str) else adapter_name.keys():
+ _adapter = _name if not isinstance(adapter_name, dict) else adapter_name[_name]
+ output: SwiftOutput = self.adapters[_adapter]
+ sub_folder = os.path.join(model_dir, _name)
+ if output.load_callback:
+ output.load_callback(self, sub_folder, _adapter)
+ continue
+ state_dict = cls.load_state_file(sub_folder)
+ if state_dict is not None:
+ if isinstance(adapter_name, dict):
+ # TODO this logic is fragile! replace `_name` may cause other parts replaced
+ state_dict = {key.replace(_name, adapter_name[_name]): value for key, value in state_dict.items()}
+ self.load_state_dict(state_dict, adapter_name=_adapter)
+ state_dict = cls.load_state_file(os.path.join(model_dir, self.EXTRA_STATE_DIR))
+ if state_dict is not None:
+ self.load_state_dict(state_dict)
+ return self
+
+ @classmethod
+ def _prepare_model(
+ cls,
+ model: nn.Module,
+ config: SwiftConfig,
+ adapter_name: str,
+ ):
+ assert (hasattr(config, SWIFT_TYPE_KEY))
+ from .mapping import SWIFT_MAPPING
+
+ adapter_cls = SWIFT_MAPPING[config.swift_type][1]
+ if adapter_cls.has_additional_modules() and not getattr(model, 'model_frozen', False):
+ for _, p in model.named_parameters():
+ p.requires_grad = False
+ model.model_frozen = True
+ config.has_additional_modules = adapter_cls.has_additional_modules()
+ return adapter_cls.prepare_model(model, config, adapter_name)
+
+ def create_or_update_model_card(self, output_dir: str):
+ """
+ Updates or create the model card.
+ """
+ if not os.path.exists(os.path.join(output_dir, 'README.md')):
+ lines = []
+ else:
+ with open(os.path.join(output_dir, 'README.md'), 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+
+ quantization_config = None
+ if hasattr(self.base_model, 'config') and hasattr(self.base_model.config, 'quantization_config'):
+ if hasattr(self.base_model.config.quantization_config, 'to_dict'):
+ quantization_config = self.base_model.config.quantization_config.to_dict()
+ training_config_text = ''
+ # Adds quantization information if it was used
+ if quantization_config is not None:
+ training_config_text += '\nThe following `bitsandbytes` quantization config was used during training:\n'
+ training_config_text += '\n'.join([f'- {name}: {value}' for name, value in quantization_config.items()])
+ training_config_text += '\n'
+
+ training_procedure_heading = '## Training procedure\n'
+ if training_procedure_heading in lines:
+ lines.insert(lines.index(training_procedure_heading) + 2, training_config_text)
+ else:
+ lines.append(f'{training_procedure_heading}\n{training_config_text}')
+
+ framework_block_heading = '### Framework versions\n'
+ from swift.version import __version__
+ if framework_block_heading in lines:
+ lines.insert(lines.index(framework_block_heading) + 2, f'- SWIFT {__version__}\n')
+ else:
+ lines.append(f'{framework_block_heading}\n\n- SWIFT {__version__}\n')
+
+ base_model_heading = '### Base model information\n'
+ lines.append(f'{base_model_heading}\n\n- BaseModel Class {self.base_model.__class__.__name__}\n')
+
+ # write the lines back to README.md
+ with open(os.path.join(output_dir, 'README.md'), 'w', encoding='utf-8') as f:
+ f.writelines(lines)
+
+ def add_weighted_adapter(
+ self,
+ adapters,
+ weights,
+ adapter_name,
+ combination_type='svd',
+ svd_rank=None,
+ svd_clamp=None,
+ svd_full_matrices=True,
+ svd_driver=None,
+ density=None,
+ majority_sign_method: Literal['total', 'frequency'] = 'total',
+ ):
+ """
+ This method adds a new adapter by merging the given adapters with the given weights.
+
+ When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to
+ the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM
+ errors.
+
+ Args:
+ adapters (`list`):
+ List of adapter names to be merged.
+ weights (`list`):
+ List of weights for each adapter.
+ adapter_name (`str`):
+ Name of the new adapter.
+ combination_type (`str`):
+ The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
+ `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
+ combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
+ mixed adapter may be too big and result in OOM errors).
+ svd_rank (`int`, *optional*):
+ Rank of output adapter for svd. If None provided, will use max rank of merging adapters.
+ svd_clamp (`float`, *optional*):
+ A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform
+ clamping. Defaults to None.
+ svd_full_matrices (`bool`, *optional*):
+ Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned
+ tensors U and Vh. Defaults to True.
+ svd_driver (`str`, *optional*):
+ Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be
+ one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd`
+ documentation. Defaults to None.
+ density (`float`, *optional*):
+ Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used
+ with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`,
+ `magnintude_prune`, `magnitude_prune_svd`]
+ majority_sign_method (`str`):
+ The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values.
+ Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`]
+ """
+ from swift.tuners.lora import LoraModel
+ lora_model = LoraModel(self.model, None, '')
+ lora_model.peft_config = {key: value.config for key, value in self.adapters.items()}
+ from peft.tuners.lora import LoraLayer
+ lora_model.targeted_module_names = [
+ key for key, value in self.model.named_modules() if isinstance(value, LoraLayer)
+ ]
+ lora_model.active_adapter = self.active_adapters
+ lora_model.add_weighted_adapter(
+ adapters=adapters,
+ weights=weights,
+ adapter_name=adapter_name,
+ combination_type=combination_type,
+ svd_rank=svd_rank,
+ svd_clamp=svd_clamp,
+ svd_full_matrices=svd_full_matrices,
+ svd_driver=svd_driver,
+ density=density,
+ majority_sign_method=majority_sign_method,
+ )
+
+ def state_dict_callback(state_dict, adapter_name, cfg):
+ from swift.tuners.lora_layers import lora_state_dict
+ return lora_state_dict(state_dict, adapter_name, cfg.bias)
+
+ def mark_trainable_callback(model, cfg):
+ from swift.tuners.lora_layers import mark_lora_as_trainable
+ mark_lora_as_trainable(model, adapter_name, cfg.bias)
+
+ cfg = lora_model.peft_config[adapter_name]
+ cfg.has_additional_modules = True
+ self.adapters[adapter_name] = SwiftOutput(
+ config=cfg,
+ state_dict_callback=partial(state_dict_callback, cfg=cfg),
+ mark_trainable_callback=partial(mark_trainable_callback, cfg=cfg),
+ optimizer_group_callback=None,
+ )
+
+ self.set_active_adapters(adapter_name)
+
+ def save_pretrained(self,
+ save_directory: str,
+ safe_serialization: bool = False,
+ adapter_name: Union[str, List[str]] = None,
+ **kwargs):
+ """Save the adapters to a local directory.
+
+ Args:
+ save_directory (`str`): The directory to use.
+ safe_serialization (`bool`): Use safe tensors to save the weights, default False.
+ adapter_name(`Union[str, List[str]]`): The adapters to be saved, default is `None` to save all.
+ """
+ peft_format = kwargs.pop('peft_format', False)
+ if os.path.isfile(save_directory):
+ raise ValueError(f'Provided path ({save_directory}) should be a directory, not a file')
+ os.makedirs(save_directory, exist_ok=True)
+ if not self.has_additional_modules:
+ if hasattr(self.base_model, 'save_pretrained'):
+ self.base_model.save_pretrained(save_directory, safe_serialization=safe_serialization)
+ else:
+ self._save_state_dict(self.base_model.state_dict(), save_directory, safe_serialization)
+ self.create_or_update_model_card(save_directory)
+ else:
+ self.create_or_update_model_card(save_directory)
+
+ adapter_names = adapter_name if isinstance(adapter_name, list) or adapter_name is None else [adapter_name]
+
+ state_dict_kwargs = {}
+ state_dict = kwargs.get('state_dict')
+ if state_dict is not None:
+ state_dict_kwargs['state_dict'] = kwargs['state_dict']
+ for adapter_name, output in self.adapters.items():
+ if adapter_names is not None and adapter_name not in adapter_names:
+ continue
+
+ save_to_peft = peft_format and output.config.swift_type == SwiftTuners.LORA
+ save_to_peft = save_to_peft and output.config.can_be_saved_to_peft()
+ if peft_format and not save_to_peft:
+ logger.error('You are using additional lora parameters, which is not compatible with peft,'
+ 'which is unable to save to peft format.')
+ output_dir = os.path.join(save_directory,
+ adapter_name) if adapter_name != 'default' or not save_to_peft else save_directory
+
+ if save_to_peft:
+ config = output.config.to_peft_config()
+ config.save_pretrained(output_dir)
+ else:
+ output.config.save_pretrained(output_dir)
+
+ if output.save_callback:
+ output.save_callback(self, output_dir, adapter_name)
+ continue
+
+ # save only the trainable weights
+ output_state_dict = self.state_dict(
+ adapter_name=adapter_name, save_extra_states=False, peft_format=save_to_peft, **state_dict_kwargs)
+ os.makedirs(output_dir, exist_ok=True)
+ if output_state_dict and output.config.has_additional_modules:
+ self._save_state_dict(output_state_dict, output_dir, safe_serialization)
+
+ output_state_dict = self.state_dict(save_extra_states=True, save_adapter=False, **state_dict_kwargs)
+ if len(output_state_dict) > 0:
+ if self.has_additional_modules:
+ os.makedirs(os.path.join(save_directory, self.EXTRA_STATE_DIR), exist_ok=True)
+ self._save_state_dict(output_state_dict, os.path.join(save_directory, self.EXTRA_STATE_DIR),
+ safe_serialization)
+ with open(
+ os.path.join(save_directory, self.EXTRA_STATE_DIR, CONFIG_NAME), 'w', encoding='utf-8') as file:
+ json.dump({'extra_state_keys': self.extra_state_keys}, file)
+ else:
+ logger.error('Full parameter training, save_extra_states will be ignored')
+
+ if not os.path.exists(os.path.join(save_directory, 'configuration.json')):
+ with open(os.path.join(save_directory, 'configuration.json'), 'w', encoding='utf-8') as f:
+ f.write('{}')
+
+ @staticmethod
+ def _save_state_dict(output_state_dict, save_directory, safe_serialization):
+ if safe_serialization:
+ from safetensors.torch import save_file as safe_save_file
+ safe_save_file(
+ output_state_dict, os.path.join(save_directory, SAFETENSORS_WEIGHTS_NAME), metadata={'format': 'pt'})
+ else:
+ torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+ @contextmanager
+ def disable_adapter(self):
+ try:
+ self.set_active_adapters(adapter_names=[])
+ yield
+ finally:
+ self.set_active_adapters(adapter_names=self.adapters.keys())
+
+ def set_active_adapters(self, adapter_names: Union[List[str], str], offload: str = None):
+ """Set activated adapters
+
+ Args:
+ adapter_names(`Union[List[str], str]`): The adapters needed to be activated
+ offload(`str`): Whether to offload the deactivated ones to `cpu` or `meta` device
+ """
+ if not adapter_names:
+ adapter_names = []
+
+ if isinstance(adapter_names, str):
+ adapter_names = [adapter_names]
+
+ adapter_names = set(adapter_names)
+ for adapter_name in (adapter_names & set(self.adapters.keys())):
+ self.activate_adapter(adapter_name)
+
+ for adapter_name in (set(self.adapters.keys()) - adapter_names):
+ self.deactivate_adapter(adapter_name, offload)
+
+ self.active_adapters = (adapter_names & set(self.adapters.keys()))
+
+ def activate_adapter(self, adapter_name: str):
+ """Activate one adapter
+
+ Args:
+ adapter_name(`str`): The adapter needed to be activated
+ """
+ if adapter_name not in self.adapters:
+ logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}')
+ return
+
+ from .mapping import SWIFT_MAPPING
+ SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
+ .activate_adapter(self.base_model, adapter_name, True)
+ self.active_adapters = self.active_adapters | {adapter_name}
+
+ def deactivate_adapter(self, adapter_name: str, offload: str = None):
+ """Deactivate one adapter
+
+ Args:
+ adapter_name(`str`): The adapter needed to be activated
+ offload(`str`): Whether to offload to `cpu` or `meta` device
+ """
+ if adapter_name not in self.adapters:
+ logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}')
+ return
+
+ from .mapping import SWIFT_MAPPING
+ SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
+ .activate_adapter(self.base_model, adapter_name, False, offload=offload)
+ self.active_adapters = self.active_adapters - {adapter_name}
+
+ def get_trainable_parameters(self):
+ """
+ Get the content of trainable parameters in the model.
+ """
+ trainable_params = 0
+ all_param = 0
+ for _, param in self.base_model.named_parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, 'ds_numel'):
+ num_params = param.ds_numel
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+ return f'trainable params: {trainable_params:,d} || all params: {all_param:,d} ' \
+ f'|| trainable%: {100 * trainable_params / all_param:.4f}' \
+ '|| cuda memory: ' \
+ f'{sum([torch.cuda.memory_allocated(i) for i in range(get_device_count())])/1024/1024/1024:.2f}' \
+ 'GiB.'
+
+
+class Swift:
+ """The Wrapper to use both Peft and Swift tuners."""
+
+ @staticmethod
+ def prepare_model(model: Union[nn.Module, SwiftModel], config: Union[SwiftConfig, PeftConfig,
+ Dict[str, SwiftConfig]], **kwargs):
+ """Prepare a model by the input config.
+
+ Args:
+ model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned.
+ config(`Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]]`): The config or config dict, can be either
+ SwiftConfigs or PeftConfigs
+ **kwargs:
+ Extra kwargs needed by SwiftModel or PeftModel.
+ Returns:
+ The model wrapped by SwiftModel or PeftModel.
+ """
+
+ if isinstance(config, (SwiftConfig, dict)):
+ return SwiftModel(model, config, **kwargs)
+ else:
+ return get_peft_model(model, config, **kwargs)
+
+ @staticmethod
+ def merge_and_unload(model: Union[PeftModel, SwiftModel], **kwargs):
+ """Merge tuners into the base model and unload them.
+
+ Args:
+ model(`Union[PeftModel, SwiftModel]`): The model instance with tuners
+ kwargs:
+ adapter_name(`Union[str, List[str]]`): The adapter_name to unload, only supported in swift tuners.
+
+ """
+ from peft import PeftModel as _PeftModel
+ if isinstance(model, _PeftModel):
+ model.merge_and_unload()
+ elif isinstance(model, SwiftModel):
+ from swift import LoRAConfig
+ from swift.tuners import LoRA
+ adapter_name = kwargs.get('adapter_name', None)
+ if isinstance(adapter_name, str):
+ adapter_name = [adapter_name]
+ for adapter, output in model.adapters.items():
+ if isinstance(output.config, LoRAConfig) and (adapter_name is None or adapter in adapter_name):
+ LoRA.unpatch_lora(model, output.config, adapter)
+
+ @staticmethod
+ @contextmanager
+ def grpo_context(model: Union[SwiftModel, torch.nn.Module], processor):
+ # Save the model and temporarily modify model.model_dir.
+ if not isinstance(model, SwiftModel):
+ yield
+ return
+ else:
+ assert len(model.adapters) == 1
+ adapter = list(model.adapters.values())[0]
+ if adapter.config.swift_type == SwiftTuners.LLAMAPRO:
+ from modelscope.hub.utils.utils import get_cache_dir
+ temp_dir = tempfile.mkdtemp(dir=get_cache_dir())
+ model_dir = model.model_dir
+ from transformers.integrations import is_deepspeed_zero3_enabled
+ if is_deepspeed_zero3_enabled():
+ raise ValueError('DeepSpeed ZeRO3 not supported for LLaMAPro&GRPO currently.')
+ model.base_model.save_pretrained(temp_dir)
+ processor.save_pretrained(temp_dir)
+ model.model_dir = temp_dir
+ yield
+ if adapter.config.swift_type == SwiftTuners.LLAMAPRO:
+ model.model_dir = model_dir
+ shutil.rmtree(temp_dir)
+
+ @staticmethod
+ def merge(model: Union[PeftModel, SwiftModel], **kwargs):
+ """Merge tuners into the base model, will not unload them.
+
+ Args:
+ model(`Union[PeftModel, SwiftModel]`): The model instance with tuners
+ """
+ from .lora_layers import LoraLayer, LoRALayer
+ for sub_module in model.modules():
+ if isinstance(sub_module, (LoraLayer, LoRALayer)):
+ sub_module.merge(**kwargs)
+
+ @staticmethod
+ def unmerge(model: Union[PeftModel, SwiftModel], **kwargs):
+ """Unmerge tuners from the base model
+
+ Args:
+ model(`Union[PeftModel, SwiftModel]`): The model instance with tuners
+ """
+ from .lora_layers import LoraLayer, LoRALayer
+ for sub_module in model.modules():
+ if isinstance(sub_module, (LoraLayer, LoRALayer)):
+ sub_module.unmerge(**kwargs)
+
+ @staticmethod
+ def save_to_peft_format(ckpt_dir: str, output_dir: str) -> None:
+ """Save swift format to peft format
+
+ Args:
+ ckpt_dir(`str`): Original swift output dir
+ output_dir(`str`): Converted peft format dir
+ """
+ assert ckpt_dir and output_dir, 'Please pass in valid ckpt_dir and output_dir.'
+ assert os.path.exists(ckpt_dir), f'ckpt_dir: {ckpt_dir} must exists in local disk.'
+ if os.path.exists(os.path.join(ckpt_dir, SwiftModel.EXTRA_STATE_DIR)):
+ raise AssertionError('Cannot transfer to peft format, because you are additional state dicts.')
+
+ adapter_names = [
+ sub_dir for sub_dir in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, sub_dir, CONFIG_NAME))
+ ]
+
+ def has_custom_content(_json):
+ if _json.get('swift_type', _json.get('peft_type')) != SwiftTuners.LORA:
+ logger.warn('Only LoRA can be converted to peft format')
+ return True
+
+ from swift import LoRAConfig
+ return not LoRAConfig(**_json).can_be_saved_to_peft()
+
+ for adapter in adapter_names:
+ with open(os.path.join(ckpt_dir, adapter, CONFIG_NAME), encoding='utf-8') as f:
+ _json = json.load(f)
+ if has_custom_content(_json):
+ raise AssertionError('Cannot transfer to peft format, '
+ 'because you have special parameters or adapter types.')
+
+ os.makedirs(output_dir, exist_ok=True)
+ if ckpt_dir != output_dir:
+ shutil.copytree(ckpt_dir, output_dir, dirs_exist_ok=True)
+
+ for adapter in adapter_names:
+ safe_serialization = os.path.isfile(os.path.join(output_dir, adapter, SAFETENSORS_WEIGHTS_NAME))
+ state_dict = SwiftModel.load_state_file(os.path.join(output_dir, adapter))
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if not key.startswith('base_model.model.'):
+ key = 'base_model.model.' + key
+ key = key.replace(f'lora_A.{adapter}.', 'lora_A.')
+ key = key.replace(f'lora_B.{adapter}.', 'lora_B.')
+ key = key.replace(f'lora_embedding_A.{adapter}.', 'lora_embedding_A.')
+ key = key.replace(f'lora_embedding_B.{adapter}.', 'lora_embedding_B.')
+ key = key.replace(f'lora_magnitude_vector.{adapter}', 'lora_magnitude_vector')
+ new_state_dict[key] = value
+ state_dict = new_state_dict
+ SwiftModel._save_state_dict(state_dict, os.path.join(output_dir, adapter), safe_serialization)
+ from swift import LoRAConfig
+ with open(os.path.join(output_dir, adapter, CONFIG_NAME), encoding='utf-8') as f:
+ _json = json.load(f)
+ peft_config = LoRAConfig(**_json).to_peft_config()
+ peft_config.save_pretrained(os.path.join(output_dir, adapter))
+
+ if 'default' in adapter_names:
+ shutil.move(os.path.join(output_dir, 'default', CONFIG_NAME), os.path.join(output_dir, CONFIG_NAME))
+ state_dict = SwiftModel.load_state_file(os.path.join(output_dir, 'default'))
+ safe_serialization = os.path.isfile(os.path.join(output_dir, 'default', SAFETENSORS_WEIGHTS_NAME))
+ SwiftModel._save_state_dict(state_dict, output_dir, safe_serialization)
+ shutil.rmtree(os.path.join(output_dir, 'default'))
+
+ @staticmethod
+ def from_pretrained(model: Union[nn.Module, SwiftModel, PeftModel],
+ model_id: str = None,
+ adapter_name: Union[str, List[str], Dict[str, str]] = None,
+ revision: str = None,
+ **kwargs):
+ """Prepare a model by a model_id in the ModelScope hub or a local dir.
+
+ Args:
+ model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned.
+ model_id(`str`): The model id of the modelhub or a local dir containing the configs/weights.
+ adapter_name(`str`, `optional`): The adapter_name to use.
+ revision(`str`, `optional`): The model revision if the model_id is a model id of the modelhub.
+ **kwargs:
+ Extra kwargs needed by ``SwiftModel.from_pretrained`` or ``PeftModel.from_pretrained``.
+ Returns:
+ The model wrapped by SwiftModel or PeftModel.
+ """
+ if not os.path.exists(model_id):
+ model_id = snapshot_download(model_id, revision=revision)
+ is_peft_model = False
+ if os.path.exists(os.path.join(model_id, CONFIG_NAME)):
+ with open(os.path.join(model_id, CONFIG_NAME), 'r', encoding='utf-8') as f:
+ _json = json.load(f)
+ is_peft_model = SWIFT_TYPE_KEY not in _json
+
+ _name = adapter_name if isinstance(
+ adapter_name, str) or adapter_name is None else adapter_name[0] \
+ if isinstance(adapter_name, list) else list(adapter_name.keys())[0]
+ _name = _name or ''
+ if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)):
+ with open(os.path.join(model_id, _name, CONFIG_NAME), 'r', encoding='utf-8') as f:
+ _json = json.load(f)
+ is_peft_model = SWIFT_TYPE_KEY not in _json and 'extra_state_keys' not in _json
+ if is_peft_model:
+
+ def load_peft_model(_model, _adapter_name, _new_name=None):
+ if not _new_name:
+ _new_name = _adapter_name
+ import peft
+ if not isinstance(_model, peft.PeftModel):
+ return PeftModel.from_pretrained(
+ _model,
+ os.path.join(model_id, _adapter_name) if _adapter_name != 'default'
+ and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id,
+ revision=revision,
+ adapter_name=_new_name,
+ **kwargs)
+ else:
+ _model.load_adapter(
+ os.path.join(model_id, _adapter_name) if _adapter_name != 'default'
+ and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id, _new_name)
+ return _model
+
+ if not adapter_name:
+ peft_model = load_peft_model(model, 'default')
+ for _dir in os.listdir(model_id):
+ if os.path.isdir(os.path.join(model_id, _dir)) and \
+ os.path.exists(os.path.join(model_id, _dir, CONFIG_NAME)):
+ peft_model = load_peft_model(peft_model, _dir)
+ elif isinstance(adapter_name, str):
+ return load_peft_model(model, adapter_name)
+ elif isinstance(adapter_name, list):
+ peft_model = model
+ for name in adapter_name:
+ peft_model = load_peft_model(peft_model, name)
+ else:
+ peft_model = model
+ for key, value in adapter_name.items():
+ peft_model = load_peft_model(peft_model, key, value)
+ return peft_model
+ else:
+ return SwiftModel.from_pretrained(model, model_id, revision=revision, adapter_name=adapter_name, **kwargs)
diff --git a/swift/tuners/llamapro.py b/swift/tuners/llamapro.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec6d254fd743750d1e7914d00a08e6ea5fc63be
--- /dev/null
+++ b/swift/tuners/llamapro.py
@@ -0,0 +1,233 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from copy import deepcopy
+from dataclasses import dataclass, field, fields
+from typing import Optional
+
+import torch
+from torch import nn
+
+from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
+from swift.utils.logger import get_logger
+from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class LLaMAProConfig(SwiftConfig):
+ """
+ The configuration class for the LLaMAPro module.
+
+ See https://arxiv.org/abs/2401.02415
+
+ Args:
+ model_type(`str`): LLaMAPro only support parts of the LLM models because of the variables need to be manually
+ modified.
+ num_new_blocks(`int`): How many new blocks need to be added
+ num_groups(`int`): The groups of new blocks are split to. Default equals to `num_new_blocks` which means each
+ single layer will be inserted into every `num_hidden_layers/num_new_blocks` original layers.
+ """
+ model_type: str = field(
+ default=None, metadata={
+ 'choices': list(MODEL_ARCH_MAPPING.keys()),
+ })
+
+ num_new_blocks: int = None
+
+ num_groups: Optional[int] = None
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.LLAMAPRO
+
+
+class LLaMAPro(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
+ """Prepare a model with `LLaMAProConfig`"""
+ num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
+ if num_hidden_layers is None:
+ num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')
+ assert num_hidden_layers is not None, 'Cannot find num of layers config'
+ assert num_hidden_layers % config.num_new_blocks == 0, f'Model layers {num_hidden_layers} ' \
+ f'should be divided by {config.num_new_blocks}'
+ if config.num_groups is None:
+ config.num_groups = config.num_new_blocks
+
+ # the except block will change the model_type, this will cause `model not found` error
+ # when using internvl
+ origin_model_type = config.model_type
+ model_type = origin_model_type
+ num_stride = num_hidden_layers // config.num_groups
+ try:
+ module_list = LLaMAPro._find_module_list(config, model)
+ except AssertionError as e:
+ model_type = LLaMAPro.search_correct_model_type(model)
+ if model_type is None:
+ language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
+ if language_model_name:
+ if isinstance(language_model_name, str):
+ language_model_name = [language_model_name]
+ language_model = model.get_submodule(language_model_name[0])
+ model_type = LLaMAPro.search_correct_model_type(language_model)
+ if model_type:
+ model = language_model
+
+ if model_type:
+ config.model_type = model_type
+ module_list = LLaMAPro._find_module_list(config, model)
+ else:
+ raise e
+
+ new_module_list = nn.ModuleList()
+ new_module_idx = []
+ for idx, module in enumerate(module_list):
+ new_module_list.append(module)
+ if (idx + 1) % num_stride == 0:
+ new_module = deepcopy(module)
+ ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
+ new_module_list.append(new_module)
+ new_module_idx.append(idx + 1 + len(new_module_idx))
+
+ LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
+ LLaMAPro._update_module_attr(config, new_module_list)
+ model.config.num_hidden_layers = len(new_module_list)
+ LLaMAPro._set_module_list(config, model, new_module_list)
+
+ def activate_module(activate: bool):
+ if activate:
+ LLaMAPro._update_module_attr(config, new_module_list)
+ LLaMAPro._set_module_list(config, model, new_module_list)
+ else:
+ LLaMAPro._update_module_attr(config, module_list)
+ LLaMAPro._set_module_list(config, model, module_list)
+
+ def state_dict_callback(state_dict, adapter_name, **kwargs):
+ model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
+ new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
+ return {
+ key: value
+ for key, value in state_dict.items() if any([m_part in key for m_part in new_module_list])
+ }
+
+ def mark_trainable_callback(model):
+ model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
+ new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
+ for name, parameter in model.named_parameters():
+ parameter: nn.Parameter
+ if any([m_part in name for m_part in new_module_list]):
+ parameter.requires_grad = True
+
+ config.model_type = origin_model_type
+ model.activate_module = activate_module
+ return SwiftOutput(
+ config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
+
+ @staticmethod
+ def _update_module_attr(config: LLaMAProConfig, module_list):
+ model_type = config.model_type
+ model_key_mapping = LLaMAPro.get_model_key_mapping(model_type, config)
+ attention = model_key_mapping.attention
+ attention = attention.split('{}.')[1]
+ if model_type == 'phi3-small':
+ raise ValueError('phi3-small does not support llamapro currently')
+ if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
+ 'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
+ for idx, module in enumerate(module_list):
+ try:
+ getattr(module, attention).layer_idx = idx
+ except AttributeError:
+ getattr(module, 'cross_attn').layer_idx = idx
+ elif model_type in ('chatglm', 'glm4'):
+ for idx, module in enumerate(module_list):
+ getattr(module, attention).layer_number = idx
+ elif model_type in ('phi2', ):
+ for idx, module in enumerate(module_list):
+ getattr(module, attention).block_idx = idx
+ else:
+ for idx, module in enumerate(module_list):
+ attrs = [
+ attr for attr in dir(getattr(module_list[0], attention))
+ if attr in ('layer_idx', 'layer_number', 'block_idx')
+ ]
+ assert len(attrs) <= 1
+ if attrs:
+ setattr(getattr(module, attention), attrs[0], idx)
+ else:
+ logger.warn(f'model_type: {model_type} seems has no layer_idx, if you encountered anything wrong,'
+ f'please give us a feedback.')
+
+ @classmethod
+ def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
+
+ model_key_mapping = SwiftAdapter.get_model_key_mapping(model_type, config)
+ assert model_key_mapping.o_proj is not None and model_key_mapping.down_proj is not None, \
+ 'LLaMAPro only support models with o_proj and down_proj components.'
+ return model_key_mapping
+
+ @classmethod
+ def search_correct_model_type(cls, module: nn.Module):
+ for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
+ arch_type: ModelKeys
+ if getattr(arch_type, 'module_list') is None:
+ # Need to be a LLM arch
+ continue
+
+ matched = True
+ for f in fields(arch_type):
+ arch_str = getattr(arch_type, f.name)
+ if f.name == 'arch_name' or arch_str is None:
+ continue
+
+ arch_str = arch_str.replace('{}', '0')
+ try:
+ sub_module = module.get_submodule(arch_str)
+ if sub_module is None:
+ matched = False
+ except AttributeError:
+ matched = False
+
+ if not matched:
+ break
+
+ if matched:
+ return arch_name
+
+ @staticmethod
+ def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
+ model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
+ o_proj = model_key_mapping.o_proj.split('{}.')[1]
+ down_proj = model_key_mapping.down_proj.split('{}.')[1]
+
+ for idx, module in enumerate(module_list):
+ if idx not in new_module_idx:
+ continue
+ _o_proj: nn.Linear = module.get_submodule(o_proj)
+ _down_proj: nn.Linear = module.get_submodule(down_proj)
+ _o_proj.weight.data = torch.zeros_like(_o_proj.weight.data)
+ _down_proj.weight.data = torch.zeros_like(_down_proj.weight.data)
+ if hasattr(_o_proj, 'bias') and _o_proj.bias is not None:
+ _o_proj.bias.data = torch.zeros_like(_o_proj.bias)
+ if hasattr(_down_proj, 'bias') and _down_proj.bias is not None:
+ _down_proj.bias.data = torch.zeros_like(_down_proj.bias)
+
+ @staticmethod
+ def _set_module_list(config, module: nn.Module, module_list: nn.ModuleList):
+ model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
+ idx = model_key_mapping.module_list.rfind('.')
+ parent = module.get_submodule(model_key_mapping.module_list[:idx])
+ setattr(parent, model_key_mapping.module_list[idx + 1:], module_list)
+
+ @staticmethod
+ def _find_module_list(config, module: nn.Module) -> nn.ModuleList:
+ model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
+ return module.get_submodule(model_key_mapping.module_list)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ module.activate_module(activate)
+
+ @staticmethod
+ def has_additional_modules():
+ return True
diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..b36e5df392d41c24a2e99f426a062ef018412dec
--- /dev/null
+++ b/swift/tuners/lora.py
@@ -0,0 +1,193 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+from dataclasses import asdict, dataclass, field
+from functools import reduce
+
+import peft
+import torch
+from packaging import version
+from transformers import Trainer
+
+from .lora_layers import * # noqa
+from .utils import SwiftAdapter, SwiftConfig, SwiftOutput, set_adapter
+
+logger = get_logger()
+
+
+@dataclass
+class LoRAConfig(LoraConfig, SwiftConfig):
+ """
+ The configuration class for the loRA module.
+
+ Args:
+ use_qa_lora(bool): Use
+ QA-LoRA:[Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2309.14717)
+ instead of LoRA. QA-LoRA only supports AutoGPTQ quantized models.
+ Deprecated, do not use this argument.
+ lora_dtype(str): The dtype for all lora modules, supported values are `fp32`, `fp16`, `bf16`.
+ Default value is `None`, which means follow the dtype of original module's weight.
+ lorap_lr_ratio(float): The lr_ratio argument for [LoRA+](https://arxiv.org/abs/2402.12354)
+ """
+
+ use_qa_lora: bool = field(
+ default=False, metadata={'help': 'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'})
+
+ use_merged_linear: bool = field(default=False, metadata={'help': 'Use merged Linear'})
+
+ enable_lora: List[bool] = field(
+ default=None, metadata={'help': 'The modules need to be turned on when using the merged linear layer'})
+
+ lora_dtype: Optional[str] = field(
+ default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'})
+
+ lorap_lr_ratio: float = field(default=2.0**4, metadata={'help': 'The lr ratio of lora_B in lora+'})
+
+ lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'})
+
+ def __post_init__(self):
+ super().__post_init__()
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.LORA
+
+ def can_be_saved_to_peft(self) -> bool:
+ if self.use_qa_lora or self.use_merged_linear:
+ logger.warn('QA-LoRA and MergedLinear cannot be saved to peft format')
+ return False
+ return True
+
+ def to_peft_config(self) -> LoraConfig:
+ _dict = asdict(self)
+ _dict.pop('use_qa_lora', None)
+ _dict.pop('enable_lora', None)
+ _dict.pop('lora_dtype', None)
+ _dict.pop('use_merged_linear', None)
+ _dict['peft_type'] = _dict['swift_type']
+ _dict.pop('swift_type', None)
+ _dict.pop('lr_ratio', None)
+ _dict.pop('model_key_mapping', None)
+ return LoraConfig(**_dict)
+
+ def save_pretrained(self, save_directory: str, **kwargs) -> None:
+ super(peft.LoraConfig, self).save_pretrained(save_directory, **kwargs)
+
+
+class LoRA(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
+ assert not config.use_qa_lora, 'Do not use qa-lora'
+ if config.use_qa_lora:
+ auto_gptq_config = get_quantization_config(model, method='gptq')
+ if auto_gptq_config:
+ config.group_size = getattr(auto_gptq_config, 'group_size', None)
+ LoraModel(model, config, adapter_name)
+
+ def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs):
+ return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias)
+
+ def mark_trainable_callback(model, cfg=None):
+ mark_lora_as_trainable(model, adapter_name, cfg.bias if cfg else config.bias)
+
+ def optimizer_group_callback(model, **defaults):
+ if config.lorap_lr_ratio is None:
+ return None, None
+
+ def get_module(name):
+ parent_idx = 2 if 'lora' in name else 1
+ module_names = name.split(sep='.')[:-parent_idx]
+ module = reduce(getattr, module_names, model)
+ return module
+
+ all_params = set()
+ param_groups = {
+ 'groupA': {},
+ 'groupB': {},
+ 'groupB_no_decay': {},
+ 'embedding': {},
+ }
+
+ decay_parameters = Trainer.get_decay_parameter_names(None, model)
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+ module = get_module(name)
+ if isinstance(module, Embedding):
+ param_groups['embedding'][name] = param
+ elif 'lora_B' in name or param.ndim == 1:
+ if name in decay_parameters:
+ param_groups['groupB'][name] = param
+ else:
+ param_groups['groupB_no_decay'][name] = param
+ else:
+ param_groups['groupA'][name] = param
+ all_params.add(name)
+
+ lr = defaults['lr']
+ weight_decay = defaults.get('weight_decay', 0.0)
+
+ param_groups = [
+ {
+ 'params': list(param_groups['groupA'].values()),
+ 'weight_decay': weight_decay,
+ 'lr': lr,
+ },
+ {
+ 'params': list(param_groups['embedding'].values()),
+ 'weight_decay': weight_decay,
+ 'lr': config.lorap_emb_lr,
+ },
+ {
+ 'params': list(param_groups['groupB'].values()),
+ 'weight_decay': weight_decay,
+ 'lr': lr * config.lorap_lr_ratio,
+ },
+ {
+ 'params': list(param_groups['groupB_no_decay'].values()),
+ 'weight_decay': 0.0,
+ 'lr': lr * config.lorap_lr_ratio,
+ },
+ ]
+ return all_params, param_groups
+
+ return SwiftOutput(
+ config=config,
+ state_dict_callback=state_dict_callback,
+ mark_trainable_callback=mark_trainable_callback,
+ optimizer_group_callback=optimizer_group_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ set_adapter(module, adapter_name, activate, offload)
+ for sub_module in module.modules():
+ if isinstance(sub_module, (LoraLayer, LoRALayer)):
+ sub_module.set_activation(adapter_name, activate)
+ if hasattr(sub_module, 'save_memory'):
+ sub_module.save_memory(adapter_name, activate, offload)
+
+ @staticmethod
+ def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
+ """Unpatch lora modules and merge the weights to original modules.
+
+ LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network.
+ 'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021)
+ See https://arxiv.org/abs/2106.09685
+
+ Args:
+ model(`torch.nn.Module`): The model called with `tune` function.
+ config(`LoRAConfig`): The `LoRAConfig` to use. Deprecated
+ adapter_name(`str`): The adapter name
+ """
+ if not config.use_merged_linear:
+ if version.parse(peft.__version__) < version.parse('0.6.3'):
+ logger.info('All adapters will be merged.')
+ LoraModel(model, None, '').merge_and_unload()
+ else:
+ LoraModel(model, None, '').merge_and_unload(adapter_names=[adapter_name])
+ else:
+ for name, sub_module in model.named_modules():
+ if isinstance(sub_module, MergedLinear):
+ sub_module.merge()
+ parent = model.get_submodule('.'.join(name.split('.')[:-1]))
+ target_name = name.split('.')[-1]
+ setattr(parent, target_name, sub_module.base_layer)
diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f681644a3829fbbf8961fe02819a567165fbaad4
--- /dev/null
+++ b/swift/tuners/lora_layers.py
@@ -0,0 +1,673 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+import math
+import re
+import warnings
+from itertools import chain
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.tuners.lora import Conv2d as _Conv2d
+from peft.tuners.lora import Embedding as _Embedding
+from peft.tuners.lora import Linear as _Linear
+from peft.tuners.lora import LoraLayer
+from peft.tuners.lora import LoraModel as _LoraModel
+from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear
+from peft.tuners.tuners_utils import BaseTunerLayer
+from peft.utils import _get_submodules, get_quantization_config
+from transformers import Conv1D
+
+from swift.utils import get_logger
+from .peft import LoraConfig
+from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter
+
+logger = get_logger()
+dispatchers = []
+
+
+class LoRAActivationMixin(ActivationMixin):
+
+ @property
+ def active_adapters(self):
+ return self.get_activated_adapters()
+
+ @property
+ def active_adapter(self) -> str:
+ return self.get_activated_adapters()
+
+ def set_adapter(self, adapter_names, offload=None):
+ if isinstance(adapter_names, str):
+ adapter_names = [adapter_names]
+
+ # Deactivate grads on the inactive adapter and activate grads on the active adapter
+ for layer_name in self.adapter_layer_names:
+ module_dict = getattr(self, layer_name)
+ for key, layer in module_dict.items():
+ if key in adapter_names:
+ self.set_activation(key, True)
+ layer.requires_grad_(True)
+ SwiftAdapter.save_memory(layer, key, self.module_key, True)
+ else:
+ self.set_activation(key, False)
+ layer.requires_grad_(False)
+ SwiftAdapter.save_memory(layer, key, self.module_key, False, offload=offload)
+
+ def save_memory(self, adapter_name, activate, offload=None):
+ for layer_name in self.adapter_layer_names:
+ module_dict = getattr(self, layer_name)
+ for key, layer in module_dict.items():
+ if key == adapter_name:
+ if activate:
+ SwiftAdapter.save_memory(layer, layer_name + '.' + key, self.module_key, True)
+ else:
+ SwiftAdapter.save_memory(layer, layer_name + '.' + key, self.module_key, False, offload=offload)
+
+ def merge(self, *args, **kwargs):
+ if not self.unique_thread:
+ raise AssertionError('Merge is unsupported in multiple thread, '
+ 'please set `USE_UNIQUE_THREAD=1` in env variable to merge LoRA.')
+ return super().merge(*args, **kwargs)
+
+
+if is_bnb_available():
+ import bitsandbytes as bnb
+ from peft.tuners.lora.bnb import Linear8bitLt as _Linear8bitLt
+
+ class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt):
+
+ def __init__(
+ self,
+ *args,
+ module_key: str,
+ **kwargs,
+ ):
+ super(Linear8bitLt, self).__init__(module_key)
+ self.set_activation(args[1], True)
+ super(ActivationMixin, self).__init__(*args, **kwargs)
+
+ def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, module_key: str, **kwargs):
+ new_module = None
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ loaded_in_8bit = kwargs.get('loaded_in_8bit', False)
+ if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
+ eightbit_kwargs = kwargs.copy()
+ eightbit_kwargs.update({
+ 'has_fp16_weights': target.state.has_fp16_weights,
+ 'threshold': target.state.threshold,
+ 'index': target.index,
+ })
+ new_module = Linear8bitLt(target, adapter_name, module_key=module_key, **eightbit_kwargs)
+
+ return new_module
+
+ dispatchers.append(dispatch_bnb_8bit)
+
+if is_bnb_4bit_available():
+ from peft.tuners.lora.bnb import Linear4bit as _Linear4bit
+
+ class Linear4bit(LoRAActivationMixin, _Linear4bit):
+
+ def __init__(
+ self,
+ *args,
+ module_key: str,
+ **kwargs,
+ ):
+ super(Linear4bit, self).__init__(module_key)
+ self.set_activation(args[1], True)
+ super(ActivationMixin, self).__init__(*args, **kwargs)
+
+ def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, module_key: str, **kwargs):
+ new_module = None
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ loaded_in_4bit = kwargs.get('loaded_in_4bit', False)
+ if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
+ fourbit_kwargs = kwargs.copy()
+ fourbit_kwargs.update({
+ 'compute_dtype': target_base_layer.compute_dtype,
+ 'compress_statistics': target_base_layer.weight.compress_statistics,
+ 'quant_type': target_base_layer.weight.quant_type,
+ })
+ new_module = Linear4bit(target, adapter_name, module_key=module_key, **fourbit_kwargs)
+
+ return new_module
+
+ dispatchers.append(dispatch_bnb_4bit)
+
+
+def dispatch_default(
+ target: torch.nn.Module,
+ adapter_name: str,
+ lora_config: LoraConfig,
+ module_key: str,
+ **kwargs,
+) -> Optional[torch.nn.Module]:
+ new_module = None
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ if isinstance(target_base_layer, torch.nn.Embedding):
+ embedding_kwargs = kwargs.copy()
+ embedding_kwargs.pop('fan_in_fan_out', None)
+ embedding_kwargs.update(lora_config.loftq_config)
+ new_module = Embedding(target, adapter_name, module_key=module_key, **embedding_kwargs)
+ elif isinstance(target_base_layer, torch.nn.Conv2d):
+ kwargs.update(lora_config.loftq_config)
+ new_module = Conv2d(target, adapter_name, module_key=module_key, **kwargs)
+ elif isinstance(target_base_layer, torch.nn.Linear):
+ if target_base_layer.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
+ # Fix issue: https://github.com/modelscope/swift/issues/342
+ return None
+ if kwargs['fan_in_fan_out']:
+ warnings.warn('fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. '
+ 'Setting fan_in_fan_out to False.')
+ kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False
+ kwargs.update(lora_config.loftq_config)
+ new_module = Linear(target, adapter_name, module_key=module_key, **kwargs)
+ elif isinstance(target_base_layer, Conv1D):
+ if not kwargs['fan_in_fan_out']:
+ warnings.warn('fan_in_fan_out is set to False but the target module is `Conv1D`. '
+ 'Setting fan_in_fan_out to True.')
+ kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True
+ kwargs.update(lora_config.loftq_config)
+ new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, module_key=module_key, **kwargs)
+
+ return new_module
+
+
+dispatchers.append(dispatch_default)
+
+
+class Embedding(LoRAActivationMixin, _Embedding):
+
+ def __init__(
+ self,
+ *args,
+ module_key: str,
+ **kwargs,
+ ) -> None:
+ super(Embedding, self).__init__(module_key)
+ self.set_activation(args[1], True)
+ super(ActivationMixin, self).__init__(*args, **kwargs)
+
+
+class Linear(LoRAActivationMixin, _Linear):
+
+ def __init__(self, *args, module_key: str, **kwargs):
+ super(Linear, self).__init__(module_key)
+ self.set_activation(args[1], True)
+ super(ActivationMixin, self).__init__(*args, **kwargs)
+
+
+class Conv2d(LoRAActivationMixin, _Conv2d):
+
+ def __init__(self, *args, module_key: str, **kwargs):
+ super(Conv2d, self).__init__(module_key)
+ self.set_activation(args[1], True)
+ super(ActivationMixin, self).__init__(*args, **kwargs)
+
+
+class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear):
+
+ def __init__(self, *args, module_key: str, **kwargs):
+ super(LoraParallelLinear, self).__init__(module_key)
+ self.set_activation(args[1], True)
+ super(ActivationMixin, self).__init__(*args, **kwargs)
+
+
+class LoraModel(_LoraModel):
+
+ prefix: str = 'lora_'
+
+ def __init__(self, model, config, adapter_name):
+ if config is not None:
+ super().__init__(model, config, adapter_name)
+ else:
+ nn.Module.__init__(self)
+ self.model = model
+
+ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+ for active_adapter in self.active_adapters:
+ bias = self.peft_config[active_adapter].bias
+ if bias == 'none':
+ continue
+
+ if bias == 'all':
+ for n, p in model.named_parameters():
+ if 'bias' in n:
+ p.requires_grad = True
+ elif bias == 'lora_only':
+ for m in model.modules():
+ if isinstance(m, LoraLayer) and hasattr(m, 'bias') and m.bias is not None:
+ m.bias.requires_grad = True
+ else:
+ raise NotImplementedError(f'Requested bias: {bias}, is not implemented.')
+
+ def inject_adapter(self,
+ model: nn.Module,
+ adapter_name: str,
+ autocast_adapter_dtype: bool = True,
+ low_cpu_mem_usage: bool = False):
+ r"""
+ Override code:
+ 1. ModulesToSaveWrapper construction method: add module_key=key argument to offload to cpu
+ """
+ peft_config = self.peft_config[adapter_name]
+ # Note: If possible, all checks should be performed *at the start of this method*.
+ # This way, we can raise early if something goes wrong, without leaving the model
+ # in a bad (half-initialized) state.
+ self._check_new_adapter_config(peft_config)
+
+ is_target_modules_in_base_model = False
+ key_list = [key for key, _ in model.named_modules()]
+
+ _check_for_modules_to_save = getattr(peft_config, 'modules_to_save', None) is not None
+ _has_modules_to_save = False
+
+ model_config = getattr(model, 'config', {'model_type': 'custom'})
+ if hasattr(model_config, 'to_dict'):
+ model_config = model_config.to_dict()
+
+ peft_config = self._prepare_adapter_config(peft_config, model_config)
+
+ from peft.tuners.tuners_utils import _maybe_include_all_linear_layers
+ try:
+ from peft.utils.constants import DUMMY_TARGET_MODULES
+ except ImportError: # compat with peft==0.11.*
+ DUMMY_TARGET_MODULES = 'dummy-target-modules'
+ if getattr(peft_config, 'target_modules', None) == DUMMY_TARGET_MODULES:
+ # dummy adapter, we allow not matching any module
+ key_list = []
+ is_target_modules_in_base_model = True
+ # update peft_config.target_modules if required
+ peft_config = _maybe_include_all_linear_layers(peft_config, model)
+ self._prepare_model(peft_config, model)
+
+ for key in key_list:
+ if '_part_' in key or not key:
+ # Avoid lora conflict with part tuner
+ continue
+ # Check for modules_to_save in case
+ if _check_for_modules_to_save and any(
+ key.endswith(f'{module_to_save}') for module_to_save in peft_config.modules_to_save):
+ # Optionally set the modules to save
+ parent, target, target_name = _get_submodules(model, key)
+
+ if not isinstance(target, ModulesToSaveWrapper):
+ new_module = ModulesToSaveWrapper(target, adapter_name=adapter_name, module_key=key)
+ setattr(parent, target_name, new_module)
+ else:
+ target.update(adapter_name)
+
+ _has_modules_to_save = True
+ continue
+
+ if not self._check_target_module_exists(peft_config, key):
+ continue
+
+ self.targeted_module_names.append(key)
+ is_target_modules_in_base_model = True
+ parent, target, target_name = _get_submodules(model, key)
+ self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
+
+ if not is_target_modules_in_base_model and hasattr(peft_config, 'target_modules'):
+ raise ValueError(f'Target modules {peft_config.target_modules} not found in the base model. '
+ f'Please check the target modules and try again.')
+
+ self._mark_only_adapters_as_trainable(self.model)
+
+ if self.peft_config[adapter_name].inference_mode:
+ for n, p in self.model.named_parameters():
+ if adapter_name in n:
+ p.requires_grad = False
+
+ if _has_modules_to_save:
+ if not hasattr(model, 'modules_to_save'):
+ model.modules_to_save = set(peft_config.modules_to_save)
+ else:
+ model.modules_to_save.update(set(peft_config.modules_to_save))
+
+ def _convert_dtype(self, target: nn.Module, lora_dtype: str):
+ if lora_dtype == 'float32':
+ torch_dtype = torch.float32
+ elif lora_dtype == 'float16':
+ torch_dtype = torch.float16
+ elif lora_dtype == 'bfloat16':
+ torch_dtype = torch.bfloat16
+ else:
+ torch_dtype = None
+
+ if torch_dtype is not None:
+ if hasattr(target, 'lora_A'):
+ target.lora_A.to(torch_dtype)
+ target.lora_B.to(torch_dtype)
+ if hasattr(target, 'lora_embedding_A'):
+ target.lora_embedding_A.to(torch_dtype)
+ target.lora_embedding_B.to(torch_dtype)
+
+ def _create_and_replace(
+ self,
+ lora_config,
+ adapter_name,
+ target,
+ target_name,
+ parent,
+ current_key,
+ **optional_kwargs,
+ ):
+ """
+ Override code:
+ 1. Import bnb from upper code
+ 2. Support dtype converting
+ 3. Support skipping NonDynamicallyQuantizableLinear
+ 4. Add current_key argument to _create_new_module
+ 5. Use Class type defined here
+ 6. Allow new_module being None
+ """
+ if current_key is None:
+ raise ValueError("Current Key shouldn't be `None`")
+
+ # Regexp matching - Find key which matches current target_name in patterns provided
+ pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
+ target_name_key = next(filter(lambda key: re.match(rf'.*\.{key}$', current_key), pattern_keys), current_key)
+ r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
+ alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)
+
+ kwargs = {
+ 'r': r,
+ 'lora_alpha': alpha,
+ 'lora_dropout': lora_config.lora_dropout,
+ 'fan_in_fan_out': lora_config.fan_in_fan_out,
+ 'init_lora_weights': lora_config.init_lora_weights,
+ 'use_rslora': lora_config.use_rslora,
+ 'use_dora': lora_config.use_dora,
+ 'loaded_in_8bit': getattr(self.model, 'is_loaded_in_8bit', False),
+ 'loaded_in_4bit': getattr(self.model, 'is_loaded_in_4bit', False),
+ }
+ # compat with peft==0.11.*
+ if hasattr(lora_config, 'runtime_config'):
+ kwargs['ephemeral_gpu_offload'] = lora_config.runtime_config.ephemeral_gpu_offload
+
+ quant_methods = ['gptq', 'aqlm', 'awq']
+ for quant_method in quant_methods:
+ quantization_config = get_quantization_config(self.model, method=quant_method)
+ if quantization_config is not None:
+ kwargs[f'{quant_method}_quantization_config'] = quantization_config
+
+ # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
+ from peft.tuners.adalora import AdaLoraLayer
+
+ if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
+ if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
+ # Fix issue: https://github.com/modelscope/swift/issues/342
+ return
+ target.update_layer(
+ adapter_name,
+ r,
+ lora_alpha=alpha,
+ lora_dropout=lora_config.lora_dropout,
+ init_lora_weights=lora_config.init_lora_weights,
+ use_rslora=lora_config.use_rslora,
+ use_dora=lora_config.use_dora,
+ )
+ self._convert_dtype(target, lora_config.lora_dtype)
+ ActivationMixin.mark_all_sub_modules_as_plugin(target)
+ else:
+ new_module = self._create_new_module(lora_config, adapter_name, target, current_key=current_key, **kwargs)
+ if new_module is not None:
+ ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
+ if adapter_name not in self.active_adapters:
+ # adding an additional adapter: it is not automatically trainable
+ new_module.requires_grad_(False)
+ self._replace_module(parent, target_name, new_module, target)
+ self._convert_dtype(new_module, lora_config.lora_dtype)
+
+ def _replace_module(self, parent, child_name, new_module, child):
+ setattr(parent, child_name, new_module)
+ # It's not necessary to set requires_grad here, as that is handled by
+ # _mark_only_adapters_as_trainable
+
+ # child layer wraps the original module, unpack it
+ if hasattr(child, 'base_layer'):
+ child = child.base_layer
+
+ if not hasattr(new_module, 'base_layer'):
+ if hasattr(new_module, 'W_q'): # HQQ
+ new_module.W_q = child.W_q
+ else:
+ new_module.weight = child.weight
+ if hasattr(child, 'bias'):
+ new_module.bias = child.bias
+
+ if getattr(child, 'state', None) is not None:
+ if hasattr(new_module, 'base_layer'):
+ new_module.base_layer.state = child.state
+ else:
+ new_module.state = child.state
+ new_module.to(child.weight.device)
+
+ meta = torch.device('meta')
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ if (self.prefix in name) or ('ranknum' in name):
+ weight = (
+ child.qweight if hasattr(child, 'qweight') else child.W_q if hasattr(child, 'W_q') else
+ child.weight if hasattr(child, 'weight') else next(child.parameters()))
+ if not any(p.device == meta for p in module.parameters()):
+ module.to(weight.device)
+
+ @staticmethod
+ def _create_new_module(lora_config, adapter_name, target, **kwargs):
+ """
+ Override code:
+ 1. Support current_key argument
+ 2. Support MergedLinear
+ 3. Support skipping NonDynamicallyQuantizableLinear(Move to dispatcher)
+ 4. Use Class type defined here(Move to dispatcher)
+ 5. return None instead of raising error when target type not found
+ """
+ # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters,
+ # because the first match is always used. Therefore, the default layers should be checked last.
+ current_key = kwargs.pop('current_key')
+ new_module = None
+ if lora_config.use_qa_lora:
+ kwargs['use_qa_lora'] = True
+ kwargs['group_size'] = lora_config.group_size
+ if lora_config.use_merged_linear:
+ bias = kwargs.pop('bias', False)
+ new_module = MergedLinear(
+ adapter_name, current_key, target, bias=bias, enable_lora=lora_config.enable_lora, **kwargs)
+ else:
+ for dispatcher in dispatchers:
+ new_module = dispatcher(target, adapter_name, lora_config=lora_config, module_key=current_key, **kwargs)
+ if new_module is not None: # first match wins
+ break
+
+ if new_module is None:
+ # no module could be matched
+ logger.debug(
+ f'Target module {target} is not supported. Currently, only the following modules are supported: '
+ '`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`.')
+ new_module = None
+
+ return new_module
+
+
+class LoRALayer(ActivationMixin):
+
+ def __init__(
+ self,
+ adapter_name: str,
+ module_key: str,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ super().__init__(module_key)
+ self.adapter_name = adapter_name
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+ if not self._unique_thread:
+ self.merge_weights = False
+
+
+class MergedLinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(self,
+ adapter_name: str,
+ module_key: str,
+ base_layer: nn.Linear,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ enable_lora: List[bool] = [False],
+ fan_in_fan_out: bool = False,
+ merge_weights: bool = True,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ **kwargs):
+ nn.Linear.__init__(self, base_layer.in_features, base_layer.out_features, bias=bias, device=device, dtype=dtype)
+ LoRALayer.__init__(
+ self,
+ adapter_name,
+ module_key,
+ r=r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+ assert base_layer.out_features % len(enable_lora) == 0, \
+ 'The length of enable_lora must divide out_features'
+ self.enable_lora = enable_lora
+ self.fan_in_fan_out = fan_in_fan_out
+ self.base_layer = base_layer
+ # Actual trainable parameters
+ if r > 0 and any(enable_lora):
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), base_layer.in_features)))
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((base_layer.out_features // len(enable_lora) * sum(enable_lora),
+ r))) # weights for Conv1D with groups=sum(enable_lora)
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ # Compute the indices
+ self.lora_ind = self.weight.new_zeros((base_layer.out_features, ),
+ dtype=torch.bool).view(len(enable_lora), -1)
+ self.lora_ind[enable_lora, :] = True
+ self.lora_ind = self.lora_ind.view(-1)
+ self.reset_parameters()
+ self.weight = self.base_layer.weight
+ if getattr(self.base_layer, 'bias', None) is not None:
+ self.bias = self.base_layer.bias
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def zero_pad(self, x):
+ result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
+ result[self.lora_ind] = x
+ return result
+
+ def merge_AB(self):
+
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+
+ delta_w = F.conv1d(self.lora_A.unsqueeze(0), self.lora_B.unsqueeze(-1), groups=sum(self.enable_lora)).squeeze(0)
+ return T(self.zero_pad(delta_w))
+
+ def merge(self, **kwargs):
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0 and any(self.enable_lora):
+ self.weight.data += self.merge_AB() * self.scaling
+
+ def unmerge(self, **kwargs):
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0 and any(self.enable_lora):
+ self.weight.data -= self.merge_AB() * self.scaling
+ self.merged = False
+
+ def forward(self, x: torch.Tensor, **kwargs):
+
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+
+ if self.merged or not self.is_activated(self.adapter_name):
+ return F.linear(x, T(self.weight), bias=self.bias)
+ else:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ if self.r > 0:
+ x_dtype = x.dtype
+ x = x.to(self.lora_A.dtype)
+ result += self.lora_dropout(x) @ T(self.merge_AB().T) * self.scaling
+ result = result.to(x_dtype)
+ return result
+
+
+def mark_lora_as_trainable(model: nn.Module, adapter_name: str, bias: str = 'none') -> None:
+ if bias == 'none':
+ return
+ elif bias == 'all':
+ for n, p in model.named_parameters():
+ if 'bias' in n:
+ p.requires_grad = True
+ elif bias == 'lora_only':
+ for n, m in model.named_modules():
+ if 'lora_' in n and f'.{adapter_name}' in n and \
+ hasattr(m, 'bias') and \
+ m.bias is not None:
+ m.bias.requires_grad = True
+ else:
+ raise NotImplementedError
+
+
+def lora_state_dict(state_dict, adapter_name: str, bias: str = 'none') -> Dict[str, torch.Tensor]:
+ if bias == 'none':
+ to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k}
+ elif bias == 'all':
+ to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k or 'bias' in k}
+ elif bias == 'lora_only':
+ to_return = {}
+ for k in state_dict:
+ if 'lora_' in k:
+ to_return[k] = state_dict[k]
+ bias_name = k.split('lora_')[0] + 'bias'
+ if bias_name in state_dict:
+ to_return[bias_name] = state_dict[bias_name]
+ else:
+ raise NotImplementedError
+ return {k: v for k, v in to_return.items() if (('lora_' in k and f'.{adapter_name}' in k) or ('bias' in k))}
diff --git a/swift/tuners/mapping.py b/swift/tuners/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa17ef89e6af7fca7af3d53aa54958d1a4ee4f94
--- /dev/null
+++ b/swift/tuners/mapping.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .adapter import Adapter, AdapterConfig
+from .llamapro import LLaMAPro, LLaMAProConfig
+from .longlora.longlora import LongLoRA, LongLoRAConfig
+from .lora import LoRA, LoRAConfig
+from .neftune import NEFTune, NEFTuneConfig
+from .part import Part, PartConfig
+from .prompt import Prompt, PromptConfig
+from .reft import Reft, ReftConfig
+from .restuning import ResTuning, ResTuningConfig
+from .scetuning.scetuning import SCETuning, SCETuningConfig
+from .side import Side, SideConfig
+
+
+class SwiftTuners:
+ ADAPTER = 'ADAPTER'
+ PROMPT = 'PROMPT'
+ LORA = 'LORA'
+ SIDE = 'SIDE'
+ RESTUNING = 'RESTUNING'
+ LONGLORA = 'longlora'
+ NEFTUNE = 'neftune'
+ LLAMAPRO = 'LLAMAPRO'
+ SCETUNING = 'SCETuning'
+ PART = 'part'
+ REFT = 'reft'
+
+
+SWIFT_MAPPING = {
+ SwiftTuners.ADAPTER: (AdapterConfig, Adapter),
+ SwiftTuners.PROMPT: (PromptConfig, Prompt),
+ SwiftTuners.LORA: (LoRAConfig, LoRA),
+ SwiftTuners.SIDE: (SideConfig, Side),
+ SwiftTuners.RESTUNING: (ResTuningConfig, ResTuning),
+ SwiftTuners.LONGLORA: (LongLoRAConfig, LongLoRA),
+ SwiftTuners.NEFTUNE: (NEFTuneConfig, NEFTune),
+ SwiftTuners.SCETUNING: (SCETuningConfig, SCETuning),
+ SwiftTuners.LLAMAPRO: (LLaMAProConfig, LLaMAPro),
+ SwiftTuners.PART: (PartConfig, Part),
+ SwiftTuners.REFT: (ReftConfig, Reft),
+}
diff --git a/swift/tuners/neftune.py b/swift/tuners/neftune.py
new file mode 100644
index 0000000000000000000000000000000000000000..6476283e5d2348e24823fbef0cd34abb06675308
--- /dev/null
+++ b/swift/tuners/neftune.py
@@ -0,0 +1,73 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass, field
+
+import torch
+from torch import nn
+
+from swift.utils.logger import get_logger
+from .utils import SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class NEFTuneConfig(SwiftConfig):
+ """
+ The configuration class for the NEFTune module.
+
+ NEFTune adds slightly noises to embedding outputs.
+ See https://arxiv.org/abs/2310.05914
+
+ Args:
+ noise_alpha(`float`): The noise alpha value used for the NEFTune, default 5.0
+ """
+ noise_alpha: float = field(default=5.0, metadata={'help': 'The noise alpha value used for the NEFTune'})
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.NEFTUNE
+
+
+class NEFTune(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: NEFTuneConfig, adapter_name: str) -> SwiftOutput:
+ """Prepare a model with `NEFTuneConfig`"""
+ for sub_module in model.modules():
+ if isinstance(sub_module, torch.nn.Embedding):
+
+ def neftune_hook(module, args, output):
+ if module.training and getattr(module, 'nef_activated'):
+ dims = torch.tensor(output.size(-1) * output.size(-2))
+ mag_norm = config.noise_alpha / torch.sqrt(dims)
+ output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
+ return output
+
+ if hasattr(sub_module, 'nef_activated'):
+ raise ValueError('NEFTune does not support a second tuner.')
+
+ sub_module.register_forward_hook(neftune_hook)
+ sub_module.nef_activated = True
+
+ def state_dict_callback(state_dict, adapter_name, **kwargs):
+ return state_dict
+
+ def mark_trainable_callback(model):
+ return
+
+ return SwiftOutput(
+ config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ for sub_module in module.modules():
+ if isinstance(sub_module, torch.nn.Embedding):
+ sub_module.nef_activated = activate
+
+ @staticmethod
+ def freeze_model():
+ return False
+
+ @staticmethod
+ def has_additional_modules():
+ return False
diff --git a/swift/tuners/part.py b/swift/tuners/part.py
new file mode 100644
index 0000000000000000000000000000000000000000..e398986f91e3726c7da42594f598cb57dc16fc90
--- /dev/null
+++ b/swift/tuners/part.py
@@ -0,0 +1,119 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import re
+from copy import deepcopy
+from dataclasses import dataclass
+from types import MethodType
+from typing import Dict, Optional
+
+import torch
+from torch import nn
+
+from swift.utils import get_logger
+from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class PartConfig(SwiftConfig):
+ """
+ Freeze the model and train a part of it.
+
+ Args:
+ target_modules(`Optional[str]`): The target modules to be trained in regex format
+ """
+
+ target_modules: Optional[str] = None
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.PART
+
+
+class Part(SwiftAdapter):
+
+ @staticmethod
+ def target_module_matched(module_key: str, config: PartConfig):
+ return re.fullmatch(config.target_modules, module_key)
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str):
+ name_list = [name for name, _ in model.named_modules(remove_duplicate=False)]
+ for name in name_list:
+ module: nn.Module = model.get_submodule(name)
+ if Part.target_module_matched(name, config) and not getattr(module, 'plugin', False):
+ if hasattr(module, 'base_layer'):
+ module = module.base_layer
+
+ def _forward(self, *args, **kwargs):
+ child_list = [
+ sub_module for name, sub_module in self.named_modules(remove_duplicate=False)
+ if '_part_' in name
+ ]
+ sub_modules = [child for child in child_list if getattr(child, 'activated', False)]
+ assert len(sub_modules) <= 1
+ if len(sub_modules) == 1:
+ return sub_modules[0].forward(*args, **kwargs)
+ else:
+ return self.forward_origin(*args, **kwargs)
+
+ if not hasattr(module, 'forward_origin'):
+ module.forward_origin = module.forward
+ module.forward = MethodType(_forward, module)
+
+ new_module = deepcopy(module)
+ for attr in dir(new_module):
+ if '_part_' in attr:
+ delattr(new_module, attr)
+ new_module.part_name = adapter_name
+ ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
+ setattr(module, f'_part_{adapter_name}', new_module)
+ new_module.requires_grad_(True)
+
+ def state_dict_callback(state_dict, adapter_name, **kwargs):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if f'_part_{adapter_name}.' in key:
+ if kwargs.get('replace_key', True):
+ new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
+ else:
+ new_key = key
+ new_state_dict[new_key] = value
+
+ return new_state_dict
+
+ def mark_trainable_callback(model: nn.Module):
+ pass
+
+ def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Dict[str, torch.Tensor]):
+ new_state_dict = {}
+ for name, module in model.named_modules(remove_duplicate=False):
+ module: nn.Module
+ if Part.target_module_matched(name, config):
+ for param_name in state_dict:
+ if param_name.startswith(name):
+ end = param_name[len(name):]
+ if '_part_' not in param_name:
+ if hasattr(module, 'base_layer'):
+ new_state_dict[name + f'.base_layer._part_{adapter_name}'
+ + end] = state_dict[param_name]
+ else:
+ new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
+ else:
+ new_state_dict[param_name] = state_dict[param_name]
+ return new_state_dict
+
+ return SwiftOutput(
+ config=config,
+ state_dict_callback=state_dict_callback,
+ mark_trainable_callback=mark_trainable_callback,
+ load_state_dict_callback=load_state_dict_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ name_list = [name for name, _ in module.named_modules(remove_duplicate=False)]
+ for name in name_list:
+ sub_module: nn.Module = module.get_submodule(name)
+ if re.fullmatch(f'.*_part_{adapter_name}$', name):
+ sub_module.activated = activate
+ SwiftAdapter.save_memory(sub_module, adapter_name, name, activate, offload)
diff --git a/swift/tuners/peft.py b/swift/tuners/peft.py
new file mode 100644
index 0000000000000000000000000000000000000000..f561db4fc049d167f87c56bfae28b201dc967b6d
--- /dev/null
+++ b/swift/tuners/peft.py
@@ -0,0 +1,392 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright 2023-present the HuggingFace Inc. team.
+import os.path
+from dataclasses import asdict, dataclass, field
+from functools import partial, reduce
+from types import MethodType
+from typing import Dict, Optional
+
+import json
+import peft
+import torch
+import torch.nn
+import transformers
+from modelscope import snapshot_download
+from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, LoftQConfig, LoHaConfig, LoKrConfig, LoraModel, OFTConfig,
+ PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
+ PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig,
+ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, VeraConfig, VeraModel, get_peft_config,
+ get_peft_model, get_peft_model_state_dict)
+from peft.config import PeftConfigMixin
+from peft.tuners import lora
+from peft.tuners.adalora import AdaLoraModel, RankAllocator
+from peft.tuners.lora import Embedding
+from transformers import Trainer
+
+from swift.utils import get_logger
+
+try:
+ from peft import FourierFTModel
+except ImportError:
+ FourierFTModel = None
+
+try:
+ from peft import BoneModel
+except ImportError:
+ BoneModel = None
+
+logger = get_logger()
+dispatchers = []
+
+
+@dataclass
+class LoraConfig(peft.LoraConfig):
+ lora_dtype: Optional[str] = field(
+ default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'})
+
+ lorap_lr_ratio: Optional[float] = field(default=None, metadata={'help': 'The lr ratio of lora_B in lora+'})
+
+ lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'})
+
+ def to_peft_config(self) -> peft.LoraConfig:
+ _dict = asdict(self)
+ _dict.pop('lora_dtype')
+ _dict.pop('lorap_lr_ratio')
+ _dict.pop('lorap_emb_lr')
+ return peft.LoraConfig(**_dict)
+
+ def save_pretrained(self, save_directory: str, **kwargs) -> None:
+ self.to_peft_config().save_pretrained(save_directory, **kwargs)
+ additional_args = {
+ 'lora_dtype': self.lora_dtype,
+ 'lorap_lr_ratio': self.lorap_lr_ratio,
+ 'lorap_emb_lr': self.lorap_emb_lr,
+ }
+ with open(os.path.join(save_directory, 'additional_config.json'), 'w', encoding='utf-8') as f:
+ json.dump(additional_args, f)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
+ if hasattr(PeftConfigMixin, 'from_pretrained_origin'):
+ self = PeftConfigMixin.from_pretrained_origin(pretrained_model_name_or_path, subfolder, **kwargs)
+ else:
+ self = super(LoraConfig, cls).from_pretrained(pretrained_model_name_or_path, subfolder, **kwargs)
+
+ if type(self) == peft.LoraConfig:
+ self = LoraConfig(**self.to_dict())
+
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, 'additional_config.json')):
+ with open(
+ os.path.join(pretrained_model_name_or_path, 'additional_config.json'), 'r', encoding='utf-8') as f:
+ _json = json.load(f)
+ for key, value in _json.items():
+ setattr(self, key, value)
+
+ return self
+
+
+def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
+ all_supported_names = ('linear', )
+ all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D, lora.Linear)
+ target_modules = getattr(peft_config, 'target_modules', None)
+ if target is None:
+ return
+
+ if isinstance(target_modules, str) and not any(
+ [name in target.__class__.__name__.lower()
+ for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
+ return
+
+ if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
+ return
+
+ return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)
+
+
+def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
+ if lora_dtype is not None:
+ torch_dtype = eval(f'torch.{lora_dtype}')
+ if hasattr(target, 'lora_A') and adapter_name in target.lora_A:
+ target.lora_A[adapter_name].to(torch_dtype)
+ target.lora_B[adapter_name].to(torch_dtype)
+ if hasattr(target, 'lora_embedding_A') and adapter_name in target.lora_embedding_A:
+ target.lora_embedding_A[adapter_name].to(torch_dtype)
+ target.lora_embedding_B[adapter_name].to(torch_dtype)
+
+
+def create_optimizer_param_groups(self: PeftModel, **defaults):
+ if not isinstance(self.peft_config[self.active_adapter],
+ LoraConfig) or self.peft_config[self.active_adapter].lorap_lr_ratio is None:
+ return None
+
+ def get_module(name):
+ parent_idx = 2 if 'lora' in name else 1
+ module_names = name.split(sep='.')[:-parent_idx]
+ module = reduce(getattr, module_names, self.base_model)
+ return module
+
+ param_groups = {
+ 'groupA': {},
+ 'groupB': {},
+ 'groupB_no_decay': {},
+ 'embedding': {},
+ }
+
+ decay_parameters = Trainer.get_decay_parameter_names(None, self.base_model)
+ for name, param in self.base_model.named_parameters():
+ if not param.requires_grad:
+ continue
+
+ module = get_module(name)
+ if isinstance(module, Embedding):
+ param_groups['embedding'][name] = param
+ elif 'lora_B' in name or param.ndim == 1:
+ if name in decay_parameters:
+ param_groups['groupB'][name] = param
+ else:
+ param_groups['groupB_no_decay'][name] = param
+ else:
+ param_groups['groupA'][name] = param
+
+ lr = defaults['lr']
+ weight_decay = defaults.get('weight_decay', 0.0)
+
+ param_groups = [
+ {
+ 'params': list(param_groups['groupA'].values()),
+ 'weight_decay': weight_decay,
+ 'lr': lr,
+ },
+ {
+ 'params': list(param_groups['embedding'].values()),
+ 'weight_decay': weight_decay,
+ 'lr': self.peft_config[self.active_adapter].lorap_emb_lr,
+ },
+ {
+ 'params': list(param_groups['groupB'].values()),
+ 'weight_decay': weight_decay,
+ 'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio,
+ },
+ {
+ 'params': list(param_groups['groupB_no_decay'].values()),
+ 'weight_decay': 0.0,
+ 'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio,
+ },
+ ]
+ return param_groups
+
+
+def adalora_forward(self, *args, **kwargs):
+ from peft.utils.integrations import gather_params_ctx
+ outputs = self.model.forward(*args, **kwargs)
+
+ if (getattr(outputs, 'loss', None) is not None) and isinstance(outputs.loss, torch.Tensor):
+ # Calculate the orthogonal regularization
+ orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
+
+ if orth_reg_weight <= 0:
+ raise ValueError('orth_reg_weight should be greater than 0. ')
+
+ regu_loss = 0
+ num_param = 0
+ for n, p in self.model.named_parameters():
+ if ('lora_A' in n or 'lora_B' in n) and self.trainable_adapter_name in n:
+ if p.shape == torch.Size([0]):
+ with gather_params_ctx(p, fwd_module=self):
+ para_cov = p @ p.T if 'lora_A' in n else p.T @ p
+ else:
+ para_cov = p @ p.T if 'lora_A' in n else p.T @ p
+ I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741
+ I.requires_grad = False
+ num_param += 1
+ if isinstance(regu_loss, torch.Tensor):
+ regu_loss = regu_loss.to(para_cov.device)
+ regu_loss += torch.norm(para_cov - I, p='fro')
+ if num_param > 0:
+ regu_loss = regu_loss / num_param
+ else:
+ regu_loss = 0
+ if isinstance(regu_loss, torch.Tensor) and isinstance(outputs.loss, torch.Tensor):
+ regu_loss = regu_loss.to(outputs.loss.device)
+ outputs.loss += orth_reg_weight * regu_loss
+ return outputs
+
+
+def adalora_mask_to_budget(self, model, budget):
+ value_ipt = {}
+ vector_ipt = {}
+ triplet_ipt = {}
+ # Get the importance score for A, E, B
+ for n, p in model.named_parameters():
+ if f'lora_A.{self.adapter_name}' in n:
+ entry_ipt = self._element_score(n)
+ comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True)
+ name_m = n.replace('lora_A', '%s')
+ if name_m not in vector_ipt:
+ vector_ipt[name_m] = [comb_ipt]
+ else:
+ vector_ipt[name_m].append(comb_ipt)
+ if f'lora_B.{self.adapter_name}' in n:
+ entry_ipt = self._element_score(n)
+ comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1)
+ name_m = n.replace('lora_B', '%s')
+ if name_m not in vector_ipt:
+ vector_ipt[name_m] = [comb_ipt]
+ else:
+ vector_ipt[name_m].append(comb_ipt)
+ if f'lora_E.{self.adapter_name}' in n:
+ entry_ipt = self._element_score(n)
+ name_m = n.replace('lora_E', '%s')
+ value_ipt[name_m] = entry_ipt
+
+ all_score = []
+ # Calculate the score for each triplet
+ for name_m in vector_ipt:
+ ipt_E = value_ipt[name_m]
+ ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
+ sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
+ name_E = name_m % 'lora_E'
+ triplet_ipt[name_E] = sum_ipt.view(-1, 1)
+ sum_ipt = sum_ipt.view(-1)
+ if all_score:
+ sum_ipt = sum_ipt.to(all_score[0].device)
+ all_score.append(sum_ipt)
+
+ # Get the threshold by ranking ipt
+ mask_threshold = torch.kthvalue(
+ torch.cat(all_score),
+ k=self.init_bgt - budget,
+ )[0].item()
+
+ rank_pattern = {}
+ # Mask the unimportant triplets
+ with torch.no_grad():
+ for n, p in model.named_parameters():
+ if f'lora_E.{self.adapter_name}' in n:
+ p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0)
+ rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist()
+ return rank_pattern
+
+
+def keep_device_forward(self, *args, **kwargs):
+ x = args[0]
+ if self.weight.device != x.device:
+ return self.forward_origin(x.to(self.weight.device), *args[1:], **kwargs)
+ else:
+ return self.forward_origin(*args, **kwargs)
+
+
+def hot_patch_peft_module():
+ from peft.tuners.lora import LoraLayer
+ if hasattr('LoraModel', '_create_and_replace_origin'):
+ return
+
+ # Fix Lora does not support NonDynamicallyQuantizableLinear
+ LoraModel._create_and_replace_origin = LoraModel._create_and_replace
+ LoraModel._create_and_replace = _create_and_replace_hook
+ AdaLoraModel._create_and_replace_origin = AdaLoraModel._create_and_replace
+ AdaLoraModel._create_and_replace = _create_and_replace_hook
+ VeraModel._create_and_replace_origin = VeraModel._create_and_replace
+ VeraModel._create_and_replace = _create_and_replace_hook
+ BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
+ BOFTModel._create_and_replace = _create_and_replace_hook
+ if FourierFTModel is not None:
+ FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
+ FourierFTModel._create_and_replace = _create_and_replace_hook
+ if BoneModel is not None:
+ BoneModel._create_and_replace_origin = BoneModel._create_and_replace
+ BoneModel._create_and_replace = _create_and_replace_hook
+
+ # Support type conversion
+ def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name: str):
+
+ self.__init_origin__(model, config, adapter_name)
+ active_adapters = self.active_adapter
+ if isinstance(active_adapters, str):
+ active_adapters = [active_adapters]
+ for active_adapter in active_adapters:
+ active_config = config[active_adapter] if isinstance(config, dict) else config
+ if hasattr(active_config, 'lora_dtype'):
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ _convert_dtype(module, active_adapter, active_config.lora_dtype)
+ for lora in list(module.lora_A.values()) + list(module.lora_B.values()):
+ if not hasattr(lora, 'forward_origin'):
+ lora.forward_origin = lora.forward
+ lora.forward = MethodType(keep_device_forward, lora)
+
+ LoraModel.__init_origin__ = LoraModel.__init__
+ LoraModel.__init__ = __new_init__
+
+ # Support LoRA+
+ PeftModel.create_optimizer_param_groups = create_optimizer_param_groups
+
+ PeftConfigMixin.from_pretrained_origin = PeftConfigMixin.from_pretrained
+ PeftConfigMixin.from_pretrained = LoraConfig.from_pretrained
+
+ # Compatible with SwiftModel
+ def dummy_function(*args, **kwargs):
+ logger.warn(f'The function {kwargs["func"]} has no effects, consider using other functions.')
+
+ PeftModel.activate_adapter = PeftModel.set_adapter
+ PeftModel.deactivate_adapter = partial(dummy_function, func='deactivate_adapter')
+ PeftModel.set_active_adapters = partial(dummy_function, func='set_active_adapters')
+
+ # Fix adalora does not support device_map
+ AdaLoraModel.forward = adalora_forward
+ RankAllocator.mask_to_budget = adalora_mask_to_budget
+
+
+def get_wrapped_class(module_class):
+ """Get a custom wrapper class for peft classes to download the models from the ModelScope hub
+
+ Args:
+ module_class: The actual module class
+
+ Returns:
+ The wrapper
+ """
+
+ class PeftWrapper(module_class):
+
+ @classmethod
+ def from_pretrained(cls, model, model_id, *args, revision: Optional[str] = None, **kwargs):
+ if not os.path.exists(model_id):
+ model_id = snapshot_download(model_id, revision=revision)
+ return module_class.from_pretrained(model, model_id, *args, **kwargs)
+
+ PeftWrapper.__name__ = module_class.__name__
+ PeftWrapper.__qualname__ = module_class.__qualname__
+ return PeftWrapper
+
+
+def wrap_module(module):
+ if not hasattr(module, 'from_pretrained'):
+ return module
+
+ return get_wrapped_class(module)
+
+
+hot_patch_peft_module()
+PeftModel = wrap_module(PeftModel)
+PeftConfig = wrap_module(PeftConfig)
+PeftModelForSeq2SeqLM = wrap_module(PeftModelForSeq2SeqLM)
+PeftModelForSequenceClassification = wrap_module(PeftModelForSequenceClassification)
+PeftModelForTokenClassification = wrap_module(PeftModelForTokenClassification)
+PeftModelForCausalLM = wrap_module(PeftModelForCausalLM)
+PromptEncoderConfig = wrap_module(PromptEncoderConfig)
+PromptTuningConfig = wrap_module(PromptTuningConfig)
+PrefixTuningConfig = wrap_module(PrefixTuningConfig)
+PromptLearningConfig = wrap_module(PromptLearningConfig)
+LoraConfig = wrap_module(LoraConfig)
+AdaLoraConfig = wrap_module(AdaLoraConfig)
+LoHaConfig = wrap_module(LoHaConfig)
+LoKrConfig = wrap_module(LoKrConfig)
+LoftQConfig = wrap_module(LoftQConfig)
+OFTConfig = wrap_module(OFTConfig)
+BOFTConfig = wrap_module(BOFTConfig)
+VeraConfig = wrap_module(VeraConfig)
+OFTConfig = wrap_module(OFTConfig)
+get_peft_config = get_peft_config
+get_peft_model_state_dict = get_peft_model_state_dict
+get_peft_model = get_peft_model
diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b3d1ab4e80eadbd9f6fb176e672d73a316da2cb
--- /dev/null
+++ b/swift/tuners/prompt.py
@@ -0,0 +1,205 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import re
+import types
+from dataclasses import dataclass, field
+from typing import List, Union
+
+import torch
+from torch import nn
+
+from swift.utils import get_logger
+from swift.utils.torch_utils import find_sub_module
+from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class PromptConfig(SwiftConfig):
+ """
+ The configuration class for the prompt module.
+
+ Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens
+ and prepend to the original tokens in the first layer or multiple layers.
+ 'Visual Prompt Tuning' by Jia et al.(2022)
+ See https://arxiv.org/abs/2203.12119
+
+ Here we apply the VPT to other fields.
+
+ Args:
+ dim(`Union[int, List[int]]`): The dimension of the hidden states, use list if there are up-sample blocks
+ or down-sample blocks
+ target_modules(str): The layer module to be replaced, in regex format
+ embedding_pos(Union[str, int]): The position of the embedding tensor
+ attention_mask_pos(Union[str, int]): The position of the attention mask
+ attention_mask_value(Union[float, int, bool]): The value to pad to the attention mask
+ prompt_length(int): The length of the prompt tokens
+ attach_front(bool): When set to True, prompt is attached in front of the embedding
+ extract_embedding(bool): Whether the embedding is extracted at final stage to keep the same dims with inputs
+ """
+
+ dim: Union[int, List[int]] = field(default=None, metadata={'help': 'The dimension of the hidden states'})
+
+ target_modules: str = field(default=None, metadata={'help': 'The layer module to be replaced, in regex format'})
+
+ embedding_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the embedding tensor'})
+
+ attention_mask_pos: Union[str, int] = field(default=None, metadata={'help': 'The position of the attention mask'})
+
+ attention_mask_value: Union[float, int, bool] = field(
+ default=0., metadata={'help': 'The value to pad to the attention mask'})
+
+ prompt_length: int = field(default=16, metadata={'help': 'The length of the prompt tokens'})
+
+ attach_front: bool = field(
+ default=True, metadata={'help': 'When set to True, prompt is attached in front of the embedding'})
+
+ extract_embedding: bool = field(
+ default=False,
+ metadata={'help': 'Whether the embedding is extracted at final stage to keep the same dims with inputs'})
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.PROMPT
+
+
+class Prompt(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: PromptConfig, adapter_name: str):
+ module_keys = [key for key, _ in model.named_modules()]
+ match_module_keys = []
+ for module_key in module_keys:
+ if isinstance(config.target_modules, str):
+ target_module_found = re.fullmatch(config.target_modules, module_key)
+ else:
+ target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)
+ if target_module_found: # noqa
+ module = model.get_submodule(module_key)
+
+ def _forward(self, *args, **kwargs):
+ if isinstance(config.embedding_pos, int):
+ input_embedding = args[config.embedding_pos]
+ else:
+ input_embedding = kwargs[config.embedding_pos]
+
+ input_embedding = getattr(self, f'prompt_{adapter_name}').forward(input_embedding)
+ if isinstance(config.embedding_pos, int):
+ args = type(args)(
+ args[0:config.embedding_pos] + (input_embedding, ) + args[config.embedding_pos + 1:])
+ else:
+ kwargs[config.embedding_pos] = input_embedding
+
+ if config.attention_mask_pos:
+ attention_mask = None
+ if isinstance(config.attention_mask_pos, int):
+ attention_mask = args[config.attention_mask_pos]
+ elif isinstance(config.attention_mask_pos, str):
+ attention_mask = kwargs[config.attention_mask_pos]
+
+ if attention_mask is not None:
+ attention_mask = getattr(self,
+ f'prompt_{adapter_name}').patch_attention_mask(attention_mask)
+ if isinstance(config.attention_mask_pos, int):
+ args = type(args)(
+ args[0:config.attention_mask_pos] + (attention_mask, )
+ + args[config.attention_mask_pos + 1:])
+ else:
+ kwargs[config.attention_mask_pos] = attention_mask
+
+ forward_output = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
+ if config.extract_embedding:
+ forward_output = getattr(self, f'prompt_{adapter_name}').extract(forward_output)
+
+ return forward_output
+
+ setattr(module, f'forward_origin_{adapter_name}', module.forward)
+ module.forward = types.MethodType(_forward, module)
+ if isinstance(config.dim, list):
+ input_dim = config.dim[len(match_module_keys)]
+ else:
+ input_dim = config.dim
+ prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), adapter_name, module_key,
+ config.prompt_length, config.attention_mask_value, config.attach_front)
+ setattr(module, f'prompt_{adapter_name}', prompt_module)
+ logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}')
+ match_module_keys.append(module_key)
+
+ def state_dict_callback(state_dict, adapter_name, **kwargs):
+ return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key}
+
+ def mark_trainable_callback(model):
+ return
+
+ return SwiftOutput(
+ config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ modules = find_sub_module(module, f'prompt_{adapter_name}')
+ for _module in modules:
+ _module: ActivationMixin
+ _module: nn.Module
+ _module.set_activation(adapter_name, activate)
+ SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
+
+
+class PromptModule(nn.Module, ActivationMixin):
+ """The implementation of vision prompt tuning method.
+
+ Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens
+ and prepend to the original tokens in the first layer or multiple layers.
+ 'Visual Prompt Tuning' by Jia et al.(2022)
+ See https://arxiv.org/abs/2203.12119
+
+ Args:
+ dim: An integer indicating the embedding dimension.
+ layer_num: An integer indicating number of layers.
+ prompt_length: An integer indicating the length of vision prompt tuning.
+ """
+
+ def __init__(self, dim, layer_num, adapter_name, module_key, prompt_length=None, mask_values=0., attach_front=True):
+ super(PromptModule, self).__init__()
+ super(nn.Module, self).__init__(module_key)
+ self.dim = dim
+ self.layer_num = layer_num
+ self.adapter_name = adapter_name
+ self.prompt_length = prompt_length
+ self.mask_values = mask_values
+ self.attach_front = attach_front
+ self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim))
+ nn.init.xavier_uniform_(self.prompt_token)
+ self.mark_all_sub_modules_as_plugin()
+
+ def forward(self, x):
+ if not self.is_activated(self.adapter_name):
+ return x
+ prompt_token = self.prompt_token.expand(x.shape[0], -1, -1).to(x.device, x.dtype)
+
+ if self.layer_num == 0:
+ if self.attach_front:
+ x = torch.cat((prompt_token, x), dim=1)
+ else:
+ x = torch.cat((x, prompt_token), dim=1)
+ else:
+ if self.attach_front:
+ x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), dim=1)
+ else:
+ x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), dim=1)
+ return x
+
+ def patch_attention_mask(self, m):
+ if not self.is_activated(self.adapter_name):
+ return m
+ prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), self.mask_values).to(m.device)
+ if self.attach_front:
+ return torch.cat((prefix_attention_mask, m), dim=-1)
+ else:
+ return torch.cat((m, prefix_attention_mask), dim=-1)
+
+ def extract(self, x):
+ if self.attach_front:
+ return x[:, self.prompt_length:, :]
+ else:
+ return x[:, :-self.prompt_length, :]
diff --git a/swift/tuners/reft.py b/swift/tuners/reft.py
new file mode 100644
index 0000000000000000000000000000000000000000..8179b61ccda8b81241cd583ec039c70665e4077a
--- /dev/null
+++ b/swift/tuners/reft.py
@@ -0,0 +1,215 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass
+from types import MethodType
+from typing import List, Literal, Optional
+
+import json
+import torch
+from torch import nn
+
+from swift.utils import get_logger, patch_getattr
+from .utils import SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class ReftConfig(SwiftConfig):
+ """
+ Train a model with Reft.
+ Paper: https://arxiv.org/pdf/2404.03592
+
+ Args:
+ model_type(`Optional[str]`): The model_type to find down_proj/layers.
+ layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`.
+ layers (`Optional[List[int]]`): The layer number to inject.
+ r(`int`): The rank of Reft.
+ intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
+ 'ConsreftIntervention', 'LobireftIntervention',
+ 'DireftIntervention', 'NodireftIntervention']`): The intervention type,
+ default LoreftIntervention
+ args (`Optional[str]`): Other reft_args in json-string format
+ """
+
+ model_type: Optional[str] = None
+ layer_key: Optional[str] = None
+ layers: Optional[List[int]] = None
+ r: int = 4
+ intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
+ 'LobireftIntervention', 'DireftIntervention',
+ 'NodireftIntervention'] = 'LoreftIntervention'
+ args: Optional[str] = None
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.REFT
+ if self.args:
+ self.args = json.loads(self.args)
+ else:
+ self.args = {}
+
+
+class Reft(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str):
+ from swift.utils.import_utils import is_pyreft_available
+ if not is_pyreft_available():
+ raise ImportError('Please install pyreft before using ReFT: ' '`pip install pyreft`')
+
+ import pyreft
+ from pyreft import ReftModel
+ from pyreft.interventions import LowRankRotateLayer
+ from pyreft import (
+ NoreftIntervention,
+ LoreftIntervention,
+ ConsreftIntervention,
+ LobireftIntervention,
+ DireftIntervention,
+ NodireftIntervention,
+ )
+
+ intervention_mapping = {
+ 'NoreftIntervention': NoreftIntervention,
+ 'LoreftIntervention': LoreftIntervention,
+ 'ConsreftIntervention': ConsreftIntervention,
+ 'LobireftIntervention': LobireftIntervention,
+ 'DireftIntervention': DireftIntervention,
+ 'NodireftIntervention': NodireftIntervention,
+ }
+
+ patch_getattr(ReftModel, 'model')
+
+ def forward(self, x):
+ self.to(x.device)
+ return self.forward_origin(x)
+
+ def forward2(self, base, source=None, subspaces=None):
+ self.to(base.device)
+ return self.forward_origin(base, source, subspaces)
+
+ if not hasattr(LowRankRotateLayer, 'forward_origin'):
+ LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
+ LowRankRotateLayer.forward = forward
+ NoreftIntervention.forward_origin = NoreftIntervention.forward
+ NoreftIntervention.forward = forward2
+ LoreftIntervention.forward_origin = LoreftIntervention.forward
+ LoreftIntervention.forward = forward2
+ ConsreftIntervention.forward_origin = ConsreftIntervention.forward
+ ConsreftIntervention.forward = forward2
+ LobireftIntervention.forward_origin = LobireftIntervention.forward
+ LobireftIntervention.forward = forward2
+ DireftIntervention.forward_origin = DireftIntervention.forward
+ DireftIntervention.forward = forward2
+ NodireftIntervention.forward_origin = NodireftIntervention.forward
+ NodireftIntervention.forward = forward2
+
+ module_list_key = config.layer_key
+ if module_list_key is None:
+ model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
+ module_list_key = model_key_mapping.module_list
+ logger.info(f'Applying Reft to module: {module_list_key}')
+ module_list: nn.ModuleList = model.get_submodule(module_list_key)
+ representations = []
+ for idx, layer in enumerate(module_list):
+ if config.layers and idx not in config.layers:
+ continue
+ intervention_config = {
+ 'layer':
+ idx,
+ 'component':
+ module_list_key + f'[{idx}].output',
+ 'low_rank_dimension':
+ config.r,
+ 'intervention':
+ intervention_mapping[config.intervention_type](
+ embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args)
+ }
+ representations.append(intervention_config)
+
+ reft_config = pyreft.ReftConfig(representations=representations)
+ reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
+ reft_model.reft_config = reft_model.config
+ reft_model.config = reft_model.model.config
+
+ def _pre_forward_hook(module, args, kwargs):
+ if 'base' in kwargs:
+ return args, kwargs
+
+ if 'input_ids' not in kwargs:
+ raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')
+ # run intervened forward pass
+ unit_locations = None
+ if 'intervention_locations' in kwargs:
+ if kwargs['intervention_locations'].dim() == 3:
+ unit_locations = {
+ 'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
+ }
+ else:
+ # this is dummy for lora only baseline
+ unit_locations = {'sources->base': (None, 0)}
+ kwargs = {
+ 'base': {
+ 'input_ids': kwargs['input_ids'],
+ 'attention_mask': kwargs['attention_mask']
+ },
+ 'unit_locations': unit_locations,
+ 'labels': kwargs['labels'],
+ 'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
+ }
+ return args, kwargs
+
+ def _post_forward_hook(module, args, kwargs, outputs):
+ return outputs[1]
+
+ def _generate(self, **kwargs):
+ # run intervened forward pass
+ unit_locations = None
+ if 'intervention_locations' in kwargs:
+ if kwargs['intervention_locations'].dim() == 3:
+ unit_locations = {
+ 'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
+ }
+ else:
+ # this is dummy for lora only baseline
+ unit_locations = {'sources->base': (None, 0)}
+
+ _kwargs = {
+ 'base': {
+ 'input_ids': kwargs.pop('input_ids'),
+ 'attention_mask': kwargs.pop('attention_mask')
+ },
+ 'unit_locations': unit_locations,
+ 'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
+ }
+ _kwargs = {**_kwargs, **kwargs}
+ return self.generate_origin(**_kwargs)[1]
+
+ reft_model.generate_origin = reft_model.generate
+ reft_model.generate = MethodType(_generate, reft_model)
+ reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
+ reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)
+
+ def save_callback(swift_model, model_dir, adapter_name):
+ reft_model.save_intervention(save_directory=model_dir, include_model=False)
+
+ def mark_trainable_callback(model):
+ return
+
+ def load_callback(swift_model, model_dir, adapter_name):
+ reft_model.load_intervention(model_dir, include_model=False)
+
+ return SwiftOutput(
+ model=reft_model,
+ config=config,
+ mark_trainable_callback=mark_trainable_callback,
+ save_callback=save_callback,
+ load_callback=load_callback)
+
+ @staticmethod
+ def has_additional_modules():
+ return True
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ assert activate, 'ReFT does not support deactivate'
diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a9def230a9c2e228d306b7304c4e006680c40ad
--- /dev/null
+++ b/swift/tuners/restuning.py
@@ -0,0 +1,327 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import re
+import types
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from swift.utils import get_logger
+from swift.utils.torch_utils import find_sub_module
+from .restuning_components import ResTuner, detach_tensors, probe_input_pre_hook, probe_output_hook
+from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class ResTuningConfig(SwiftConfig):
+ """
+ The configuration class for the ResTuning module.
+
+ ResTuning is a flexible parameter-efficient and memory-efficient tuning paradigm framework.
+ 'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone'
+ by Jiang et al.(2023)
+ See
+
+ Args:
+ dims(`Union[List[int], int]`): The dimensions of the hidden states
+ root_modules(`str`): The root module to be replaced, can a regex string
+ root_modules_hook(`str`): The hook type of root modules, can be "input" or "output"
+ stem_modules(`Union[List[str], str]`): The stem modules to be replaced,
+ can a regex string or name list of full match format
+ stem_modules_hook(`Union[List[str], str]`): The hook type of stem modules, can be "input" or "output"
+ target_modules(`str`): The target module to be replaced, can a regex string
+ target_modules_hook(`str`): The hook type of target modules, can be "input" or "output"
+ tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module,
+ can a string or customized config
+ use_upsample(bool): Whether to use auxiliary upsample module
+ upsample_out_channels(List[int]): The channels if `use_upsample`
+ zero_init_last(bool): Use zero to initialize the last Linear in every sub tuner.
+
+ """
+
+ dims: Optional[Union[List[int], int]] = field(
+ default=None, metadata={'help': 'The dimensions of the hidden states'})
+
+ root_modules: str = field(
+ default=None,
+ metadata={
+ 'help':
+ 'The root module to be replaced, can a regex string (use the first matching module) or full match format'
+ })
+
+ root_modules_hook: str = field(
+ default='input', metadata={'help': 'The hook type of root modules, can be "input" or "output"'})
+
+ stem_modules: Optional[Union[List[str], str]] = field(
+ default=None,
+ metadata={'help': 'The stem modules to be replaced, can a regex string or name list of full match format'})
+
+ stem_modules_hook: str = field(
+ default='output', metadata={'help': 'The hook type of stem modules, can be "input" or "output"'})
+
+ target_modules: str = field(
+ default=None,
+ metadata={
+ 'help':
+ 'The target module to be replaced, can a regex string (use the first matching module) or full match format'
+ })
+
+ target_modules_hook: str = field(
+ default='input', metadata={'help': 'The hook type of target modules, can be "input" or "output"'})
+
+ target_hidden_pos: Union[int, str] = field(
+ default=None, metadata={'help': 'The position of the hidden state for target modules output'})
+
+ tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field(
+ default=None, metadata={'help': 'The configuration of the tuning module, can a string or customized config'})
+
+ use_upsample: bool = field(default=False, metadata={'help': 'Whether to use auxiliary upsample module'})
+
+ upsample_out_channels: List[int] = field(
+ default=None, metadata={'help': 'The number of output channels when "use_upsample" is set to "True"'})
+
+ zero_init_last: bool = field(default=False, metadata={'help': 'Zero init last weight'})
+
+ use_bypass: bool = field(default=True, metadata={'help': 'Whether to use bypass'})
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.RESTUNING
+ self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos
+
+
+class ResTuning(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: ResTuningConfig, adapter_name: str) -> SwiftOutput:
+ """Prepare a model with `ResTuningConfig`"""
+
+ def _forward_seq(self, input, *args, **kwargs):
+ for idx, module in enumerate(self):
+ if idx >= len(self.origin_module_keys):
+ continue
+ input = module(input)
+ return input
+
+ def _forward_target(self, *args, **kwargs):
+ if self.target_modules_hook == 'input':
+ if isinstance(self.target_hidden_pos, int):
+ args = list(args)
+ _arg = args[self.target_hidden_pos]
+ else:
+ _arg = kwargs[self.target_hidden_pos]
+ args_main = _forward_restuning(self, _arg)
+ if isinstance(self.target_hidden_pos, int):
+ args[self.target_hidden_pos] = args_main
+ else:
+ kwargs[self.target_hidden_pos] = args_main
+ args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
+ else:
+ _args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
+ _arg = _args_main[self.target_hidden_pos] if isinstance(_args_main, (tuple, list, dict)) else _args_main
+ args_main = _forward_restuning(self, _arg)
+ if type(_args_main) != type(args_main):
+ _args_main[self.target_hidden_pos] = args_main
+ args_main = _args_main
+ return args_main
+
+ def _forward_restuning(self, origin_arg):
+ probe_results = []
+ root_module_ins = self.root_module_ins_list[0]
+ stem_module_ins_list = self.stem_module_ins_list
+ top_module = model.get_submodule('')
+ if root_module_ins:
+ if root_module_ins.root_modules_hook == 'input':
+ probe_results.append(root_module_ins.probe_input_data)
+ else:
+ probe_results.append(root_module_ins.probe_output_data)
+ for i, st_mod in enumerate(stem_module_ins_list):
+ if i == 0 and root_module_ins is None:
+ probe_results.append(st_mod.probe_input_data)
+ if st_mod.stem_modules_hook == 'input':
+ probe_results.append(st_mod.probe_input_data)
+ else:
+ probe_results.append(st_mod.probe_output_data)
+ args_main = getattr(top_module, f'restuning_{adapter_name}')(probe_results, origin_arg)
+ return args_main
+
+ # 1. Matching the root module
+ module_keys = [key for key, _ in model.named_modules()]
+ root_module_ins_list = []
+ if config.root_modules:
+ for module_key in module_keys:
+ if re.fullmatch(config.root_modules, module_key):
+ root_module = model.get_submodule(module_key)
+ logger.info(f'Matching root module [{module_key}] of type {type(root_module)}')
+ if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)):
+ logger.warning(
+ f'Type of {type(root_module)} may not be supported because of its customized forward')
+ if config.root_modules_hook == 'input':
+ root_module.register_forward_pre_hook(probe_input_pre_hook)
+ else:
+ root_module.register_forward_hook(probe_output_hook)
+ root_module.root_modules_hook = config.root_modules_hook
+ root_module_ins_list.append(root_module)
+ break
+ if len(root_module_ins_list) == 0:
+ logger.error('Cannot match root modules')
+
+ # 2. Matching the stem module
+ stem_module_ins_list = []
+ stem_module_ins_index = []
+ for module_key in module_keys:
+ if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \
+ (isinstance(config.stem_modules, list) and module_key in config.stem_modules):
+ stem_module = model.get_submodule(module_key)
+ if isinstance(config.stem_modules, list):
+ stem_module_ins_index.append(config.stem_modules.index(module_key))
+ logger.info(f'Matching stem module [{module_key}] of type {type(stem_module)}')
+ if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)):
+ logger.warning(
+ f'Type of {type(stem_module)} may not be supported because of its customized forward')
+ if len(root_module_ins_list) == 0 and len(stem_module_ins_list) == 0:
+ stem_module.register_forward_pre_hook(probe_input_pre_hook)
+ if config.stem_modules_hook == 'input':
+ stem_module.register_forward_pre_hook(probe_input_pre_hook)
+ else:
+ stem_module.register_forward_hook(probe_output_hook)
+ stem_module.stem_modules_hook = config.stem_modules_hook
+ stem_module_ins_list.append(stem_module)
+ if isinstance(config.stem_modules, list):
+ stem_module_ins_list = [
+ stem_module_ins_list[stem_module_ins_index.index(i)] for i in range(len(stem_module_ins_index))
+ ]
+ depth = len(stem_module_ins_list)
+ if len(stem_module_ins_list) == 0:
+ raise Exception('Cannot match source modules')
+
+ # 3. Init restuning module
+ if len(stem_module_ins_list) != 0:
+ top_module = model.get_submodule('')
+ restuning_module = ResTuningBypassModule(config.dims, depth, adapter_name, config.use_upsample,
+ config.upsample_out_channels, config.zero_init_last,
+ config.tuner_cfg)
+ setattr(top_module, f'restuning_{adapter_name}', restuning_module)
+
+ # 4. Matching the target module
+ target_module_ins = None
+ for module_key in module_keys:
+ if re.fullmatch(config.target_modules, module_key):
+ tgt_module = model.get_submodule(module_key)
+ logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
+ if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
+ raise Exception(
+ f'Type of {type(tgt_module)} may not be supported because of its customized forward')
+
+ tgt_module.target_modules_hook = config.target_modules_hook
+ tgt_module.target_hidden_pos = config.target_hidden_pos
+ tgt_module.root_module_ins_list = root_module_ins_list
+ tgt_module.stem_module_ins_list = stem_module_ins_list
+ target_module_ins = tgt_module
+
+ if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'origin_module_keys'):
+ tgt_module.origin_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))
+
+ setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(_forward_seq, tgt_module))
+ else:
+ setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
+ tgt_module.forward = types.MethodType(_forward_target, tgt_module)
+ if target_module_ins is None:
+ raise Exception('Cannot match target modules')
+
+ def state_dict_callback(state_dict, adapter_name, **kwargs):
+ return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key}
+
+ def mark_trainable_callback(model):
+ return
+
+ return SwiftOutput(
+ config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ modules = find_sub_module(module, f'restuning_{adapter_name}')
+ for _module in modules:
+ _module: ActivationMixin
+ _module: nn.Module
+ _module.set_activation(adapter_name, activate)
+ SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
+
+
+class ResTuningBypassModule(nn.Module, ActivationMixin):
+ """The implementation of ResTuningBypass method.
+ """
+
+ def __init__(
+ self,
+ dims,
+ depth,
+ adapter_name,
+ use_upsample=False,
+ upsample_out_channels=None,
+ zero_init_last=False,
+ tuner_cfg=None,
+ ):
+ super(ResTuningBypassModule, self).__init__()
+ super(nn.Module, self).__init__('')
+ self.adapter_name = adapter_name
+
+ self.bypass_blocks = nn.Sequential(*[
+ ResTunerBypassBlock(
+ dim=dims[i] if isinstance(dims, list) else dims,
+ layer_num=i,
+ depth=depth,
+ use_upsample=use_upsample,
+ upsample_out_channels=upsample_out_channels[i] if isinstance(upsample_out_channels, list
+ ) else upsample_out_channels,
+ zero_init_last=zero_init_last,
+ tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list) else tuner_cfg) for i in range(depth)
+ ])
+ self.mark_all_sub_modules_as_plugin()
+
+ def forward(self, x_list, origin_arg, **kwargs):
+ if not self.is_activated(self.adapter_name):
+ return origin_arg
+ x_bypass = detach_tensors(x_list.pop(0))
+ x_bypass = x_bypass[0] if isinstance(x_bypass, (list, tuple)) else x_bypass
+ x_list = detach_tensors(x_list)
+ x_list = [_x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list]
+ for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)):
+ target_size = x_list[i + 1].shape[2:] if i < len(x_list) - 1 else None
+ x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs)
+ return x_bypass
+
+
+class ResTunerBypassBlock(nn.Module):
+
+ def __init__(self, dim, layer_num=-1, depth=-1, use_upsample=False, zero_init_last=False, tuner_cfg=None, **kwargs):
+ super().__init__()
+ self.layer_num = layer_num
+ self.depth = depth
+
+ if isinstance(tuner_cfg, str):
+ lateral_cfg = tuner_cfg
+ vertical_cfg = tuner_cfg
+ aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None
+ elif isinstance(tuner_cfg, dict):
+ lateral_cfg = tuner_cfg['lateral_cfg'] if 'lateral_cfg' in tuner_cfg else None
+ vertical_cfg = tuner_cfg['vertical_cfg'] if 'vertical_cfg' in tuner_cfg else None
+ aux_cfg = tuner_cfg['aux_cfg'] if 'aux_cfg' in tuner_cfg else None
+
+ self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'lateral', lateral_cfg, **kwargs)
+ self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'vertical', vertical_cfg, **kwargs)
+ if aux_cfg and len(aux_cfg) != 0:
+ self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, 'aux', aux_cfg, **kwargs)
+
+ def forward(self, x_stem, x_bypass, target_size=None, **kwargs):
+ x_lateral = self.lateral_tuner(x_stem)
+ x_vertical = self.vertical_tuner(x_bypass)
+
+ x_bypass_out = x_lateral + x_vertical
+ if hasattr(self, 'aux_tuner'):
+ x_bypass_out = self.aux_tuner(x_bypass_out, target_size)
+ return x_bypass_out
diff --git a/swift/tuners/side.py b/swift/tuners/side.py
new file mode 100644
index 0000000000000000000000000000000000000000..a315bcd3a9527c38d96ac34a9da59cf04e01c91c
--- /dev/null
+++ b/swift/tuners/side.py
@@ -0,0 +1,245 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import re
+import types
+from collections import OrderedDict
+from dataclasses import dataclass, field
+from functools import partial
+from itertools import repeat
+from typing import Union
+
+import torch
+from torch import nn
+
+from swift.utils.logger import get_logger
+from swift.utils.torch_utils import find_sub_module
+from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class SideConfig(SwiftConfig):
+ """
+ The configuration class for the side module.
+
+ Side-Tuning only needs to train one side network and
+ weights the output of pre-trained model and side network.
+ 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
+ by Zhang et al.(2019)
+ See https://arxiv.org/abs/1912.13503
+
+ Args:
+ target_modules: The feedforward module to be replaced, in regex format
+ """
+
+ dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'})
+
+ target_modules: str = field(
+ default=None, metadata={'help': 'The target module to be replaced, in full match format'})
+
+ side_module_name: str = field(default='fcn4', metadata={'help': 'The name of the additive side networks'})
+
+ source_hidden_pos: Union[str, int] = field(
+ default=0,
+ metadata={
+ 'help': 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)'
+ })
+
+ target_hidden_pos: Union[str, int] = field(
+ default=0,
+ metadata={
+ 'help': 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)'
+ })
+
+ def __post_init__(self):
+ from .mapping import SwiftTuners
+ self.swift_type = SwiftTuners.SIDE
+
+
+class Side(SwiftAdapter):
+
+ @staticmethod
+ def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput:
+ """Prepare a model with `SideConfig`"""
+ module_keys = [key for key, _ in model.named_modules()]
+
+ for module_key in module_keys:
+ if re.fullmatch(config.target_modules, module_key): # noqa
+ tgt_module = model.get_submodule(module_key)
+ logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
+ if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
+ raise Exception(
+ f'Type of {type(tgt_module)} may not be supported because of its customized forward')
+
+ def _forward(self, *args, **kwargs):
+ args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
+
+ if isinstance(config.source_hidden_pos, int):
+ x = args[config.source_hidden_pos]
+ else:
+ x = kwargs[config.source_hidden_pos]
+
+ x_main = args_main[config.target_hidden_pos] \
+ if isinstance(args_main, (tuple, list, dict)) else args_main
+ out = getattr(self, f'side_{adapter_name}')(x, x_main)
+ if isinstance(args_main, (tuple, list, dict)):
+ args_main[config.target_hidden_pos] = out
+ else:
+ args_main = out
+ return args_main
+
+ if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'):
+ tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))
+
+ def forward_seq(self, input, *args, **kwargs):
+ for idx, module in enumerate(self):
+ if idx >= len(tgt_module.tgt_module_keys):
+ continue
+ input = module(input)
+ return input
+
+ setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module))
+ else:
+ setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
+ tgt_module.forward = types.MethodType(_forward, tgt_module)
+ side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name)
+ setattr(tgt_module, f'side_{adapter_name}', side_module)
+ logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}')
+
+ def state_dict_callback(state_dict, adapter_name, **kwargs):
+ return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}
+
+ def mark_trainable_callback(model):
+ return
+
+ return SwiftOutput(
+ config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
+
+ @staticmethod
+ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
+ modules = find_sub_module(module, f'side_{adapter_name}')
+ for _module in modules:
+ _module: ActivationMixin
+ _module: nn.Module
+ _module.set_activation(adapter_name, activate)
+ SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
+
+
+class SideModule(nn.Module, ActivationMixin):
+ """The implementation of vision side-tuning method.
+
+ Side-Tuning only needs to train one side network and
+ weights the output of pre-trained model and side network.
+ 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
+ by Zhang et al.(2019)
+ See https://arxiv.org/abs/1912.13503
+
+ Args:
+ side_module_name: The name of the additive side networks.
+ """
+
+ def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
+ super(SideModule, self).__init__()
+ super(nn.Module, self).__init__(module_key)
+ self.adapter_name = adapter_name
+
+ side_module_name = side_module_name.lower()
+ if side_module_name == 'fcn4':
+ self.side_net = FCN4(out_dims=dim)
+ elif side_module_name == 'mlp':
+ self.side_net = Mlp(dim)
+ elif side_module_name == 'alexnet':
+ import torchvision
+ mm = torchvision.models.alexnet(pretrained=True)
+ self.side_net = nn.Sequential(
+ OrderedDict([('features', mm.features), ('avgpool', mm.avgpool), ('flatten', nn.Flatten()),
+ ('fc', nn.Linear(9216, dim, bias=False))]))
+ else:
+ raise ValueError(f'Unsupported side_module_name: {side_module_name}')
+ self.alpha = nn.Parameter(torch.tensor(0.0))
+ self.mark_all_sub_modules_as_plugin()
+
+ def forward(self, x, x_main):
+ if not self.is_activated(self.adapter_name):
+ return x_main
+ alpha_squashed = torch.sigmoid(self.alpha)
+ x_side = self.side_net(x)
+ x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side
+ return x_out
+
+
+class FCN4(nn.Module):
+ """The implementation of simple FCN4 network for side network.
+ """
+
+ def __init__(self, out_dims=-1, **kwargs):
+ super(FCN4, self).__init__(**kwargs)
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dilation=1), nn.GroupNorm(2, 16),
+ nn.ReLU())
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 16),
+ nn.ReLU())
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 32),
+ nn.ReLU())
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 64),
+ nn.ReLU())
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
+ if out_dims > 0:
+ self.fc = nn.Linear(64, out_dims)
+ else:
+ self.fc = None
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.pool(x)
+ x = x.view(x.size(0), -1)
+ if self.fc is not None:
+ x = self.fc(x)
+ return x
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer.
+ """
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = tuple(repeat(bias, 2))
+ drop_probs = tuple(repeat(drop, 2))
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x