import unittest

import torch

from Andromeda.model import Andromeda
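

# Unit tests for the Andromeda language model (Andromeda.model.Andromeda):
# TestAndromeda covers basic wiring, forward passes, and output consistency;
# TestAndromedaExtended sweeps input shapes and constructor hyperparameters.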
class TestAndromeda(unittest.TestCase):
    def setUp(self):
        self.model = Andromeda()

    def test_initialization(self):
        # Both the underlying Transformer and its AutoregressiveWrapper must exist.
        self.assertIsNotNone(self.model.Andromeda, "Transformer is not initialized.")
        self.assertIsNotNone(self.model.decoder, "AutoregressiveWrapper is not initialized.")

    def test_forward_pass(self):
        input_tokens = torch.randint(0, 50432, (1, 8192))
        output = self.model(input_tokens)
        self.assertIsInstance(output, torch.Tensor, "Output is not a PyTorch tensor.")
        self.assertEqual(output.shape[0], input_tokens.shape[0], "Output batch size does not match input.")

    def test_error_handling(self):
        with self.assertRaises(Exception):
            self.model.forward(None)

    def test_model_parameters(self):
        self.assertEqual(self.model.Andromeda.num_tokens, 50432, "Number of tokens is not correctly set.")
        self.assertEqual(self.model.Andromeda.max_seq_len, 8192, "Max sequence length is not correctly set.")

    def test_model_output(self):
        # Use eval mode so stochastic layers (e.g. dropout) do not break determinism.
        self.model.eval()
        input_tokens = torch.randint(0, 50432, (1, 8192))
        with torch.no_grad():
            output1 = self.model(input_tokens)
            output2 = self.model(input_tokens)
        self.assertTrue(torch.allclose(output1, output2), "Model does not produce consistent output.")
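

# Extended checks: input/batch sizes, token-range edges, and constructor
# hyperparameters (depth, dim, heads, ALiBi, rotary/absolute positions, flash attention).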
class TestAndromedaExtended(unittest.TestCase):
    def setUp(self):
        self.model = Andromeda()

    def test_input_size(self):
        for seq_len in [512, 1024, 2048, 4096]:
            input_tokens = torch.randint(0, 50432, (1, seq_len))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[1], seq_len, f"Output sequence length does not match input for seq_len={seq_len}.")

    def test_batch_size(self):
        for batch_size in [2, 4, 8, 16]:
            input_tokens = torch.randint(0, 50432, (batch_size, 8192))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[0], batch_size, f"Output batch size does not match input for batch_size={batch_size}.")

    def test_token_range(self):
        for token in [0, 50431]:
            input_tokens = torch.full((1, 8192), fill_value=token)
            output = self.model(input_tokens)
            self.assertIsInstance(output, torch.Tensor, f"Output is not a PyTorch tensor for token={token}.")

    def test_model_depth(self):
        for depth in [16, 32, 64]:
            model = Andromeda(depth=depth)
            self.assertEqual(model.Andromeda.attn_layers.depth, depth, f"Model depth is not correctly set for depth={depth}.")

    def test_model_dim(self):
        for dim in [1280, 2560, 5120]:
            model = Andromeda(dim=dim)
            self.assertEqual(model.Andromeda.attn_layers.dim, dim, f"Model dimension is not correctly set for dim={dim}.")

    def test_model_heads(self):
        for heads in [12, 24, 48]:
            model = Andromeda(heads=heads)
            self.assertEqual(model.Andromeda.attn_layers.heads, heads, f"Number of heads is not correctly set for heads={heads}.")

    def test_model_dim_head(self):
        for dim_head in [64, 128, 256]:
            model = Andromeda(dim_head=dim_head)
            self.assertEqual(model.Andromeda.attn_layers.dim_head, dim_head, f"Head dimension is not correctly set for dim_head={dim_head}.")

    def test_model_alibi_num_heads(self):
        for alibi_num_heads in [6, 12, 24]:
            model = Andromeda(alibi_num_heads=alibi_num_heads)
            self.assertEqual(model.Andromeda.attn_layers.alibi_num_heads, alibi_num_heads, f"Number of alibi heads is not correctly set for alibi_num_heads={alibi_num_heads}.")

    def test_model_shift_tokens(self):
        for shift_tokens in [0, 1, 2]:
            model = Andromeda(shift_tokens=shift_tokens)
            self.assertEqual(model.Andromeda.attn_layers.shift_tokens, shift_tokens, f"Number of shift tokens is not correctly set for shift_tokens={shift_tokens}.")

    def test_model_use_abs_pos_emb(self):
        for use_abs_pos_emb in [True, False]:
            model = Andromeda(use_abs_pos_emb=use_abs_pos_emb)
            self.assertEqual(model.Andromeda.use_abs_pos_emb, use_abs_pos_emb, f"Use absolute position embedding flag is not correctly set for use_abs_pos_emb={use_abs_pos_emb}.")

    def test_model_alibi_pos_bias(self):
        for alibi_pos_bias in [True, False]:
            model = Andromeda(alibi_pos_bias=alibi_pos_bias)
            self.assertEqual(model.Andromeda.attn_layers.alibi_pos_bias, alibi_pos_bias, f"Alibi position bias flag is not correctly set for alibi_pos_bias={alibi_pos_bias}.")

    def test_model_rotary_xpos(self):
        for rotary_xpos in [True, False]:
            model = Andromeda(rotary_xpos=rotary_xpos)
            self.assertEqual(model.Andromeda.attn_layers.rotary_xpos, rotary_xpos, f"Rotary position flag is not correctly set for rotary_xpos={rotary_xpos}.")

    def test_model_attn_flash(self):
        for attn_flash in [True, False]:
            model = Andromeda(attn_flash=attn_flash)
            self.assertEqual(model.Andromeda.attn_layers.attn_flash, attn_flash, f"Attention flash flag is not correctly set for attn_flash={attn_flash}.")
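

# Note: these tests instantiate full-sized Andromeda models, so the suite can be
# slow and memory-hungry. Run this file directly or via `python -m unittest`.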
if __name__ == '__main__':
    unittest.main()