Spaces:
Sleeping
Sleeping
| import pytest | |
| from fastapi.testclient import TestClient | |
| import time | |
| import re | |
| from evolutiontransformer.api import app | |
| def get_final_answer(text: str) -> int | None: | |
| numbers = re.findall(r"\d+", text) | |
| return int(numbers[-1]) if numbers else None | |
| def client(): | |
| with TestClient(app) as c: | |
| yield c | |
| def await_task_completion(client, task_id, timeout=60): | |
| start_time = time.time() | |
| while time.time() - start_time < timeout: | |
| status_response = client.get(f"/tasks/{task_id}") | |
| print(status_response.json()) | |
| if status_response.status_code == 500: | |
| return {"error": status_response.json().get("detail", "Unknown error")} | |
| assert status_response.status_code == 200 | |
| status_data = status_response.json() | |
| if status_data["status"] == "SUCCESS": | |
| return status_data["result"] | |
| time.sleep(2) | |
| else: | |
| pytest.fail( | |
| f"Task {task_id} did not complete within the {timeout}-second timeout." | |
| ) | |
| return None | |
| def test_generate_endpoint_svamp(client): | |
| """ | |
| Tests inference on svamp | |
| """ | |
| response = client.post( | |
| "/generate", | |
| json={ | |
| "model_name": "svamp", | |
| "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:", | |
| "max_new_tokens": 50, | |
| "temperature": 0.7, | |
| }, | |
| ) | |
| assert response.status_code == 200 | |
| data = response.json() | |
| assert "task_id" in data | |
| task_id = data["task_id"] | |
| final_result = await_task_completion(client, task_id) | |
| assert "response" in final_result | |
| output_text = final_result["response"] | |
| answer = get_final_answer(output_text) | |
| assert answer == 14 | |
| def test_merge_then_inference_svamp_1(client): | |
| """ | |
| Tests merging then inference for svamp dataset | |
| """ | |
| merge_response = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": "svamp", | |
| "model2_name": "tinystories", | |
| "layer_recipe": [[(i, 0, 1.0)] for i in range(24)], | |
| "embedding_lambdas": [1.0, 1.0], | |
| "linear_lambdas": [1.0, 1.0], | |
| "merged_name": "svamp_merged", | |
| }, | |
| ) | |
| assert merge_response.status_code == 200 | |
| merge_data = merge_response.json() | |
| assert "task_id" in merge_data | |
| merge_task_id = merge_data["task_id"] | |
| merge_status_data = await_task_completion(client, merge_task_id) | |
| model_name = merge_status_data["response"] | |
| time.sleep(5) | |
| generate_response = client.post( | |
| "/generate", | |
| json={ | |
| "model_name": model_name, | |
| "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:", | |
| "max_new_tokens": 50, | |
| "temperature": 0.7, | |
| }, | |
| ) | |
| assert generate_response.status_code == 200 | |
| generate_data = generate_response.json() | |
| assert "task_id" in generate_data | |
| generate_task_id = generate_data["task_id"] | |
| final_result = await_task_completion(client, generate_task_id) | |
| assert "response" in final_result | |
| output_text = final_result["response"] | |
| answer = get_final_answer(output_text) | |
| assert answer == 14 | |
| def test_merge_then_inference_svamp_2(client): | |
| """ | |
| Tests merging then inference for svamp dataset | |
| """ | |
| merge_repsonse = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": "svamp", | |
| "model2_name": "tinystories", | |
| "layer_recipe": [[(i % 24, 0, 1.0 if i < 24 else 0.5)] for i in range(48)], | |
| "embedding_lambdas": [1.0, 1.0], | |
| "linear_lambdas": [1.0, 1.0], | |
| "merged_name": "svamp_merged", | |
| }, | |
| ) | |
| assert merge_repsonse.status_code == 200 | |
| merge_data = merge_repsonse.json() | |
| assert "task_id" in merge_data | |
| merge_task_id = merge_data["task_id"] | |
| merge_status_data = await_task_completion(client, merge_task_id) | |
| model_name = merge_status_data["response"] | |
| merge_response2 = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": model_name, | |
| "model2_name": "tinystories", | |
| "layer_recipe": [[(i, 1, 0.25)] for i in range(24)], | |
| "embedding_lambdas": [0.0, 0.0], | |
| "linear_lambdas": [0.0, 0.0], | |
| "merged_name": "svamp_merged", | |
| }, | |
| ) | |
| assert merge_response2.status_code == 200 | |
| merge_data2 = merge_response2.json() | |
| assert "task_id" in merge_data2 | |
| merge_task_id2 = merge_data2["task_id"] | |
| merge_status_data2 = await_task_completion(client, merge_task_id2) | |
| model_name2 = merge_status_data2["response"] | |
| time.sleep(5) | |
| generate_response = client.post( | |
| "/generate", | |
| json={ | |
| "model_name": model_name2, | |
| "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:", | |
| "max_new_tokens": 50, | |
| "temperature": 0.7, | |
| }, | |
| ) | |
| assert generate_response.status_code == 200 | |
| generate_data = generate_response.json() | |
| assert "task_id" in generate_data | |
| generate_task_id = generate_data["task_id"] | |
| final_result = await_task_completion(client, generate_task_id) | |
| assert "response" in final_result | |
| output_text = final_result["response"] | |
| answer = get_final_answer(output_text) | |
| assert answer == 14 | |
| def test_merge_two_children_then_merge(client): | |
| """ | |
| Tests creating two children and merging them | |
| """ | |
| merge_response1 = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": "svamp", | |
| "model2_name": "tinystories", | |
| "layer_recipe": [[(i, 0, 0.8)] for i in range(12)] | |
| + [[(i, 1, 0.6)] for i in range(12)], | |
| "embedding_lambdas": [0.7, 0.3], | |
| "linear_lambdas": [0.8, 0.2], | |
| "merged_name": "child1", | |
| }, | |
| ) | |
| assert merge_response1.status_code == 200 | |
| merge_data1 = merge_response1.json() | |
| assert "task_id" in merge_data1 | |
| merge_task_id1 = merge_data1["task_id"] | |
| merge_status_data1 = await_task_completion(client, merge_task_id1) | |
| child1_name = merge_status_data1["response"] | |
| merge_response2 = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": "svamp", | |
| "model2_name": "tinystories", | |
| "layer_recipe": [[(i, 1, 0.9)] for i in range(8)] | |
| + [[(i, 0, 0.4)] for i in range(16)], | |
| "embedding_lambdas": [0.2, 0.9], | |
| "linear_lambdas": [0.3, 0.7], | |
| "merged_name": "child2", | |
| }, | |
| ) | |
| assert merge_response2.status_code == 200 | |
| merge_data2 = merge_response2.json() | |
| assert "task_id" in merge_data2 | |
| merge_task_id2 = merge_data2["task_id"] | |
| merge_status_data2 = await_task_completion(client, merge_task_id2) | |
| child2_name = merge_status_data2["response"] | |
| merge_response3 = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": child1_name, | |
| "model2_name": child2_name, | |
| "layer_recipe": [[(i, 0, 0.6), (i, 1, 0.4)] for i in range(24)], | |
| "embedding_lambdas": [0.5, 0.5], | |
| "linear_lambdas": [0.6, 0.4], | |
| "merged_name": "final_merged", | |
| }, | |
| ) | |
| assert merge_response3.status_code == 200 | |
| merge_data3 = merge_response3.json() | |
| assert "task_id" in merge_data3 | |
| merge_task_id3 = merge_data3["task_id"] | |
| merge_status_data3 = await_task_completion(client, merge_task_id3) | |
| final_model_name = merge_status_data3["response"] | |
| time.sleep(5) | |
| number_of_models = client.post(f"/list_models") | |
| assert number_of_models.status_code == 200 | |
| number_of_models_data = number_of_models.json() | |
| assert "task_id" in number_of_models_data | |
| number_of_models_task_id = number_of_models_data["task_id"] | |
| number_of_models_result = await_task_completion(client, number_of_models_task_id) | |
| assert "response" in number_of_models_result | |
| models = number_of_models_result["response"] | |
| print(models) | |
| assert len(models) == 5 | |
| generate_response = client.post( | |
| "/generate", | |
| json={ | |
| "model_name": final_model_name, | |
| "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:", | |
| "max_new_tokens": 50, | |
| "temperature": 0.7, | |
| }, | |
| ) | |
| assert generate_response.status_code == 200 | |
| generate_data = generate_response.json() | |
| assert "task_id" in generate_data | |
| generate_task_id = generate_data["task_id"] | |
| final_result = await_task_completion(client, generate_task_id) | |
| assert "response" in final_result | |
| output_text = final_result["response"] | |
| answer = get_final_answer(output_text) | |
| assert answer == 14 | |
| def test_merge_fail(client): | |
| """ | |
| Tests merging with too many layers | |
| """ | |
| merge_repsonse = client.post( | |
| "/merge", | |
| json={ | |
| "model1_name": "svamp", | |
| "model2_name": "tinystories", | |
| "layer_recipe": [[(i, 0, 1.0)] for i in range(50)], | |
| "embedding_lambdas": [1.0, 1.0], | |
| "linear_lambdas": [1.0, 1.0], | |
| "merged_name": "svamp_merged", | |
| }, | |
| ) | |
| assert merge_repsonse.status_code == 200 | |
| merge_data = merge_repsonse.json() | |
| assert "task_id" in merge_data | |
| merge_task_id = merge_data["task_id"] | |
| merge_status_data = await_task_completion(client, merge_task_id) | |
| assert "response" not in merge_status_data | |
| assert "error" in merge_status_data | |
| assert "Layer recipe too long" in merge_status_data["error"] | |