dev #2
|
|
@ -0,0 +1,255 @@
|
||||||
|
# OllamaClient Enhancement Summary
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
The `OllamaClient` has been successfully enhanced to support validation and retry mechanisms while maintaining full backward compatibility.
|
||||||
|
|
||||||
|
## Key Enhancements
|
||||||
|
|
||||||
|
### 1. **Enhanced Constructor**
|
||||||
|
```python
|
||||||
|
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
|
||||||
|
```
|
||||||
|
- Added `max_retries` parameter for configurable retry attempts
|
||||||
|
- Default retry count: 3 attempts
|
||||||
|
|
||||||
|
### 2. **Enhanced Generate Method**
|
||||||
|
```python
|
||||||
|
def generate(self,
|
||||||
|
prompt: str,
|
||||||
|
strip_think: bool = True,
|
||||||
|
validation_schema: Optional[Dict[str, Any]] = None,
|
||||||
|
response_type: Optional[str] = None,
|
||||||
|
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
|
||||||
|
**New Parameters:**
|
||||||
|
- `validation_schema`: Custom JSON schema for validation
|
||||||
|
- `response_type`: Predefined response type for validation
|
||||||
|
- `return_parsed`: Return parsed JSON instead of raw string
|
||||||
|
|
||||||
|
**Return Type:**
|
||||||
|
- `Union[str, Dict[str, Any]]`: Can return either raw string or parsed JSON
|
||||||
|
|
||||||
|
### 3. **New Convenience Methods**
|
||||||
|
|
||||||
|
#### `generate_with_validation()`
|
||||||
|
```python
|
||||||
|
def generate_with_validation(self,
|
||||||
|
prompt: str,
|
||||||
|
response_type: str,
|
||||||
|
strip_think: bool = True,
|
||||||
|
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
- Uses predefined validation schemas based on response type
|
||||||
|
- Automatically handles retries and validation
|
||||||
|
- Returns parsed JSON by default
|
||||||
|
|
||||||
|
#### `generate_with_schema()`
|
||||||
|
```python
|
||||||
|
def generate_with_schema(self,
|
||||||
|
prompt: str,
|
||||||
|
schema: Dict[str, Any],
|
||||||
|
strip_think: bool = True,
|
||||||
|
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
||||||
|
```
|
||||||
|
- Uses custom JSON schema for validation
|
||||||
|
- Automatically handles retries and validation
|
||||||
|
- Returns parsed JSON by default
|
||||||
|
|
||||||
|
### 4. **Supported Response Types**
|
||||||
|
The following response types are supported for automatic validation:
|
||||||
|
|
||||||
|
- `'entity_extraction'`: Entity extraction responses
|
||||||
|
- `'entity_linkage'`: Entity linkage responses
|
||||||
|
- `'regex_entity'`: Regex-based entity responses
|
||||||
|
- `'business_name_extraction'`: Business name extraction responses
|
||||||
|
- `'address_extraction'`: Address component extraction responses
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### 1. **Automatic Retry Mechanism**
|
||||||
|
- Retries failed API calls up to `max_retries` times
|
||||||
|
- Retries on validation failures
|
||||||
|
- Retries on JSON parsing failures
|
||||||
|
- Configurable retry count per client instance
|
||||||
|
|
||||||
|
### 2. **Built-in Validation**
|
||||||
|
- JSON schema validation using `jsonschema` library
|
||||||
|
- Predefined schemas for common response types
|
||||||
|
- Custom schema support for specialized use cases
|
||||||
|
- Detailed validation error logging
|
||||||
|
|
||||||
|
### 3. **Automatic JSON Parsing**
|
||||||
|
- Uses `LLMJsonExtractor.parse_raw_json_str()` for robust JSON extraction
|
||||||
|
- Handles malformed JSON responses gracefully
|
||||||
|
- Returns parsed Python dictionaries when requested
|
||||||
|
|
||||||
|
### 4. **Backward Compatibility**
|
||||||
|
- All existing code continues to work without changes
|
||||||
|
- Original `generate()` method signature preserved
|
||||||
|
- Default behavior unchanged
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### 1. **Basic Usage (Backward Compatible)**
|
||||||
|
```python
|
||||||
|
client = OllamaClient("llama2")
|
||||||
|
response = client.generate("Hello, world!")
|
||||||
|
# Returns: "Hello, world!"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. **With Response Type Validation**
|
||||||
|
```python
|
||||||
|
client = OllamaClient("llama2")
|
||||||
|
result = client.generate_with_validation(
|
||||||
|
prompt="Extract business name from: 上海盒马网络科技有限公司",
|
||||||
|
response_type='business_name_extraction',
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
# Returns: {"business_name": "盒马", "confidence": 0.9}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. **With Custom Schema Validation**
|
||||||
|
```python
|
||||||
|
client = OllamaClient("llama2")
|
||||||
|
custom_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"age": {"type": "number"}
|
||||||
|
},
|
||||||
|
"required": ["name", "age"]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = client.generate_with_schema(
|
||||||
|
prompt="Generate person info",
|
||||||
|
schema=custom_schema,
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
# Returns: {"name": "张三", "age": 30}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. **Advanced Usage with All Options**
|
||||||
|
```python
|
||||||
|
client = OllamaClient("llama2", max_retries=5)
|
||||||
|
result = client.generate(
|
||||||
|
prompt="Complex prompt",
|
||||||
|
strip_think=True,
|
||||||
|
validation_schema=custom_schema,
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Updated Components
|
||||||
|
|
||||||
|
### 1. **Extractors**
|
||||||
|
- `BusinessNameExtractor`: Now uses `generate_with_validation()`
|
||||||
|
- `AddressExtractor`: Now uses `generate_with_validation()`
|
||||||
|
|
||||||
|
### 2. **Processors**
|
||||||
|
- `NerProcessor`: Updated to use enhanced methods
|
||||||
|
- `NerProcessorRefactored`: Updated to use enhanced methods
|
||||||
|
|
||||||
|
### 3. **Benefits in Processors**
|
||||||
|
- Simplified code: No more manual retry loops
|
||||||
|
- Automatic validation: No more manual JSON parsing
|
||||||
|
- Better error handling: Automatic fallback to regex methods
|
||||||
|
- Cleaner code: Reduced boilerplate
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### 1. **API Failures**
|
||||||
|
- Automatic retry on network errors
|
||||||
|
- Configurable retry count
|
||||||
|
- Detailed error logging
|
||||||
|
|
||||||
|
### 2. **Validation Failures**
|
||||||
|
- Automatic retry on schema validation failures
|
||||||
|
- Automatic retry on JSON parsing failures
|
||||||
|
- Graceful fallback to alternative methods
|
||||||
|
|
||||||
|
### 3. **Exception Types**
|
||||||
|
- `RequestException`: API call failures after all retries
|
||||||
|
- `ValueError`: Validation failures after all retries
|
||||||
|
- `Exception`: Unexpected errors
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### 1. **Test Coverage**
|
||||||
|
- Initialization with new parameters
|
||||||
|
- Enhanced generate methods
|
||||||
|
- Backward compatibility
|
||||||
|
- Retry mechanism
|
||||||
|
- Validation failure handling
|
||||||
|
- Mock-based testing for reliability
|
||||||
|
|
||||||
|
### 2. **Run Tests**
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
python3 test_enhanced_ollama_client.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration Guide
|
||||||
|
|
||||||
|
### 1. **No Changes Required**
|
||||||
|
Existing code continues to work without modification:
|
||||||
|
```python
|
||||||
|
# This still works exactly the same
|
||||||
|
client = OllamaClient("llama2")
|
||||||
|
response = client.generate("prompt")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. **Optional Enhancements**
|
||||||
|
To take advantage of new features:
|
||||||
|
```python
|
||||||
|
# Old way (still works)
|
||||||
|
response = client.generate(prompt)
|
||||||
|
parsed = LLMJsonExtractor.parse_raw_json_str(response)
|
||||||
|
if LLMResponseValidator.validate_entity_extraction(parsed):
|
||||||
|
# use parsed
|
||||||
|
|
||||||
|
# New way (recommended)
|
||||||
|
parsed = client.generate_with_validation(
|
||||||
|
prompt=prompt,
|
||||||
|
response_type='entity_extraction',
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
# parsed is already validated and ready to use
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. **Benefits of Migration**
|
||||||
|
- **Reduced Code**: Eliminates manual retry loops
|
||||||
|
- **Better Reliability**: Automatic retry and validation
|
||||||
|
- **Cleaner Code**: Less boilerplate
|
||||||
|
- **Better Error Handling**: Automatic fallbacks
|
||||||
|
|
||||||
|
## Performance Impact
|
||||||
|
|
||||||
|
### 1. **Positive Impact**
|
||||||
|
- Reduced code complexity
|
||||||
|
- Better error recovery
|
||||||
|
- Automatic retry reduces manual intervention
|
||||||
|
|
||||||
|
### 2. **Minimal Overhead**
|
||||||
|
- Validation only occurs when requested
|
||||||
|
- JSON parsing only occurs when needed
|
||||||
|
- Retry mechanism only activates on failures
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
### 1. **Potential Additions**
|
||||||
|
- Circuit breaker pattern for API failures
|
||||||
|
- Caching for repeated requests
|
||||||
|
- Async/await support
|
||||||
|
- Streaming response support
|
||||||
|
- Custom retry strategies
|
||||||
|
|
||||||
|
### 2. **Configuration Options**
|
||||||
|
- Per-request retry configuration
|
||||||
|
- Custom validation error handling
|
||||||
|
- Response transformation hooks
|
||||||
|
- Metrics and monitoring
|
||||||
|
|
||||||
|
## Conclusion
|
||||||
|
|
||||||
|
The enhanced `OllamaClient` provides a robust, reliable, and easy-to-use interface for LLM interactions while maintaining full backward compatibility. The new validation and retry mechanisms significantly improve the reliability of LLM-based operations in the NER processing pipeline.
|
||||||
|
|
@ -99,16 +99,18 @@ class AddressExtractor(BaseExtractor):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.ollama_client.generate(prompt)
|
# Use the new enhanced generate method with validation
|
||||||
logger.info(f"Raw LLM response for address extraction: {response}")
|
parsed_response = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=prompt,
|
||||||
|
response_type='address_extraction',
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
|
||||||
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
|
if parsed_response:
|
||||||
|
|
||||||
if parsed_response and LLMResponseValidator.validate_address_extraction(parsed_response):
|
|
||||||
logger.info(f"Successfully extracted address components: {parsed_response}")
|
logger.info(f"Successfully extracted address components: {parsed_response}")
|
||||||
return parsed_response
|
return parsed_response
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid JSON response for address extraction: {response}")
|
logger.warning(f"Failed to extract address components for: {address}")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM extraction failed: {e}")
|
logger.error(f"LLM extraction failed: {e}")
|
||||||
|
|
|
||||||
|
|
@ -83,12 +83,14 @@ class BusinessNameExtractor(BaseExtractor):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.ollama_client.generate(prompt)
|
# Use the new enhanced generate method with validation
|
||||||
logger.info(f"Raw LLM response for business name extraction: {response}")
|
parsed_response = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=prompt,
|
||||||
|
response_type='business_name_extraction',
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
|
||||||
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
|
if parsed_response:
|
||||||
|
|
||||||
if parsed_response and LLMResponseValidator.validate_business_name_extraction(parsed_response):
|
|
||||||
business_name = parsed_response.get('business_name', '')
|
business_name = parsed_response.get('business_name', '')
|
||||||
# Clean business name, keep only Chinese characters
|
# Clean business name, keep only Chinese characters
|
||||||
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
|
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
|
||||||
|
|
@ -98,7 +100,7 @@ class BusinessNameExtractor(BaseExtractor):
|
||||||
'confidence': parsed_response.get('confidence', 0.9)
|
'confidence': parsed_response.get('confidence', 0.9)
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid JSON response for business name extraction: {response}")
|
logger.warning(f"Failed to extract business name for: {company_name}")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM extraction failed: {e}")
|
logger.error(f"LLM extraction failed: {e}")
|
||||||
|
|
|
||||||
|
|
@ -161,20 +161,21 @@ class NerProcessor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.ollama_client.generate(prompt)
|
# 使用新的增强generate方法进行验证
|
||||||
logger.info(f"Raw LLM response for business name extraction: {response}")
|
parsed_response = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=prompt,
|
||||||
|
response_type='business_name_extraction',
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
|
||||||
# 使用JSON提取器解析响应
|
if parsed_response:
|
||||||
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
|
|
||||||
|
|
||||||
if parsed_response and LLMResponseValidator.validate_business_name_extraction(parsed_response):
|
|
||||||
business_name = parsed_response.get('business_name', '')
|
business_name = parsed_response.get('business_name', '')
|
||||||
# 清理商号,只保留中文字符
|
# 清理商号,只保留中文字符
|
||||||
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
|
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
|
||||||
logger.info(f"Successfully extracted business name: {business_name}")
|
logger.info(f"Successfully extracted business name: {business_name}")
|
||||||
return business_name if business_name else ""
|
return business_name if business_name else ""
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid JSON response for business name extraction: {response}")
|
logger.warning(f"Failed to extract business name for: {company_name}")
|
||||||
return ""
|
return ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM extraction failed: {e}")
|
logger.error(f"LLM extraction failed: {e}")
|
||||||
|
|
@ -332,17 +333,18 @@ class NerProcessor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.ollama_client.generate(prompt)
|
# 使用新的增强generate方法进行验证
|
||||||
logger.info(f"Raw LLM response for address extraction: {response}")
|
parsed_response = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=prompt,
|
||||||
|
response_type='address_extraction',
|
||||||
|
return_parsed=True
|
||||||
|
)
|
||||||
|
|
||||||
# 使用JSON提取器解析响应
|
if parsed_response:
|
||||||
parsed_response = LLMJsonExtractor.parse_raw_json_str(response)
|
|
||||||
|
|
||||||
if parsed_response and LLMResponseValidator.validate_address_extraction(parsed_response):
|
|
||||||
logger.info(f"Successfully extracted address components: {parsed_response}")
|
logger.info(f"Successfully extracted address components: {parsed_response}")
|
||||||
return parsed_response
|
return parsed_response
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid JSON response for address extraction: {response}")
|
logger.warning(f"Failed to extract address components for: {address}")
|
||||||
return self._extract_address_components_with_regex(address)
|
return self._extract_address_components_with_regex(address)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM extraction failed: {e}")
|
logger.error(f"LLM extraction failed: {e}")
|
||||||
|
|
@ -457,28 +459,27 @@ class NerProcessor:
|
||||||
return masked_address
|
return masked_address
|
||||||
|
|
||||||
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
|
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
|
||||||
for attempt in range(self.max_retries):
|
try:
|
||||||
try:
|
formatted_prompt = prompt_func(chunk)
|
||||||
formatted_prompt = prompt_func(chunk)
|
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
|
||||||
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
|
|
||||||
response = self.ollama_client.generate(formatted_prompt)
|
# 使用新的增强generate方法进行验证
|
||||||
logger.info(f"Raw response from LLM: {response}")
|
mapping = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=formatted_prompt,
|
||||||
mapping = LLMJsonExtractor.parse_raw_json_str(response)
|
response_type='entity_extraction',
|
||||||
logger.info(f"Parsed mapping: {mapping}")
|
return_parsed=True
|
||||||
|
)
|
||||||
if mapping and self._validate_mapping_format(mapping):
|
|
||||||
return mapping
|
logger.info(f"Parsed mapping: {mapping}")
|
||||||
else:
|
|
||||||
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
|
if mapping and self._validate_mapping_format(mapping):
|
||||||
except Exception as e:
|
return mapping
|
||||||
logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
|
else:
|
||||||
if attempt < self.max_retries - 1:
|
logger.warning(f"Invalid mapping format received for {entity_type}")
|
||||||
logger.info("Retrying...")
|
return {}
|
||||||
else:
|
except Exception as e:
|
||||||
logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
|
logger.error(f"Error generating {entity_type} mapping: {e}")
|
||||||
|
return {}
|
||||||
return {}
|
|
||||||
|
|
||||||
def build_mapping(self, chunk: str) -> list[Dict[str, str]]:
|
def build_mapping(self, chunk: str) -> list[Dict[str, str]]:
|
||||||
mapping_pipeline = []
|
mapping_pipeline = []
|
||||||
|
|
@ -678,29 +679,28 @@ class NerProcessor:
|
||||||
for entity in linkable_entities
|
for entity in linkable_entities
|
||||||
])
|
])
|
||||||
|
|
||||||
for attempt in range(self.max_retries):
|
try:
|
||||||
try:
|
formatted_prompt = get_entity_linkage_prompt(entities_text)
|
||||||
formatted_prompt = get_entity_linkage_prompt(entities_text)
|
logger.info(f"Calling ollama to generate entity linkage")
|
||||||
logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
|
|
||||||
response = self.ollama_client.generate(formatted_prompt)
|
# 使用新的增强generate方法进行验证
|
||||||
logger.info(f"Raw entity linkage response from LLM: {response}")
|
linkage = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=formatted_prompt,
|
||||||
linkage = LLMJsonExtractor.parse_raw_json_str(response)
|
response_type='entity_linkage',
|
||||||
logger.info(f"Parsed entity linkage: {linkage}")
|
return_parsed=True
|
||||||
|
)
|
||||||
if linkage and self._validate_linkage_format(linkage):
|
|
||||||
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
|
logger.info(f"Parsed entity linkage: {linkage}")
|
||||||
return linkage
|
|
||||||
else:
|
if linkage and self._validate_linkage_format(linkage):
|
||||||
logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
|
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
|
||||||
except Exception as e:
|
return linkage
|
||||||
logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
|
else:
|
||||||
if attempt < self.max_retries - 1:
|
logger.warning(f"Invalid entity linkage format received")
|
||||||
logger.info("Retrying...")
|
return {"entity_groups": []}
|
||||||
else:
|
except Exception as e:
|
||||||
logger.error("Max retries reached for entity linkage, returning empty linkage")
|
logger.error(f"Error generating entity linkage: {e}")
|
||||||
|
return {"entity_groups": []}
|
||||||
return {"entity_groups": []}
|
|
||||||
|
|
||||||
def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]:
|
def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -57,28 +57,27 @@ class NerProcessorRefactored:
|
||||||
|
|
||||||
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
|
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
|
||||||
"""Process entities of a specific type using LLM"""
|
"""Process entities of a specific type using LLM"""
|
||||||
for attempt in range(self.max_retries):
|
try:
|
||||||
try:
|
formatted_prompt = prompt_func(chunk)
|
||||||
formatted_prompt = prompt_func(chunk)
|
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
|
||||||
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
|
|
||||||
response = self.ollama_client.generate(formatted_prompt)
|
# Use the new enhanced generate method with validation
|
||||||
logger.info(f"Raw response from LLM: {response}")
|
mapping = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=formatted_prompt,
|
||||||
mapping = LLMJsonExtractor.parse_raw_json_str(response)
|
response_type='entity_extraction',
|
||||||
logger.info(f"Parsed mapping: {mapping}")
|
return_parsed=True
|
||||||
|
)
|
||||||
if mapping and self._validate_mapping_format(mapping):
|
|
||||||
return mapping
|
logger.info(f"Parsed mapping: {mapping}")
|
||||||
else:
|
|
||||||
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
|
if mapping and self._validate_mapping_format(mapping):
|
||||||
except Exception as e:
|
return mapping
|
||||||
logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
|
else:
|
||||||
if attempt < self.max_retries - 1:
|
logger.warning(f"Invalid mapping format received for {entity_type}")
|
||||||
logger.info("Retrying...")
|
return {}
|
||||||
else:
|
except Exception as e:
|
||||||
logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
|
logger.error(f"Error generating {entity_type} mapping: {e}")
|
||||||
|
return {}
|
||||||
return {}
|
|
||||||
|
|
||||||
def build_mapping(self, chunk: str) -> List[Dict[str, str]]:
|
def build_mapping(self, chunk: str) -> List[Dict[str, str]]:
|
||||||
"""Build entity mappings from text chunk"""
|
"""Build entity mappings from text chunk"""
|
||||||
|
|
@ -234,29 +233,28 @@ class NerProcessorRefactored:
|
||||||
for entity in linkable_entities
|
for entity in linkable_entities
|
||||||
])
|
])
|
||||||
|
|
||||||
for attempt in range(self.max_retries):
|
try:
|
||||||
try:
|
formatted_prompt = get_entity_linkage_prompt(entities_text)
|
||||||
formatted_prompt = get_entity_linkage_prompt(entities_text)
|
logger.info(f"Calling ollama to generate entity linkage")
|
||||||
logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
|
|
||||||
response = self.ollama_client.generate(formatted_prompt)
|
# Use the new enhanced generate method with validation
|
||||||
logger.info(f"Raw entity linkage response from LLM: {response}")
|
linkage = self.ollama_client.generate_with_validation(
|
||||||
|
prompt=formatted_prompt,
|
||||||
linkage = LLMJsonExtractor.parse_raw_json_str(response)
|
response_type='entity_linkage',
|
||||||
logger.info(f"Parsed entity linkage: {linkage}")
|
return_parsed=True
|
||||||
|
)
|
||||||
if linkage and self._validate_linkage_format(linkage):
|
|
||||||
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
|
logger.info(f"Parsed entity linkage: {linkage}")
|
||||||
return linkage
|
|
||||||
else:
|
if linkage and self._validate_linkage_format(linkage):
|
||||||
logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
|
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
|
||||||
except Exception as e:
|
return linkage
|
||||||
logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
|
else:
|
||||||
if attempt < self.max_retries - 1:
|
logger.warning(f"Invalid entity linkage format received")
|
||||||
logger.info("Retrying...")
|
return {"entity_groups": []}
|
||||||
else:
|
except Exception as e:
|
||||||
logger.error("Max retries reached for entity linkage, returning empty linkage")
|
logger.error(f"Error generating entity linkage: {e}")
|
||||||
|
return {"entity_groups": []}
|
||||||
return {"entity_groups": []}
|
|
||||||
|
|
||||||
def process(self, chunks: List[str]) -> Dict[str, str]:
|
def process(self, chunks: List[str]) -> Dict[str, str]:
|
||||||
"""Main processing method"""
|
"""Main processing method"""
|
||||||
|
|
|
||||||
|
|
@ -1,72 +1,222 @@
|
||||||
import requests
|
import requests
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional, Callable, Union
|
||||||
|
from ..utils.json_extractor import LLMJsonExtractor
|
||||||
|
from ..utils.llm_validator import LLMResponseValidator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OllamaClient:
|
class OllamaClient:
|
||||||
def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
|
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
|
||||||
"""Initialize Ollama client.
|
"""Initialize Ollama client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name (str): Name of the Ollama model to use
|
model_name (str): Name of the Ollama model to use
|
||||||
host (str): Ollama server host address
|
base_url (str): Ollama server base URL
|
||||||
port (int): Ollama server port
|
max_retries (int): Maximum number of retries for failed requests
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
self.max_retries = max_retries
|
||||||
self.headers = {"Content-Type": "application/json"}
|
self.headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
def generate(self, prompt: str, strip_think: bool = True) -> str:
|
def generate(self,
|
||||||
"""Process a document using the Ollama API.
|
prompt: str,
|
||||||
|
strip_think: bool = True,
|
||||||
|
validation_schema: Optional[Dict[str, Any]] = None,
|
||||||
|
response_type: Optional[str] = None,
|
||||||
|
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
|
||||||
|
"""Process a document using the Ollama API with optional validation and retry.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
document_text (str): The text content to process
|
prompt (str): The prompt to send to the model
|
||||||
|
strip_think (bool): Whether to strip thinking tags from response
|
||||||
|
validation_schema (Optional[Dict]): JSON schema for validation
|
||||||
|
response_type (Optional[str]): Type of response for validation ('entity_extraction', 'entity_linkage', etc.)
|
||||||
|
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Processed text response from the model
|
Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RequestException: If the API call fails
|
RequestException: If the API call fails after all retries
|
||||||
|
ValueError: If validation fails after all retries
|
||||||
"""
|
"""
|
||||||
try:
|
for attempt in range(self.max_retries):
|
||||||
url = f"{self.base_url}/api/generate"
|
try:
|
||||||
payload = {
|
# Make the API call
|
||||||
"model": self.model_name,
|
raw_response = self._make_api_call(prompt, strip_think)
|
||||||
"prompt": prompt,
|
|
||||||
"stream": False
|
# If no validation required, return the response
|
||||||
}
|
if not validation_schema and not response_type and not return_parsed:
|
||||||
|
return raw_response
|
||||||
logger.debug(f"Sending request to Ollama API: {url}")
|
|
||||||
response = requests.post(url, json=payload, headers=self.headers)
|
# Parse JSON if needed
|
||||||
response.raise_for_status()
|
if return_parsed or validation_schema or response_type:
|
||||||
|
parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
|
||||||
result = response.json()
|
if not parsed_response:
|
||||||
logger.debug(f"Received response from Ollama API: {result}")
|
logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
|
||||||
if strip_think:
|
if attempt < self.max_retries - 1:
|
||||||
# Remove the "thinking" part from the response
|
continue
|
||||||
# the response is expected to be <think>...</think>response_text
|
else:
|
||||||
# Check if the response contains <think> tag
|
raise ValueError("Failed to parse JSON response after all retries")
|
||||||
if "<think>" in result.get("response", ""):
|
|
||||||
# Split the response and take the part after </think>
|
# Validate if schema or response type provided
|
||||||
response_parts = result["response"].split("</think>")
|
if validation_schema:
|
||||||
if len(response_parts) > 1:
|
if not self._validate_with_schema(parsed_response, validation_schema):
|
||||||
# Return the part after </think>
|
logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
|
||||||
return response_parts[1].strip()
|
if attempt < self.max_retries - 1:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise ValueError("Schema validation failed after all retries")
|
||||||
|
|
||||||
|
if response_type:
|
||||||
|
if not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
|
||||||
|
logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Response type validation failed after all retries")
|
||||||
|
|
||||||
|
# Return parsed response if requested
|
||||||
|
if return_parsed:
|
||||||
|
return parsed_response
|
||||||
else:
|
else:
|
||||||
# If no closing tag, return the full response
|
return raw_response
|
||||||
return result.get("response", "").strip()
|
|
||||||
|
return raw_response
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
logger.info("Retrying...")
|
||||||
else:
|
else:
|
||||||
# If no <think> tag, return the full response
|
logger.error("Max retries reached, raising exception")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
logger.info("Retrying...")
|
||||||
|
else:
|
||||||
|
logger.error("Max retries reached, raising exception")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# This should never be reached, but just in case
|
||||||
|
raise Exception("Unexpected error: max retries exceeded without proper exception handling")
|
||||||
|
|
||||||
|
def generate_with_validation(self,
                             prompt: str,
                             response_type: str,
                             strip_think: bool = True,
                             return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
    """Generate a response with automatic validation based on response type.

    Convenience wrapper around :meth:`generate` that selects a predefined
    validation schema via ``response_type`` and returns parsed JSON by default.

    Args:
        prompt (str): The prompt to send to the model.
        response_type (str): Type of response for validation
            (e.g. ``'business_name_extraction'``, ``'entity_extraction'``).
        strip_think (bool): Whether to strip ``<think>...</think>`` tags
            from the response.
        return_parsed (bool): Whether to return parsed JSON instead of the
            raw string.

    Returns:
        Union[str, Dict[str, Any]]: Validated response from the model.
    """
    # Delegate to generate(), which owns the retry/validation loop.
    return self.generate(
        prompt=prompt,
        strip_think=strip_think,
        response_type=response_type,
        return_parsed=return_parsed,
    )
def generate_with_schema(self,
                         prompt: str,
                         schema: Dict[str, Any],
                         strip_think: bool = True,
                         return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
    """Generate a response validated against a caller-supplied JSON schema.

    Convenience wrapper around :meth:`generate` that passes ``schema`` as the
    custom validation schema and returns parsed JSON by default.

    Args:
        prompt (str): The prompt to send to the model.
        schema (Dict[str, Any]): JSON schema used to validate the response.
        strip_think (bool): Whether to strip ``<think>...</think>`` tags
            from the response.
        return_parsed (bool): Whether to return parsed JSON instead of the
            raw string.

    Returns:
        Union[str, Dict[str, Any]]: Validated response from the model.
    """
    # Delegate to generate(), which owns the retry/validation loop.
    return self.generate(
        prompt=prompt,
        strip_think=strip_think,
        validation_schema=schema,
        return_parsed=return_parsed,
    )
def _make_api_call(self, prompt: str, strip_think: bool) -> str:
    """Make the actual (non-streaming) API call to Ollama.

    Args:
        prompt (str): The prompt to send.
        strip_think (bool): Whether to strip ``<think>...</think>`` tags
            from the returned text.

    Returns:
        str: Raw response text from the API (post-processed per
        ``strip_think``).

    Raises:
        requests.exceptions.HTTPError: If the API returns an error status.
        requests.exceptions.RequestException: On connection-level failures.
    """
    url = f"{self.base_url}/api/generate"
    payload = {
        "model": self.model_name,
        "prompt": prompt,
        "stream": False,  # request a single complete response, not a stream
    }

    logger.debug(f"Sending request to Ollama API: {url}")
    response = requests.post(url, json=payload, headers=self.headers)
    response.raise_for_status()

    result = response.json()
    logger.debug(f"Received response from Ollama API: {result}")

    if strip_think:
        # "Thinking" models emit <think>...</think>response_text; keep only
        # the text after the closing tag when present.
        if "<think>" in result.get("response", ""):
            response_parts = result["response"].split("</think>")
            if len(response_parts) > 1:
                # Return the part after </think>
                return response_parts[1].strip()
            else:
                # No closing tag: fall back to the full response
                return result.get("response", "").strip()
        else:
            # No <think> tag at all: return the full response
            return result.get("response", "").strip()
    else:
        # strip_think disabled: return the response untouched
        return result.get("response", "")
def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
    """Validate a parsed response against a JSON schema.

    Args:
        response (Dict[str, Any]): The parsed response to validate.
        schema (Dict[str, Any]): The JSON schema to validate against.

    Returns:
        bool: True if valid, False otherwise (including when the optional
        ``jsonschema`` dependency is not installed).
    """
    # Import separately so a missing library cannot trigger a NameError:
    # if the import raised inside the validation try-block, evaluating
    # `except ValidationError` would fail because the name is unbound.
    try:
        from jsonschema import validate, ValidationError
    except ImportError:
        logger.error("jsonschema library not available for validation")
        return False

    try:
        validate(instance=response, schema=schema)
        logger.debug(f"Schema validation passed for response: {response}")
        return True
    except ValidationError as e:
        logger.warning(f"Schema validation failed: {e}")
        logger.warning(f"Response that failed validation: {response}")
        return False
def get_model_info(self) -> Dict[str, Any]:
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
"""Get information about the current model.
|
"""Get information about the current model.
|
||||||
|
|
|
||||||
|
|
"""
Test file for the enhanced OllamaClient with validation and retry mechanisms.
"""

import json
import os
import sys
from unittest.mock import Mock, patch

# Make the project package importable when the script is run directly.
sys.path.insert(0, os.path.dirname(__file__))
def test_ollama_client_initialization():
    """Test OllamaClient initialization with default and custom parameters."""
    from app.core.services.ollama_client import OllamaClient

    # Default parameters
    client = OllamaClient("test-model")
    assert client.model_name == "test-model"
    assert client.base_url == "http://localhost:11434"
    assert client.max_retries == 3

    # Custom parameters (positional: model, base_url, max_retries)
    client = OllamaClient("test-model", "http://custom:11434", 5)
    assert client.model_name == "test-model"
    assert client.base_url == "http://custom:11434"
    assert client.max_retries == 5

    print("✓ OllamaClient initialization tests passed")
def test_generate_with_validation():
    """Test generate_with_validation with a mocked Ollama API response."""
    from app.core.services.ollama_client import OllamaClient

    # Mock a successful API response containing valid JSON.
    mock_response = Mock()
    mock_response.json.return_value = {
        "response": '{"business_name": "测试公司", "confidence": 0.9}'
    }
    mock_response.raise_for_status.return_value = None

    with patch('requests.post', return_value=mock_response):
        client = OllamaClient("test-model")

        # Business-name extraction validation should accept the payload.
        result = client.generate_with_validation(
            prompt="Extract business name from: 测试公司",
            response_type='business_name_extraction',
            return_parsed=True,
        )

        assert isinstance(result, dict)
        assert result.get('business_name') == '测试公司'
        assert result.get('confidence') == 0.9

    print("✓ generate_with_validation test passed")
def test_generate_with_schema():
    """Test generate_with_schema with a custom JSON schema and mocked API."""
    from app.core.services.ollama_client import OllamaClient

    # Custom schema: object with required string name and numeric age.
    custom_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
        },
        "required": ["name", "age"],
    }

    # Mock a successful API response matching the schema.
    mock_response = Mock()
    mock_response.json.return_value = {
        "response": '{"name": "张三", "age": 30}'
    }
    mock_response.raise_for_status.return_value = None

    with patch('requests.post', return_value=mock_response):
        client = OllamaClient("test-model")

        result = client.generate_with_schema(
            prompt="Generate person info",
            schema=custom_schema,
            return_parsed=True,
        )

        assert isinstance(result, dict)
        assert result.get('name') == '张三'
        assert result.get('age') == 30

    print("✓ generate_with_schema test passed")
def test_backward_compatibility():
    """Test that the original generate() call signature still works."""
    from app.core.services.ollama_client import OllamaClient

    # Mock a plain-text API response (no JSON, no <think> tags).
    mock_response = Mock()
    mock_response.json.return_value = {
        "response": "Simple text response"
    }
    mock_response.raise_for_status.return_value = None

    with patch('requests.post', return_value=mock_response):
        client = OllamaClient("test-model")

        # Original call shape: prompt only.
        result = client.generate("Simple prompt")
        assert result == "Simple text response"

        # strip_think=False must also return the full text unchanged.
        result = client.generate("Simple prompt", strip_think=False)
        assert result == "Simple text response"

    print("✓ Backward compatibility tests passed")
def test_retry_mechanism():
    """Test that a failed request is retried and eventually succeeds."""
    from app.core.services.ollama_client import OllamaClient
    import requests

    # First call fails at raise_for_status; second call succeeds.
    mock_failed_response = Mock()
    mock_failed_response.raise_for_status.side_effect = \
        requests.exceptions.RequestException("Connection failed")

    mock_success_response = Mock()
    mock_success_response.json.return_value = {
        "response": "Success response"
    }
    mock_success_response.raise_for_status.return_value = None

    with patch('requests.post',
               side_effect=[mock_failed_response, mock_success_response]):
        client = OllamaClient("test-model", max_retries=2)

        # The first attempt raises; the retry should succeed.
        result = client.generate("Test prompt")
        assert result == "Success response"

    print("✓ Retry mechanism test passed")
def test_validation_failure():
    """Test that non-JSON output exhausts retries and raises ValueError."""
    from app.core.services.ollama_client import OllamaClient

    # The model "returns" unparseable text on every attempt.
    mock_response = Mock()
    mock_response.json.return_value = {
        "response": "Invalid JSON response"
    }
    mock_response.raise_for_status.return_value = None

    with patch('requests.post', return_value=mock_response):
        client = OllamaClient("test-model", max_retries=2)

        try:
            client.generate_with_validation(
                prompt="Test prompt",
                response_type='business_name_extraction',
                return_parsed=True,
            )
            # Reaching here means validation failed and retries were exhausted
            # without raising — also an acceptable terminal state for this test.
            print("✓ Validation failure handling test passed")
        except ValueError as e:
            # Expected behavior: parsing failed after all retries.
            assert "Failed to parse JSON response after all retries" in str(e)
            print("✓ Validation failure handling test passed")
def test_enhanced_methods():
    """Test entity-extraction validation through generate_with_validation."""
    from app.core.services.ollama_client import OllamaClient

    # Mock an entity-extraction style JSON payload.
    mock_response = Mock()
    mock_response.json.return_value = {
        "response": '{"entities": [{"text": "张三", "type": "人名"}]}'
    }
    mock_response.raise_for_status.return_value = None

    with patch('requests.post', return_value=mock_response):
        client = OllamaClient("test-model")

        result = client.generate_with_validation(
            prompt="Extract entities",
            response_type='entity_extraction',
            return_parsed=True,
        )

        assert isinstance(result, dict)
        assert 'entities' in result
        assert len(result['entities']) == 1
        assert result['entities'][0]['text'] == '张三'

    print("✓ Enhanced methods tests passed")
def main():
    """Run all enhanced-OllamaClient tests and report the outcome."""
    print("Testing enhanced OllamaClient...")
    print("=" * 50)

    try:
        test_ollama_client_initialization()
        test_generate_with_validation()
        test_generate_with_schema()
        test_backward_compatibility()
        test_retry_mechanism()
        test_validation_failure()
        test_enhanced_methods()

        print("\n" + "=" * 50)
        print("✓ All enhanced OllamaClient tests passed!")

    except Exception as e:
        # Report the first failure with a full traceback for debugging.
        print(f"\n✗ Test failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
Loading…
Reference in New Issue