File size: 785 Bytes
a1652f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Copyright (c) Alibaba, Inc. and its affiliates.
import functools

from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.training import get_args, global_vars, initialize, training

from swift.utils import get_logger

# Module-level logger obtained from swift.utils.get_logger().
logger = get_logger()


def patch_megatron_tokenizer(tokenizer):
    """Make Megatron's ``build_tokenizer`` hand back an already-built tokenizer.

    Installs a replacement for ``global_vars.build_tokenizer``. The replacement
    does not construct a tokenizer. It records on ``args`` how many slots
    ``args.padded_vocab_size`` has beyond ``tokenizer.vocab_size``, then returns
    ``tokenizer`` itself.

    Args:
        tokenizer: Object returned by the replacement; must expose ``vocab_size``.
    """

    def build_tokenizer(args):
        # NOTE(review): reads args.padded_vocab_size, so it assumes that field
        # was populated before this hook runs — confirm against Megatron init.
        surplus = args.padded_vocab_size - tokenizer.vocab_size
        args.extra_vocab_size = surplus
        return tokenizer

    # Overwrite the module attribute (presumably so Megatron's own tokenizer
    # setup looks it up and gets the closure above — confirm).
    setattr(global_vars, 'build_tokenizer', build_tokenizer)


def patch_torch_dist_shard(thread_count):
    """Force ``TorchDistSaveShardedStrategy`` to be built with ``thread_count``.

    Replaces ``TorchDistSaveShardedStrategy.__init__`` with a wrapper. The
    wrapper sets ``thread_count=thread_count`` in the keyword arguments,
    overriding any keyword value the caller passed, and then delegates to the
    original ``__init__``.

    The wrapper remembers the unpatched ``__init__`` and always delegates to it.
    Calling this function again therefore replaces the previous value. Without
    this, wrappers would stack and the oldest one, which runs last, would
    silently overwrite the newer ``thread_count``.

    Args:
        thread_count: Value injected as the ``thread_count`` keyword argument.
    """
    current_init = TorchDistSaveShardedStrategy.__init__
    # Unwrap any earlier patch so we always wrap the genuine original.
    original_init = getattr(current_init, '_swift_original_init', current_init)

    @functools.wraps(original_init)
    def _patched_init(*args, **kwargs):
        # NOTE(review): if a caller passes thread_count positionally, this
        # keyword injection raises "multiple values" TypeError, same as before.
        kwargs['thread_count'] = thread_count
        return original_init(*args, **kwargs)

    _patched_init._swift_original_init = original_init
    TorchDistSaveShardedStrategy.__init__ = _patched_init