# Copyright (c) Alibaba, Inc. and its affiliates.
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.training import get_args, global_vars, initialize, training

from swift.utils import get_logger

logger = get_logger()


def patch_megatron_tokenizer(tokenizer):
    """Make Megatron reuse the given tokenizer instead of building its own."""

    def build_tokenizer(args):
        # Record how many padding tokens Megatron added on top of the tokenizer's vocab.
        args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
        return tokenizer

    # Monkey-patch the tokenizer factory so Megatron's global_vars returns this tokenizer.
    global_vars.build_tokenizer = build_tokenizer


def patch_torch_dist_shard(thread_count):
    """Force TorchDistSaveShardedStrategy to save shards with `thread_count` threads."""
    __init__ = TorchDistSaveShardedStrategy.__init__

    def __new_init__(*args, **kwargs):
        # Override whatever thread_count the caller passed with the configured value.
        kwargs['thread_count'] = thread_count
        return __init__(*args, **kwargs)

    TorchDistSaveShardedStrategy.__init__ = __new_init__
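

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): both helpers are
# assumed to be called before Megatron builds its tokenizer or writes a
# distributed checkpoint, so the patched attributes take effect first. The model
# id and thread count below are examples only.
#
# from transformers import AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-7B')
# patch_megatron_tokenizer(tokenizer)       # Megatron reuses the HF tokenizer
# patch_torch_dist_shard(thread_count=2)    # cap checkpoint-save threads per rank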