import unittest

import torch

from Andromeda.model import AndromedaTokenizer


class TestAndromedaTokenizer(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AndromedaTokenizer()

    def test_initialization(self):
        self.assertIsNotNone(self.tokenizer.tokenizer, "Tokenizer is not initialized.")
        self.assertEqual(self.tokenizer.tokenizer.eos_token, "<eos>", "EOS token is not correctly set.")
        self.assertEqual(self.tokenizer.tokenizer.pad_token, "<pad>", "PAD token is not correctly set.")
        self.assertEqual(self.tokenizer.tokenizer.model_max_length, 8192, "Model max length is not correctly set.")

    def test_tokenize_texts(self):
        texts = ["Hello, world!", "Andromeda is great."]
        tokenized_texts = self.tokenizer.tokenize_texts(texts)
        self.assertEqual(tokenized_texts.shape[0], len(texts), "Number of tokenized texts does not match input.")
        self.assertTrue(all(isinstance(t, torch.Tensor) for t in tokenized_texts), "Not all tokenized texts are PyTorch tensors.")

    def test_decode(self):
        texts = ["Hello, world!", "Andromeda is great."]
        tokenized_texts = self.tokenizer.tokenize_texts(texts)
        # Note: this round-trip check assumes decode() strips the padding/special
        # tokens that batch tokenization introduces; otherwise the equality below fails.
        decoded_texts = [self.tokenizer.decode(t) for t in tokenized_texts]
        self.assertEqual(decoded_texts, texts, "Decoded texts do not match original texts.")

    def test_len(self):
        num_tokens = len(self.tokenizer)
        self.assertIsInstance(num_tokens, int, "Number of tokens is not an integer.")
        self.assertGreater(num_tokens, 0, "Number of tokens is not greater than 0.")
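
    # Sketch of an extra guard, not in the original suite: it assumes that
    # tokenize_texts() truncates (or pads) sequences to at most the configured
    # model_max_length of 8192 and returns a 2-D (batch, length) tensor. The
    # test name and the long_text input are illustrative; drop this check if
    # the wrapper does not truncate.
    def test_tokenized_length_within_max(self):
        long_text = "word " * 20000  # far more tokens than the 8192 limit
        tokenized = self.tokenizer.tokenize_texts([long_text])
        self.assertLessEqual(
            tokenized.shape[1],
            self.tokenizer.tokenizer.model_max_length,
            "Tokenized sequence exceeds model_max_length.",
        )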


if __name__ == '__main__':
    unittest.main()