# evolutiontransformer/tests/test_model_actions.py
# Author: tcmmichaelb139 — commit 66f1733 ("init")
import torch
from transformers import AutoModelForCausalLM
import re
from evolutiontransformer.worker import (
load_base_models_if_needed,
BASE_MODELS,
inference,
inference_task,
merge_models,
)
def get_final_answer(text: str) -> int | None:
numbers = re.findall(r"\d+", text)
return int(numbers[-1]) if numbers else None
def test_inference():
    """Smoke-test inference_task on the SVAMP base model with a simple word problem."""
    print("### Testing inference on SVAMP model...")
    question = "If there are 3 cars and 2 bikes, how many vehicles are there in total?\nAnswer:"
    # inference_task is keyed by session id; presumably it loads the base
    # models itself, since load_base_models_if_needed() is not called here.
    result = inference_task("test_session", "svamp", question)
    # The model's reply should end with the numeric answer 3 + 2 = 5.
    assert get_final_answer(result["response"]) == 5
def test_merge_models():
    """Merging with full weight on 'svamp' everywhere must reproduce the SVAMP base model.

    Builds an identity recipe (every layer, embedding, and head taken 100%
    from the SVAMP model) and checks the merged model's parameters are
    identical to the original's.
    """
    load_base_models_if_needed()
    model_recipe = {
        # 24 layers, each sourced entirely from SVAMP — presumably the
        # base architecture's layer count; confirm against the worker.
        "layer_recipe": [[(i, "svamp", 1.0)] for i in range(24)],
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }
    merged_model = merge_models(model_recipe)
    # strict=True fails loudly if the two models have a different number of
    # parameters (plain zip would silently truncate), and the name check
    # catches ordering mismatches that allclose alone could hide.
    for (name1, param1), (name2, param2) in zip(
        BASE_MODELS["svamp"].named_parameters(),
        merged_model.named_parameters(),
        strict=True,
    ):
        assert name1 == name2, f"parameter order mismatch: {name1} vs {name2}"
        assert torch.allclose(param1, param2), f"values differ for {name1}"
def test_merge_models_with_inference1():
    """Merge a 48-entry recipe (SVAMP at 1.0 then 0.5) and print a sample generation.

    No assertion is made on the output; this exercises merge + inference
    end-to-end and surfaces the response for manual inspection.
    """
    load_base_models_if_needed()
    # First 24 entries use full SVAMP weight; the next 24 repeat the same
    # layer indices (i % 24) at half weight.
    layer_recipe = [
        [(i % 24, "svamp", 1.0 if i < 24 else 0.5)] for i in range(48)
    ]
    model_recipe = {
        "layer_recipe": layer_recipe,
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }
    merged_model = merge_models(model_recipe)
    prompt = "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:"
    print(inference(merged_model, prompt))
def test_merge_models_with_inference2():
    """Merge a pure-tinystories recipe with zeroed lambdas and print a sample generation.

    No assertion is made on the output; this exercises merge + inference
    end-to-end and surfaces the response for manual inspection.
    """
    load_base_models_if_needed()
    model_recipe = {
        # All 24 layers come from the tinystories model; embedding and
        # head mixing coefficients are zeroed out.
        "layer_recipe": [[(i, "tinystories", 1.0)] for i in range(24)],
        "embedding_lambdas": [0.0, 0.0],
        "linear_lambdas": [0.0, 0.0],
    }
    merged_model = merge_models(model_recipe)
    prompt = "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:"
    print(inference(merged_model, prompt))