""" Test file for the enhanced OllamaClient with validation and retry mechanisms. """ import sys import os import json from unittest.mock import Mock, patch # Add the current directory to the Python path sys.path.insert(0, os.path.dirname(__file__)) def test_ollama_client_initialization(): """Test OllamaClient initialization with new parameters""" from app.core.services.ollama_client import OllamaClient # Test with default parameters client = OllamaClient("test-model") assert client.model_name == "test-model" assert client.base_url == "http://localhost:11434" assert client.max_retries == 3 # Test with custom parameters client = OllamaClient("test-model", "http://custom:11434", 5) assert client.model_name == "test-model" assert client.base_url == "http://custom:11434" assert client.max_retries == 5 print("✓ OllamaClient initialization tests passed") def test_generate_with_validation(): """Test generate_with_validation method""" from app.core.services.ollama_client import OllamaClient # Mock the API response mock_response = Mock() mock_response.json.return_value = { "response": '{"business_name": "测试公司", "confidence": 0.9}' } mock_response.raise_for_status.return_value = None with patch('requests.post', return_value=mock_response): client = OllamaClient("test-model") # Test with business name extraction validation result = client.generate_with_validation( prompt="Extract business name from: 测试公司", response_type='business_name_extraction', return_parsed=True ) assert isinstance(result, dict) assert result.get('business_name') == '测试公司' assert result.get('confidence') == 0.9 print("✓ generate_with_validation test passed") def test_generate_with_schema(): """Test generate_with_schema method""" from app.core.services.ollama_client import OllamaClient # Define a custom schema custom_schema = { "type": "object", "properties": { "name": {"type": "string"}, "age": {"type": "number"} }, "required": ["name", "age"] } # Mock the API response mock_response = Mock() mock_response.json.return_value = { "response": '{"name": "张三", "age": 30}' } mock_response.raise_for_status.return_value = None with patch('requests.post', return_value=mock_response): client = OllamaClient("test-model") # Test with custom schema validation result = client.generate_with_schema( prompt="Generate person info", schema=custom_schema, return_parsed=True ) assert isinstance(result, dict) assert result.get('name') == '张三' assert result.get('age') == 30 print("✓ generate_with_schema test passed") def test_backward_compatibility(): """Test backward compatibility with original generate method""" from app.core.services.ollama_client import OllamaClient # Mock the API response mock_response = Mock() mock_response.json.return_value = { "response": "Simple text response" } mock_response.raise_for_status.return_value = None with patch('requests.post', return_value=mock_response): client = OllamaClient("test-model") # Test original generate method (should still work) result = client.generate("Simple prompt") assert result == "Simple text response" # Test with strip_think=False result = client.generate("Simple prompt", strip_think=False) assert result == "Simple text response" print("✓ Backward compatibility tests passed") def test_retry_mechanism(): """Test retry mechanism for failed requests""" from app.core.services.ollama_client import OllamaClient import requests # Mock failed requests followed by success mock_failed_response = Mock() mock_failed_response.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed") mock_success_response = Mock() mock_success_response.json.return_value = { "response": "Success response" } mock_success_response.raise_for_status.return_value = None with patch('requests.post', side_effect=[mock_failed_response, mock_success_response]): client = OllamaClient("test-model", max_retries=2) # Should retry and eventually succeed result = client.generate("Test prompt") assert result == "Success response" print("✓ Retry mechanism test passed") def test_validation_failure(): """Test validation failure handling""" from app.core.services.ollama_client import OllamaClient # Mock API response with invalid JSON mock_response = Mock() mock_response.json.return_value = { "response": "Invalid JSON response" } mock_response.raise_for_status.return_value = None with patch('requests.post', return_value=mock_response): client = OllamaClient("test-model", max_retries=2) try: # This should fail validation and retry result = client.generate_with_validation( prompt="Test prompt", response_type='business_name_extraction', return_parsed=True ) # If we get here, it means validation failed and retries were exhausted print("✓ Validation failure handling test passed") except ValueError as e: # Expected behavior - validation failed after retries assert "Failed to parse JSON response after all retries" in str(e) print("✓ Validation failure handling test passed") def test_enhanced_methods(): """Test the new enhanced methods""" from app.core.services.ollama_client import OllamaClient # Mock the API response mock_response = Mock() mock_response.json.return_value = { "response": '{"entities": [{"text": "张三", "type": "人名"}]}' } mock_response.raise_for_status.return_value = None with patch('requests.post', return_value=mock_response): client = OllamaClient("test-model") # Test generate_with_validation result = client.generate_with_validation( prompt="Extract entities", response_type='entity_extraction', return_parsed=True ) assert isinstance(result, dict) assert 'entities' in result assert len(result['entities']) == 1 assert result['entities'][0]['text'] == '张三' print("✓ Enhanced methods tests passed") def main(): """Run all tests""" print("Testing enhanced OllamaClient...") print("=" * 50) try: test_ollama_client_initialization() test_generate_with_validation() test_generate_with_schema() test_backward_compatibility() test_retry_mechanism() test_validation_failure() test_enhanced_methods() print("\n" + "=" * 50) print("✓ All enhanced OllamaClient tests passed!") except Exception as e: print(f"\n✗ Test failed: {e}") import traceback traceback.print_exc() if __name__ == "__main__": main()