Spaces:
Paused
Paused
| import os | |
| import sys | |
| import traceback | |
| from unittest import mock | |
| import pytest | |
| from dotenv import load_dotenv | |
| import litellm.proxy | |
| import litellm.proxy.proxy_server | |
| load_dotenv() | |
| import io | |
| import os | |
| # this file is to test litellm/proxy | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import asyncio | |
| import logging | |
| from litellm.proxy.proxy_server import ProxyConfig | |
| INVALID_FILES = ["config_with_missing_include.yaml"] | |
async def test_basic_reading_configs_from_files():
    """
    Load each file in the example_config_yaml folder through ProxyConfig.

    Names listed in INVALID_FILES are intentionally broken fixtures and are
    skipped; every other file must load without raising.
    """
    loader = ProxyConfig()
    examples_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "example_config_yaml"
    )

    # Everything found in the example folder is a candidate config.
    files = os.listdir(examples_dir)
    print(files)

    for file in (name for name in files if name not in INVALID_FILES):
        print("reading file=", file)
        config = await loader.get_config(
            config_file_path=os.path.join(examples_dir, file)
        )
        print(config)
async def test_read_config_from_bad_file_path():
    """
    Loading a config from a path that does not exist must raise an exception.
    """
    missing_path = "non-existent-file.yaml"
    loader = ProxyConfig()
    with pytest.raises(Exception):
        await loader.get_config(config_file_path=missing_path)
async def test_read_config_file_with_os_environ_vars():
    """
    Ensures os.environ variables are read correctly from config.yaml

    Following vars are set as os.environ variables in the config.yaml file
    - DEFAULT_USER_ROLE
    - AWS_ACCESS_KEY_ID
    - AWS_SECRET_ACCESS_KEY
    - AZURE_GPT_4O
    - FIREWORKS
    """
    _env_vars_for_testing = {
        "DEFAULT_USER_ROLE": "admin",
        "AWS_ACCESS_KEY_ID": "1234567890",
        "AWS_SECRET_ACCESS_KEY": "1234567890",
        "AZURE_GPT_4O": "1234567890",
        "FIREWORKS": "1234567890",
    }

    # patch.dict puts os.environ back to its prior state when the block exits,
    # including when an assertion below fails, so the test values can never
    # leak into other tests.
    with mock.patch.dict(os.environ, _env_vars_for_testing):
        # Read config
        proxy_config_instance = ProxyConfig()
        current_path = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(
            current_path, "example_config_yaml", "config_with_env_vars.yaml"
        )
        config = await proxy_config_instance.get_config(config_file_path=config_path)
        print(config)

        # Add assertions
        litellm_settings = config["litellm_settings"]
        assert litellm_settings["default_internal_user_params"]["user_role"] == "admin"
        assert (
            litellm_settings["s3_callback_params"]["s3_aws_access_key_id"]
            == "1234567890"
        )
        assert (
            litellm_settings["s3_callback_params"]["s3_aws_secret_access_key"]
            == "1234567890"
        )

        # Only azure and fireworks models are checked for the substituted key.
        for model in config["model_list"]:
            model_id = model["litellm_params"]["model"]
            if "azure" in model_id or "fireworks" in model_id:
                assert model["litellm_params"]["api_key"] == "1234567890"
async def test_basic_include_directive():
    """
    A config that uses the include directive should load with the included
    model list merged in and its own settings left intact.
    """
    loader = ProxyConfig()
    here = os.path.dirname(os.path.abspath(__file__))
    merged = await loader.get_config(
        config_file_path=os.path.join(
            here, "example_config_yaml", "config_with_include.yaml"
        )
    )

    # The merged model list must be non-empty and contain "included-model".
    assert len(merged["model_list"]) > 0
    assert any(
        entry["model_name"] == "included-model" for entry in merged["model_list"]
    )

    # Settings declared in the including file are still present.
    assert merged["litellm_settings"]["callbacks"] == ["prometheus"]
async def test_missing_include_file():
    """
    Loading a config whose include list names a missing file must raise
    FileNotFoundError.
    """
    loader = ProxyConfig()
    broken_config = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "example_config_yaml",
        "config_with_missing_include.yaml",
    )
    with pytest.raises(FileNotFoundError):
        await loader.get_config(config_file_path=broken_config)
async def test_multiple_includes():
    """
    A config listing several include files should end up with the models from
    every one of them, while keeping its own settings.
    """
    loader = ProxyConfig()
    here = os.path.dirname(os.path.abspath(__file__))
    merged = await loader.get_config(
        config_file_path=os.path.join(
            here, "example_config_yaml", "config_with_multiple_includes.yaml"
        )
    )

    # Exactly two models are expected, one for each expected name.
    assert len(merged["model_list"]) == 2
    for expected_name in ("included-model-1", "included-model-2"):
        assert any(
            entry["model_name"] == expected_name for entry in merged["model_list"]
        )

    # Settings declared in the including file are still present.
    assert merged["litellm_settings"]["callbacks"] == ["prometheus"]
def test_add_callbacks_from_db_config():
    """Test that callbacks are added correctly and duplicates are prevented"""
    # Setup
    from litellm.integrations.langfuse.langfuse_prompt_management import (
        LangfusePromptManagement,
    )

    proxy_config = ProxyConfig()

    # Reset litellm callbacks before test
    litellm.success_callback = []
    litellm.failure_callback = []

    config_data = {
        "litellm_settings": {
            "success_callback": ["langfuse", "custom_callback_api"],
            "failure_callback": ["langfuse"],
        }
    }

    try:
        # Test Case 1: Add new callbacks
        proxy_config._add_callbacks_from_db_config(config_data)

        # 1 instance of LangfusePromptManagement should exist in litellm.success_callback
        num_langfuse_instances = sum(
            isinstance(callback, LangfusePromptManagement)
            for callback in litellm.success_callback
        )
        assert num_langfuse_instances == 1
        assert len(litellm.success_callback) == 2
        assert len(litellm.failure_callback) == 1

        # Test Case 2: Try adding duplicate callbacks
        proxy_config._add_callbacks_from_db_config(config_data)

        # Verify no duplicates were added
        assert len(litellm.success_callback) == 2
        assert len(litellm.failure_callback) == 1
    finally:
        # Cleanup runs even when an assertion above fails, so callbacks
        # registered by this test never leak into later tests.
        litellm.success_callback = []
        litellm.failure_callback = []
        # NOTE(review): this clears the list rather than restoring its prior
        # value - confirm that is the intended post-test state.
        litellm._known_custom_logger_compatible_callbacks = []
def test_add_callbacks_invalid_input():
    """Test handling of invalid input for callbacks"""
    proxy_config = ProxyConfig()

    # Reset callbacks
    litellm.success_callback = []
    litellm.failure_callback = []

    try:
        # Test Case 1: Invalid callback format
        config_data = {
            "litellm_settings": {
                "success_callback": "invalid_string_format",  # Should be a list
                "failure_callback": 123,  # Should be a list
            }
        }
        proxy_config._add_callbacks_from_db_config(config_data)

        # Verify no callbacks were added with invalid input
        assert len(litellm.success_callback) == 0
        assert len(litellm.failure_callback) == 0

        # Test Case 2: Missing litellm_settings
        config_data = {}
        proxy_config._add_callbacks_from_db_config(config_data)

        # Verify no callbacks were added
        assert len(litellm.success_callback) == 0
        assert len(litellm.failure_callback) == 0
    finally:
        # Cleanup runs even when an assertion above fails, so any callback
        # that was unexpectedly registered does not leak into later tests.
        litellm.success_callback = []
        litellm.failure_callback = []