Create README.md
Browse files
    	
        README.md
    ADDED
    
    | 
         @@ -0,0 +1,272 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ---
         
     | 
| 2 | 
         
            +
            license: apache-2.0
         
     | 
| 3 | 
         
            +
            language:
         
     | 
| 4 | 
         
            +
            - en
         
     | 
| 5 | 
         
            +
            pipeline_tag: text-generation
         
     | 
| 6 | 
         
            +
            tags:
         
     | 
| 7 | 
         
            +
            - music
         
     | 
| 8 | 
         
            +
            - art
         
     | 
| 9 | 
         
            +
            ---
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            <div align="center">
         
     | 
| 12 | 
         
            +
                <img src="Yi_logo.svg" width="150px" style="display: inline-block;">
         
     | 
| 13 | 
         
            +
                <img src="m-a-p.png" width="150px" style="display: inline-block;">
         
     | 
| 14 | 
         
            +
            </div>
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            ## MuPT: Symbolic Music Generative Pre-trained Transformer
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            MuPT is a series of pre-trained models for symbolic music generation. The models were trained on a large-scale dataset of symbolic music, including millions of monophonic and polyphonic pieces from different genres and styles. They use the Llama 2 architecture and can be further used for downstream music generation tasks such as melody generation, accompaniment generation, and multi-track music generation.
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            - 09/01/2024: A series of pre-trained MuPT models was released, with parameter counts ranging from 110M to 1.3B.
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            ## Model architecture
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            The architecture details of the MuPT-v1 models are listed below:
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            | Name | Parameters | Training Data (Music Pieces) | Seq Length | Hidden Size | Layers | Heads |
         
     | 
| 27 | 
         
            +
            | :--- | :---: | :---: | :---: | :---: | :---: | :---: |
         
     | 
| 28 | 
         
            +
            | MuPT-v1-8192-110M | 110M | 7M x 8 epochs | 8192 | 768 | 12 | 12 |
         
     | 
| 29 | 
         
            +
            | MuPT-v1-8192-345M | 345M | 7M x 6 epochs | 8192 | 1024 | 24 | 16 |
         
     | 
| 30 | 
         
            +
            | MuPT-v1-8192-770M | 770M | 7M x 5 epochs | 8192 | 1280 | 36 | 20 |
         
     | 
| 31 | 
         
            +
            | MuPT-v1-8192-1.3B | 1.3B | 7M x 8 epochs | 8192 | 1536 | 48 | 24 |
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            ## Model Usage
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            #### Huggingface
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            ##### Inference
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
            ```python
         
     | 
| 40 | 
         
            +
            from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
            tokenizer = AutoTokenizer.from_pretrained("m-a-p/MuPT_v1_8192_770M",
         
     | 
| 43 | 
         
            +
                                                        trust_remote_code=True,
         
     | 
| 44 | 
         
            +
                                                        use_fast=False)
         
     | 
| 45 | 
         
            +
            model = AutoModelForCausalLM.from_pretrained("m-a-p/MuPT_v1_8192_770M").eval().half().cuda()
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
            prefix = "X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB" # replace "\n" with "<n>" for all the MuPT-8192 models, but not for MuPT-4096 models
         
     | 
| 48 | 
         
            +
            inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            max_length = 256
         
     | 
| 51 | 
         
            +
            outputs = model.generate(
         
     | 
| 52 | 
         
            +
                inputs.input_ids,
         
     | 
| 53 | 
         
            +
                max_length=max_length
         
     | 
| 54 | 
         
            +
            )
         
     | 
| 55 | 
         
            +
            outputs = tokenizer.decode(outputs[0])
         
     | 
| 56 | 
         
            +
            print(outputs)
         
     | 
| 57 | 
         
            +
            ```
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
            ##### Post-processing
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
            Since we merged multiple tracks into one track during training, we need to separate the outputs into standard ABC notation sequences. The post-processing code is as follows:
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            ```python
         
     | 
| 64 | 
         
            +
            import re
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
            SEPARATORS = ['|', '|]', '||', '[|', '|:', ':|', '::']
         
     | 
| 67 | 
         
            +
            SEP_DICT = {}
         
     | 
| 68 | 
         
            +
            for i, sep in enumerate(SEPARATORS, start=1):
         
     | 
| 69 | 
         
            +
                # E.g. ' | ': ' <1>'
         
     | 
| 70 | 
         
            +
                SEP_DICT[' '+sep+' '] = f' <{i}>'
         
     | 
| 71 | 
         
            +
            NEWSEP = '<|>'
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
            def sep2tok(row):
         
     | 
| 74 | 
         
            +
                for sep, tok in SEP_DICT.items():
         
     | 
| 75 | 
         
            +
                    row = row.replace(sep, tok+'<=> ')
         
     | 
| 76 | 
         
            +
                return row
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
            def tok2sep(bar):
         
     | 
| 79 | 
         
            +
                for sep, tok in SEP_DICT.items():
         
     | 
| 80 | 
         
            +
                    bar = bar.replace(tok, sep)
         
     | 
| 81 | 
         
            +
                return bar
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
            def spacing(row):
         
     | 
| 85 | 
         
            +
                
         
     | 
| 86 | 
         
            +
                for sep in SEPARATORS:
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                    def subfunc(match):
         
     | 
| 89 | 
         
            +
                        symbol = [':', '|', ']']
         
     | 
| 90 | 
         
            +
                        if match.group(1) is None:
         
     | 
| 91 | 
         
            +
                            return f' {sep}'
         
     | 
| 92 | 
         
            +
                        elif match.group(1) in symbol:
         
     | 
| 93 | 
         
            +
                            return f' {sep}{match.group(1)}'
         
     | 
| 94 | 
         
            +
                        else:
         
     | 
| 95 | 
         
            +
                            return ' '+sep+' '+match.group(1)
         
     | 
| 96 | 
         
            +
                            
         
     | 
| 97 | 
         
            +
                    pattern = r' ' + re.escape(sep) + r'(.{1})'
         
     | 
| 98 | 
         
            +
                    row = re.sub(pattern, subfunc, row)
         
     | 
| 99 | 
         
            +
                    row = row.replace('\n'+sep+'"', '\n '+sep+' "') # B \n|"A -> B \n | "A
         
     | 
| 100 | 
         
            +
                    row = row.replace(' '+sep+'\n', ' '+sep+' \n')  # B |\n -> B | \n
         
     | 
| 101 | 
         
            +
                return row
         
     | 
| 102 | 
         
            +
              
         
     | 
| 103 | 
         
            +
             def decode(piece):
         
     | 
| 104 | 
         
            +
                dec_piece = ''
         
     | 
| 105 | 
         
            +
                idx = piece.find(' '+NEWSEP+' ')
         
     | 
| 106 | 
         
            +
                heads = piece[:idx]
         
     | 
| 107 | 
         
            +
                scores = piece[idx:]
         
     | 
| 108 | 
         
            +
                scores_lst = re.split(' <\|>', scores)
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
                all_bar_lst = []
         
     | 
| 111 | 
         
            +
                for bar in scores_lst:
         
     | 
| 112 | 
         
            +
                    if bar == '':
         
     | 
| 113 | 
         
            +
                        continue
         
     | 
| 114 | 
         
            +
                    bar = sep2tok(bar)
         
     | 
| 115 | 
         
            +
                    bar_lst = re.split('<=>', bar)
         
     | 
| 116 | 
         
            +
                    bar_lst = list(map(tok2sep, bar_lst))
         
     | 
| 117 | 
         
            +
                    if len(all_bar_lst) == 0:
         
     | 
| 118 | 
         
            +
                        all_bar_lst = [[] for _ in range(len(bar_lst))]
         
     | 
| 119 | 
         
            +
                    for i in range(len(bar_lst)):
         
     | 
| 120 | 
         
            +
                        all_bar_lst[i].append(bar_lst[i])
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                if len(all_bar_lst) > 1:
         
     | 
| 123 | 
         
            +
                    # There might be the bar number like %30 at the end 
         
     | 
| 124 | 
         
            +
                    # which need to be specially handled.
         
     | 
| 125 | 
         
            +
                    if len(all_bar_lst[0]) > len(all_bar_lst[1]):
         
     | 
| 126 | 
         
            +
                        last_bar_lst = all_bar_lst[0][-1].split()
         
     | 
| 127 | 
         
            +
                        all_bar_lst[0].pop()
         
     | 
| 128 | 
         
            +
                        for i in range(len(all_bar_lst)):
         
     | 
| 129 | 
         
            +
                            all_bar_lst[i].append(last_bar_lst[i])
         
     | 
| 130 | 
         
            +
                            # Add the remaining symbols to the last row.
         
     | 
| 131 | 
         
            +
                            if i == len(all_bar_lst) - 1:
         
     | 
| 132 | 
         
            +
                                for j in range(i+1, len(last_bar_lst)):
         
     | 
| 133 | 
         
            +
                                    all_bar_lst[i][-1] += ' ' + last_bar_lst[j]
         
     | 
| 134 | 
         
            +
                    # Ensure the lengths are consistent. 
         
     | 
| 135 | 
         
            +
                    length = len(all_bar_lst[0])
         
     | 
| 136 | 
         
            +
                    for lst in all_bar_lst[1:]:
         
     | 
| 137 | 
         
            +
                        # assert len(lst) == length       
         
     | 
| 138 | 
         
            +
                        pass
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                dec_piece += heads
         
     | 
| 141 | 
         
            +
                for i in range(len(all_bar_lst)):
         
     | 
| 142 | 
         
            +
                    if len(all_bar_lst) > 1:
         
     | 
| 143 | 
         
            +
                        dec_piece += f'V:{i+1}\n'
         
     | 
| 144 | 
         
            +
                    dec_piece += ''.join(all_bar_lst[i])
         
     | 
| 145 | 
         
            +
                    dec_piece += '\n'
         
     | 
| 146 | 
         
            +
                # Remove redundant spaces.
         
     | 
| 147 | 
         
            +
                dec_piece = re.sub(' {2,}', ' ', dec_piece)
         
     | 
| 148 | 
         
            +
                
         
     | 
| 149 | 
         
            +
                return dec_piece
         
     | 
| 150 | 
         
            +
            ```
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
            Processed Output:
         
     | 
| 153 | 
         
            +
            ```shell
         
     | 
| 154 | 
         
            +
            X:1
         
     | 
| 155 | 
         
            +
            L:1/8
         
     | 
| 156 | 
         
            +
            Q:1/8=200
         
     | 
| 157 | 
         
            +
            M:4/4
            K:Gmin
         
     | 
| 158 | 
         
            +
            |:\"Gm\" BGdB fdBG |\"F\" AFcF dFcF |\"Gm\" BGdG gFBF |\"F\" AFAG AF F2 |\"Gm\" BGBd fffd |\"F\" cdcB cdeg |
         
     | 
| 159 | 
         
            +
            \"Gm\" fdcB\"Eb\" AFcA |1 BGFG\"F\" AFGc :|2 BGFG\"F\" AF F2 ||
         
     | 
| 160 | 
         
            +
            ```
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
            Once you encode the post-processed ABC notation into audio, you will hear the following music.
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
            <audio controls src="https://cdn-uploads.huggingface.co/production/uploads/640701cb4dc5f2846c91d4eb/gnBULaFjcUyXYzzIwXLZq.mpga"></audio>
         
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
            #### Megatron-LM
         
     | 
| 167 | 
         
            +
             
     | 
| 168 | 
         
            +
            We now provide usage instructions based on [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main).
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
            Before starting, make sure you have set up the relevant environment and codebase.
         
     | 
| 171 | 
         
            +
             
         
     | 
| 172 | 
         
            +
            ```shell
         
     | 
| 173 | 
         
            +
            # pull Megatron-LM codebase
         
     | 
| 174 | 
         
            +
            mkdir -p /path/to/workspace && cd /path/to/workspace
         
     | 
| 175 | 
         
            +
            git clone https://github.com/NVIDIA/Megatron-LM.git
         
     | 
| 176 | 
         
            +
            # download the pre-trained MuPT models checkpoint and vocab files from Huggingface page
         
     | 
| 177 | 
         
            +
            mkdir -p /models/MuPT_v0_8192_1.3B && cd /models/MuPT_v0_8192_1.3B
         
     | 
| 178 | 
         
            +
            wget -O model_optim_rng.pt https://huggingface.co/m-a-p/MuPT_v0_8192_1.3B/resolve/main/model_optim_rng.pt?download=true
         
     | 
| 179 | 
         
            +
            wget -O newline.vocab https://huggingface.co/m-a-p/MuPT_v0_8192_1.3B/resolve/main/newline.vocab?download=true
         
     | 
| 180 | 
         
            +
            wget -O newline.txt https://huggingface.co/m-a-p/MuPT_v0_8192_1.3B/resolve/main/newline.txt?download=true
         
     | 
| 181 | 
         
            +
            ```
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
            We recommend using the latest version of [NGC's PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for MuPT inference. See the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main) repository for more details.
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
            ```shell
         
     | 
| 186 | 
         
            +
            # pull the latest NGC's PyTorch container, mount the workspace directory and enter the container
         
     | 
| 187 | 
         
            +
            docker run --gpus all -it --name megatron --shm-size=16g -v $PWD:/workspace -p 5000:5000 nvcr.io/nvidia/pytorch:23.11-py3 /bin/bash
         
     | 
| 188 | 
         
            +
            ```
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
            Once you enter the container, you can start a REST server for inference. 
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
            <details>
         
     | 
| 193 | 
         
            +
                <summary>Click to expand the example script</summary>
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                #!/bin/bash
         
     | 
| 196 | 
         
            +
                # This example will start serving the 1.3B model.
         
     | 
| 197 | 
         
            +
                export CUDA_DEVICE_MAX_CONNECTIONS=1
         
     | 
| 198 | 
         
            +
             
     | 
| 199 | 
         
            +
                DISTRIBUTED_ARGS="--nproc_per_node 1 \
         
     | 
| 200 | 
         
            +
                                --nnodes 1 \
         
     | 
| 201 | 
         
            +
                                --node_rank 0 \
         
     | 
| 202 | 
         
            +
                                --master_addr localhost \
         
     | 
| 203 | 
         
            +
                                --master_port 6000"
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                CHECKPOINT=/path/to/model/checkpoint/folder
         
     | 
| 206 | 
         
            +
                VOCAB_FILE=/path/to/vocab/file
         
     | 
| 207 | 
         
            +
                MERGE_FILE=/path/to/merge/file
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
                MODEL_SIZE="1.3B"
         
     | 
| 210 | 
         
            +
                if   [[ ${MODEL_SIZE} == "110M" ]];   then HIDDEN_SIZE=768;  NUM_HEAD=12; NUM_QUERY_GROUP=12; NUM_LAYERS=12; FFN_HIDDEN_SIZE=3072; NORM_EPS=1e-5;
         
     | 
| 211 | 
         
            +
                elif [[ ${MODEL_SIZE} == "345M" ]];   then HIDDEN_SIZE=1024;  NUM_HEAD=16; NUM_QUERY_GROUP=16; NUM_LAYERS=24; FFN_HIDDEN_SIZE=4096; NORM_EPS=1e-5;
         
     | 
| 212 | 
         
            +
                elif [[ ${MODEL_SIZE} == "770M" ]];   then HIDDEN_SIZE=1280;  NUM_HEAD=20; NUM_QUERY_GROUP=20; NUM_LAYERS=36; FFN_HIDDEN_SIZE=5120; NORM_EPS=1e-5;
         
     | 
| 213 | 
         
            +
                elif [[ ${MODEL_SIZE} == "1.3B" ]];   then HIDDEN_SIZE=1536;  NUM_HEAD=24; NUM_QUERY_GROUP=24; NUM_LAYERS=48; FFN_HIDDEN_SIZE=6144; NORM_EPS=1e-5;
         
     | 
| 214 | 
         
            +
                else echo "invalid MODEL_SIZE: ${MODEL_SIZE}"; exit 1
         
     | 
| 215 | 
         
            +
                fi
         
     | 
| 216 | 
         
            +
                MAX_SEQ_LEN=8192
         
     | 
| 217 | 
         
            +
                MAX_POSITION_EMBEDDINGS=8192
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                pip install flask-restful
         
     | 
| 220 | 
         
            +
             
     | 
| 221 | 
         
            +
                torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py   \
         
     | 
| 222 | 
         
            +
                    --tensor-model-parallel-size 1  \
         
     | 
| 223 | 
         
            +
                    --pipeline-model-parallel-size 1  \
         
     | 
| 224 | 
         
            +
                    --num-layers ${NUM_LAYERS}  \
         
     | 
| 225 | 
         
            +
                    --hidden-size ${HIDDEN_SIZE}  \
         
     | 
| 226 | 
         
            +
                    --ffn-hidden-size ${FFN_HIDDEN_SIZE} \
         
     | 
| 227 | 
         
            +
                    --load ${CHECKPOINT}  \
         
     | 
| 228 | 
         
            +
                    --group-query-attention \
         
     | 
| 229 | 
         
            +
                    --num-query-groups ${NUM_QUERY_GROUP} \
         
     | 
| 230 | 
         
            +
                    --position-embedding-type rope \
         
     | 
| 231 | 
         
            +
                    --num-attention-heads ${NUM_HEAD}  \
         
     | 
| 232 | 
         
            +
                    --max-position-embeddings ${MAX_POSITION_EMBEDDINGS}  \
         
     | 
| 233 | 
         
            +
                    --tokenizer-type GPT2BPETokenizer  \
         
     | 
| 234 | 
         
            +
                    --normalization RMSNorm \
         
     | 
| 235 | 
         
            +
                    --norm-epsilon ${NORM_EPS} \
         
     | 
| 236 | 
         
            +
                    --make-vocab-size-divisible-by 1 \
         
     | 
| 237 | 
         
            +
                    --swiglu \
         
     | 
| 238 | 
         
            +
                    --use-flash-attn \
         
     | 
| 239 | 
         
            +
                    --bf16  \
         
     | 
| 240 | 
         
            +
                    --micro-batch-size 1  \
         
     | 
| 241 | 
         
            +
                    --disable-bias-linear \
         
     | 
| 242 | 
         
            +
                    --no-bias-gelu-fusion \
         
     | 
| 243 | 
         
            +
                    --untie-embeddings-and-output-weights \
         
     | 
| 244 | 
         
            +
                    --seq-length ${MAX_SEQ_LEN}  \
         
     | 
| 245 | 
         
            +
                    --vocab-file $VOCAB_FILE  \
         
     | 
| 246 | 
         
            +
                    --merge-file $MERGE_FILE  \
         
     | 
| 247 | 
         
            +
                    --attention-dropout 0.0 \
         
     | 
| 248 | 
         
            +
                    --hidden-dropout 0.0 \
         
     | 
| 249 | 
         
            +
                    --weight-decay 1e-1 \
         
     | 
| 250 | 
         
            +
                    --clip-grad 1.0 \
         
     | 
| 251 | 
         
            +
                    --adam-beta1 0.9 \
         
     | 
| 252 | 
         
            +
                    --adam-beta2 0.95 \
         
     | 
| 253 | 
         
            +
                    --adam-eps 1e-8 \
         
     | 
| 254 | 
         
            +
                    --seed 42
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
            </details>
         
     | 
| 257 | 
         
            +
             
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
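            Use curl to query the server directly; make sure the port in the URL matches the port your server listens on (the `docker run` command above publishes port 5000). Note that the newline character `\n` is represented by `<n>` in the vocabulary, so every newline in the prompt must be replaced with `<n>`, and every `<n>` in the generated text must be converted back to a newline.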
         
     | 
| 260 | 
         
            +
             
     | 
| 261 | 
         
            +
            ```shell
         
     | 
| 262 | 
         
            +
            curl 'http://localhost:6000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB"], "tokens_to_generate":4096}'
         
     | 
| 263 | 
         
            +
            ```
         
     | 
| 264 | 
         
            +
            Processed Output:
         
     | 
| 265 | 
         
            +
            ```shell
         
     | 
| 266 | 
         
            +
            X:1
         
     | 
| 267 | 
         
            +
            L:1/8
         
     | 
| 268 | 
         
            +
            Q:1/8=200
         
     | 
| 269 | 
         
            +
            M:4/4
            K:Gmin
         
     | 
| 270 | 
         
            +
            |:\"Gm\" BGdB fdBG |\"F\" AFcF dFcF |\"Gm\" BGdG gFBF |\"F\" AFAG AF F2 |\"Gm\" BGBd fffd |\"F\" cdcB cdeg |
         
     | 
| 271 | 
         
            +
            \"Gm\" fdcB\"Eb\" AFcA |1 BGFG\"F\" AFGc :|2 BGFG\"F\" AF F2 ||
         
     | 
| 272 | 
         
            +
            ```
         
     |