Spaces:
Paused
Paused
| import ast | |
| import os | |
| import re | |
def find_azure_files(base_dir):
    """
    Collect the path of every ``.py`` file beneath *base_dir*.

    The tree is traversed with ``os.walk`` (top-down), and paths are returned
    in walk order. Each path is the walked directory joined with the file name.
    Because ``os.walk`` ignores errors by default, a missing *base_dir* produces
    an empty list rather than an exception.
    """
    return [
        os.path.join(directory, filename)
        for directory, _subdirs, filenames in os.walk(base_dir)
        for filename in filenames
        if filename.endswith(".py")
    ]
# Client classes that must only be constructed via the sanctioned factory methods.
_AZURE_CLIENT_CLASSES = frozenset({"AzureOpenAI", "AsyncAzureOpenAI"})
# Classes whose own methods are allowed to construct the clients directly.
_EXEMPT_CLASSES = frozenset({"BaseAzureLLM"})
# Method names (in any class) that exist specifically to create clients.
_CLIENT_FACTORY_METHODS = frozenset(
    {"get_azure_openai_client", "initialize_azure_sdk_client"}
)


def _called_name(call):
    """Return the final name a call targets, or None for other callee shapes.

    ``Foo(...)`` gives ``"Foo"`` and ``pkg.Foo(...)`` gives ``"Foo"``. Any
    other callee expression (e.g. ``f()()`` or ``x[0]()``) gives ``None``.
    """
    func = call.func
    if isinstance(func, ast.Name):
        return func.id
    if isinstance(func, ast.Attribute):
        return func.attr
    return None


def check_direct_instantiation(file_path):
    """
    Check if a file directly instantiates AzureOpenAI or AsyncAzureOpenAI
    outside of the BaseAzureLLM class methods.

    Only methods (sync or async) defined directly in a class body are
    inspected, but everything nested inside such a method is searched.
    Module-level functions are not inspected. Methods named
    ``get_azure_openai_client`` or ``initialize_azure_sdk_client`` are
    exempt in every class.

    Args:
        file_path: Path to a Python source file.

    Returns:
        A list of human-readable issue strings, one per offending call, in
        AST-walk order. The list is empty when the file is clean.

    Raises:
        OSError: if the file cannot be read.
        SyntaxError: if the file is not valid Python. The message names
            *file_path*.
    """
    # Read raw bytes so ast.parse/compile decode the source according to
    # PEP 263 (coding cookie / BOM, UTF-8 default) instead of the platform
    # locale, which can fail or misdecode on non-UTF-8 systems.
    with open(file_path, "rb") as file:
        content = file.read()
    tree = ast.parse(content, filename=file_path)

    issues = []
    # ast.walk visits every class, including nested ones; skipping an exempt
    # class here does not stop its nested classes from being visited.
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef) or node.name in _EXEMPT_CLASSES:
            continue
        class_name = node.name
        for method in node.body:
            if not isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if method.name in _CLIENT_FACTORY_METHODS:
                continue
            for subnode in ast.walk(method):
                if not isinstance(subnode, ast.Call):
                    continue
                called = _called_name(subnode)
                if called in _AZURE_CLIENT_CLASSES:
                    issues.append(
                        f"Direct instantiation of {called} in {class_name}.{method.name}"
                    )
    return issues
def main(base_dir="../../litellm/llms/azure"):
    """
    Main function to run the test.

    Scans every Python file under *base_dir* with check_direct_instantiation
    and fails if any direct AzureOpenAI / AsyncAzureOpenAI construction is
    found.

    Args:
        base_dir: Directory to scan. The default is a relative path and is
            resolved against the current working directory. Presumably the
            script is meant to be run from its own folder; confirm against CI.

    Raises:
        Exception: if *base_dir* contains no Python files, or if any direct
            instantiation is found. An empty scan almost always means the
            path did not resolve, and before this check it was reported as
            success.
    """
    # Sort so the report order is stable across filesystems (os.walk order
    # is not guaranteed).
    azure_files = sorted(find_azure_files(base_dir))
    print(f"Found {len(azure_files)} Azure Python files to check")
    if not azure_files:
        # Guard against a silent false pass when the relative path is wrong.
        raise Exception(
            f"No Python files found under {os.path.abspath(base_dir)!r}; "
            "run this script from its own directory or pass the correct base_dir."
        )

    all_issues = [
        f"{file_path}: {issue}"
        for file_path in azure_files
        for issue in check_direct_instantiation(file_path)
    ]

    if all_issues:
        print("Found direct instantiations of AzureOpenAI or AsyncAzureOpenAI:")
        for issue in all_issues:
            print(f" - {issue}")
        raise Exception(
            f"Found {len(all_issues)} direct instantiations of AzureOpenAI or AsyncAzureOpenAI classes. Use get_azure_openai_client instead."
        )
    print("All Azure modules are correctly using get_azure_openai_client!")


if __name__ == "__main__":
    main()