Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.

1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.
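
As a rough preview, the two classes we are about to write only need to fill in
a handful of methods on top of what ``nn.Module`` already provides (both base
classes also take the vocabulary ``dictionary`` in their constructors). The
skeleton below is just a sketch of those signatures; the class names here are
placeholders and are not part of fairseq::

    from fairseq.models import FairseqEncoder, FairseqDecoder

    class SkeletonEncoder(FairseqEncoder):
        # Encode a batch of source sentences; the return value is passed
        # unchanged to the Decoder.
        def forward(self, src_tokens, src_lengths):
            raise NotImplementedError  # implemented in the Encoder section below

        # Reorder a previously computed encoder output; needed for beam search.
        def reorder_encoder_out(self, encoder_out, new_order):
            raise NotImplementedError  # implemented in the Encoder section below

    class SkeletonDecoder(FairseqDecoder):
        # Given the (shifted) target tokens and the encoder output, return
        # logits over the target vocabulary.
        def forward(self, prev_output_tokens, encoder_out):
            raise NotImplementedError  # implemented in the Decoder section below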

Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::

    import torch.nn as nn
    from fairseq import utils
    from fairseq.models import FairseqEncoder

    class SimpleLSTMEncoder(FairseqEncoder):

        def __init__(
            self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
        ):
            super().__init__(dictionary)
            self.args = args

            # Our encoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                input_size=embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
                batch_first=True,
            )

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We discuss Tasks in the next tutorial, but for now just
            # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
            # has shape `(batch)`.

            # Note that the source is typically padded on the left. This can be
            # configured by adding the `--left-pad-source "False"` command-line
            # argument, but here we'll make the Encoder handle either kind of
            # padding by converting everything to be right-padded.
            if self.args.left_pad_source:
                # Convert left-padding to right-padding.
                src_tokens = utils.convert_padding_direction(
                    src_tokens,
                    padding_idx=self.dictionary.pad(),
                    left_to_right=True
                )

            # Embed the source.
            x = self.embed_tokens(src_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Pack the sequence into a PackedSequence object to feed to the LSTM.
            x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

            # Get the output from the LSTM.
            _outputs, (final_hidden, _final_cell) = self.lstm(x)

            # Return the Encoder's output. This can be any object and will be
            # passed directly to the Decoder.
            return {
                # this will have shape `(bsz, hidden_dim)`
                'final_hidden': final_hidden.squeeze(0),
            }

        # Encoders are required to implement this method so that we can rearrange
        # the order of the batch elements during inference (e.g., beam search).
        def reorder_encoder_out(self, encoder_out, new_order):
            """
            Reorder encoder output according to `new_order`.

            Args:
                encoder_out: output from the ``forward()`` method
                new_order (LongTensor): desired order

            Returns:
                `encoder_out` rearranged according to `new_order`
            """
            final_hidden = encoder_out['final_hidden']
            return {
                'final_hidden': final_hidden.index_select(0, new_order),
            }
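
It can be handy to sanity-check the Encoder in isolation before wiring it into
a full model. The quick smoke test below is not part of the tutorial file; it
is a sketch you could run from an interactive session in which
``SimpleLSTMEncoder`` has been defined. It builds a toy fairseq ``Dictionary``,
a small right-padded batch (the toy German words and tiny dimensions are
arbitrary), and checks the shape of the returned final hidden state::

    import torch
    from argparse import Namespace
    from fairseq.data import Dictionary

    # Build a toy vocabulary; Dictionary() already contains the special
    # pad/eos/unk symbols.
    toy_dict = Dictionary()
    for word in 'hallo welt wie geht es'.split():
        toy_dict.add_symbol(word)

    # `left_pad_source=False` matches the right-padded dummy batch below.
    encoder = SimpleLSTMEncoder(
        args=Namespace(left_pad_source=False),
        dictionary=toy_dict,
        embed_dim=8,
        hidden_dim=16,
    )

    # Two sentences, right-padded to length 4, with lengths in descending
    # order as required by ``pack_padded_sequence()``.
    src_tokens = torch.tensor([
        [toy_dict.index('hallo'), toy_dict.index('welt'),
         toy_dict.index('wie'), toy_dict.eos()],
        [toy_dict.index('geht'), toy_dict.index('es'),
         toy_dict.eos(), toy_dict.pad()],
    ])
    src_lengths = torch.tensor([4, 3])

    out = encoder(src_tokens, src_lengths)
    print(out['final_hidden'].shape)  # torch.Size([2, 16])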

Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word -- which
is sometimes called *teacher forcing*. More specifically, we'll use a
:class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project
to the size of the output vocabulary to predict each target word.

::

    import torch
    from fairseq.models import FairseqDecoder

    class SimpleLSTMDecoder(FairseqDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            super().__init__(dictionary)

            # Our decoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                # For the first layer we'll concatenate the Encoder's final hidden
                # state with the embedded target tokens.
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )

            # Define the output projection.
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # During training Decoders are expected to take the entire target sequence
        # (shifted right by one position) and produce logits over the vocabulary.
        # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
        # ``dictionary.eos()``, followed by the target sequence.
        def forward(self, prev_output_tokens, encoder_out):
            """
            Args:
                prev_output_tokens (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing
                encoder_out (Tensor, optional): output from the encoder, used for
                    encoder-side attention

            Returns:
                tuple:
                    - the last decoder layer's output of shape
                      `(batch, tgt_len, vocab)`
                    - the last decoder layer's attention weights of shape
                      `(batch, tgt_len, src_len)`
            """
            bsz, tgt_len = prev_output_tokens.size()

            # Extract the final hidden state from the Encoder.
            final_encoder_hidden = encoder_out['final_hidden']

            # Embed the target sequence, which has been shifted right by one
            # position and now starts with the end-of-sentence symbol.
            x = self.embed_tokens(prev_output_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Concatenate the Encoder's final hidden state to *every* embedded
            # target token.
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # Using PackedSequence objects in the Decoder is harder than in the
            # Encoder, since the targets are not sorted in descending length order,
            # which is a requirement of ``pack_padded_sequence()``. Instead we'll
            # feed nn.LSTM directly.
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )
            output, _ = self.lstm(
                x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
                initial_state,
            )
            x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

            # Project the outputs to the size of the vocabulary.
            x = self.output_projection(x)

            # Return the logits and ``None`` for the attention weights
            return x, None
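
During training the logits returned above are scored against the *unshifted*
target sequence. fairseq's built-in criterions (e.g., the ``cross_entropy``
criterion) handle this for you; the helper below is only a sketch, not part of
the tutorial file, to make the teacher-forcing shapes concrete::

    import torch.nn.functional as F

    def teacher_forcing_loss(decoder, prev_output_tokens, encoder_out, targets):
        """Cross-entropy over the decoder's logits, ignoring padding.

        *prev_output_tokens* is the target shifted right (starting with EOS)
        and *targets* is the unshifted target, so the logits at position t are
        scored against the token the model should produce at position t.
        """
        logits, _ = decoder(prev_output_tokens, encoder_out)
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),       # `(batch * tgt_len, vocab)`
            targets.view(-1),                       # `(batch * tgt_len,)`
            ignore_index=decoder.dictionary.pad(),  # skip padding positions
        )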

2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model with
fairseq using the :func:`~fairseq.models.register_model` function decorator.
Once the model is registered we'll be able to use it with the existing
:ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

    from fairseq.models import FairseqEncoderDecoderModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('simple_lstm')
    class SimpleLSTMModel(FairseqEncoderDecoderModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add some new command-line arguments to configure dropout
            # and the dimensionality of the embeddings and hidden states.
            parser.add_argument(
                '--encoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the encoder embeddings',
            )
            parser.add_argument(
                '--encoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the encoder hidden state',
            )
            parser.add_argument(
                '--encoder-dropout', type=float, default=0.1,
                help='encoder dropout probability',
            )
            parser.add_argument(
                '--decoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the decoder embeddings',
            )
            parser.add_argument(
                '--decoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the decoder hidden state',
            )
            parser.add_argument(
                '--decoder-dropout', type=float, default=0.1,
                help='decoder dropout probability',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a SimpleLSTMModel instance.

            # Initialize our Encoder and Decoder.
            encoder = SimpleLSTMEncoder(
                args=args,
                dictionary=task.source_dictionary,
                embed_dim=args.encoder_embed_dim,
                hidden_dim=args.encoder_hidden_dim,
                dropout=args.encoder_dropout,
            )
            decoder = SimpleLSTMDecoder(
                dictionary=task.target_dictionary,
                encoder_hidden_dim=args.encoder_hidden_dim,
                embed_dim=args.decoder_embed_dim,
                hidden_dim=args.decoder_hidden_dim,
                dropout=args.decoder_dropout,
            )
            model = SimpleLSTMModel(encoder, decoder)

            # Print the model architecture.
            print(model)

            return model

        # We could override the ``forward()`` if we wanted more control over how
        # the encoder and decoder interact, but it's not necessary for this
        # tutorial since we can inherit the default implementation provided by
        # the FairseqEncoderDecoderModel base class, which looks like:
        #
        # def forward(self, src_tokens, src_lengths, prev_output_tokens):
        #     encoder_out = self.encoder(src_tokens, src_lengths)
        #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
        #     return decoder_out

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'simple_lstm'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
    def tutorial_simple_lstm(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
        args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
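
The same pattern can be repeated to define additional presets for the same
model. For example, a larger variant could be added to the same file as
follows; the ``tutorial_simple_lstm_big`` name and the sizes below are made up
for illustration and no such architecture ships with fairseq::

    # Purely illustrative second preset; pick whatever sizes you like.
    @register_model_architecture('simple_lstm', 'tutorial_simple_lstm_big')
    def tutorial_simple_lstm_big(args):
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 512)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
        args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 512)

It would then be selectable with ``--arch tutorial_simple_lstm_big``.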

3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

    Make sure you've already preprocessed the data from the IWSLT example in the
    :file:`examples/translation/` directory.

.. code-block:: console

    > fairseq-train data-bin/iwslt14.tokenized.de-en \
      --arch tutorial_simple_lstm \
      --encoder-dropout 0.2 --decoder-dropout 0.2 \
      --optimizer adam --lr 0.005 --lr-shrink 0.5 \
      --max-tokens 12000
    (...)
    | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
    | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`fairseq-generate` script to
generate translations and compute our BLEU score over the test set:

.. code-block:: console

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)``). We can make this significantly faster by instead caching the
previous hidden states.

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
special mode at inference time where the Model only receives a single timestep
of input corresponding to the immediately previous output token (for teacher
forcing) and must produce the next output incrementally. Thus the model must
cache any long-term state that is needed about the sequence, e.g., hidden
states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

    import torch
    from fairseq.models import FairseqIncrementalDecoder

    class SimpleLSTMDecoder(FairseqIncrementalDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            # This remains the same as before.
            super().__init__(dictionary)
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)
            self.lstm = nn.LSTM(
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # We now take an additional kwarg (*incremental_state*) for caching the
        # previous hidden and cell states.
        def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
            if incremental_state is not None:
                # If the *incremental_state* argument is not ``None`` then we are
                # in incremental inference mode. While *prev_output_tokens* will
                # still contain the entire decoded prefix, we will only use the
                # last step and assume that the rest of the state is cached.
                prev_output_tokens = prev_output_tokens[:, -1:]

            # This remains the same as before.
            bsz, tgt_len = prev_output_tokens.size()
            final_encoder_hidden = encoder_out['final_hidden']
            x = self.embed_tokens(prev_output_tokens)
            x = self.dropout(x)
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # We will now check the cache and load the cached previous hidden and
            # cell states, if they exist, otherwise we will initialize them to
            # zeros (as before). We will use the ``utils.get_incremental_state()``
            # and ``utils.set_incremental_state()`` helpers.
            initial_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )
            if initial_state is None:
                # first time initialization, same as the original version
                initial_state = (
                    final_encoder_hidden.unsqueeze(0),  # hidden
                    torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
                )

            # Run one step of our LSTM.
            output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

            # Update the cache with the latest hidden and cell states.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', latest_state,
            )

            # This remains the same as before
            x = output.transpose(0, 1)
            x = self.output_projection(x)
            return x, None

        # The ``FairseqIncrementalDecoder`` interface also requires implementing a
        # ``reorder_incremental_state()`` method, which is used during beam search
        # to select and reorder the incremental state.
        def reorder_incremental_state(self, incremental_state, new_order):
            # Load the cached state.
            prev_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )

            # Reorder batches according to *new_order*.
            reordered_state = (
                prev_state[0].index_select(1, new_order),  # hidden
                prev_state[1].index_select(1, new_order),  # cell
            )

            # Update the cached state.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', reordered_state,
            )
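
To see how the *incremental_state* dictionary is threaded through successive
calls, here is a minimal hand-rolled greedy decoding loop. This is a sketch for
intuition only (no beam search, no stopping at EOS); in practice fairseq's
``SequenceGenerator`` drives incremental decoding for you and also calls
``reorder_incremental_state()`` during beam search::

    import torch

    def greedy_decode(model, src_tokens, src_lengths, max_len=50):
        """Sketch of single-step decoding with a shared incremental_state."""
        encoder_out = model.encoder(src_tokens, src_lengths)
        incremental_state = {}  # one cache, reused across all timesteps
        bsz = src_tokens.size(0)
        eos = model.decoder.dictionary.eos()

        # Start every hypothesis with the end-of-sentence symbol, mirroring
        # *prev_output_tokens* during training.
        tokens = src_tokens.new_full((bsz, 1), eos)
        for _ in range(max_len):
            # The decoder only looks at the last input token and reads the
            # rest of its state from *incremental_state*.
            logits, _ = model.decoder(
                tokens, encoder_out, incremental_state=incremental_state,
            )
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens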

Finally, we can rerun generation and observe the speedup:

.. code-block:: console

    # Before

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

    # After

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)