legal-doc-masker/backend/test_enhanced_ollama_client.py

231 lines
7.4 KiB
Python

"""
Tests for the enhanced OllamaClient with validation and retry mechanisms.

The request-issuing tests patch ``requests.post`` with canned responses
instead of contacting a live Ollama server. The module can be executed
directly as a script (see ``main``), which runs every ``test_*`` function
in order.
"""
import sys
import os
import json
from unittest.mock import Mock, patch
# Put the directory containing this file at the front of sys.path so the
# ``app`` package (imported inside each test) resolves no matter which
# working directory the script is launched from.
sys.path.insert(0, os.path.dirname(__file__))
def test_ollama_client_initialization():
    """Verify OllamaClient stores its constructor arguments and defaults.

    Covers the default base URL / retry count as well as explicitly passed
    positional overrides for ``base_url`` and ``max_retries``.
    """
    from app.core.services.ollama_client import OllamaClient

    # (positional constructor args, expected base_url, expected max_retries)
    cases = (
        (("test-model",), "http://localhost:11434", 3),
        (("test-model", "http://custom:11434", 5), "http://custom:11434", 5),
    )
    for ctor_args, want_url, want_retries in cases:
        instance = OllamaClient(*ctor_args)
        assert instance.model_name == "test-model"
        assert instance.base_url == want_url
        assert instance.max_retries == want_retries

    print("✓ OllamaClient initialization tests passed")
def test_generate_with_validation():
    """Check generate_with_validation parses a business-name extraction reply.

    ``requests.post`` is patched to return a canned Ollama payload whose
    ``response`` field holds a JSON document. The parsed dict must expose the
    same business name and confidence.
    """
    from app.core.services.ollama_client import OllamaClient

    canned_payload = {
        "response": '{"business_name": "测试公司", "confidence": 0.9}'
    }
    fake_reply = Mock(**{
        "json.return_value": canned_payload,
        "raise_for_status.return_value": None,
    })

    with patch('requests.post', return_value=fake_reply):
        parsed = OllamaClient("test-model").generate_with_validation(
            prompt="Extract business name from: 测试公司",
            response_type='business_name_extraction',
            return_parsed=True,
        )
        assert isinstance(parsed, dict)
        assert parsed.get('business_name') == '测试公司'
        assert parsed.get('confidence') == 0.9

    print("✓ generate_with_validation test passed")
def test_generate_with_schema():
    """Check generate_with_schema validates a reply against a caller schema.

    The schema requires a string ``name`` and a numeric ``age``. The patched
    HTTP call returns a JSON document satisfying it, and the parsed dict must
    carry those values through unchanged.
    """
    from app.core.services.ollama_client import OllamaClient

    person_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
        },
        "required": ["name", "age"],
    }

    reply = Mock()
    reply.configure_mock(**{
        "json.return_value": {"response": '{"name": "张三", "age": 30}'},
        "raise_for_status.return_value": None,
    })

    expected_fields = {'name': '张三', 'age': 30}
    with patch('requests.post', return_value=reply):
        client = OllamaClient("test-model")
        outcome = client.generate_with_schema(
            prompt="Generate person info",
            schema=person_schema,
            return_parsed=True,
        )
        assert isinstance(outcome, dict)
        for field_name, field_value in expected_fields.items():
            assert outcome.get(field_name) == field_value

    print("✓ generate_with_schema test passed")
def test_backward_compatibility():
    """Ensure the legacy ``generate`` API still returns plain text.

    Calls ``generate`` once with default arguments and once with
    ``strip_think=False``. Both calls must yield the raw ``response`` string
    from the patched HTTP reply.
    """
    from app.core.services.ollama_client import OllamaClient

    plain_reply = Mock(**{
        "json.return_value": {"response": "Simple text response"},
        "raise_for_status.return_value": None,
    })

    with patch('requests.post', return_value=plain_reply):
        client = OllamaClient("test-model")
        # First the default call, then the explicit strip_think=False variant.
        for extra_kwargs in ({}, {"strip_think": False}):
            text = client.generate("Simple prompt", **extra_kwargs)
            assert text == "Simple text response"

    print("✓ Backward compatibility tests passed")
def test_retry_mechanism():
    """Confirm a transient HTTP failure is retried and the retry's text returned.

    The first patched ``requests.post`` call yields a response whose
    ``raise_for_status`` raises ``RequestException``, and the second call
    succeeds. With ``max_retries=2`` the client must return the second
    response's text.
    """
    from app.core.services.ollama_client import OllamaClient
    import requests

    flaky = Mock()
    flaky.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed")

    healthy = Mock()
    healthy.configure_mock(**{
        "json.return_value": {"response": "Success response"},
        "raise_for_status.return_value": None,
    })

    # Successive calls to requests.post consume this sequence in order.
    scripted_replies = [flaky, healthy]
    with patch('requests.post', side_effect=scripted_replies):
        answer = OllamaClient("test-model", max_retries=2).generate("Test prompt")
        assert answer == "Success response"

    print("✓ Retry mechanism test passed")
def test_validation_failure():
    """Check that an unparseable model reply is never accepted as a valid extraction.

    ``requests.post`` is patched to always return the plain text
    ``"Invalid JSON response"``, which cannot satisfy a
    ``business_name_extraction`` request. Two outcomes are accepted:

    * ``generate_with_validation`` raises ``ValueError`` whose message reports
      that JSON parsing failed after all retries; or
    * it returns normally, in which case the result must not look like a
      successful extraction (no truthy ``business_name`` in a dict).

    In both cases the patched HTTP call must actually have been made.
    """
    from app.core.services.ollama_client import OllamaClient

    # Canned Ollama reply whose "response" field is not JSON at all.
    mock_response = Mock()
    mock_response.json.return_value = {
        "response": "Invalid JSON response"
    }
    mock_response.raise_for_status.return_value = None

    with patch('requests.post', return_value=mock_response) as mock_post:
        client = OllamaClient("test-model", max_retries=2)
        try:
            result = client.generate_with_validation(
                prompt="Test prompt",
                response_type='business_name_extraction',
                return_parsed=True
            )
        except ValueError as e:
            # Expected behavior - validation failed after retries.
            assert "Failed to parse JSON response after all retries" in str(e)
        else:
            # Without this check the no-exception path would pass vacuously.
            # If the client returns instead of raising, it must not have
            # fabricated a business name out of non-JSON text.
            assert not (isinstance(result, dict) and result.get('business_name')), (
                f"Invalid JSON was accepted as a valid extraction: {result!r}"
            )
        # Guard against passing without exercising the (patched) HTTP path.
        assert mock_post.called, "requests.post was never called"

    print("✓ Validation failure handling test passed")
def test_enhanced_methods():
    """Exercise generate_with_validation with the ``entity_extraction`` type.

    The patched reply carries a single entity. The parsed result must be a
    dict whose ``entities`` list holds exactly that entity.
    """
    from app.core.services.ollama_client import OllamaClient

    entity_reply = Mock(**{
        "json.return_value": {
            "response": '{"entities": [{"text": "张三", "type": "人名"}]}'
        },
        "raise_for_status.return_value": None,
    })

    with patch('requests.post', return_value=entity_reply):
        extracted = OllamaClient("test-model").generate_with_validation(
            prompt="Extract entities",
            response_type='entity_extraction',
            return_parsed=True,
        )
        assert isinstance(extracted, dict)
        assert 'entities' in extracted
        entity_list = extracted['entities']
        assert len(entity_list) == 1
        assert entity_list[0]['text'] == '张三'

    print("✓ Enhanced methods tests passed")
def main():
    """Run every test in this module in order and report the outcome.

    Stops at the first failing test, prints its traceback, and terminates the
    process with exit status 1, so shell scripts and CI treat the run as
    failed. On success it prints a summary line and returns ``None``.
    """
    tests = (
        test_ollama_client_initialization,
        test_generate_with_validation,
        test_generate_with_schema,
        test_backward_compatibility,
        test_retry_mechanism,
        test_validation_failure,
        test_enhanced_methods,
    )

    print("Testing enhanced OllamaClient...")
    print("=" * 50)
    try:
        for test in tests:
            test()
    except Exception as e:
        import traceback
        print(f"\n✗ Test failed: {e}")
        traceback.print_exc()
        # Signal failure through the exit status. Merely printing would let
        # the script exit 0 and look like a successful run.
        sys.exit(1)
    else:
        print("\n" + "=" * 50)
        print("✓ All enhanced OllamaClient tests passed!")
# Run the whole suite via main() when this file is executed directly as a script.
if __name__ == "__main__":
    main()