import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from x_transformers import Encoder, Decoder
from transformers import AutoTokenizer

# Load the tokenizer once at module level: prefer a local copy if one has been
# saved next to this file, otherwise fall back to the pretrained
# Helsinki-NLP/opus-mt-de-en (German-to-English Marian) tokenizer.
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    # If loading fails, `tokenizer` remains undefined and the mask-creation
    # and generation methods below will raise a NameError when called.
    print(f"Warning: Tokenizer load failed: {e}")
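# Illustrative only (not part of the module's behaviour): a small sketch of how
# this tokenizer produces the padded id tensors that the masks below are built
# from. The sentence and variable names here are assumptions:
#
#   batch = tokenizer(["Ein kleines Beispiel."], return_tensors="pt", padding=True)
#   src = batch["input_ids"]   # (batch, src_len) token ids, padded with pad_token_id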


class RoPETransformer(nn.Module):
    """Encoder-decoder Transformer using rotary position embeddings (RoPE)."""

    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads,
                 d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model  # kept for the sqrt(d_model) embedding scaling
        # Shared token embedding for source and target; tied to the output
        # projection further below. (max_length is not used inside the model.)
        self.embedding = nn.Embedding(vocab_size, d_model)

        self.dropout_layer = nn.Dropout(dropout)
        # Encoder stack from x_transformers: rotary position embeddings,
        # flash attention and RMSNorm. ff_mult is a multiplier of d_model,
        # so the requested feed-forward width dff is converted to a ratio.
        self.encoder = Encoder(
            dim=d_model,
            depth=num_encoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True,
        )
        # Decoder stack mirrors the encoder configuration; cross_attend=True
        # adds cross-attention over the encoder output, and the Decoder class
        # applies causal self-attention masking internally.
        self.decoder = Decoder(
            dim=d_model,
            depth=num_decoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True,
        )
        # Output projection, weight-tied with the input embedding.
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.embedding.weight
    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask,
                memory_key_padding_mask, tgt_mask):
        # memory_key_padding_mask and tgt_mask are accepted for interface
        # compatibility but are not needed here: the x_transformers Decoder
        # is causal by construction, and the encoder mask below doubles as
        # the cross-attention (memory) mask.
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        src_emb = self.dropout_layer(src_emb)

        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout_layer(tgt_emb)

        # x_transformers expects boolean masks where True marks positions to
        # keep, so the padding masks (True at pad tokens) are inverted.
        enc_mask = ~src_padding_mask if src_padding_mask is not None else None
        dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None

        memory = self.encoder(src_emb, mask=enc_mask)

        decoder_output = self.decoder(
            tgt_emb,
            context=memory,
            mask=dec_mask,
            context_mask=enc_mask,
        )

        return self.final_linear(decoder_output)

    def create_masks(self, src, tgt):
        # Boolean padding masks: True at pad positions.
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        # Boolean causal mask (True above the diagonal, i.e. masked future
        # positions). forward() does not consume it, since the x_transformers
        # Decoder is already causal, but it is returned for interface parity.
        tgt_mask = torch.triu(
            torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=src.device),
            diagonal=1,
        )
        # The source padding mask is reused as the memory key padding mask.
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        """Translate src with beam search and return the best token ids per example."""
        self.eval()

        # Encode the source once; True in enc_mask marks real (non-pad) tokens.
        src_padding_mask = (src == tokenizer.pad_token_id)
        enc_mask = ~src_padding_mask

        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        # dropout_layer is a no-op here because the model is in eval mode.
        memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)
        batch_size = src.shape[0]
        # Duplicate the encoder output and its mask once per beam.
        memory = memory.repeat_interleave(num_beams, dim=0)
        enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)

        # Marian/opus-mt models use the pad token as the decoder start token.
        initial_token = tokenizer.pad_token_id
        beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)
        for step in range(max_length - 1):
            if finished_beams.all():
                break

            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)

            decoder_output = self.decoder(
                self.dropout_layer(tgt_emb),
                context=memory,
                context_mask=enc_mask,
            )

            # Only the logits for the last generated position are needed.
            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Never generate pad; finished beams extend with EOS at zero cost
            # so their scores stay fixed.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams, tokenizer.eos_token_id] = 0

            total_scores = beam_scores.unsqueeze(1) + log_probs
            if step == 0:
                # All beams start identical, so keep only beam 0 per example
                # to avoid selecting duplicate hypotheses.
                total_scores = total_scores.view(batch_size, num_beams, -1)
                total_scores[:, 1:, :] = -torch.inf
                total_scores = total_scores.view(batch_size * num_beams, -1)

            # Pick the num_beams best (beam, token) continuations per example.
            total_scores = total_scores.view(batch_size, -1)
            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)

            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]

            batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)
            effective_indices = (batch_indices * num_beams + beam_indices).view(-1)

            # Reorder beams (and their finished flags) to match the selected
            # continuations, then append the chosen tokens.
            beams = beams[effective_indices]
            beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            finished_beams = finished_beams[effective_indices] | (beams[:, -1] == tokenizer.eos_token_id)
        # Per example, keep the beam whose score divided by its number of
        # non-pad tokens is highest.
        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        lengths = (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)
        normalized_scores = final_scores / lengths
        best_beams = final_beams[torch.arange(batch_size, device=src.device), normalized_scores.argmax(1), :]
        # Switch back to training mode before returning.
        self.train()
        return best_beams
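

# The block below is an illustrative usage sketch only, not part of the model
# definition: the hyperparameters, sentences, label-shifting convention and
# loss setup are assumptions chosen to show how create_masks, forward and
# generate fit together.
if __name__ == "__main__":
    model = RoPETransformer(
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        d_model=512,
        dff=2048,
        vocab_size=tokenizer.vocab_size,
        max_length=128,
        dropout=0.1,
    )

    # Teacher-forced training-style step: shift labels right with the pad
    # (decoder start) token and predict the unshifted labels.
    batch = tokenizer(
        ["Das ist ein Test."], text_target=["This is a test."],
        return_tensors="pt", padding=True,
    )
    src = batch["input_ids"]
    labels = batch["labels"]
    tgt_in = torch.cat(
        [torch.full((labels.size(0), 1), tokenizer.pad_token_id), labels[:, :-1]], dim=1
    )
    src_pad, tgt_pad, mem_pad, tgt_mask = model.create_masks(src, tgt_in)
    logits = model(src, tgt_in, src_pad, tgt_pad, mem_pad, tgt_mask)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
        ignore_index=tokenizer.pad_token_id,
    )
    print("loss:", loss.item())

    # Beam-search translation of the same source sentence.
    out = model.generate(src, max_length=64, num_beams=4)
    print(tokenizer.decode(out[0], skip_special_tokens=True))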