feat: refactor ollama client with built-in retry logic and schema validation

This commit is contained in:
tigermren 2025-08-17 20:09:00 +08:00
parent 70b6617c5e
commit c85e166208
7 changed files with 794 additions and 157 deletions

View File

@@ -0,0 +1,255 @@
# OllamaClient Enhancement Summary
## Overview
The `OllamaClient` has been enhanced to support response validation and automatic retries while maintaining full backward compatibility.
## Key Enhancements
### 1. **Enhanced Constructor**
```python
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
```
- Added `max_retries` parameter for configurable retry attempts
- Default retry count: 3 attempts
### 2. **Enhanced Generate Method**
```python
def generate(self,
             prompt: str,
             strip_think: bool = True,
             validation_schema: Optional[Dict[str, Any]] = None,
             response_type: Optional[str] = None,
             return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
```
**New Parameters:**
- `validation_schema`: Custom JSON schema for validation
- `response_type`: Predefined response type for validation
- `return_parsed`: Return parsed JSON instead of raw string
**Return Type:**
- `Union[str, Dict[str, Any]]`: Can return either raw string or parsed JSON
### 3. **New Convenience Methods**
#### `generate_with_validation()`
```python
def generate_with_validation(self,
                             prompt: str,
                             response_type: str,
                             strip_think: bool = True,
                             return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
```
- Uses predefined validation schemas based on response type
- Automatically handles retries and validation
- Returns parsed JSON by default
#### `generate_with_schema()`
```python
def generate_with_schema(self,
                         prompt: str,
                         schema: Dict[str, Any],
                         strip_think: bool = True,
                         return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
```
- Uses custom JSON schema for validation
- Automatically handles retries and validation
- Returns parsed JSON by default
### 4. **Supported Response Types**
The following response types are supported for automatic validation:
- `'entity_extraction'`: Entity extraction responses
- `'entity_linkage'`: Entity linkage responses
- `'regex_entity'`: Regex-based entity responses
- `'business_name_extraction'`: Business name extraction responses
- `'address_extraction'`: Address component extraction responses
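The concrete schemas for these types live in `LLMResponseValidator` and are not reproduced in this summary; a hypothetical sketch of what a response-type-to-schema registry could look like (names and shapes are illustrative):

```python
# Hypothetical registry mapping response types to JSON schemas.
# The real schemas are defined inside LLMResponseValidator and may differ.
RESPONSE_SCHEMAS = {
    "business_name_extraction": {
        "type": "object",
        "properties": {
            "business_name": {"type": "string"},
            "confidence": {"type": "number"},
        },
        "required": ["business_name"],
    },
    "entity_linkage": {
        "type": "object",
        "properties": {
            "entity_groups": {"type": "array", "items": {"type": "object"}},
        },
        "required": ["entity_groups"],
    },
}

def schema_for(response_type: str) -> dict:
    """Look up the schema for a response type; reject unknown types."""
    try:
        return RESPONSE_SCHEMAS[response_type]
    except KeyError:
        raise ValueError(f"Unsupported response type: {response_type}")
```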
## Features
### 1. **Automatic Retry Mechanism**
- Retries failed API calls up to `max_retries` times
- Retries on validation failures
- Retries on JSON parsing failures
- Configurable retry count per client instance
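Conceptually, all three retry triggers (API errors, parse failures, validation failures) share a single loop. A simplified, self-contained sketch of the pattern (not the client's exact code):

```python
import logging

logger = logging.getLogger(__name__)

def generate_with_retry(call, parse, validate, max_retries: int = 3):
    """Retry wrapper: retries on API errors, JSON-parse failures,
    and validation failures, up to max_retries attempts."""
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            raw = call()            # the API call; may raise on network errors
            parsed = parse(raw)     # returns a dict, or None on malformed JSON
            if parsed is not None and validate(parsed):
                return parsed
            logger.warning("attempt %d/%d failed validation", attempt, max_retries)
        except Exception as exc:
            last_error = exc
            logger.warning("attempt %d/%d raised: %s", attempt, max_retries, exc)
    raise ValueError(f"all {max_retries} attempts failed (last error: {last_error})")
```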
### 2. **Built-in Validation**
- JSON schema validation using `jsonschema` library
- Predefined schemas for common response types
- Custom schema support for specialized use cases
- Detailed validation error logging
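Schema checking follows the standard `jsonschema` pattern of calling `validate` and catching `ValidationError`; a standalone sketch (the schema shown is illustrative):

```python
from jsonschema import validate, ValidationError

# Illustrative schema; the client accepts any caller-supplied schema.
BUSINESS_NAME_SCHEMA = {
    "type": "object",
    "properties": {
        "business_name": {"type": "string"},
        "confidence": {"type": "number"},
    },
    "required": ["business_name"],
}

def is_valid(response: dict) -> bool:
    """Return True if the parsed response matches the schema."""
    try:
        validate(instance=response, schema=BUSINESS_NAME_SCHEMA)
        return True
    except ValidationError:
        # In the client, the validation error is logged in detail here
        return False
```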
### 3. **Automatic JSON Parsing**
- Uses `LLMJsonExtractor.parse_raw_json_str()` for robust JSON extraction
- Handles malformed JSON responses gracefully
- Returns parsed Python dictionaries when requested
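The extractor's internals are not shown in this summary; a stdlib-only sketch of the kind of tolerant parsing it performs (hypothetical implementation, not `LLMJsonExtractor` itself):

```python
import json
import re
from typing import Optional

def parse_raw_json_str(raw: str) -> Optional[dict]:
    """Extract the first JSON object from an LLM response, tolerating
    markdown code fences and surrounding prose. Returns None on failure."""
    # Strip markdown code fences such as ``` or ```json
    cleaned = re.sub(r"```(?:json)?", "", raw)
    # Take the outermost {...} span
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end < start:
        return None
    try:
        parsed = json.loads(cleaned[start:end + 1])
        return parsed if isinstance(parsed, dict) else None
    except json.JSONDecodeError:
        return None
```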
### 4. **Backward Compatibility**
- All existing code continues to work without changes
- Original `generate()` method signature preserved
- Default behavior unchanged
## Usage Examples
### 1. **Basic Usage (Backward Compatible)**
```python
client = OllamaClient("llama2")
response = client.generate("Hello, world!")
# Returns the model's raw string response
```
### 2. **With Response Type Validation**
```python
client = OllamaClient("llama2")
result = client.generate_with_validation(
    prompt="Extract business name from: 上海盒马网络科技有限公司",
    response_type='business_name_extraction',
    return_parsed=True
)
# Returns: {"business_name": "盒马", "confidence": 0.9}
```
### 3. **With Custom Schema Validation**
```python
client = OllamaClient("llama2")
custom_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "number"}
    },
    "required": ["name", "age"]
}
result = client.generate_with_schema(
    prompt="Generate person info",
    schema=custom_schema,
    return_parsed=True
)
# Returns: {"name": "张三", "age": 30}
```
### 4. **Advanced Usage with All Options**
```python
client = OllamaClient("llama2", max_retries=5)
result = client.generate(
    prompt="Complex prompt",
    strip_think=True,
    validation_schema=custom_schema,
    return_parsed=True
)
)
```
## Updated Components
### 1. **Extractors**
- `BusinessNameExtractor`: Now uses `generate_with_validation()`
- `AddressExtractor`: Now uses `generate_with_validation()`
### 2. **Processors**
- `NerProcessor`: Updated to use enhanced methods
- `NerProcessorRefactored`: Updated to use enhanced methods
### 3. **Benefits in Processors**
- Simplified code: No more manual retry loops
- Automatic validation: No more manual JSON parsing
- Better error handling: Automatic fallback to regex methods
- Cleaner code: Reduced boilerplate
## Error Handling
### 1. **API Failures**
- Automatic retry on network errors
- Configurable retry count
- Detailed error logging
### 2. **Validation Failures**
- Automatic retry on schema validation failures
- Automatic retry on JSON parsing failures
- Graceful fallback to alternative methods
### 3. **Exception Types**
- `RequestException`: API call failures after all retries
- `ValueError`: Validation failures after all retries
- `Exception`: Unexpected errors
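A caller can branch on the documented exception types like this (sketch; `safe_extract` and its empty-dict fallback are illustrative, not part of the client):

```python
import requests

def safe_extract(client, prompt: str) -> dict:
    """Call the enhanced client and fall back to an empty result
    when retries are exhausted."""
    try:
        return client.generate_with_validation(
            prompt=prompt,
            response_type="entity_extraction",
            return_parsed=True,
        )
    except requests.exceptions.RequestException as exc:
        # API call failed after all retries
        print(f"API failed: {exc}")
    except ValueError as exc:
        # JSON parsing or validation failed after all retries
        print(f"Validation failed: {exc}")
    return {}
```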
## Testing
### 1. **Test Coverage**
- Initialization with new parameters
- Enhanced generate methods
- Backward compatibility
- Retry mechanism
- Validation failure handling
- Mock-based testing for reliability
### 2. **Run Tests**
```bash
cd backend
python3 test_enhanced_ollama_client.py
```
## Migration Guide
### 1. **No Changes Required**
Existing code continues to work without modification:
```python
# This still works exactly the same
client = OllamaClient("llama2")
response = client.generate("prompt")
```
### 2. **Optional Enhancements**
To take advantage of new features:
```python
# Old way (still works)
response = client.generate(prompt)
parsed = LLMJsonExtractor.parse_raw_json_str(response)
if LLMResponseValidator.validate_entity_extraction(parsed):
    ...  # use the validated result

# New way (recommended)
parsed = client.generate_with_validation(
    prompt=prompt,
    response_type='entity_extraction',
    return_parsed=True
)
# parsed is already validated and ready to use
```
### 3. **Benefits of Migration**
- **Reduced Code**: Eliminates manual retry loops
- **Better Reliability**: Automatic retry and validation
- **Cleaner Code**: Less boilerplate
- **Better Error Handling**: Automatic fallbacks
## Performance Impact
### 1. **Positive Impact**
- Reduced code complexity
- Better error recovery
- Automatic retry reduces manual intervention
### 2. **Minimal Overhead**
- Validation only occurs when requested
- JSON parsing only occurs when needed
- Retry mechanism only activates on failures
## Future Enhancements
### 1. **Potential Additions**
- Circuit breaker pattern for API failures
- Caching for repeated requests
- Async/await support
- Streaming response support
- Custom retry strategies
### 2. **Configuration Options**
- Per-request retry configuration
- Custom validation error handling
- Response transformation hooks
- Metrics and monitoring
## Conclusion
The enhanced `OllamaClient` provides a robust, reliable, and easy-to-use interface for LLM interactions while maintaining full backward compatibility. The new validation and retry mechanisms significantly improve the reliability of LLM-based operations in the NER processing pipeline.

View File

@@ -99,16 +99,18 @@ class AddressExtractor(BaseExtractor):
"""
try:
response = self.ollama_client.generate(prompt)
logger.info(f"Raw LLM response for address extraction: {response}")
# Use the new enhanced generate method with validation
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='address_extraction',
return_parsed=True
)
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
if parsed_response and LLMResponseValidator.validate_address_extraction(parsed_response):
if parsed_response:
logger.info(f"Successfully extracted address components: {parsed_response}")
return parsed_response
else:
logger.warning(f"Invalid JSON response for address extraction: {response}")
logger.warning(f"Failed to extract address components for: {address}")
return None
except Exception as e:
logger.error(f"LLM extraction failed: {e}")

View File

@@ -83,12 +83,14 @@ class BusinessNameExtractor(BaseExtractor):
"""
try:
response = self.ollama_client.generate(prompt)
logger.info(f"Raw LLM response for business name extraction: {response}")
# Use the new enhanced generate method with validation
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='business_name_extraction',
return_parsed=True
)
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
if parsed_response and LLMResponseValidator.validate_business_name_extraction(parsed_response):
if parsed_response:
business_name = parsed_response.get('business_name', '')
# Clean business name, keep only Chinese characters
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
@@ -98,7 +100,7 @@ class BusinessNameExtractor(BaseExtractor):
'confidence': parsed_response.get('confidence', 0.9)
}
else:
logger.warning(f"Invalid JSON response for business name extraction: {response}")
logger.warning(f"Failed to extract business name for: {company_name}")
return None
except Exception as e:
logger.error(f"LLM extraction failed: {e}")

View File

@@ -161,20 +161,21 @@ class NerProcessor:
"""
try:
response = self.ollama_client.generate(prompt)
logger.info(f"Raw LLM response for business name extraction: {response}")
# Use the new enhanced generate method with validation
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='business_name_extraction',
return_parsed=True
)
# Parse the response with the JSON extractor
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
if parsed_response and LLMResponseValidator.validate_business_name_extraction(parsed_response):
if parsed_response:
business_name = parsed_response.get('business_name', '')
# Clean the business name, keeping only Chinese characters
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
logger.info(f"Successfully extracted business name: {business_name}")
return business_name if business_name else ""
else:
logger.warning(f"Invalid JSON response for business name extraction: {response}")
logger.warning(f"Failed to extract business name for: {company_name}")
return ""
except Exception as e:
logger.error(f"LLM extraction failed: {e}")
@@ -332,17 +333,18 @@ class NerProcessor:
"""
try:
response = self.ollama_client.generate(prompt)
logger.info(f"Raw LLM response for address extraction: {response}")
# Use the new enhanced generate method with validation
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='address_extraction',
return_parsed=True
)
# Parse the response with the JSON extractor
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
if parsed_response and LLMResponseValidator.validate_address_extraction(parsed_response):
if parsed_response:
logger.info(f"Successfully extracted address components: {parsed_response}")
return parsed_response
else:
logger.warning(f"Invalid JSON response for address extraction: {response}")
logger.warning(f"Failed to extract address components for: {address}")
return self._extract_address_components_with_regex(address)
except Exception as e:
logger.error(f"LLM extraction failed: {e}")
@@ -457,28 +459,27 @@ class NerProcessor:
return masked_address
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
for attempt in range(self.max_retries):
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw response from LLM: {response}")
mapping = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
return {}
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
# Use the new enhanced generate method with validation
mapping = self.ollama_client.generate_with_validation(
prompt=formatted_prompt,
response_type='entity_extraction',
return_parsed=True
)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received for {entity_type}")
return {}
except Exception as e:
logger.error(f"Error generating {entity_type} mapping: {e}")
return {}
def build_mapping(self, chunk: str) -> list[Dict[str, str]]:
mapping_pipeline = []
@@ -678,29 +679,28 @@ class NerProcessor:
for entity in linkable_entities
])
for attempt in range(self.max_retries):
try:
formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw entity linkage response from LLM: {response}")
linkage = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed entity linkage: {linkage}")
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
else:
logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error("Max retries reached for entity linkage, returning empty linkage")
return {"entity_groups": []}
try:
formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage")
# Use the new enhanced generate method with validation
linkage = self.ollama_client.generate_with_validation(
prompt=formatted_prompt,
response_type='entity_linkage',
return_parsed=True
)
logger.info(f"Parsed entity linkage: {linkage}")
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
else:
logger.warning(f"Invalid entity linkage format received")
return {"entity_groups": []}
except Exception as e:
logger.error(f"Error generating entity linkage: {e}")
return {"entity_groups": []}
def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]:
"""

View File

@@ -57,28 +57,27 @@ class NerProcessorRefactored:
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
"""Process entities of a specific type using LLM"""
for attempt in range(self.max_retries):
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw response from LLM: {response}")
mapping = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
return {}
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
# Use the new enhanced generate method with validation
mapping = self.ollama_client.generate_with_validation(
prompt=formatted_prompt,
response_type='entity_extraction',
return_parsed=True
)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received for {entity_type}")
return {}
except Exception as e:
logger.error(f"Error generating {entity_type} mapping: {e}")
return {}
def build_mapping(self, chunk: str) -> List[Dict[str, str]]:
"""Build entity mappings from text chunk"""
@@ -234,29 +233,28 @@ class NerProcessorRefactored:
for entity in linkable_entities
])
for attempt in range(self.max_retries):
try:
formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw entity linkage response from LLM: {response}")
linkage = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed entity linkage: {linkage}")
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
else:
logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error("Max retries reached for entity linkage, returning empty linkage")
return {"entity_groups": []}
try:
formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage")
# Use the new enhanced generate method with validation
linkage = self.ollama_client.generate_with_validation(
prompt=formatted_prompt,
response_type='entity_linkage',
return_parsed=True
)
logger.info(f"Parsed entity linkage: {linkage}")
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
else:
logger.warning(f"Invalid entity linkage format received")
return {"entity_groups": []}
except Exception as e:
logger.error(f"Error generating entity linkage: {e}")
return {"entity_groups": []}
def process(self, chunks: List[str]) -> Dict[str, str]:
"""Main processing method"""

View File

@@ -1,72 +1,222 @@
import requests
import logging
from typing import Dict, Any
from typing import Dict, Any, Optional, Callable, Union
from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator
logger = logging.getLogger(__name__)
class OllamaClient:
def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
"""Initialize Ollama client.
Args:
model_name (str): Name of the Ollama model to use
host (str): Ollama server host address
port (int): Ollama server port
base_url (str): Ollama server base URL
max_retries (int): Maximum number of retries for failed requests
"""
self.model_name = model_name
self.base_url = base_url
self.max_retries = max_retries
self.headers = {"Content-Type": "application/json"}
def generate(self, prompt: str, strip_think: bool = True) -> str:
"""Process a document using the Ollama API.
def generate(self,
prompt: str,
strip_think: bool = True,
validation_schema: Optional[Dict[str, Any]] = None,
response_type: Optional[str] = None,
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
"""Process a document using the Ollama API with optional validation and retry.
Args:
document_text (str): The text content to process
prompt (str): The prompt to send to the model
strip_think (bool): Whether to strip thinking tags from response
validation_schema (Optional[Dict]): JSON schema for validation
response_type (Optional[str]): Type of response for validation ('entity_extraction', 'entity_linkage', etc.)
return_parsed (bool): Whether to return parsed JSON instead of raw string
Returns:
str: Processed text response from the model
Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON)
Raises:
RequestException: If the API call fails
RequestException: If the API call fails after all retries
ValueError: If validation fails after all retries
"""
try:
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": False
}
logger.debug(f"Sending request to Ollama API: {url}")
response = requests.post(url, json=payload, headers=self.headers)
response.raise_for_status()
result = response.json()
logger.debug(f"Received response from Ollama API: {result}")
if strip_think:
# Remove the "thinking" part from the response
# the response is expected to be <think>...</think>response_text
# Check if the response contains <think> tag
if "<think>" in result.get("response", ""):
# Split the response and take the part after </think>
response_parts = result["response"].split("</think>")
if len(response_parts) > 1:
# Return the part after </think>
return response_parts[1].strip()
for attempt in range(self.max_retries):
try:
# Make the API call
raw_response = self._make_api_call(prompt, strip_think)
# If no validation required, return the response
if not validation_schema and not response_type and not return_parsed:
return raw_response
# Parse JSON if needed
if return_parsed or validation_schema or response_type:
parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
if not parsed_response:
logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
continue
else:
raise ValueError("Failed to parse JSON response after all retries")
# Validate if schema or response type provided
if validation_schema:
if not self._validate_with_schema(parsed_response, validation_schema):
logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
continue
else:
raise ValueError("Schema validation failed after all retries")
if response_type:
if not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
continue
else:
raise ValueError(f"Response type validation failed after all retries")
# Return parsed response if requested
if return_parsed:
return parsed_response
else:
# If no closing tag, return the full response
return result.get("response", "").strip()
return raw_response
return raw_response
except requests.exceptions.RequestException as e:
logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
# If no <think> tag, return the full response
logger.error("Max retries reached, raising exception")
raise
except Exception as e:
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error("Max retries reached, raising exception")
raise
# This should never be reached, but just in case
raise Exception("Unexpected error: max retries exceeded without proper exception handling")
def generate_with_validation(self,
prompt: str,
response_type: str,
strip_think: bool = True,
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
"""Generate response with automatic validation based on response type.
Args:
prompt (str): The prompt to send to the model
response_type (str): Type of response for validation
strip_think (bool): Whether to strip thinking tags from response
return_parsed (bool): Whether to return parsed JSON instead of raw string
Returns:
Union[str, Dict[str, Any]]: Validated response from the model
"""
return self.generate(
prompt=prompt,
strip_think=strip_think,
response_type=response_type,
return_parsed=return_parsed
)
def generate_with_schema(self,
prompt: str,
schema: Dict[str, Any],
strip_think: bool = True,
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
"""Generate response with custom schema validation.
Args:
prompt (str): The prompt to send to the model
schema (Dict): JSON schema for validation
strip_think (bool): Whether to strip thinking tags from response
return_parsed (bool): Whether to return parsed JSON instead of raw string
Returns:
Union[str, Dict[str, Any]]: Validated response from the model
"""
return self.generate(
prompt=prompt,
strip_think=strip_think,
validation_schema=schema,
return_parsed=return_parsed
)
def _make_api_call(self, prompt: str, strip_think: bool) -> str:
"""Make the actual API call to Ollama.
Args:
prompt (str): The prompt to send
strip_think (bool): Whether to strip thinking tags
Returns:
str: Raw response from the API
"""
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": False
}
logger.debug(f"Sending request to Ollama API: {url}")
response = requests.post(url, json=payload, headers=self.headers)
response.raise_for_status()
result = response.json()
logger.debug(f"Received response from Ollama API: {result}")
if strip_think:
# Remove the "thinking" part from the response
# the response is expected to be <think>...</think>response_text
# Check if the response contains <think> tag
if "<think>" in result.get("response", ""):
# Split the response and take the part after </think>
response_parts = result["response"].split("</think>")
if len(response_parts) > 1:
# Return the part after </think>
return response_parts[1].strip()
else:
# If no closing tag, return the full response
return result.get("response", "").strip()
else:
# If strip_think is False, return the full response
return result.get("response", "")
# If no <think> tag, return the full response
return result.get("response", "").strip()
else:
# If strip_think is False, return the full response
return result.get("response", "")
def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
"""Validate response against a JSON schema.
Args:
response (Dict): The parsed response to validate
schema (Dict): The JSON schema to validate against
except requests.exceptions.RequestException as e:
logger.error(f"Error calling Ollama API: {str(e)}")
raise
Returns:
bool: True if valid, False otherwise
"""
try:
from jsonschema import validate, ValidationError
validate(instance=response, schema=schema)
logger.debug(f"Schema validation passed for response: {response}")
return True
except ValidationError as e:
logger.warning(f"Schema validation failed: {e}")
logger.warning(f"Response that failed validation: {response}")
return False
except ImportError:
logger.error("jsonschema library not available for validation")
return False
def get_model_info(self) -> Dict[str, Any]:
"""Get information about the current model.

View File

@@ -0,0 +1,230 @@
"""
Test file for the enhanced OllamaClient with validation and retry mechanisms.
"""
import sys
import os
import json
from unittest.mock import Mock, patch
# Add the current directory to the Python path
sys.path.insert(0, os.path.dirname(__file__))
def test_ollama_client_initialization():
"""Test OllamaClient initialization with new parameters"""
from app.core.services.ollama_client import OllamaClient
# Test with default parameters
client = OllamaClient("test-model")
assert client.model_name == "test-model"
assert client.base_url == "http://localhost:11434"
assert client.max_retries == 3
# Test with custom parameters
client = OllamaClient("test-model", "http://custom:11434", 5)
assert client.model_name == "test-model"
assert client.base_url == "http://custom:11434"
assert client.max_retries == 5
print("✓ OllamaClient initialization tests passed")
def test_generate_with_validation():
"""Test generate_with_validation method"""
from app.core.services.ollama_client import OllamaClient
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": '{"business_name": "测试公司", "confidence": 0.9}'
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test with business name extraction validation
result = client.generate_with_validation(
prompt="Extract business name from: 测试公司",
response_type='business_name_extraction',
return_parsed=True
)
assert isinstance(result, dict)
assert result.get('business_name') == '测试公司'
assert result.get('confidence') == 0.9
print("✓ generate_with_validation test passed")
def test_generate_with_schema():
"""Test generate_with_schema method"""
from app.core.services.ollama_client import OllamaClient
# Define a custom schema
custom_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"}
},
"required": ["name", "age"]
}
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": '{"name": "张三", "age": 30}'
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test with custom schema validation
result = client.generate_with_schema(
prompt="Generate person info",
schema=custom_schema,
return_parsed=True
)
assert isinstance(result, dict)
assert result.get('name') == '张三'
assert result.get('age') == 30
print("✓ generate_with_schema test passed")
def test_backward_compatibility():
"""Test backward compatibility with original generate method"""
from app.core.services.ollama_client import OllamaClient
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": "Simple text response"
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test original generate method (should still work)
result = client.generate("Simple prompt")
assert result == "Simple text response"
# Test with strip_think=False
result = client.generate("Simple prompt", strip_think=False)
assert result == "Simple text response"
print("✓ Backward compatibility tests passed")
def test_retry_mechanism():
"""Test retry mechanism for failed requests"""
from app.core.services.ollama_client import OllamaClient
import requests
# Mock failed requests followed by success
mock_failed_response = Mock()
mock_failed_response.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed")
mock_success_response = Mock()
mock_success_response.json.return_value = {
"response": "Success response"
}
mock_success_response.raise_for_status.return_value = None
with patch('requests.post', side_effect=[mock_failed_response, mock_success_response]):
client = OllamaClient("test-model", max_retries=2)
# Should retry and eventually succeed
result = client.generate("Test prompt")
assert result == "Success response"
print("✓ Retry mechanism test passed")
def test_validation_failure():
"""Test validation failure handling"""
from app.core.services.ollama_client import OllamaClient
# Mock API response with invalid JSON
mock_response = Mock()
mock_response.json.return_value = {
"response": "Invalid JSON response"
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model", max_retries=2)
try:
# This should fail validation and retry
result = client.generate_with_validation(
prompt="Test prompt",
response_type='business_name_extraction',
return_parsed=True
)
# If we get here, it means validation failed and retries were exhausted
print("✓ Validation failure handling test passed")
except ValueError as e:
# Expected behavior - validation failed after retries
assert "Failed to parse JSON response after all retries" in str(e)
print("✓ Validation failure handling test passed")
def test_enhanced_methods():
"""Test the new enhanced methods"""
from app.core.services.ollama_client import OllamaClient
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": '{"entities": [{"text": "张三", "type": "人名"}]}'
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test generate_with_validation
result = client.generate_with_validation(
prompt="Extract entities",
response_type='entity_extraction',
return_parsed=True
)
assert isinstance(result, dict)
assert 'entities' in result
assert len(result['entities']) == 1
assert result['entities'][0]['text'] == '张三'
print("✓ Enhanced methods tests passed")
def main():
"""Run all tests"""
print("Testing enhanced OllamaClient...")
print("=" * 50)
try:
test_ollama_client_initialization()
test_generate_with_validation()
test_generate_with_schema()
test_backward_compatibility()
test_retry_mechanism()
test_validation_failure()
test_enhanced_methods()
print("\n" + "=" * 50)
print("✓ All enhanced OllamaClient tests passed!")
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()