import re
import torch
import random
from tqdm import tqdm
from unidecode import unidecode
from torch.utils.data import Dataset
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
from utils import PATCH_SIZE, PATCH_LENGTH, PATCH_SAMPLING_BATCH_SIZE


class Patchilizer:
    """
    A class for converting music bars to patches and vice versa.
    """

    def __init__(self):
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

    def split_bars(self, body):
        """
        Split a body of music into individual bars.
        """
        bars = re.split(self.regexPattern, "".join(body))
        bars = list(filter(None, bars))  # remove empty strings
        if bars[0] in self.delimiters:  # fold a leading delimiter into the first bar
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        # pair each bar's content with its trailing delimiter
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        """
        Convert a bar into a patch of specified length.
        """
        # bos + one token per character + eos, truncated then right-padded to patch_size
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        """
        Convert a patch into a bar.
        """
        # drop the eos token and render everything above it as a character;
        # pad and bos tokens map to the empty string
        return "".join(
            chr(idx) if idx > self.eos_token_id else ""
            for idx in patch
            if idx != self.eos_token_id
        )

    def encode(
        self,
        abc_code,
        patch_length=PATCH_LENGTH,
        patch_size=PATCH_SIZE,
        add_special_patches=False,
    ):
        """
        Encode music into patches of specified length.
        """
        lines = unidecode(abc_code).split("\n")
        lines = list(filter(None, lines))  # remove empty lines

        body = ""
        patches = []

        for line in lines:
            # header lines such as "K:C" or the "%%score" directive become one
            # patch each; everything else accumulates into the tune body
            if len(line) > 1 and (
                (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
            ):
                if body:
                    bars = self.split_bars(body)
                    patches.extend(
                        self.bar2patch(
                            bar + "\n" if idx == len(bars) - 1 else bar, patch_size
                        )
                        for idx, bar in enumerate(bars)
                    )
                    body = ""
                patches.append(self.bar2patch(line + "\n", patch_size))
            else:
                body += line + "\n"

        if body:
            patches.extend(
                self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
            )

        if add_special_patches:
            # sentinel patches marking the start and end of the whole tune
            bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
            eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
            patches = [bos_patch] + patches + [eos_patch]

        return patches[:patch_length]

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return "".join(self.patch2bar(patch) for patch in patches)


class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded
        :return: the encoded patches
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))
        return self.base(inputs_embeds=patches)


class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the characters within each bar patch sequentially.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.base = GPT2LMHeadModel(config)

    def forward(
        self,
        encoded_patches: torch.Tensor,
        target_patches: torch.Tensor,
        patch_sampling_batch_size: int,
    ):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :param patch_sampling_batch_size: the number of patches to sample per
            batch during training (0 disables sampling)
        :return: the decoded patches
        """
        # prepare the labels for model training: padded positions are ignored
        # by the loss via the -100 sentinel
        target_masks = target_patches == self.pad_token_id
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # build the attention mask (zero out the ignored positions)
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # randomly subsample patches to bound memory usage during training
        if (
            patch_sampling_batch_size != 0
            and patch_sampling_batch_size < target_patches.shape[0]
        ):
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:patch_sampling_batch_size])

            target_patches = target_patches[selected_indices, :]
            target_masks = target_masks[selected_indices, :]
            encoded_patches = encoded_patches[selected_indices, :]
            labels = labels[selected_indices, :]
        # get input embeddings for the target characters
        inputs_embeds = torch.nn.functional.embedding(
            target_patches, self.base.transformer.wte.weight
        )

        # replace the first (bos) embedding of each patch with its encoded
        # patch feature, conditioning the char-level decoder on the patch
        inputs_embeds = torch.cat(
            (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
        )

        return self.base(
            inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
        )

    def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token
        return torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
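
# Usage sketch (illustrative): one sampling step. `char_decoder` stands for a
# CharLevelDecoder instance; `encoded_patch` is a single patch feature of shape
# (n_embd,); `tokens` starts as torch.tensor([1]) (the bos token); the result
# is a 128-way probability distribution over the next character.
#     probs = char_decoder.generate(encoded_patch, tokens)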


class TunesFormer(PreTrainedModel):
    """
    TunesFormer is a hierarchical music generation model based on bar patching.
    It includes a patch-level decoder and a character-level decoder.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config, share_weights=False):
        super().__init__(encoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

        if share_weights:
            # align the two configs so both decoders can share one transformer
            max_layers = max(
                encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
            )
            max_context_size = max(encoder_config.max_length, decoder_config.max_length)
            max_position_embeddings = max(
                encoder_config.max_position_embeddings,
                decoder_config.max_position_embeddings,
            )
            encoder_config.num_hidden_layers = max_layers
            encoder_config.max_length = max_context_size
            encoder_config.max_position_embeddings = max_position_embeddings
            decoder_config.num_hidden_layers = max_layers
            decoder_config.max_length = max_context_size
            decoder_config.max_position_embeddings = max_position_embeddings

        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

        if share_weights:
            self.patch_level_decoder.base = self.char_level_decoder.base.transformer

    def forward(
        self,
        patches: torch.Tensor,
        patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
    ):
        """
        The forward pass of the TunesFormer model.
        :param patches: the patches to be both encoded and decoded
        :param patch_sampling_batch_size: the number of patches to sample per
            batch during training (0 disables sampling)
        :return: the decoded patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

        # the char-level decoder predicts each patch from the feature of the
        # previous one, hence the [:-1] / [1:] offset; batch size is assumed to be 1
        return self.char_level_decoder(
            encoded_patches.squeeze(0)[:-1, :],
            patches.squeeze(0)[1:, :],
            patch_sampling_batch_size,
        )

    def generate(
        self,
        patches: torch.Tensor,
        tokens: torch.Tensor,
        top_p: float = 1,
        top_k: int = 0,
        temperature: float = 1,
        seed: int = None,
    ):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded
        :param tokens: already generated tokens in the current patch, or None
        :param top_p: nucleus sampling threshold
        :param top_k: top-k sampling cutoff (0 disables it)
        :param temperature: sampling temperature
        :param seed: random seed for reproducible sampling, or None
        :return: the generated patch and the seed used for its last token
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

        if tokens is None:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        generated_patch = []
        random.seed(seed)

        while True:
            if seed is not None:
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None
            prob = (
                self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
                .cpu()
                .detach()
                .numpy()
            )
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            token = temperature_sampling(prob, temperature=temperature, seed=n_seed)

            generated_patch.append(token)
            if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
                break
            else:
                tokens = torch.cat(
                    (tokens, torch.tensor([token], device=self.device)), dim=0
                )

        return generated_patch, n_seed
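
# Usage sketch (illustrative, with simplified stopping criteria): a caller
# grows the patch sequence one generated patch at a time. `prompt` and the
# sampling hyperparameters are assumptions for the example.
#     patches = torch.tensor(
#         patchilizer.encode(prompt, add_special_patches=True), device=model.device
#     ).reshape(1, -1)
#     while patches.shape[1] < PATCH_LENGTH * PATCH_SIZE:
#         new_patch, seed = model.generate(patches, tokens=None, top_p=0.8)
#         if new_patch[0] == patchilizer.eos_token_id:
#             break  # an immediate eos means the model emitted the end-of-tune patch
#         bar = patchilizer.patch2bar(new_patch)
#         next_patch = torch.tensor(
#             patchilizer.bar2patch(bar), device=model.device
#         ).reshape(1, -1)
#         patches = torch.cat((patches, next_patch), dim=1)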


class PatchilizedData(Dataset):
    def __init__(self, items, patchilizer):
        self.texts = []

        for item in tqdm(items):
            # prepend the control code in place of the tune's first line
            text = item["control code"] + "\n".join(
                item["abc notation"].split("\n")[1:]
            )
            input_patch = patchilizer.encode(text, add_special_patches=True)
            input_patch = torch.tensor(input_patch)
            if torch.sum(input_patch) != 0:  # skip tunes that encode to all padding
                self.texts.append(input_patch)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]
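

# Minimal smoke test (a sketch, not part of the original file): builds a tiny
# TunesFormer and runs one forward pass on a single encoded tune. The GPT2Config
# hyperparameters below are illustrative assumptions, not the released settings.
if __name__ == "__main__":
    from transformers import GPT2Config

    patchilizer = Patchilizer()
    patch_config = GPT2Config(
        vocab_size=1, n_positions=PATCH_LENGTH, n_embd=256, n_layer=2, n_head=4
    )
    char_config = GPT2Config(
        vocab_size=128, n_positions=PATCH_SIZE, n_embd=256, n_layer=2, n_head=4
    )
    model = TunesFormer(patch_config, char_config, share_weights=False)

    tune = "X:1\nL:1/8\nK:C\nCDEF GABc | cBAG FEDC |]\n"
    patches = torch.tensor(patchilizer.encode(tune, add_special_patches=True))
    outputs = model(patches.unsqueeze(0))  # (1, n_patches, PATCH_SIZE)
    print(outputs.loss)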