231 lines
7.4 KiB
Python
231 lines
7.4 KiB
Python
"""
|
|
Test file for the enhanced OllamaClient with validation and retry mechanisms.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import json
|
|
from unittest.mock import Mock, patch
|
|
|
|
# Make the directory containing this file importable (presumably so the
# ``app`` package next to it resolves when the file is run directly).
# ``abspath`` keeps the entry valid even when ``__file__`` is relative
# (e.g. ``python <file>`` on Pythons before 3.9) and the working directory
# changes later.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
def test_ollama_client_initialization():
    """Check that OllamaClient stores model name, base URL and retry count.

    Covers both the constructor defaults and explicitly supplied positional
    arguments (model name, base URL, max retries).
    """
    from app.core.services.ollama_client import OllamaClient

    cases = [
        # (constructor args, expected base_url, expected max_retries)
        (("test-model",), "http://localhost:11434", 3),
        (("test-model", "http://custom:11434", 5), "http://custom:11434", 5),
    ]
    for ctor_args, want_url, want_retries in cases:
        instance = OllamaClient(*ctor_args)
        assert instance.model_name == "test-model"
        assert instance.base_url == want_url
        assert instance.max_retries == want_retries

    print("✓ OllamaClient initialization tests passed")
|
|
|
|
|
|
def test_generate_with_validation():
    """Exercise generate_with_validation with a business-name payload.

    ``requests.post`` is patched so the client receives a canned reply whose
    ``response`` field holds a JSON document; the parsed dict must carry the
    business name and confidence from that document.
    """
    from app.core.services.ollama_client import OllamaClient

    payload = '{"business_name": "测试公司", "confidence": 0.9}'
    fake_reply = Mock()
    fake_reply.json.return_value = {"response": payload}
    fake_reply.raise_for_status.return_value = None

    with patch('requests.post', return_value=fake_reply):
        parsed = OllamaClient("test-model").generate_with_validation(
            prompt="Extract business name from: 测试公司",
            response_type='business_name_extraction',
            return_parsed=True,
        )

        assert isinstance(parsed, dict)
        assert parsed.get('business_name') == '测试公司'
        assert parsed.get('confidence') == 0.9

        print("✓ generate_with_validation test passed")
|
|
|
|
|
|
def test_generate_with_schema():
    """Exercise generate_with_schema against a caller-supplied JSON schema.

    The schema requires a string ``name`` and a numeric ``age``; the patched
    HTTP reply supplies both, and the parsed dict must contain those values.
    """
    from app.core.services.ollama_client import OllamaClient

    person_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
        },
        "required": ["name", "age"],
    }

    fake_reply = Mock()
    fake_reply.json.return_value = {"response": '{"name": "张三", "age": 30}'}
    fake_reply.raise_for_status.return_value = None

    with patch('requests.post', return_value=fake_reply):
        client = OllamaClient("test-model")
        parsed = client.generate_with_schema(
            prompt="Generate person info",
            schema=person_schema,
            return_parsed=True,
        )

        assert isinstance(parsed, dict)
        assert parsed.get('name') == '张三'
        assert parsed.get('age') == 30

        print("✓ generate_with_schema test passed")
|
|
|
|
|
|
def test_backward_compatibility():
    """Ensure the plain generate() API still returns the raw response text.

    generate() is called once with default arguments and once with
    ``strip_think=False``; both must return the patched reply text unchanged.
    """
    from app.core.services.ollama_client import OllamaClient

    fake_reply = Mock()
    fake_reply.json.return_value = {"response": "Simple text response"}
    fake_reply.raise_for_status.return_value = None

    with patch('requests.post', return_value=fake_reply):
        client = OllamaClient("test-model")
        for extra_kwargs in ({}, {"strip_think": False}):
            text = client.generate("Simple prompt", **extra_kwargs)
            assert text == "Simple text response"

        print("✓ Backward compatibility tests passed")
|
|
|
|
|
|
def test_retry_mechanism():
    """Verify that generate() retries after a failed HTTP request.

    The first patched ``requests.post`` call returns a response whose
    ``raise_for_status`` raises ``RequestException``; the second returns a
    normal reply. The client (built with ``max_retries=2``) must recover,
    return the second reply's text, and have issued exactly two requests.
    """
    from app.core.services.ollama_client import OllamaClient
    import requests

    failing_reply = Mock()
    failing_reply.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed")

    ok_reply = Mock()
    ok_reply.json.return_value = {"response": "Success response"}
    ok_reply.raise_for_status.return_value = None

    with patch('requests.post', side_effect=[failing_reply, ok_reply]) as mock_post:
        client = OllamaClient("test-model", max_retries=2)

        # Should retry and eventually succeed
        result = client.generate("Test prompt")
        assert result == "Success response"

        # Assert the retry explicitly instead of relying on the return value
        # alone: one failed attempt followed by one successful attempt.
        assert mock_post.call_count == 2, (
            f"expected 2 POST attempts, got {mock_post.call_count}"
        )

        print("✓ Retry mechanism test passed")
|
|
|
|
|
|
def test_validation_failure():
    """Verify that an unparseable model reply is rejected after retries.

    Every patched ``requests.post`` call returns plain text instead of the
    JSON object expected for ``business_name_extraction``. The client is
    expected to give up and raise ``ValueError`` whose message mentions the
    exhausted retries. Returning normally is treated as a test failure.
    """
    from app.core.services.ollama_client import OllamaClient

    bad_reply = Mock()
    bad_reply.json.return_value = {"response": "Invalid JSON response"}
    bad_reply.raise_for_status.return_value = None

    with patch('requests.post', return_value=bad_reply):
        client = OllamaClient("test-model", max_retries=2)

        try:
            # This should fail validation and retry
            result = client.generate_with_validation(
                prompt="Test prompt",
                response_type='business_name_extraction',
                return_parsed=True,
            )
        except ValueError as e:
            # Expected behavior - validation failed after retries
            assert "Failed to parse JSON response after all retries" in str(e)
        else:
            # Returning normally means invalid output was accepted; fail
            # loudly instead of reporting success. (Explicit raise so the
            # check survives ``python -O``.)
            raise AssertionError(
                "expected ValueError for invalid JSON response, "
                f"got {result!r}"
            )

    print("✓ Validation failure handling test passed")
|
|
|
|
|
|
def test_enhanced_methods():
    """Exercise generate_with_validation with an entity-extraction payload.

    The patched reply contains a single entity ("张三", typed "人名", i.e.
    person name); the parsed result must expose it under ``entities``.
    """
    from app.core.services.ollama_client import OllamaClient

    fake_reply = Mock()
    fake_reply.json.return_value = {
        "response": '{"entities": [{"text": "张三", "type": "人名"}]}'
    }
    fake_reply.raise_for_status.return_value = None

    with patch('requests.post', return_value=fake_reply):
        parsed = OllamaClient("test-model").generate_with_validation(
            prompt="Extract entities",
            response_type='entity_extraction',
            return_parsed=True,
        )

        assert isinstance(parsed, dict)
        assert 'entities' in parsed
        entities = parsed['entities']
        assert len(entities) == 1
        assert entities[0]['text'] == '张三'

        print("✓ Enhanced methods tests passed")
|
|
|
|
|
|
def main(tests=None):
    """Run the OllamaClient test functions as a standalone script.

    Every test is executed even if an earlier one fails, so a single run
    reports all failures; a traceback is printed for each failing test.

    Args:
        tests: Optional iterable of zero-argument callables to run. When
            ``None`` (the default), the test functions defined in this file
            are run in their declared order.

    Returns:
        Process exit status: ``0`` if every test passed, ``1`` otherwise.
    """
    import traceback

    if tests is None:
        tests = (
            test_ollama_client_initialization,
            test_generate_with_validation,
            test_generate_with_schema,
            test_backward_compatibility,
            test_retry_mechanism,
            test_validation_failure,
            test_enhanced_methods,
        )

    print("Testing enhanced OllamaClient...")
    print("=" * 50)

    failures = []
    for test in tests:
        try:
            test()
        except Exception as e:
            # Record and keep going so later tests still run and report.
            name = getattr(test, "__name__", repr(test))
            failures.append(name)
            print(f"\n✗ Test failed: {name}: {e}")
            traceback.print_exc()

    print("\n" + "=" * 50)
    if failures:
        print(f"✗ {len(failures)} test(s) failed: {', '.join(failures)}")
        return 1

    print("✓ All enhanced OllamaClient tests passed!")
    return 0


if __name__ == "__main__":
    # Propagate the result so shells / CI see a non-zero status on failure.
    sys.exit(main())
|