Spaces:
Paused
Paused
| # Test the following scenarios: | |
| # 1. Generate a Key, and use it to make a call | |
| # 2. Make a call with invalid key, expect it to fail | |
| # 3. Make a call to a key with invalid model - expect to fail | |
| # 4. Make a call to a key with valid model - expect to pass | |
| # 5. Make a call with user over budget, expect to fail | |
| # 6. Make a streaming chat/completions call with user over budget, expect to fail | |
| # 7. Make a call with an key that never expires, expect to pass | |
| # 8. Make a call with an expired key, expect to fail | |
| # 9. Delete a Key | |
| # 10. Generate a key, call key/info. Assert info returned is the same as generated key info | |
| # 11. Generate a Key, cal key/info, call key/update, call key/info | |
| # 12. Make a call with key over budget, expect to fail | |
| # 14. Make a streaming chat/completions call with key over budget, expect to fail | |
| # 15. Generate key, when `allow_user_auth`=False - check if `/key/info` returns key_name=null | |
| # 16. Generate key, when `allow_user_auth`=True - check if `/key/info` returns key_name=sk...<last-4-digits> | |
| # function to call to generate key - async def new_user(data: NewUserRequest): | |
| # function to validate a request - async def user_auth(request: Request): | |
| import os | |
| import sys | |
| import traceback | |
| import uuid | |
| from datetime import datetime, timezone | |
| from dotenv import load_dotenv | |
| from fastapi import Request | |
| from fastapi.routing import APIRoute | |
| import httpx | |
| load_dotenv() | |
| import io | |
| import os | |
| import time | |
| # this file is to test litellm/proxy | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import asyncio | |
| import logging | |
| import pytest | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.proxy.management_endpoints.internal_user_endpoints import ( | |
| new_user, | |
| user_info, | |
| user_update, | |
| ) | |
| from litellm.proxy.auth.auth_checks import get_key_object | |
| from litellm.proxy.management_endpoints.key_management_endpoints import ( | |
| delete_key_fn, | |
| generate_key_fn, | |
| generate_key_helper_fn, | |
| info_key_fn, | |
| list_keys, | |
| regenerate_key_fn, | |
| update_key_fn, | |
| ) | |
| from litellm.proxy.management_endpoints.team_endpoints import ( | |
| new_team, | |
| team_info, | |
| update_team, | |
| ) | |
| from litellm.proxy.proxy_server import ( | |
| LitellmUserRoles, | |
| audio_transcriptions, | |
| chat_completion, | |
| completion, | |
| embeddings, | |
| model_list, | |
| moderations, | |
| user_api_key_auth, | |
| ) | |
| from litellm.proxy.image_endpoints import image_generation | |
| from litellm.proxy.management_endpoints.customer_endpoints import ( | |
| new_end_user, | |
| ) | |
| from litellm.proxy.spend_tracking.spend_management_endpoints import ( | |
| global_spend, | |
| spend_key_fn, | |
| spend_user_fn, | |
| view_spend_logs, | |
| ) | |
| from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend | |
| verbose_proxy_logger.setLevel(level=logging.DEBUG) | |
| from starlette.datastructures import URL | |
| from litellm.caching.caching import DualCache | |
| from litellm.proxy._types import ( | |
| DynamoDBArgs, | |
| GenerateKeyRequest, | |
| KeyRequest, | |
| LiteLLM_UpperboundKeyGenerateParams, | |
| NewCustomerRequest, | |
| NewTeamRequest, | |
| NewUserRequest, | |
| ProxyErrorTypes, | |
| ProxyException, | |
| UpdateKeyRequest, | |
| UpdateTeamRequest, | |
| UpdateUserRequest, | |
| UserAPIKeyAuth, | |
| ) | |
| proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) | |
| request_data = { | |
| "model": "azure-gpt-3.5", | |
| "messages": [ | |
| {"role": "user", "content": "this is my new test. respond in 50 lines"} | |
| ], | |
| } | |
| def prisma_client(): | |
| from litellm.proxy.proxy_cli import append_query_params | |
| ### add connection pool + pool timeout args | |
| params = {"connection_limit": 100, "pool_timeout": 60} | |
| database_url = os.getenv("DATABASE_URL") | |
| modified_url = append_query_params(database_url, params) | |
| os.environ["DATABASE_URL"] = modified_url | |
| # Assuming PrismaClient is a class that needs to be instantiated | |
| prisma_client = PrismaClient( | |
| database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj | |
| ) | |
| # Reset litellm.proxy.proxy_server.prisma_client to None | |
| litellm.proxy.proxy_server.litellm_proxy_budget_name = ( | |
| f"litellm-proxy-budget-{time.time()}" | |
| ) | |
| litellm.proxy.proxy_server.user_custom_key_generate = None | |
| return prisma_client | |
| async def test_new_user_response(prisma_client): | |
| try: | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| _team_id = "ishaan-special-team_{}".format(uuid.uuid4()) | |
| await new_team( | |
| NewTeamRequest( | |
| team_id=_team_id, | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| _response = await new_user( | |
| data=NewUserRequest( | |
| models=["azure-gpt-3.5"], | |
| team_id=_team_id, | |
| tpm_limit=20, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| ) | |
| print(_response) | |
| assert _response.models == ["azure-gpt-3.5"] | |
| assert _response.team_id == _team_id | |
| assert _response.tpm_limit == 20 | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| def test_generate_and_call_with_valid_key(prisma_client, api_route): | |
| # 1. Generate a Key, and use it to make a call | |
| from unittest.mock import MagicMock | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| user_api_key_dict = UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ) | |
| request = NewUserRequest(user_role=LitellmUserRoles.INTERNAL_USER) | |
| key = await new_user(request, user_api_key_dict=user_api_key_dict) | |
| print(key) | |
| user_id = key.user_id | |
| # check /user/info to verify user_role was set correctly | |
| request_mock = MagicMock() | |
| new_user_info = await user_info( | |
| request=request_mock, user_id=user_id, user_api_key_dict=user_api_key_dict | |
| ) | |
| new_user_info = new_user_info.user_info | |
| print("new_user_info=", new_user_info) | |
| assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER | |
| assert new_user_info["user_id"] == user_id | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict | |
| value_from_prisma = await prisma_client.get_data( | |
| token=generated_key, | |
| ) | |
| print("token from prisma", value_from_prisma) | |
| request = Request( | |
| { | |
| "type": "http", | |
| "route": api_route, | |
| "path": api_route.path, | |
| "headers": [("Authorization", bearer_token)], | |
| } | |
| ) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| asyncio.run(test()) | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| def test_call_with_invalid_key(prisma_client): | |
| # 2. Make a call with invalid key, expect it to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| generated_key = "sk-126666" | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}, receive=None) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("got result", result) | |
| pytest.fail(f"This should have failed!. IT's an invalid key") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| print(e.message) | |
| assert "Authentication Error, Invalid proxy server token passed" in e.message | |
| pass | |
| def test_call_with_invalid_model(prisma_client): | |
| litellm.set_verbose = True | |
| # 3. Make a call to a key with an invalid model - expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest(models=["mistral"]) | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| async def return_body(): | |
| return b'{"model": "gemini-pro-vision"}' | |
| request.body = return_body | |
| # use generated key to auth in | |
| print( | |
| "Bearer token being sent to user_api_key_auth() - {}".format( | |
| bearer_token | |
| ) | |
| ) | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| pytest.fail(f"This should have failed!. IT's an invalid model") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.key_model_access_denied | |
| assert e.param == "model" | |
| def test_call_with_valid_model(prisma_client): | |
| # 4. Make a call to a key with a valid model - expect to pass | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest(models=["mistral"]) | |
| key = await new_user( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| async def return_body(): | |
| return b'{"model": "mistral"}' | |
| request.body = return_body | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| asyncio.run(test()) | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| async def test_call_with_valid_model_using_all_models(prisma_client): | |
| """ | |
| Do not delete | |
| this is the Admin UI flow | |
| 1. Create a team with model = `all-proxy-models` | |
| 2. Create a key with model = `all-team-models` | |
| 3. Call /chat/completions with the key -> expect to pass | |
| """ | |
| # Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| team_request = NewTeamRequest( | |
| team_alias="testing-team", | |
| models=["all-proxy-models"], | |
| ) | |
| new_team_response = await new_team( | |
| data=team_request, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("new_team_response", new_team_response) | |
| created_team_id = new_team_response["team_id"] | |
| request = GenerateKeyRequest( | |
| models=["all-team-models"], team_id=created_team_id | |
| ) | |
| key = await generate_key_fn( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| async def return_body(): | |
| return b'{"model": "mistral"}' | |
| request.body = return_body | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # call /key/info for key - models == "all-proxy-models" | |
| key_info = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("key_info", key_info) | |
| models = key_info["info"]["models"] | |
| assert models == ["all-team-models"] | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| def test_call_with_user_over_budget(prisma_client): | |
| # 5. Make a call with a key over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest(max_budget=0.00001) | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| resp = ModelResponse( | |
| id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "stream": False, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": generated_key, | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00002, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await asyncio.sleep(5) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail("This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print("got an errror=", e) | |
| error_detail = e.message | |
| assert "ExceededBudget:" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| def test_end_user_cache_write_unit_test(): | |
| """ | |
| assert end user object is being written to cache as expected | |
| """ | |
| pass | |
| def test_call_with_end_user_over_budget(prisma_client): | |
| # Test if a user passed to /chat/completions is tracked & fails when they cross their budget | |
| # we only check this when litellm.max_end_user_budget is set | |
| import random | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm, "max_end_user_budget", 0.00001) | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| user = f"ishaan {uuid.uuid4().hex}" | |
| request = NewCustomerRequest( | |
| user_id=user, max_budget=0.000001 | |
| ) # create a key with no budget | |
| await new_end_user( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| bearer_token = "Bearer sk-1234" | |
| async def return_body(): | |
| return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}' | |
| # return string as bytes | |
| return return_string.encode() | |
| request.body = return_body | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| resp = ModelResponse( | |
| id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "stream": False, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": "sk-1234", | |
| "user_api_key_end_user_id": user, | |
| }, | |
| "proxy_server_request": { | |
| "body": { | |
| "user": user, | |
| } | |
| }, | |
| }, | |
| "response_cost": 10, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await asyncio.sleep(10) | |
| await update_spend( | |
| prisma_client=prisma_client, | |
| db_writer_client=None, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail("This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print(f"raised error: {e}, traceback: {traceback.format_exc()}") | |
| error_detail = e.message | |
| assert "Budget has been exceeded! Current" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| def test_call_with_proxy_over_budget(prisma_client): | |
| # 5.1 Make a call with a proxy over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}" | |
| setattr( | |
| litellm.proxy.proxy_server, | |
| "litellm_proxy_admin_name", | |
| litellm_proxy_budget_name, | |
| ) | |
| setattr(litellm, "max_budget", 0.00001) | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| user_api_key_cache.set_cache( | |
| key="{}:spend".format(litellm_proxy_budget_name), value=0 | |
| ) | |
| setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest() | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| resp = ModelResponse( | |
| id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "stream": False, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": generated_key, | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00002, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await asyncio.sleep(5) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail(f"This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| if hasattr(e, "message"): | |
| error_detail = e.message | |
| else: | |
| error_detail = traceback.format_exc() | |
| assert "Budget has been exceeded" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| def test_call_with_user_over_budget_stream(prisma_client): | |
| # 6. Make a call with a key over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| import logging | |
| from litellm._logging import verbose_proxy_logger | |
| litellm.set_verbose = True | |
| verbose_proxy_logger.setLevel(logging.DEBUG) | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest(max_budget=0.00001) | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| resp = ModelResponse( | |
| id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "stream": True, | |
| "complete_streaming_response": resp, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": generated_key, | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00002, | |
| }, | |
| completion_response=ModelResponse(), | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await asyncio.sleep(5) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail("This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| error_detail = e.message | |
| assert "ExceededBudget:" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| def test_call_with_proxy_over_budget_stream(prisma_client): | |
| # 6.1 Make a call with a global proxy over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}" | |
| setattr( | |
| litellm.proxy.proxy_server, | |
| "litellm_proxy_admin_name", | |
| litellm_proxy_budget_name, | |
| ) | |
| setattr(litellm, "max_budget", 0.00001) | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| user_api_key_cache.set_cache( | |
| key="{}:spend".format(litellm_proxy_budget_name), value=0 | |
| ) | |
| setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) | |
| import logging | |
| from litellm._logging import verbose_proxy_logger | |
| litellm.set_verbose = True | |
| verbose_proxy_logger.setLevel(logging.DEBUG) | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| ## CREATE PROXY + USER BUDGET ## | |
| # request = NewUserRequest( | |
| # max_budget=0.00001, user_id=litellm_proxy_budget_name | |
| # ) | |
| request = NewUserRequest() | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| resp = ModelResponse( | |
| id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac", | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "stream": True, | |
| "complete_streaming_response": resp, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": generated_key, | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00002, | |
| }, | |
| completion_response=ModelResponse(), | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await asyncio.sleep(5) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail(f"This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| error_detail = e.message | |
| assert "Budget has been exceeded" in error_detail | |
| print(vars(e)) | |
| def test_generate_and_call_with_valid_key_never_expires(prisma_client): | |
| # 7. Make a call with an key that never expires, expect to pass | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest(duration=None) | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| asyncio.run(test()) | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| def test_generate_and_call_with_expired_key(prisma_client): | |
| # 8. Make a call with an expired key, expect to fail | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest(duration="0s") | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail("This should have failed!. It's an expired key") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| print(e.message) | |
| assert "Authentication Error" in e.message | |
| assert e.type == ProxyErrorTypes.expired_key | |
| pass | |
| def test_delete_key(prisma_client): | |
| # 9. Generate a Key, delete it. Check if deletion works fine | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm.proxy.proxy_server, "user_custom_auth", None) | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| request = NewUserRequest() | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| delete_key_request = KeyRequest(keys=[generated_key]) | |
| bearer_token = "Bearer sk-1234" | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/key/delete") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print(f"result: {result}") | |
| result.user_role = LitellmUserRoles.PROXY_ADMIN | |
| # delete the key | |
| result_delete_key = await delete_key_fn( | |
| data=delete_key_request, user_api_key_dict=result | |
| ) | |
| print("result from delete key", result_delete_key) | |
| assert result_delete_key == {"deleted_keys": [generated_key]} | |
| assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict | |
| assert ( | |
| hash_token(generated_key) | |
| not in user_api_key_cache.in_memory_cache.cache_dict | |
| ) | |
| asyncio.run(test()) | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| def test_delete_key_auth(prisma_client): | |
| # 10. Generate a Key, delete it, use it to make a call -> expect fail | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| request = NewUserRequest() | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| delete_key_request = KeyRequest(keys=[generated_key]) | |
| # delete the key | |
| bearer_token = "Bearer sk-1234" | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/key/delete") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print(f"result: {result}") | |
| result.user_role = LitellmUserRoles.PROXY_ADMIN | |
| result_delete_key = await delete_key_fn( | |
| data=delete_key_request, user_api_key_dict=result | |
| ) | |
| print("result from delete key", result_delete_key) | |
| assert result_delete_key == {"deleted_keys": [generated_key]} | |
| request = Request(scope={"type": "http"}, receive=None) | |
| request._url = URL(url="/chat/completions") | |
| assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict | |
| assert ( | |
| hash_token(generated_key) | |
| not in user_api_key_cache.in_memory_cache.cache_dict | |
| ) | |
| # use generated key to auth in | |
| bearer_token = "Bearer " + generated_key | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("got result", result) | |
| pytest.fail(f"This should have failed!. IT's an invalid key") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| print(e.message) | |
| assert "Authentication Error" in e.message | |
| pass | |
| def test_generate_and_call_key_info(prisma_client): | |
| # 10. Generate a Key, cal key/info | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = NewUserRequest( | |
| metadata={"team": "litellm-team3", "project": "litellm-project3"} | |
| ) | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| # use generated key to auth in | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| ), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["max_parallel_requests"] == None | |
| assert result["info"]["metadata"] == { | |
| "team": "litellm-team3", | |
| "project": "litellm-project3", | |
| } | |
| # cleanup - delete key | |
| delete_key_request = KeyRequest(keys=[generated_key]) | |
| bearer_token = "Bearer sk-1234" | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/key/delete") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print(f"result: {result}") | |
| result.user_role = LitellmUserRoles.PROXY_ADMIN | |
| result_delete_key = await delete_key_fn( | |
| data=delete_key_request, user_api_key_dict=result | |
| ) | |
| asyncio.run(test()) | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| def test_generate_and_update_key(prisma_client): | |
| # 11. Generate a Key, cal key/info, call key/update, call key/info | |
| # Check if data gets updated | |
| # Check if untouched data does not get updated | |
| import uuid | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| # create team "[email protected]"" | |
| print("creating team [email protected]") | |
| _team_1 = "[email protected]_{}".format(uuid.uuid4()) | |
| await new_team( | |
| NewTeamRequest( | |
| team_id=_team_1, | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| _team_2 = "ishaan-special-team_{}".format(uuid.uuid4()) | |
| await new_team( | |
| NewTeamRequest( | |
| team_id=_team_2, | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| request = NewUserRequest( | |
| metadata={"project": "litellm-project3"}, | |
| team_id=_team_1, | |
| ) | |
| key = await new_user( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| # use generated key to auth in | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| ), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["max_parallel_requests"] == None | |
| assert result["info"]["metadata"] == { | |
| "project": "litellm-project3", | |
| } | |
| assert result["info"]["team_id"] == _team_1 | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/update/key") | |
| # update the key | |
| response1 = await update_key_fn( | |
| request=Request, | |
| data=UpdateKeyRequest( | |
| key=generated_key, | |
| models=["ada", "babbage", "curie", "davinci"], | |
| budget_duration="1mo", | |
| max_budget=100, | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print("response1=", response1) | |
| # update the tpm limit | |
| response2 = await update_key_fn( | |
| request=Request, | |
| data=UpdateKeyRequest(key=generated_key, tpm_limit=1000), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print("response2=", response2) | |
| # get info on key after update | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| ), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["max_parallel_requests"] == None | |
| assert result["info"]["metadata"] == { | |
| "project": "litellm-project3", | |
| } | |
| assert result["info"]["models"] == ["ada", "babbage", "curie", "davinci"] | |
| assert result["info"]["tpm_limit"] == 1000 | |
| assert result["info"]["budget_duration"] == "1mo" | |
| assert result["info"]["max_budget"] == 100 | |
| # budget_reset_at should exist for "1mo" duration | |
| assert result["info"]["budget_reset_at"] is not None | |
| budget_reset_at = result["info"]["budget_reset_at"].replace(tzinfo=timezone.utc) | |
| current_time = datetime.now(timezone.utc) | |
| print(f"Budget reset time: {budget_reset_at}") | |
| print(f"Current time: {current_time}") | |
| # Instead of checking exact timing, just verify that: | |
| # 1. Both are in the same day (for tests running same day) | |
| # 2. Or budget_reset_at is in next month | |
| if budget_reset_at.day == current_time.day: | |
| # Same day of month - just check month difference | |
| month_diff = budget_reset_at.month - current_time.month | |
| if budget_reset_at.year > current_time.year: | |
| month_diff += 12 | |
| # Should be scheduled for next month (at least 0.5 month away) | |
| assert month_diff >= 1, f"Expected reset to be at least 1 month ahead, got {month_diff} months" | |
| assert month_diff <= 2, f"Expected reset to be at most 2 months ahead, got {month_diff} months" | |
| else: | |
| # Just ensure the date is reasonable (not more than 40 days away) | |
| days_diff = (budget_reset_at - current_time).days | |
| assert 0 <= days_diff <= 40, f"Expected reset date to be reasonable, got {days_diff} days from now" | |
| # cleanup - delete key | |
| delete_key_request = KeyRequest(keys=[generated_key]) | |
| # delete the key | |
| bearer_token = "Bearer sk-1234" | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/key/delete") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print(f"result: {result}") | |
| result.user_role = LitellmUserRoles.PROXY_ADMIN | |
| result_delete_key = await delete_key_fn( | |
| data=delete_key_request, user_api_key_dict=result | |
| ) | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"An exception occurred - {str(e)}\n{traceback.format_exc()}") | |
| def test_key_generate_with_custom_auth(prisma_client): | |
| # custom - generate key function | |
| async def custom_generate_key_fn(data: GenerateKeyRequest) -> dict: | |
| """ | |
| Asynchronous function for generating a key based on the input data. | |
| Args: | |
| data (GenerateKeyRequest): The input data for key generation. | |
| Returns: | |
| dict: A dictionary containing the decision and an optional message. | |
| { | |
| "decision": False, | |
| "message": "This violates LiteLLM Proxy Rules. No team id provided.", | |
| } | |
| """ | |
| # decide if a key should be generated or not | |
| print("using custom auth function!") | |
| data_json = data.json() # type: ignore | |
| # Unpacking variables | |
| team_id = data_json.get("team_id") | |
| duration = data_json.get("duration") | |
| models = data_json.get("models") | |
| aliases = data_json.get("aliases") | |
| config = data_json.get("config") | |
| spend = data_json.get("spend") | |
| user_id = data_json.get("user_id") | |
| max_parallel_requests = data_json.get("max_parallel_requests") | |
| metadata = data_json.get("metadata") | |
| tpm_limit = data_json.get("tpm_limit") | |
| rpm_limit = data_json.get("rpm_limit") | |
| if team_id is not None and team_id == "[email protected]": | |
| # only team_id="[email protected]" can make keys | |
| return { | |
| "decision": True, | |
| } | |
| else: | |
| print("Failed custom auth") | |
| return { | |
| "decision": False, | |
| "message": "This violates LiteLLM Proxy Rules. No team id provided.", | |
| } | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr( | |
| litellm.proxy.proxy_server, "user_custom_key_generate", custom_generate_key_fn | |
| ) | |
| try: | |
| async def test(): | |
| try: | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest() | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| pytest.fail(f"Expected an exception. Got {key}") | |
| except Exception as e: | |
| # this should fail | |
| print("Got Exception", e) | |
| print(e.message) | |
| print("First request failed!. This is expected") | |
| assert ( | |
| "This violates LiteLLM Proxy Rules. No team id provided." | |
| in e.message | |
| ) | |
| request_2 = GenerateKeyRequest( | |
| team_id="[email protected]", | |
| ) | |
| key = await generate_key_fn( | |
| request_2, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| asyncio.run(test()) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| print(e.message) | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| def test_call_with_key_over_budget(prisma_client): | |
| # 12. Make a call with a key over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest(max_budget=0.00001) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.caching.caching import Cache | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| litellm.cache = Cache() | |
| import time | |
| import uuid | |
| request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}" | |
| resp = ModelResponse( | |
| id=request_id, | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "model": "chatgpt-v-3", | |
| "stream": False, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": hash_token(generated_key), | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00002, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await update_spend( | |
| prisma_client=prisma_client, | |
| db_writer_client=None, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # test spend_log was written and we can read it | |
| spend_logs = await view_spend_logs( | |
| request_id=request_id, | |
| user_api_key_dict=UserAPIKeyAuth(api_key=generated_key), | |
| ) | |
| print("read spend logs", spend_logs) | |
| assert len(spend_logs) == 1 | |
| spend_log = spend_logs[0] | |
| assert spend_log.request_id == request_id | |
| assert spend_log.spend == float("2e-05") | |
| assert spend_log.model == "chatgpt-v-3" | |
| assert ( | |
| spend_log.cache_key | |
| == "509ba0554a7129ae4f4fd13d11c141acce5549bb6aaf1f629ed543101615658e" | |
| ) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail("This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| # print(f"Error - {str(e)}") | |
| traceback.print_exc() | |
| if hasattr(e, "message"): | |
| error_detail = e.message | |
| else: | |
| error_detail = str(e) | |
| assert "Budget has been exceeded" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| def test_call_with_key_over_budget_no_cache(prisma_client): | |
| # 12. Make a call with a key over budget, expect to fail | |
| # ✅ Tests if spend trackign works when the key does not exist in memory | |
| # Related to this: https://github.com/BerriAI/litellm/issues/3920 | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest(max_budget=0.00001) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| user_api_key_cache.in_memory_cache.cache_dict = {} | |
| setattr(litellm.proxy.proxy_server, "proxy_batch_write_at", 1) | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.caching.caching import Cache | |
| litellm.cache = Cache() | |
| import time | |
| import uuid | |
| request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}" | |
| resp = ModelResponse( | |
| id=request_id, | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| proxy_db_logger = _ProxyDBLogger() | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "model": "chatgpt-v-3", | |
| "stream": False, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": hash_token(generated_key), | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00002, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await asyncio.sleep(10) | |
| await update_spend( | |
| prisma_client=prisma_client, | |
| db_writer_client=None, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # test spend_log was written and we can read it | |
| spend_logs = await view_spend_logs( | |
| request_id=request_id, | |
| user_api_key_dict=UserAPIKeyAuth(api_key=generated_key), | |
| ) | |
| print("read spend logs", spend_logs) | |
| assert len(spend_logs) == 1 | |
| spend_log = spend_logs[0] | |
| assert spend_log.request_id == request_id | |
| assert spend_log.spend == float("2e-05") | |
| assert spend_log.model == "chatgpt-v-3" | |
| assert ( | |
| spend_log.cache_key | |
| == "509ba0554a7129ae4f4fd13d11c141acce5549bb6aaf1f629ed543101615658e" | |
| ) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail(f"This should have failed!. They key crossed it's budget") | |
| asyncio.run(test()) | |
| except Exception as e: | |
| # print(f"Error - {str(e)}") | |
| traceback.print_exc() | |
| if hasattr(e, "message"): | |
| error_detail = e.message | |
| else: | |
| error_detail = str(e) | |
| assert "Budget has been exceeded" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| async def test_call_with_key_over_model_budget( | |
| prisma_client, request_model, should_pass | |
| ): | |
| # 12. Make a call with a key over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| verbose_proxy_logger.setLevel(logging.DEBUG) | |
| # init model max budget limiter | |
| from litellm.proxy.hooks.model_max_budget_limiter import ( | |
| _PROXY_VirtualKeyModelMaxBudgetLimiter, | |
| ) | |
| model_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( | |
| dual_cache=DualCache() | |
| ) | |
| litellm.callbacks.append(model_budget_limiter) | |
| try: | |
| # set budget for chatgpt-v-3 to 0.000001, expect the next request to fail | |
| model_max_budget = { | |
| "gpt-4o-mini": { | |
| "budget_limit": "0.000001", | |
| "time_period": "1d", | |
| }, | |
| "gpt-4o": { | |
| "budget_limit": "200", | |
| "time_period": "30d", | |
| }, | |
| } | |
| request = GenerateKeyRequest( | |
| max_budget=100000, # the key itself has a very high budget | |
| model_max_budget=model_max_budget, | |
| ) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| async def return_body(): | |
| request_str = f'{{"model": "{request_model}"}}' # Added extra curly braces to escape JSON | |
| return request_str.encode() | |
| request.body = return_body | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| await litellm.acompletion( | |
| model=request_model, | |
| messages=[{"role": "user", "content": "Hello, how are you?"}], | |
| metadata={ | |
| "user_api_key": hash_token(generated_key), | |
| "user_api_key_model_max_budget": model_max_budget, | |
| }, | |
| ) | |
| await asyncio.sleep(2) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| if should_pass is True: | |
| print( | |
| f"Passed request for model={request_model}, model_max_budget={model_max_budget}" | |
| ) | |
| return | |
| print("result from user auth with new key", result) | |
| pytest.fail("This should have failed!. They key crossed it's budget") | |
| except Exception as e: | |
| # print(f"Error - {str(e)}") | |
| print( | |
| f"Failed request for model={request_model}, model_max_budget={model_max_budget}" | |
| ) | |
| assert ( | |
| should_pass is False | |
| ), f"This should have failed!. They key crossed it's budget for model={request_model}. {e}" | |
| traceback.print_exc() | |
| error_detail = e.message | |
| assert f"exceeded budget for model={request_model}" in error_detail | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.budget_exceeded | |
| print(vars(e)) | |
| finally: | |
| litellm.callbacks.remove(model_budget_limiter) | |
| async def test_call_with_key_never_over_budget(prisma_client): | |
| # Make a call with a key with budget=None, it should never fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest(max_budget=None) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key: {result}") | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| import time | |
| import uuid | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| request_id = f"chatcmpl-{uuid.uuid4()}" | |
| resp = ModelResponse( | |
| id=request_id, | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage( | |
| prompt_tokens=210000, completion_tokens=200000, total_tokens=41000 | |
| ), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "model": "chatgpt-v-3", | |
| "stream": False, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": hash_token(generated_key), | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 200000, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await update_spend( | |
| prisma_client=prisma_client, | |
| db_writer_client=None, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| except Exception as e: | |
| pytest.fail(f"This should have not failed!. They key uses max_budget=None. {e}") | |
| async def test_call_with_key_over_budget_stream(prisma_client): | |
| # 14. Make a call with a key over budget, expect to fail | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| import logging | |
| from litellm._logging import verbose_proxy_logger | |
| litellm.set_verbose = True | |
| verbose_proxy_logger.setLevel(logging.DEBUG) | |
| try: | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest(max_budget=0.00001) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| print(f"generated_key: {generated_key}") | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| import time | |
| import uuid | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| proxy_db_logger = _ProxyDBLogger() | |
| request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}" | |
| resp = ModelResponse( | |
| id=request_id, | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "call_type": "acompletion", | |
| "model": "sagemaker-chatgpt-v-3", | |
| "stream": True, | |
| "complete_streaming_response": resp, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": hash_token(generated_key), | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00005, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| await update_spend( | |
| prisma_client=prisma_client, | |
| db_writer_client=None, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| pytest.fail(f"This should have failed!. They key crossed it's budget") | |
| except Exception as e: | |
| print("Got Exception", e) | |
| error_detail = e.message | |
| assert "Budget has been exceeded" in error_detail | |
| print(vars(e)) | |
| async def test_aview_spend_per_user(prisma_client): | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| user_by_spend = await spend_user_fn(user_id=None) | |
| assert type(user_by_spend) == list | |
| assert len(user_by_spend) > 0 | |
| first_user = user_by_spend[0] | |
| print("\nfirst_user=", first_user) | |
| assert first_user["spend"] >= 0 | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| async def test_view_spend_per_key(prisma_client): | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| key_by_spend = await spend_key_fn() | |
| assert type(key_by_spend) == list | |
| assert len(key_by_spend) > 0 | |
| first_key = key_by_spend[0] | |
| print("\nfirst_key=", first_key) | |
| assert first_key.spend >= 0 | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| async def test_key_name_null(prisma_client): | |
| """ | |
| - create key | |
| - get key info | |
| - assert key_name is null | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| os.environ["DISABLE_KEY_NAME"] = "True" | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| request = GenerateKeyRequest() | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print("generated key=", key) | |
| generated_key = key.key | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["info"]["key_name"] is None | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| finally: | |
| os.environ["DISABLE_KEY_NAME"] = "False" | |
| async def test_key_name_set(prisma_client): | |
| """ | |
| - create key | |
| - get key info | |
| - assert key_name is not null | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| request = GenerateKeyRequest() | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| generated_key = key.key | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert isinstance(result["info"]["key_name"], str) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| async def test_default_key_params(prisma_client): | |
| """ | |
| - create key | |
| - get key info | |
| - assert key_name is not null | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}) | |
| litellm.default_key_generate_params = {"max_budget": 0.000122} | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| request = GenerateKeyRequest() | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| generated_key = key.key | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["info"]["max_budget"] == 0.000122 | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| async def test_upperbound_key_param_larger_budget(prisma_client): | |
| """ | |
| - create key | |
| - get key info | |
| - assert key_name is not null | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams( | |
| max_budget=0.001, budget_duration="1m" | |
| ) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| request = GenerateKeyRequest( | |
| max_budget=200000, | |
| budget_duration="30d", | |
| ) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| # print(result) | |
| except Exception as e: | |
| assert e.code == str(400) | |
| async def test_upperbound_key_param_larger_duration(prisma_client): | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams( | |
| max_budget=100, duration="14d" | |
| ) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| request = GenerateKeyRequest( | |
| max_budget=10, | |
| duration="30d", | |
| ) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| pytest.fail("Expected this to fail but it passed") | |
| # print(result) | |
| except Exception as e: | |
| assert e.code == str(400) | |
| async def test_upperbound_key_param_none_duration(prisma_client): | |
| from datetime import datetime, timedelta | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams( | |
| max_budget=100, duration="14d" | |
| ) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| request = GenerateKeyRequest() | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| # print(result) | |
| assert key.max_budget == 100 | |
| assert key.expires is not None | |
| _date_key_expires = key.expires.date() | |
| _fourteen_days_from_now = (datetime.now() + timedelta(days=14)).date() | |
| assert _date_key_expires == _fourteen_days_from_now | |
| except Exception as e: | |
| pytest.fail(f"Got exception {e}") | |
| def test_get_bearer_token(): | |
| from litellm.proxy.auth.user_api_key_auth import _get_bearer_token | |
| # Test valid Bearer token | |
| api_key = "Bearer valid_token" | |
| result = _get_bearer_token(api_key) | |
| assert result == "valid_token", f"Expected 'valid_token', got '{result}'" | |
| # Test empty API key | |
| api_key = "" | |
| result = _get_bearer_token(api_key) | |
| assert result == "", f"Expected '', got '{result}'" | |
| # Test API key without Bearer prefix | |
| api_key = "invalid_token" | |
| result = _get_bearer_token(api_key) | |
| assert result == "", f"Expected '', got '{result}'" | |
| # Test API key with Bearer prefix and extra spaces | |
| api_key = " Bearer valid_token " | |
| result = _get_bearer_token(api_key) | |
| assert result == "", f"Expected '', got '{result}'" | |
| # Test API key with Bearer prefix and no token | |
| api_key = "Bearer sk-1234" | |
| result = _get_bearer_token(api_key) | |
| assert result == "sk-1234", f"Expected 'valid_token', got '{result}'" | |
| async def test_update_logs_with_spend_logs_url(prisma_client): | |
| """ | |
| Unit test for making sure spend logs list is still updated when url passed in | |
| """ | |
| from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter | |
| db_spend_update_writer = DBSpendUpdateWriter() | |
| payload = {"startTime": datetime.now(), "endTime": datetime.now()} | |
| await db_spend_update_writer._insert_spend_log_to_db(payload=payload, prisma_client=prisma_client) | |
| assert len(prisma_client.spend_log_transactions) > 0 | |
| prisma_client.spend_log_transactions = [] | |
| spend_logs_url = "" | |
| payload = {"startTime": datetime.now(), "endTime": datetime.now()} | |
| await db_spend_update_writer._insert_spend_log_to_db( | |
| payload=payload, spend_logs_url=spend_logs_url, prisma_client=prisma_client | |
| ) | |
| assert len(prisma_client.spend_log_transactions) > 0 | |
| async def test_user_api_key_auth(prisma_client): | |
| from litellm.proxy.proxy_server import ProxyException | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # Test case: No API Key passed in | |
| try: | |
| await user_api_key_auth(request, api_key=None) | |
| pytest.fail(f"This should have failed!. IT's an invalid key") | |
| except ProxyException as exc: | |
| print(exc.message) | |
| assert exc.message == "Authentication Error, No api key passed in." | |
| # Test case: Malformed API Key (missing 'Bearer ' prefix) | |
| try: | |
| await user_api_key_auth(request, api_key="my_token") | |
| pytest.fail(f"This should have failed!. IT's an invalid key") | |
| except ProxyException as exc: | |
| print(exc.message) | |
| assert ( | |
| exc.message | |
| == "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: my_token" | |
| ) | |
| # Test case: User passes empty string API Key | |
| try: | |
| await user_api_key_auth(request, api_key="") | |
| pytest.fail(f"This should have failed!. IT's an invalid key") | |
| except ProxyException as exc: | |
| print(exc.message) | |
| assert ( | |
| exc.message | |
| == "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: " | |
| ) | |
| async def test_user_api_key_auth_without_master_key(prisma_client): | |
| # if master key is not set, expect all calls to go through | |
| try: | |
| from litellm.proxy.proxy_server import ProxyException | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", None) | |
| setattr( | |
| litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True} | |
| ) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # Test case: No API Key passed in | |
| await user_api_key_auth(request, api_key=None) | |
| await user_api_key_auth(request, api_key="my_token") | |
| await user_api_key_auth(request, api_key="") | |
| await user_api_key_auth(request, api_key="Bearer " + "1234") | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| async def test_key_with_no_permissions(prisma_client): | |
| """ | |
| - create key | |
| - get key info | |
| - assert key_name is null | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False}) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| response = await generate_key_helper_fn( | |
| request_type="key", | |
| **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": "ishaan", "team_id": "litellm-dashboard"}, # type: ignore | |
| ) | |
| print(response) | |
| key = response["token"] | |
| # make a /chat/completions call -> it should fail | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key="Bearer " + key) | |
| print("result from user auth with new key", result) | |
| pytest.fail(f"This should have failed!. IT's an invalid key") | |
| except Exception as e: | |
| print("Got Exception", e) | |
| print(e.message) | |
| async def track_cost_callback_helper_fn(generated_key: str, user_id: str): | |
| import uuid | |
| from litellm import Choices, Message, ModelResponse, Usage | |
| from litellm.proxy.proxy_server import _ProxyDBLogger | |
| request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}" | |
| resp = ModelResponse( | |
| id=request_id, | |
| choices=[ | |
| Choices( | |
| finish_reason=None, | |
| index=0, | |
| message=Message( | |
| content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a", | |
| role="assistant", | |
| ), | |
| ) | |
| ], | |
| model="gpt-35-turbo", # azure always has model written like this | |
| usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), | |
| ) | |
| proxy_db_logger = _ProxyDBLogger() | |
| await proxy_db_logger._PROXY_track_cost_callback( | |
| kwargs={ | |
| "call_type": "acompletion", | |
| "model": "sagemaker-chatgpt-v-3", | |
| "stream": True, | |
| "complete_streaming_response": resp, | |
| "litellm_params": { | |
| "metadata": { | |
| "user_api_key": hash_token(generated_key), | |
| "user_api_key_user_id": user_id, | |
| } | |
| }, | |
| "response_cost": 0.00005, | |
| }, | |
| completion_response=resp, | |
| start_time=datetime.now(), | |
| end_time=datetime.now(), | |
| ) | |
| async def test_proxy_load_test_db(prisma_client): | |
| """ | |
| Run 1500 req./s against track_cost_callback function | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| import logging | |
| import time | |
| from litellm._logging import verbose_proxy_logger | |
| litellm.set_verbose = True | |
| verbose_proxy_logger.setLevel(logging.DEBUG) | |
| try: | |
| start_time = time.time() | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest(max_budget=0.00001) | |
| key = await generate_key_fn( | |
| request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| user_id = key.user_id | |
| bearer_token = "Bearer " + generated_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| print("result from user auth with new key", result) | |
| # update spend using track_cost callback, make 2nd request, it should fail | |
| n = 5000 | |
| tasks = [ | |
| track_cost_callback_helper_fn(generated_key=generated_key, user_id=user_id) | |
| for _ in range(n) | |
| ] | |
| completions = await asyncio.gather(*tasks) | |
| await asyncio.sleep(120) | |
| try: | |
| # call spend logs | |
| spend_logs = await view_spend_logs( | |
| api_key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(api_key=generated_key), | |
| ) | |
| print(f"len responses: {len(spend_logs)}") | |
| assert len(spend_logs) == n | |
| print(n, time.time() - start_time, len(spend_logs)) | |
| except Exception: | |
| print(n, time.time() - start_time, 0) | |
| raise Exception(f"it worked! key={key.key}") | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |
| async def test_master_key_hashing(prisma_client): | |
| try: | |
| import uuid | |
| print("prisma client=", prisma_client) | |
| master_key = "sk-1234" | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", master_key) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| _team_id = "ishaans-special-team_{}".format(uuid.uuid4()) | |
| user_api_key_dict = UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ) | |
| await new_team( | |
| NewTeamRequest(team_id=_team_id), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| _response = await new_user( | |
| data=NewUserRequest( | |
| models=["azure-gpt-3.5"], | |
| team_id=_team_id, | |
| tpm_limit=20, | |
| ), | |
| user_api_key_dict=user_api_key_dict, | |
| ) | |
| print(_response) | |
| assert _response.models == ["azure-gpt-3.5"] | |
| assert _response.team_id == _team_id | |
| assert _response.tpm_limit == 20 | |
| bearer_token = "Bearer " + master_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # use generated key to auth in | |
| result: UserAPIKeyAuth = await user_api_key_auth( | |
| request=request, api_key=bearer_token | |
| ) | |
| assert result.api_key == hash_token(master_key) | |
| except Exception as e: | |
| print("Got Exception", e) | |
| pytest.fail(f"Got exception {e}") | |
| async def test_reset_spend_authentication(prisma_client): | |
| """ | |
| 1. Test master key can access this route -> ONLY MASTER KEY SHOULD BE ABLE TO RESET SPEND | |
| 2. Test that non-master key gets rejected | |
| 3. Test that non-master key with role == LitellmUserRoles.PROXY_ADMIN or admin gets rejected | |
| """ | |
| print("prisma client=", prisma_client) | |
| master_key = "sk-1234" | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", master_key) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| bearer_token = "Bearer " + master_key | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/global/spend/reset") | |
| # Test 1 - Master Key | |
| result: UserAPIKeyAuth = await user_api_key_auth( | |
| request=request, api_key=bearer_token | |
| ) | |
| print("result from user auth with Master key", result) | |
| assert result.token is not None | |
| # Test 2 - Non-Master Key | |
| _response = await new_user( | |
| data=NewUserRequest( | |
| tpm_limit=20, | |
| ) | |
| ) | |
| generate_key = "Bearer " + _response.key | |
| try: | |
| await user_api_key_auth(request=request, api_key=generate_key) | |
| pytest.fail(f"This should have failed!. IT's an expired key") | |
| except Exception as e: | |
| print("Got Exception", e) | |
| assert ( | |
| "Tried to access route=/global/spend/reset, which is only for MASTER KEY" | |
| in e.message | |
| ) | |
| # Test 3 - Non-Master Key with role == LitellmUserRoles.PROXY_ADMIN or admin | |
| _response = await new_user( | |
| data=NewUserRequest( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| tpm_limit=20, | |
| ) | |
| ) | |
| generate_key = "Bearer " + _response.key | |
| try: | |
| await user_api_key_auth(request=request, api_key=generate_key) | |
| pytest.fail(f"This should have failed!. IT's an expired key") | |
| except Exception as e: | |
| print("Got Exception", e) | |
| assert ( | |
| "Tried to access route=/global/spend/reset, which is only for MASTER KEY" | |
| in e.message | |
| ) | |
| async def test_create_update_team(prisma_client): | |
| """ | |
| - Set max_budget, budget_duration, max_budget, tpm_limit, rpm_limit | |
| - Assert response has correct values | |
| - Update max_budget, budget_duration, max_budget, tpm_limit, rpm_limit | |
| - Assert response has correct values | |
| - Call team_info and assert response has correct values | |
| """ | |
| print("prisma client=", prisma_client) | |
| master_key = "sk-1234" | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", master_key) | |
| import datetime | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| from litellm.proxy.proxy_server import user_api_key_cache | |
| _team_id = "test-team_{}".format(uuid.uuid4()) | |
| response = await new_team( | |
| NewTeamRequest( | |
| team_id=_team_id, | |
| max_budget=20, | |
| budget_duration="30d", | |
| tpm_limit=20, | |
| rpm_limit=20, | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print("RESPONSE from new_team", response) | |
| assert response["team_id"] == _team_id | |
| assert response["max_budget"] == 20 | |
| assert response["tpm_limit"] == 20 | |
| assert response["rpm_limit"] == 20 | |
| assert response["budget_duration"] == "30d" | |
| assert response["budget_reset_at"] is not None and isinstance( | |
| response["budget_reset_at"], datetime.datetime | |
| ) | |
| # updating team budget duration and reset at | |
| response = await update_team( | |
| UpdateTeamRequest( | |
| team_id=_team_id, | |
| max_budget=30, | |
| budget_duration="2d", | |
| tpm_limit=30, | |
| rpm_limit=30, | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print("RESPONSE from update_team", response) | |
| _updated_info = response["data"] | |
| _updated_info = dict(_updated_info) | |
| assert _updated_info["team_id"] == _team_id | |
| assert _updated_info["max_budget"] == 30 | |
| assert _updated_info["tpm_limit"] == 30 | |
| assert _updated_info["rpm_limit"] == 30 | |
| assert _updated_info["budget_duration"] == "2d" | |
| assert _updated_info["budget_reset_at"] is not None and isinstance( | |
| _updated_info["budget_reset_at"], datetime.datetime | |
| ) | |
| # budget_reset_at should be 2 days from now | |
| budget_reset_at = _updated_info["budget_reset_at"].replace(tzinfo=timezone.utc) | |
| current_time = datetime.datetime.now(timezone.utc) | |
| # Verify that budget_reset_at is at midnight (hour, minute, second are all 0) | |
| assert budget_reset_at.hour == 0 | |
| assert budget_reset_at.minute == 0 | |
| assert budget_reset_at.second == 0 | |
| # Calculate days difference - should be close to 2 days (within 1 day to account for time of test execution) | |
| days_diff = (budget_reset_at.date() - current_time.date()).days | |
| assert 1 <= days_diff <= 2 | |
| # now hit team_info | |
| try: | |
| response = await team_info( | |
| team_id=_team_id, | |
| http_request=Request(scope={"type": "http"}), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| except Exception as e: | |
| print(e) | |
| pytest.fail("Receives error - {}".format(e)) | |
| _team_info = response["team_info"] | |
| _team_info = dict(_team_info) | |
| assert _team_info["team_id"] == _team_id | |
| assert _team_info["max_budget"] == 30 | |
| assert _team_info["tpm_limit"] == 30 | |
| assert _team_info["rpm_limit"] == 30 | |
| assert _team_info["budget_duration"] == "2d" | |
| assert _team_info["budget_reset_at"] is not None and isinstance( | |
| _team_info["budget_reset_at"], datetime.datetime | |
| ) | |
| async def test_update_user_role(prisma_client): | |
| """ | |
| Tests if we update user role, incorrect values are not stored in cache | |
| -> create a user with role == INTERNAL_USER | |
| -> access an Admin only route -> expect to fail | |
| -> update user role to == PROXY_ADMIN | |
| -> access an Admin only route -> expect to succeed | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| key = await new_user( | |
| data=NewUserRequest( | |
| user_role=LitellmUserRoles.INTERNAL_USER, | |
| ) | |
| ) | |
| print(key) | |
| api_key = "Bearer " + key.key | |
| api_route = APIRoute(path="/global/spend", endpoint=global_spend) | |
| request = Request( | |
| { | |
| "type": "http", | |
| "route": api_route, | |
| "path": "/global/spend", | |
| "headers": [("Authorization", api_key)], | |
| } | |
| ) | |
| request._url = URL(url="/global/spend") | |
| # use generated key to auth in | |
| try: | |
| result = await user_api_key_auth(request=request, api_key=api_key) | |
| print("result from user auth with new key", result) | |
| except Exception as e: | |
| print(e) | |
| pass | |
| await user_update( | |
| data=UpdateUserRequest( | |
| user_id=key.user_id, user_role=LitellmUserRoles.PROXY_ADMIN | |
| ) | |
| ) | |
| # await asyncio.sleep(3) | |
| # use generated key to auth in | |
| print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n") | |
| result = await user_api_key_auth(request=request, api_key=api_key) | |
| print("result from user auth with new key", result) | |
| async def test_update_user_unit_test(prisma_client): | |
| """ | |
| Unit test for /user/update | |
| Ensure that params are updated for UpdateUserRequest | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| key = await new_user( | |
| data=NewUserRequest( | |
| user_email=f"test-{uuid.uuid4()}@test.com", | |
| ) | |
| ) | |
| print(key) | |
| user_info = await user_update( | |
| data=UpdateUserRequest( | |
| user_id=key.user_id, | |
| team_id="1234", | |
| max_budget=100, | |
| budget_duration="10d", | |
| tpm_limit=100, | |
| rpm_limit=100, | |
| metadata={"very-new-metadata": "something"}, | |
| ) | |
| ) | |
| print("user_info", user_info) | |
| assert user_info is not None | |
| _user_info = user_info["data"].model_dump() | |
| assert _user_info["user_id"] == key.user_id | |
| assert _user_info["team_id"] == "1234" | |
| assert _user_info["max_budget"] == 100 | |
| assert _user_info["budget_duration"] == "10d" | |
| assert _user_info["tpm_limit"] == 100 | |
| assert _user_info["rpm_limit"] == 100 | |
| assert _user_info["metadata"] == {"very-new-metadata": "something"} | |
| # budget_reset_at should be at midnight 10 days from now | |
| budget_reset_at = _user_info["budget_reset_at"].replace(tzinfo=timezone.utc) | |
| current_time = datetime.now(timezone.utc) | |
| # Verify that budget_reset_at is at midnight (hour, minute, second are all 0) | |
| assert budget_reset_at.hour == 0 | |
| assert budget_reset_at.minute == 0 | |
| assert budget_reset_at.second == 0 | |
| # Calculate days difference - should be close to 10 days (within 1 day to account for time of test execution) | |
| days_diff = (budget_reset_at.date() - current_time.date()).days | |
| assert 9 <= days_diff <= 10 | |
| async def test_custom_api_key_header_name(prisma_client): | |
| """ """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr( | |
| litellm.proxy.proxy_server, | |
| "general_settings", | |
| {"litellm_key_header_name": "x-litellm-key"}, | |
| ) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| api_route = APIRoute(path="/chat/completions", endpoint=chat_completion) | |
| request = Request( | |
| { | |
| "type": "http", | |
| "route": api_route, | |
| "path": api_route.path, | |
| "headers": [ | |
| (b"x-litellm-key", b"Bearer sk-1234"), | |
| ], | |
| } | |
| ) | |
| # this should pass because we pass the master key as X-Litellm-Key and litellm_key_header_name="X-Litellm-Key" in general settings | |
| result = await user_api_key_auth(request=request, api_key="Bearer invalid-key") | |
| # this should fail because X-Litellm-Key is invalid | |
| request = Request( | |
| { | |
| "type": "http", | |
| "route": api_route, | |
| "path": api_route.path, | |
| "headers": [], | |
| } | |
| ) | |
| try: | |
| result = await user_api_key_auth(request=request, api_key="Bearer sk-1234") | |
| pytest.fail(f"This should have failed!. invalid Auth on this request") | |
| except Exception as e: | |
| print("failed with error", e) | |
| assert ( | |
| "Malformed API Key passed in. Ensure Key has `Bearer ` prefix" in e.message | |
| ) | |
| pass | |
| # this should pass because X-Litellm-Key is valid | |
| async def test_generate_key_with_model_tpm_limit(prisma_client): | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest( | |
| metadata={ | |
| "team": "litellm-team3", | |
| "model_tpm_limit": {"gpt-4": 100}, | |
| "model_rpm_limit": {"gpt-4": 2}, | |
| } | |
| ) | |
| key = await generate_key_fn( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| # use generated key to auth in | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["metadata"] == { | |
| "team": "litellm-team3", | |
| "model_tpm_limit": {"gpt-4": 100}, | |
| "model_rpm_limit": {"gpt-4": 2}, | |
| } | |
| # Update model tpm_limit and rpm_limit | |
| request = UpdateKeyRequest( | |
| key=generated_key, | |
| model_tpm_limit={"gpt-4": 200}, | |
| model_rpm_limit={"gpt-4": 3}, | |
| ) | |
| _request = Request(scope={"type": "http"}) | |
| _request._url = URL(url="/update/key") | |
| await update_key_fn( | |
| data=request, | |
| request=_request, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["metadata"] == { | |
| "team": "litellm-team3", | |
| "model_tpm_limit": {"gpt-4": 200}, | |
| "model_rpm_limit": {"gpt-4": 3}, | |
| } | |
| async def test_generate_key_with_guardrails(prisma_client): | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| request = GenerateKeyRequest( | |
| guardrails=["aporia-pre-call"], | |
| metadata={ | |
| "team": "litellm-team3", | |
| }, | |
| ) | |
| key = await generate_key_fn( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print("generated key=", key) | |
| generated_key = key.key | |
| # use generated key to auth in | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["metadata"] == { | |
| "team": "litellm-team3", | |
| "guardrails": ["aporia-pre-call"], | |
| } | |
| # Update model tpm_limit and rpm_limit | |
| request = UpdateKeyRequest( | |
| key=generated_key, | |
| guardrails=["aporia-pre-call", "aporia-post-call"], | |
| ) | |
| _request = Request(scope={"type": "http"}) | |
| _request._url = URL(url="/update/key") | |
| await update_key_fn( | |
| data=request, | |
| request=_request, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| result = await info_key_fn( | |
| key=generated_key, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| print("result from info_key_fn", result) | |
| assert result["key"] == generated_key | |
| print("\n info for key=", result["info"]) | |
| assert result["info"]["metadata"] == { | |
| "team": "litellm-team3", | |
| "guardrails": ["aporia-pre-call", "aporia-post-call"], | |
| } | |
| async def test_team_guardrails(prisma_client): | |
| """ | |
| - Test setting guardrails on a team | |
| - Assert this is returned when calling /team/info | |
| - Team/update with guardrails should update the guardrails | |
| - Assert new guardrails are returned when calling /team/info | |
| """ | |
| litellm.set_verbose = True | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| _new_team = NewTeamRequest( | |
| team_alias="test-teamA", | |
| guardrails=["aporia-pre-call"], | |
| ) | |
| new_team_response = await new_team( | |
| data=_new_team, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("new_team_response", new_team_response) | |
| # call /team/info | |
| team_info_response = await team_info( | |
| team_id=new_team_response["team_id"], | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("team_info_response", team_info_response) | |
| assert team_info_response["team_info"].metadata["guardrails"] == ["aporia-pre-call"] | |
| # team update with guardrails | |
| team_update_response = await update_team( | |
| data=UpdateTeamRequest( | |
| team_id=new_team_response["team_id"], | |
| guardrails=["aporia-pre-call", "aporia-post-call"], | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("team_update_response", team_update_response) | |
| # call /team/info again | |
| team_info_response = await team_info( | |
| team_id=new_team_response["team_id"], | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("team_info_response", team_info_response) | |
| assert team_info_response["team_info"].metadata["guardrails"] == [ | |
| "aporia-pre-call", | |
| "aporia-post-call", | |
| ] | |
| async def test_team_access_groups(prisma_client): | |
| """ | |
| Test team based model access groups | |
| - Test calling a model in the access group -> pass | |
| - Test calling a model not in the access group -> fail | |
| """ | |
| litellm.set_verbose = True | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| # create router with access groups | |
| litellm_router = litellm.Router( | |
| model_list=[ | |
| { | |
| "model_name": "gemini-pro-vision", | |
| "litellm_params": { | |
| "model": "vertex_ai/gemini-1.0-pro-vision-001", | |
| }, | |
| "model_info": {"access_groups": ["beta-models"]}, | |
| }, | |
| { | |
| "model_name": "gpt-4o", | |
| "litellm_params": { | |
| "model": "gpt-4o", | |
| }, | |
| "model_info": {"access_groups": ["beta-models"]}, | |
| }, | |
| ] | |
| ) | |
| setattr(litellm.proxy.proxy_server, "llm_router", litellm_router) | |
| # Create team with models=["beta-models"] | |
| team_request = NewTeamRequest( | |
| team_alias="testing-team", | |
| models=["beta-models"], | |
| ) | |
| new_team_response = await new_team( | |
| data=team_request, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("new_team_response", new_team_response) | |
| created_team_id = new_team_response["team_id"] | |
| # create key with team_id=created_team_id | |
| request = GenerateKeyRequest( | |
| team_id=created_team_id, | |
| ) | |
| key = await generate_key_fn( | |
| data=request, | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| print(key) | |
| generated_key = key.key | |
| bearer_token = "Bearer " + generated_key | |
| request._url = URL(url="/chat/completions") | |
| for model in ["gpt-4o", "gemini-pro-vision"]: | |
| # Expect these to pass | |
| async def return_body(): | |
| return_string = f'{{"model": "{model}"}}' | |
| # return string as bytes | |
| return return_string.encode() | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| request.body = return_body | |
| # use generated key to auth in | |
| print( | |
| "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token) | |
| ) | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| for model in ["gpt-4", "gpt-4o-mini", "gemini-experimental"]: | |
| # Expect these to fail | |
| async def return_body_2(): | |
| return_string = f'{{"model": "{model}"}}' | |
| # return string as bytes | |
| return return_string.encode() | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| request.body = return_body_2 | |
| # use generated key to auth in | |
| print( | |
| "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token) | |
| ) | |
| try: | |
| result = await user_api_key_auth(request=request, api_key=bearer_token) | |
| pytest.fail(f"This should have failed!. IT's an invalid model") | |
| except Exception as e: | |
| print("got exception", e) | |
| assert isinstance(e, ProxyException) | |
| assert e.type == ProxyErrorTypes.team_model_access_denied | |
| assert e.param == "model" | |
| async def test_team_tags(prisma_client): | |
| """ | |
| - Test setting tags on a team | |
| - Assert this is returned when calling /team/info | |
| - Team/update with tags should update the tags | |
| - Assert new tags are returned when calling /team/info | |
| """ | |
| litellm.set_verbose = True | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| _new_team = NewTeamRequest( | |
| team_alias="test-teamA", | |
| tags=["teamA"], | |
| ) | |
| new_team_response = await new_team( | |
| data=_new_team, | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("new_team_response", new_team_response) | |
| # call /team/info | |
| team_info_response = await team_info( | |
| team_id=new_team_response["team_id"], | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("team_info_response", team_info_response) | |
| assert team_info_response["team_info"].metadata["tags"] == ["teamA"] | |
| # team update with tags | |
| team_update_response = await update_team( | |
| data=UpdateTeamRequest( | |
| team_id=new_team_response["team_id"], | |
| tags=["teamA", "teamB"], | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("team_update_response", team_update_response) | |
| # call /team/info again | |
| team_info_response = await team_info( | |
| team_id=new_team_response["team_id"], | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| print("team_info_response", team_info_response) | |
| assert team_info_response["team_info"].metadata["tags"] == ["teamA", "teamB"] | |
| async def test_aadmin_only_routes(prisma_client): | |
| """ | |
| Tests if setting admin_only_routes works | |
| only an admin should be able to access admin only routes | |
| """ | |
| litellm.set_verbose = True | |
| print(f"os.getenv('DATABASE_URL')={os.getenv('DATABASE_URL')}") | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| general_settings = { | |
| "allowed_routes": ["/embeddings", "/key/generate"], | |
| "admin_only_routes": ["/key/generate"], | |
| } | |
| from litellm.proxy import proxy_server | |
| initial_general_settings = getattr(proxy_server, "general_settings") | |
| setattr(proxy_server, "general_settings", general_settings) | |
| admin_user = await new_user( | |
| data=NewUserRequest( | |
| user_name="admin", | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| non_admin_user = await new_user( | |
| data=NewUserRequest( | |
| user_name="non-admin", | |
| user_role=LitellmUserRoles.INTERNAL_USER, | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| admin_user_key = admin_user.key | |
| non_admin_user_key = non_admin_user.key | |
| assert admin_user_key is not None | |
| assert non_admin_user_key is not None | |
| # assert non-admin can not access admin routes | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/key/generate") | |
| await user_api_key_auth( | |
| request=request, | |
| api_key="Bearer " + admin_user_key, | |
| ) | |
| # this should pass | |
| try: | |
| await user_api_key_auth( | |
| request=request, | |
| api_key="Bearer " + non_admin_user_key, | |
| ) | |
| pytest.fail("Expected this call to fail. User is over limit.") | |
| except Exception as e: | |
| print("error str=", str(e.message)) | |
| error_str = str(e.message) | |
| assert "Route" in error_str and "admin only route" in error_str | |
| pass | |
| setattr(proxy_server, "general_settings", initial_general_settings) | |
| async def test_list_keys(prisma_client): | |
| """ | |
| Test the list_keys function: | |
| - Test basic key | |
| - Test pagination | |
| - Test filtering by user_id, and key_alias | |
| """ | |
| from fastapi import Query | |
| from litellm.proxy.proxy_server import hash_token | |
| from litellm.proxy._types import LitellmUserRoles | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| # Test basic listing | |
| request = Request(scope={"type": "http", "query_string": b""}) | |
| response = await list_keys( | |
| request, | |
| UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN.value, | |
| ), | |
| page=1, | |
| size=10, | |
| ) | |
| print("response=", response) | |
| assert "keys" in response | |
| assert len(response["keys"]) > 0 | |
| assert "total_count" in response | |
| assert "current_page" in response | |
| assert "total_pages" in response | |
| # Test pagination | |
| response = await list_keys( | |
| request, | |
| UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), | |
| page=1, | |
| size=2, | |
| ) | |
| print("pagination response=", response) | |
| assert len(response["keys"]) == 2 | |
| assert response["current_page"] == 1 | |
| # Test filtering by user_id | |
| unique_id = str(uuid.uuid4()) | |
| team_id = f"key-list-team-{unique_id}" | |
| key_alias = f"key-list-alias-{unique_id}" | |
| user_id = f"key-list-user-{unique_id}" | |
| response = await new_user( | |
| data=NewUserRequest( | |
| user_id=f"key-list-user-{unique_id}", | |
| user_role=LitellmUserRoles.INTERNAL_USER, | |
| key_alias=f"key-list-alias-{unique_id}", | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), | |
| ) | |
| _key = hash_token(response.key) | |
| await asyncio.sleep(2) | |
| # Test filtering by user_id | |
| response = await list_keys( | |
| request, | |
| UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), | |
| user_id=user_id, | |
| page=1, | |
| size=10, | |
| ) | |
| print("filtered user_id response=", response) | |
| assert len(response["keys"]) == 1 | |
| assert _key in response["keys"] | |
| # Test filtering by key_alias | |
| response = await list_keys( | |
| request, | |
| UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value), | |
| key_alias=key_alias, | |
| page=1, | |
| size=10, | |
| ) | |
| assert len(response["keys"]) == 1 | |
| assert _key in response["keys"] | |
| async def test_auth_vertex_ai_route(prisma_client): | |
| """ | |
| If user is premium user and vertex-ai route is used. Assert Virtual Key checks are run | |
| """ | |
| litellm.set_verbose = True | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "premium_user", True) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| route = "/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent" | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url=route) | |
| request._headers = {"Authorization": "Bearer sk-12345"} | |
| try: | |
| await user_api_key_auth(request=request, api_key="Bearer " + "sk-12345") | |
| pytest.fail("Expected this call to fail. User is over limit.") | |
| except Exception as e: | |
| print(vars(e)) | |
| print("error str=", str(e.message)) | |
| error_str = str(e.message) | |
| assert e.code == "401" | |
| assert "Invalid proxy server token passed" in error_str | |
| pass | |
| async def test_user_api_key_auth_db_unavailable(): | |
| """ | |
| Test that user_api_key_auth handles DB connection failures appropriately when: | |
| 1. DB connection fails during token validation | |
| 2. allow_requests_on_db_unavailable=True | |
| """ | |
| litellm.set_verbose = True | |
| # Mock dependencies | |
| class MockPrismaClient: | |
| async def get_data(self, *args, **kwargs): | |
| print("MockPrismaClient.get_data() called") | |
| raise httpx.ConnectError("Failed to connect to DB") | |
| async def connect(self): | |
| print("MockPrismaClient.connect() called") | |
| pass | |
| class MockDualCache: | |
| async def async_get_cache(self, *args, **kwargs): | |
| return None | |
| async def async_set_cache(self, *args, **kwargs): | |
| pass | |
| async def set_cache(self, *args, **kwargs): | |
| pass | |
| # Set up test environment | |
| setattr(litellm.proxy.proxy_server, "prisma_client", MockPrismaClient()) | |
| setattr(litellm.proxy.proxy_server, "user_api_key_cache", MockDualCache()) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr( | |
| litellm.proxy.proxy_server, | |
| "general_settings", | |
| {"allow_requests_on_db_unavailable": True}, | |
| ) | |
| # Create test request | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # Run test with a sample API key | |
| result = await user_api_key_auth( | |
| request=request, | |
| api_key="Bearer sk-123456789", | |
| ) | |
| # Verify results | |
| assert isinstance(result, UserAPIKeyAuth) | |
| assert result.key_name == "failed-to-connect-to-db" | |
| assert result.user_id == litellm.proxy.proxy_server.litellm_proxy_admin_name | |
| async def test_user_api_key_auth_db_unavailable_not_allowed(): | |
| """ | |
| Test that user_api_key_auth raises an exception when: | |
| This is default behavior | |
| 1. DB connection fails during token validation | |
| 2. allow_requests_on_db_unavailable=False (default behavior) | |
| """ | |
| # Mock dependencies | |
| class MockPrismaClient: | |
| async def get_data(self, *args, **kwargs): | |
| print("MockPrismaClient.get_data() called") | |
| raise httpx.ConnectError("Failed to connect to DB") | |
| async def connect(self): | |
| print("MockPrismaClient.connect() called") | |
| pass | |
| class MockDualCache: | |
| async def async_get_cache(self, *args, **kwargs): | |
| return None | |
| async def async_set_cache(self, *args, **kwargs): | |
| pass | |
| async def set_cache(self, *args, **kwargs): | |
| pass | |
| # Set up test environment | |
| setattr(litellm.proxy.proxy_server, "prisma_client", MockPrismaClient()) | |
| setattr(litellm.proxy.proxy_server, "user_api_key_cache", MockDualCache()) | |
| setattr(litellm.proxy.proxy_server, "general_settings", {}) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| # Create test request | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/chat/completions") | |
| # Run test with a sample API key | |
| with pytest.raises(litellm.proxy._types.ProxyException): | |
| await user_api_key_auth( | |
| request=request, | |
| api_key="Bearer sk-123456789", | |
| ) | |
| ## E2E Virtual Key + Secret Manager Tests ######################################### | |
| async def test_key_generate_with_secret_manager_call(prisma_client): | |
| """ | |
| Generate a key | |
| assert it exists in the secret manager | |
| delete the key | |
| assert it is deleted from the secret manager | |
| """ | |
| from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2 | |
| from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings | |
| from litellm.proxy.hooks.key_management_event_hooks import ( | |
| LITELLM_PREFIX_STORED_VIRTUAL_KEYS, | |
| ) | |
| litellm.set_verbose = True | |
| #### Test Setup ############################################################ | |
| aws_secret_manager_client = AWSSecretsManagerV2() | |
| litellm.secret_manager_client = aws_secret_manager_client | |
| litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER | |
| litellm._key_management_settings = KeyManagementSettings( | |
| store_virtual_keys=True, | |
| ) | |
| general_settings = { | |
| "key_management_system": "aws_secret_manager", | |
| "key_management_settings": { | |
| "store_virtual_keys": True, | |
| }, | |
| } | |
| setattr(litellm.proxy.proxy_server, "general_settings", general_settings) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| ############################################################################ | |
| # generate new key | |
| key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}" | |
| spend = 100 | |
| max_budget = 400 | |
| models = ["fake-openai-endpoint"] | |
| new_key = await generate_key_fn( | |
| data=GenerateKeyRequest( | |
| key_alias=key_alias, spend=spend, max_budget=max_budget, models=models | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| generated_key = new_key.key | |
| print(generated_key) | |
| await asyncio.sleep(2) | |
| # read from the secret manager | |
| result = await aws_secret_manager_client.async_read_secret( | |
| secret_name=f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{key_alias}" | |
| ) | |
| # Assert the correct key is stored in the secret manager | |
| print("response from AWS Secret Manager") | |
| print(result) | |
| assert result == generated_key | |
| # delete the key | |
| await delete_key_fn( | |
| data=KeyRequest(keys=[generated_key]), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234" | |
| ), | |
| ) | |
| await asyncio.sleep(2) | |
| # Assert the key is deleted from the secret manager | |
| result = await aws_secret_manager_client.async_read_secret( | |
| secret_name=f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{key_alias}" | |
| ) | |
| assert result is None | |
| # cleanup | |
| setattr(litellm.proxy.proxy_server, "general_settings", {}) | |
| ################################################################################ | |
| async def test_key_alias_uniqueness(prisma_client): | |
| """ | |
| Test that: | |
| 1. We cannot create two keys with the same alias | |
| 2. We cannot update a key to use an alias that's already taken | |
| 3. We can update a key while keeping its existing alias | |
| """ | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| # Create first key with an alias | |
| unique_alias = f"test-alias-{uuid.uuid4()}" | |
| key1 = await generate_key_fn( | |
| data=GenerateKeyRequest(key_alias=unique_alias), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| # Try to create second key with same alias - should fail | |
| try: | |
| key2 = await generate_key_fn( | |
| data=GenerateKeyRequest(key_alias=unique_alias), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| pytest.fail("Should not be able to create a second key with the same alias") | |
| except Exception as e: | |
| print("vars(e)=", vars(e)) | |
| assert "Unique key aliases across all keys are required" in str(e.message) | |
| # Create another key with different alias | |
| another_alias = f"test-alias-{uuid.uuid4()}" | |
| key3 = await generate_key_fn( | |
| data=GenerateKeyRequest(key_alias=another_alias), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| # Try to update key3 to use key1's alias - should fail | |
| try: | |
| await update_key_fn( | |
| data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias), | |
| request=Request(scope={"type": "http"}), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| pytest.fail("Should not be able to update a key to use an existing alias") | |
| except Exception as e: | |
| assert "Unique key aliases across all keys are required" in str(e.message) | |
| # Update key1 with its own existing alias - should succeed | |
| updated_key = await update_key_fn( | |
| data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias), | |
| request=Request(scope={"type": "http"}), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| assert updated_key is not None | |
| except Exception as e: | |
| print("got exceptions, e=", e) | |
| print("vars(e)=", vars(e)) | |
| pytest.fail(f"An unexpected error occurred: {str(e)}") | |
| async def test_enforce_unique_key_alias(prisma_client): | |
| """ | |
| Unit test the _enforce_unique_key_alias function: | |
| 1. Test it allows unique aliases | |
| 2. Test it blocks duplicate aliases for new keys | |
| 3. Test it allows updating a key with its own existing alias | |
| 4. Test it blocks updating a key with another key's alias | |
| """ | |
| from litellm.proxy.management_endpoints.key_management_endpoints import ( | |
| _enforce_unique_key_alias, | |
| ) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| # Test 1: Allow unique alias | |
| unique_alias = f"test-alias-{uuid.uuid4()}" | |
| await _enforce_unique_key_alias( | |
| key_alias=unique_alias, | |
| prisma_client=prisma_client, | |
| ) # Should pass | |
| # Create a key with this alias in the database | |
| key1 = await generate_key_fn( | |
| data=GenerateKeyRequest(key_alias=unique_alias), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| # Test 2: Block duplicate alias for new key | |
| try: | |
| await _enforce_unique_key_alias( | |
| key_alias=unique_alias, | |
| prisma_client=prisma_client, | |
| ) | |
| pytest.fail("Should not allow duplicate alias") | |
| except Exception as e: | |
| assert "Unique key aliases across all keys are required" in str(e.message) | |
| # Test 3: Allow updating key with its own alias | |
| await _enforce_unique_key_alias( | |
| key_alias=unique_alias, | |
| existing_key_token=hash_token(key1.key), | |
| prisma_client=prisma_client, | |
| ) # Should pass | |
| # Test 4: Block updating with another key's alias | |
| another_key = await generate_key_fn( | |
| data=GenerateKeyRequest(key_alias=f"test-alias-{uuid.uuid4()}"), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| try: | |
| await _enforce_unique_key_alias( | |
| key_alias=unique_alias, | |
| existing_key_token=another_key.key, | |
| prisma_client=prisma_client, | |
| ) | |
| pytest.fail("Should not allow using another key's alias") | |
| except Exception as e: | |
| assert "Unique key aliases across all keys are required" in str(e.message) | |
| except Exception as e: | |
| print("Unexpected error:", e) | |
| pytest.fail(f"An unexpected error occurred: {str(e)}") | |
| def test_should_track_cost_callback(): | |
| """ | |
| Test that the should_track_cost_callback function works as expected | |
| """ | |
| from litellm.proxy.hooks.proxy_track_cost_callback import ( | |
| _should_track_cost_callback, | |
| ) | |
| assert _should_track_cost_callback( | |
| user_api_key=None, | |
| user_id=None, | |
| team_id=None, | |
| end_user_id="1234", | |
| ) | |
| async def test_get_paginated_teams(prisma_client): | |
| """ | |
| Test the get_paginated_teams function: | |
| 1. Test pagination returns valid results | |
| 2. Test total count matches across pages | |
| 3. Test page size is respected | |
| """ | |
| from litellm.proxy.management_endpoints.team_endpoints import get_paginated_teams | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| try: | |
| # Get first page with page_size=2 | |
| teams_page_1, total_count_1 = await get_paginated_teams( | |
| prisma_client=prisma_client, page_size=2, page=1 | |
| ) | |
| print("teams_page_1=", teams_page_1) | |
| print("total_count_1=", total_count_1) | |
| # Get second page | |
| teams_page_2, total_count_2 = await get_paginated_teams( | |
| prisma_client=prisma_client, page_size=2, page=2 | |
| ) | |
| print("teams_page_2=", teams_page_2) | |
| print("total_count_2=", total_count_2) | |
| # Verify results | |
| assert isinstance(teams_page_1, list) # Should return a list | |
| assert isinstance(total_count_1, int) # Should return an integer count | |
| assert ( | |
| total_count_1 == total_count_2 | |
| ) # Total count should be consistent across pages | |
| assert len(teams_page_1) <= 2 # Should respect page_size limit | |
| except Exception as e: | |
| print(f"Error occurred: {e}") | |
| pytest.fail(f"Test failed with exception: {e}") | |
| async def test_reset_budget_job(prisma_client, entity_type): | |
| """ | |
| Test that the ResetBudgetJob correctly resets budgets for keys, users, and teams. | |
| For each entity type: | |
| 1. Create a new entity with max_budget=100, spend=99, budget_duration=5s | |
| 2. Call the reset_budget function | |
| 3. Verify the entity's spend is reset to 0 and budget_reset_at is updated | |
| """ | |
| from datetime import datetime, timedelta | |
| import time | |
| from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob | |
| from litellm.proxy.utils import ProxyLogging | |
| # Setup | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| proxy_logging_obj = ProxyLogging(user_api_key_cache=None) | |
| reset_budget_job = ResetBudgetJob( | |
| proxy_logging_obj=proxy_logging_obj, prisma_client=prisma_client | |
| ) | |
| # Create entity based on type | |
| entity_id = None | |
| if entity_type == "key": | |
| # Create a key with specific budget settings | |
| key = await generate_key_fn( | |
| data=GenerateKeyRequest( | |
| max_budget=100, | |
| budget_duration="5s", | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| entity_id = key.token_id | |
| print("generated key=", key) | |
| # Update the key to set spend and reset_at to now | |
| updated = await prisma_client.db.litellm_verificationtoken.update_many( | |
| where={"token": key.token_id}, | |
| data={ | |
| "spend": 99.0, | |
| }, | |
| ) | |
| print("Updated key=", updated) | |
| elif entity_type == "user": | |
| # Create a user with specific budget settings | |
| user = await new_user( | |
| data=NewUserRequest( | |
| max_budget=100, | |
| budget_duration="5s", | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| ) | |
| entity_id = user.user_id | |
| # Update the user to set spend and reset_at to now | |
| await prisma_client.db.litellm_usertable.update_many( | |
| where={"user_id": user.user_id}, | |
| data={ | |
| "spend": 99.0, | |
| }, | |
| ) | |
| elif entity_type == "team": | |
| # Create a team with specific budget settings | |
| team_id = f"test-team-{uuid.uuid4()}" | |
| team = await new_team( | |
| NewTeamRequest( | |
| team_id=team_id, | |
| max_budget=100, | |
| budget_duration="5s", | |
| ), | |
| user_api_key_dict=UserAPIKeyAuth( | |
| user_role=LitellmUserRoles.PROXY_ADMIN, | |
| api_key="sk-1234", | |
| user_id="1234", | |
| ), | |
| http_request=Request(scope={"type": "http"}), | |
| ) | |
| entity_id = team_id | |
| # Update the team to set spend and reset_at to now | |
| current_time = datetime.utcnow() | |
| await prisma_client.db.litellm_teamtable.update( | |
| where={"team_id": team_id}, | |
| data={ | |
| "spend": 99.0, | |
| }, | |
| ) | |
| # Verify entity was created and updated with spend | |
| if entity_type == "key": | |
| entity_before = await prisma_client.db.litellm_verificationtoken.find_unique( | |
| where={"token": entity_id} | |
| ) | |
| elif entity_type == "user": | |
| entity_before = await prisma_client.db.litellm_usertable.find_unique( | |
| where={"user_id": entity_id} | |
| ) | |
| elif entity_type == "team": | |
| entity_before = await prisma_client.db.litellm_teamtable.find_unique( | |
| where={"team_id": entity_id} | |
| ) | |
| assert entity_before is not None | |
| assert entity_before.spend == 99.0 | |
| # Wait for 5 seconds to pass | |
| print("sleeping for 5 seconds") | |
| time.sleep(5) | |
| # Call the reset_budget function | |
| await reset_budget_job.reset_budget() | |
| # Verify the entity's spend is reset and budget_reset_at is updated | |
| if entity_type == "key": | |
| entity_after = await prisma_client.db.litellm_verificationtoken.find_unique( | |
| where={"token": entity_id} | |
| ) | |
| elif entity_type == "user": | |
| entity_after = await prisma_client.db.litellm_usertable.find_unique( | |
| where={"user_id": entity_id} | |
| ) | |
| elif entity_type == "team": | |
| entity_after = await prisma_client.db.litellm_teamtable.find_unique( | |
| where={"team_id": entity_id} | |
| ) | |
| assert entity_after is not None | |
| assert entity_after.spend == 0.0 | |
| def test_delete_nonexistent_key_returns_404(prisma_client): | |
| # Try to delete a key that does not exist, expect a 404 error | |
| import random, string | |
| from litellm.proxy._types import KeyRequest, UserAPIKeyAuth, LitellmUserRoles, ProxyException | |
| from litellm.proxy.management_endpoints.key_management_endpoints import delete_key_fn | |
| from starlette.datastructures import URL | |
| from fastapi import Request | |
| print("prisma client=", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) | |
| setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") | |
| try: | |
| async def test(): | |
| await litellm.proxy.proxy_server.prisma_client.connect() | |
| # Generate a random key that does not exist | |
| random_key = "sk-" + ''.join(random.choices(string.ascii_letters + string.digits, k=24)) | |
| delete_key_request = KeyRequest(keys=[random_key]) | |
| bearer_token = "Bearer sk-1234" | |
| request = Request(scope={"type": "http"}) | |
| request._url = URL(url="/key/delete") | |
| # use admin to auth in | |
| result = await litellm.proxy.proxy_server.user_api_key_auth(request=request, api_key=bearer_token) | |
| result.user_role = LitellmUserRoles.PROXY_ADMIN | |
| try: | |
| await delete_key_fn(data=delete_key_request, user_api_key_dict=result) | |
| pytest.fail("Expected ProxyException 404 for non-existent key, but delete_key_fn did not raise.") | |
| except ProxyException as e: | |
| print("Caught ProxyException:", e) | |
| assert str(e.code) == "404" | |
| assert "No keys found" in str(e.message) or "No matching keys or aliases found to delete" in str(e.message) | |
| import asyncio | |
| asyncio.run(test()) | |
| except Exception as e: | |
| pytest.fail(f"An exception occurred - {str(e)}") | |