Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM | |
| import re | |
| from evolutiontransformer.worker import ( | |
| load_base_models_if_needed, | |
| BASE_MODELS, | |
| inference, | |
| inference_task, | |
| merge_models, | |
| ) | |
| def get_final_answer(text: str) -> int | None: | |
| numbers = re.findall(r"\d+", text) | |
| return int(numbers[-1]) if numbers else None | |
def test_inference():
    """Check that the "svamp" model answers a simple addition question with 5.

    Sends a single prompt through ``inference_task`` and compares the last
    number found in the ``"response"`` field of the result against 5.
    """
    print("### Testing inference on SVAMP model...")
    question = "If there are 3 cars and 2 bikes, how many vehicles are there in total?\nAnswer:"
    result = inference_task("test_session", "svamp", question)
    predicted = get_final_answer(result["response"])
    assert predicted == 5
def test_merge_models():
    """Check that an all-"svamp" recipe reproduces the "svamp" base model.

    Builds a 24-entry layer recipe whose every entry is a single
    ``(index, "svamp", 1.0)`` tuple, merges it, and asserts that each
    parameter of the merged model is numerically close to the corresponding
    parameter of ``BASE_MODELS["svamp"]``.
    """
    load_base_models_if_needed()

    # presumably each tuple is (layer index, source model, weight) — confirm
    # against merge_models.
    layer_recipe = []
    for layer_idx in range(24):
        layer_recipe.append([(layer_idx, "svamp", 1.0)])

    recipe = {
        "layer_recipe": layer_recipe,
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }
    merged = merge_models(recipe)

    reference_params = BASE_MODELS["svamp"].named_parameters()
    merged_params = merged.named_parameters()
    # NOTE: zip stops at the shorter iterator and the names are not compared,
    # so parameters are matched purely by position and any surplus on either
    # side goes unchecked.
    for (_, expected), (_, actual) in zip(reference_params, merged_params):
        assert torch.allclose(expected, actual)
def test_merge_models_with_inference1():
    """Smoke-test inference on a 48-entry recipe built only from "svamp".

    The recipe lists layer indices 0-23 with value 1.0, then the same
    indices again with value 0.5. The inference result is printed, not
    asserted, so this test only fails if merging or inference raises.
    """
    load_base_models_if_needed()

    first_pass = [[(idx, "svamp", 1.0)] for idx in range(24)]
    second_pass = [[(idx, "svamp", 0.5)] for idx in range(24)]
    recipe = {
        "layer_recipe": first_pass + second_pass,
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }
    merged = merge_models(recipe)

    question = "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:"
    answer = inference(merged, question)
    print(answer)
def test_merge_models_with_inference2():
    """Smoke-test inference on a 24-entry recipe built only from "tinystories".

    Both the embedding and linear lambdas are set to 0.0. The inference
    result is printed, not asserted, so this test only fails if merging or
    inference raises.
    """
    load_base_models_if_needed()

    layer_recipe = []
    for layer_idx in range(24):
        layer_recipe.append([(layer_idx, "tinystories", 1.0)])

    recipe = {
        "layer_recipe": layer_recipe,
        "embedding_lambdas": [0.0, 0.0],
        "linear_lambdas": [0.0, 0.0],
    }
    merged = merge_models(recipe)

    question = "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:"
    answer = inference(merged, question)
    print(answer)