File size: 4,123 Bytes
0fd441a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# tests/test_llm.py
# run with pytest tests/.
# 
# import pytest
from unittest.mock import patch, MagicMock
import huggingface_hub
from huggingface_hub import get_token

from llm.llm_login import login_huggingface
from llm.provider_validator import is_valid_provider, suggest_providers
from llm.hf_client import HFChatClient  # Assuming this exists
from llm.openai_client import OpenAIClient  # Assuming this exists

def test_login_huggingface_success():
    with patch('huggingface_hub.login') as mock_login:
        api_token = "hf_test_token"
        login_huggingface(api_token)
        mock_login.assert_called_once_with(token=api_token, add_to_git_credential=False)

def test_login_huggingface_no_token():
    with patch('huggingface_hub.login') as mock_login:
        with pytest.raises(ValueError, match="API token required"):
            login_huggingface(None)

@patch('huggingface_hub.login')
def test_login_huggingface_error(mock_login):
    mock_login.side_effect = Exception("Login failed")
    with pytest.raises(Exception, match="Login failed"):
        login_huggingface("invalid_token")

def test_is_valid_provider():
    assert is_valid_provider("huggingface") is True
    assert is_valid_provider("openai") is True
    assert is_valid_provider("invalid_provider") is False
    assert is_valid_provider("") is False
    assert is_valid_provider(None) is False

def test_suggest_providers():
    suggestions = suggest_providers("hugngface")  # Typo example
    assert isinstance(suggestions, list)
    assert "huggingface" in suggestions
    
    no_suggestions = suggest_providers("completely_unknown")
    assert isinstance(no_suggestions, list)
    assert len(no_suggestions) == 0

@patch('llm.hf_client.HFChatClient.__init__')
def test_hf_client_init(mock_init):
    mock_init.return_value = None
    client = HFChatClient(model_id="test-model", api_token="test_token")
    mock_init.assert_called_once_with(
        model_id="test-model",
        api_token="test_token",
        # Add other expected params based on actual __init__
    )

@patch('llm.hf_client.login_huggingface')
@patch('llm.hf_client.get_token')
def test_hf_client_token(mock_get_token, mock_login):
    mock_get_token.return_value = "cached_token"
    mock_login.return_value = None
    
    client = HFChatClient(model_id="test-model")
    assert client.api_token == "cached_token"

@patch('openai.OpenAI')
def test_openai_client_init(mock_openai):
    mock_client = MagicMock()
    mock_openai.return_value = mock_client
    
    client = OpenAIClient(api_key="sk_test_key", base_url="https://api.openai.com/v1")
    mock_openai.assert_called_once_with(
        api_key="sk_test_key",
        base_url="https://api.openai.com/v1"
    )
    assert client.client == mock_client

@patch('openai.OpenAI')
def test_openai_client_chat(mock_openai):
    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_response.choices = [MagicMock(content="Hello!")]
    mock_client.chat.completions.create.return_value = mock_response
    mock_openai.return_value = mock_client
    
    client = OpenAIClient(api_key="sk_test_key")
    response = client.chat("Hello", model="gpt-3.5-turbo")
    
    assert response == "Hello!"
    mock_client.chat.completions.create.assert_called_once_with(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}]
    )

def test_provider_validator_edge_cases():
    # Test with non-string inputs
    assert is_valid_provider(123) is False
    assert suggest_providers(123) == []
    
    # Test case insensitivity
    assert is_valid_provider("HUGGINGFACE") is True
    assert is_valid_provider("OpEnAi") is True

@patch('huggingface_hub.get_token')
def test_get_token_from_env(mock_get_token):
    mock_get_token.return_value = None
    with patch.dict('os.environ', {'HUGGINGFACE_HUB_TOKEN': 'env_token'}):
        token = get_token()
        assert token == 'env_token'

@patch('huggingface_hub.get_token')
def test_get_token_from_cache(mock_get_token):
    mock_get_token.return_value = 'cached_token'
    token = get_token()
    assert token == 'cached_token'