Spaces:
Paused
Paused
import os
import sys
import traceback
import uuid
from datetime import datetime
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
import io
import os  # NOTE(review): duplicate of the `import os` above
import time

# this file is to test litellm/proxy
# NOTE: "../.." is resolved against the current working directory, not against
# this file's location. It presumably points at the repo root when pytest is
# run from the tests directory, so the local `litellm` checkout is imported
# ahead of any installed copy -- confirm. It has to run before `import litellm`.
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the directory two levels above the CWD to the front of sys.path
import asyncio
import logging

# Load environment variables from a .env file if one exists
# (e.g. DATABASE_URL, which prisma_client() reads).
load_dotenv()
import pytest
import uuid  # NOTE(review): duplicate of the `import uuid` above
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy.proxy_server import (
    LitellmUserRoles,
    audio_transcriptions,
    chat_completion,
    completion,
    embeddings,
    model_list,
    moderations,
    user_api_key_auth,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend

# Turn on DEBUG-level output from litellm's proxy logger for these tests.
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from starlette.datastructures import URL
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy._types import LiteLLM_AuditLogs, LitellmTableNames
from litellm.caching.caching import DualCache
from unittest.mock import patch, AsyncMock

# Module-level ProxyLogging instance, built with a fresh DualCache.
# prisma_client() passes it to PrismaClient.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
import json
@pytest.mark.asyncio
async def test_create_audit_log_for_update_premium_user():
    """
    Basic unit test for create_audit_log_for_update.

    Setup:
    - ``premium_user`` and ``litellm.store_audit_logs`` are patched to True.
    - The proxy's ``prisma_client`` is replaced by a mock.

    Expected: recording a team update calls
    ``prisma_client.db.litellm_auditlog.create`` exactly once. Every field of
    the LiteLLM_AuditLogs record must be forwarded as ``data``.
    """
    with patch("litellm.proxy.proxy_server.premium_user", True), patch(
        "litellm.store_audit_logs", True
    ), patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
        # create() is presumably awaited by the code under test, so it is
        # replaced with an awaitable AsyncMock rather than a plain MagicMock.
        mock_prisma.db.litellm_auditlog.create = AsyncMock()

        request_data = LiteLLM_AuditLogs(
            id="test_id",
            updated_at=datetime.now(),
            changed_by="test_changed_by",
            action="updated",
            table_name=LitellmTableNames.TEAM_TABLE_NAME,
            object_id="test_object_id",
            updated_values=json.dumps({"key": "value"}),
            before_value=json.dumps({"old_key": "old_value"}),
        )

        await create_audit_log_for_update(request_data)

        mock_prisma.db.litellm_auditlog.create.assert_called_once_with(
            data={
                "id": "test_id",
                "updated_at": request_data.updated_at,
                "changed_by": request_data.changed_by,
                "action": request_data.action,
                "table_name": request_data.table_name,
                "object_id": request_data.object_id,
                "updated_values": request_data.updated_values,
                "before_value": request_data.before_value,
            }
        )
@pytest.fixture
def prisma_client():
    """
    Build a PrismaClient for the database named by the DATABASE_URL env var.

    Connection-pool query parameters (connection_limit=100, pool_timeout=60)
    are appended to DATABASE_URL. The augmented URL is written back to
    ``os.environ`` before the client is constructed. ``connect()`` is not
    called here; the requesting test is expected to do that itself.

    The requesting test is skipped when DATABASE_URL is unset or empty,
    because there is no database to run against.

    Returns:
        A new PrismaClient wired to the module-level ``proxy_logging_obj``.
    """
    from litellm.proxy.proxy_cli import append_query_params

    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        # Without this guard append_query_params(None, ...) fails with an
        # opaque error instead of clearly reporting the missing config.
        pytest.skip("DATABASE_URL is not set; skipping database-backed test")

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    modified_url = append_query_params(database_url, params)
    # Written back to the environment, presumably because Prisma also reads
    # DATABASE_URL from there -- confirm before removing.
    os.environ["DATABASE_URL"] = modified_url

    client = PrismaClient(
        database_url=modified_url, proxy_logging_obj=proxy_logging_obj
    )
    return client
@pytest.mark.asyncio
async def test_create_audit_log_in_db(prisma_client):
    """
    Integration test: create_audit_log_for_update persists a row to the DB.

    Steps:
    - Point the proxy at the real PrismaClient from the ``prisma_client``
      fixture and set a master key.
    - Enable ``premium_user`` and ``litellm.store_audit_logs``.
    - Write an audit log with a unique id, then read it back by that id.

    The proxy/litellm globals are patched only for the duration of the test,
    so they are restored even if an assertion fails. The client is
    disconnected afterwards.
    """
    print("prisma client=", prisma_client)
    with patch("litellm.proxy.proxy_server.prisma_client", prisma_client), patch(
        "litellm.proxy.proxy_server.master_key", "sk-1234"
    ), patch("litellm.proxy.proxy_server.premium_user", True), patch(
        "litellm.store_audit_logs", True
    ):
        await prisma_client.connect()
        try:
            # Unique id so the read-back cannot match a row from a previous run.
            audit_log_id = f"audit_log_id_{uuid.uuid4()}"

            # create an audit log entry for a team update
            request_data = LiteLLM_AuditLogs(
                id=audit_log_id,
                updated_at=datetime.now(),
                changed_by="test_changed_by",
                action="updated",
                table_name=LitellmTableNames.TEAM_TABLE_NAME,
                object_id="test_object_id",
                updated_values=json.dumps({"key": "value"}),
                before_value=json.dumps({"old_key": "old_value"}),
            )

            await create_audit_log_for_update(request_data)
            # Grace period before reading back, in case the write lands late.
            await asyncio.sleep(1)

            # now read the log back from the db by its unique id
            last_log = await prisma_client.db.litellm_auditlog.find_first(
                where={"id": audit_log_id}
            )
            assert last_log is not None, f"audit log {audit_log_id} was not written"
            assert last_log.id == audit_log_id
        finally:
            await prisma_client.disconnect()