Spaces:
Paused
Paused
| """ | |
| Test that all cache calls in async functions in router_strategy/ are async | |
| """ | |
| import os | |
| import sys | |
| from typing import Dict, List, Tuple | |
| import ast | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import os | |
| class AsyncCacheCallVisitor(ast.NodeVisitor): | |
| def __init__(self): | |
| self.async_functions: Dict[str, List[Tuple[str, int]]] = {} | |
| self.current_function = None | |
| def visit_AsyncFunctionDef(self, node): | |
| """Visit async function definitions and store their cache calls""" | |
| self.current_function = node.name | |
| self.async_functions[node.name] = [] | |
| self.generic_visit(node) | |
| self.current_function = None | |
| def visit_Call(self, node): | |
| """Visit function calls and check for cache operations""" | |
| if self.current_function is not None: | |
| # Check if it's a cache-related call | |
| if isinstance(node.func, ast.Attribute): | |
| method_name = node.func.attr | |
| if any(keyword in method_name.lower() for keyword in ["cache"]): | |
| # Get the full method call path | |
| if isinstance(node.func.value, ast.Name): | |
| full_call = f"{node.func.value.id}.{method_name}" | |
| elif isinstance(node.func.value, ast.Attribute): | |
| # Handle nested attributes like self.router_cache.get | |
| parts = [] | |
| current = node.func.value | |
| while isinstance(current, ast.Attribute): | |
| parts.append(current.attr) | |
| current = current.value | |
| if isinstance(current, ast.Name): | |
| parts.append(current.id) | |
| parts.reverse() | |
| parts.append(method_name) | |
| full_call = ".".join(parts) | |
| else: | |
| full_call = method_name | |
| # Store both the call and its line number | |
| self.async_functions[self.current_function].append( | |
| (full_call, node.lineno) | |
| ) | |
| self.generic_visit(node) | |
| def get_python_files(directory: str) -> List[str]: | |
| """Get all Python files in the router_strategy directory""" | |
| python_files = [] | |
| for file in os.listdir(directory): | |
| if file.endswith(".py") and not file.startswith("__"): | |
| python_files.append(os.path.join(directory, file)) | |
| return python_files | |
| def analyze_file(file_path: str) -> Dict[str, List[Tuple[str, int]]]: | |
| """Analyze a Python file for async functions and their cache calls""" | |
| with open(file_path, "r") as file: | |
| tree = ast.parse(file.read()) | |
| visitor = AsyncCacheCallVisitor() | |
| visitor.visit(tree) | |
| return visitor.async_functions | |
| def test_router_strategy_async_cache_calls(): | |
| """Test that all cache calls in async functions are properly async""" | |
| strategy_dir = os.path.join( | |
| os.path.dirname(os.path.dirname(os.path.dirname(__file__))), | |
| "litellm", | |
| "router_strategy", | |
| ) | |
| # Get all Python files in the router_strategy directory | |
| python_files = get_python_files(strategy_dir) | |
| print("python files:", python_files) | |
| all_async_functions: Dict[str, Dict[str, List[Tuple[str, int]]]] = {} | |
| for file_path in python_files: | |
| file_name = os.path.basename(file_path) | |
| async_functions = analyze_file(file_path) | |
| if async_functions: | |
| all_async_functions[file_name] = async_functions | |
| print(f"\nAnalyzing {file_name}:") | |
| for func_name, cache_calls in async_functions.items(): | |
| print(f"\nAsync function: {func_name}") | |
| print(f"Cache calls found: {cache_calls}") | |
| # Assert that cache calls in async functions use async methods | |
| for call, line_number in cache_calls: | |
| if any(keyword in call.lower() for keyword in ["cache"]): | |
| assert ( | |
| "async" in call.lower() | |
| ), f"VIOLATION: Cache call '{call}' in async function '{func_name}' should be async. file path: {file_path}, line number: {line_number}" | |
| # Assert we found async functions to analyze | |
| assert ( | |
| len(all_async_functions) > 0 | |
| ), "No async functions found in router_strategy directory" | |
| if __name__ == "__main__": | |
| test_router_strategy_async_cache_calls() | |