Compare commits

..

No commits in common. "84499f52ea4c14afe474142d5216eb4ddf5c5957" and "8399bc37fca55c66524062be625a1220259d4396" have entirely different histories.

46 changed files with 313 additions and 3976 deletions

View File

@ -1,255 +0,0 @@
# OllamaClient Enhancement Summary
## Overview
The `OllamaClient` has been successfully enhanced to support validation and retry mechanisms while maintaining full backward compatibility.
## Key Enhancements
### 1. **Enhanced Constructor**
```python
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
```
- Added `max_retries` parameter for configurable retry attempts
- Default retry count: 3 attempts
### 2. **Enhanced Generate Method**
```python
def generate(self,
prompt: str,
strip_think: bool = True,
validation_schema: Optional[Dict[str, Any]] = None,
response_type: Optional[str] = None,
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
```
**New Parameters:**
- `validation_schema`: Custom JSON schema for validation
- `response_type`: Predefined response type for validation
- `return_parsed`: Return parsed JSON instead of raw string
**Return Type:**
- `Union[str, Dict[str, Any]]`: Can return either raw string or parsed JSON
### 3. **New Convenience Methods**
#### `generate_with_validation()`
```python
def generate_with_validation(self,
prompt: str,
response_type: str,
strip_think: bool = True,
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
```
- Uses predefined validation schemas based on response type
- Automatically handles retries and validation
- Returns parsed JSON by default
#### `generate_with_schema()`
```python
def generate_with_schema(self,
prompt: str,
schema: Dict[str, Any],
strip_think: bool = True,
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
```
- Uses custom JSON schema for validation
- Automatically handles retries and validation
- Returns parsed JSON by default
### 4. **Supported Response Types**
The following response types are supported for automatic validation:
- `'entity_extraction'`: Entity extraction responses
- `'entity_linkage'`: Entity linkage responses
- `'regex_entity'`: Regex-based entity responses
- `'business_name_extraction'`: Business name extraction responses
- `'address_extraction'`: Address component extraction responses
## Features
### 1. **Automatic Retry Mechanism**
- Retries failed API calls up to `max_retries` times
- Retries on validation failures
- Retries on JSON parsing failures
- Configurable retry count per client instance
### 2. **Built-in Validation**
- JSON schema validation using `jsonschema` library
- Predefined schemas for common response types
- Custom schema support for specialized use cases
- Detailed validation error logging
### 3. **Automatic JSON Parsing**
- Uses `LLMJsonExtractor.parse_raw_json_str()` for robust JSON extraction
- Handles malformed JSON responses gracefully
- Returns parsed Python dictionaries when requested
### 4. **Backward Compatibility**
- All existing code continues to work without changes
- Original `generate()` method signature preserved
- Default behavior unchanged
## Usage Examples
### 1. **Basic Usage (Backward Compatible)**
```python
client = OllamaClient("llama2")
response = client.generate("Hello, world!")
# Returns: "Hello, world!"
```
### 2. **With Response Type Validation**
```python
client = OllamaClient("llama2")
result = client.generate_with_validation(
prompt="Extract business name from: 上海盒马网络科技有限公司",
response_type='business_name_extraction',
return_parsed=True
)
# Returns: {"business_name": "盒马", "confidence": 0.9}
```
### 3. **With Custom Schema Validation**
```python
client = OllamaClient("llama2")
custom_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"}
},
"required": ["name", "age"]
}
result = client.generate_with_schema(
prompt="Generate person info",
schema=custom_schema,
return_parsed=True
)
# Returns: {"name": "张三", "age": 30}
```
### 4. **Advanced Usage with All Options**
```python
client = OllamaClient("llama2", max_retries=5)
result = client.generate(
prompt="Complex prompt",
strip_think=True,
validation_schema=custom_schema,
return_parsed=True
)
```
## Updated Components
### 1. **Extractors**
- `BusinessNameExtractor`: Now uses `generate_with_validation()`
- `AddressExtractor`: Now uses `generate_with_validation()`
### 2. **Processors**
- `NerProcessor`: Updated to use enhanced methods
- `NerProcessorRefactored`: Updated to use enhanced methods
### 3. **Benefits in Processors**
- Simplified code: No more manual retry loops
- Automatic validation: No more manual JSON parsing
- Better error handling: Automatic fallback to regex methods
- Cleaner code: Reduced boilerplate
## Error Handling
### 1. **API Failures**
- Automatic retry on network errors
- Configurable retry count
- Detailed error logging
### 2. **Validation Failures**
- Automatic retry on schema validation failures
- Automatic retry on JSON parsing failures
- Graceful fallback to alternative methods
### 3. **Exception Types**
- `RequestException`: API call failures after all retries
- `ValueError`: Validation failures after all retries
- `Exception`: Unexpected errors
## Testing
### 1. **Test Coverage**
- Initialization with new parameters
- Enhanced generate methods
- Backward compatibility
- Retry mechanism
- Validation failure handling
- Mock-based testing for reliability
### 2. **Run Tests**
```bash
cd backend
python3 test_enhanced_ollama_client.py
```
## Migration Guide
### 1. **No Changes Required**
Existing code continues to work without modification:
```python
# This still works exactly the same
client = OllamaClient("llama2")
response = client.generate("prompt")
```
### 2. **Optional Enhancements**
To take advantage of new features:
```python
# Old way (still works)
response = client.generate(prompt)
parsed = LLMJsonExtractor.parse_raw_json_str(response)
if LLMResponseValidator.validate_entity_extraction(parsed):
# use parsed
# New way (recommended)
parsed = client.generate_with_validation(
prompt=prompt,
response_type='entity_extraction',
return_parsed=True
)
# parsed is already validated and ready to use
```
### 3. **Benefits of Migration**
- **Reduced Code**: Eliminates manual retry loops
- **Better Reliability**: Automatic retry and validation
- **Cleaner Code**: Less boilerplate
- **Better Error Handling**: Automatic fallbacks
## Performance Impact
### 1. **Positive Impact**
- Reduced code complexity
- Better error recovery
- Automatic retry reduces manual intervention
### 2. **Minimal Overhead**
- Validation only occurs when requested
- JSON parsing only occurs when needed
- Retry mechanism only activates on failures
## Future Enhancements
### 1. **Potential Additions**
- Circuit breaker pattern for API failures
- Caching for repeated requests
- Async/await support
- Streaming response support
- Custom retry strategies
### 2. **Configuration Options**
- Per-request retry configuration
- Custom validation error handling
- Response transformation hooks
- Metrics and monitoring
## Conclusion
The enhanced `OllamaClient` provides a robust, reliable, and easy-to-use interface for LLM interactions while maintaining full backward compatibility. The new validation and retry mechanisms significantly improve the reliability of LLM-based operations in the NER processing pipeline.

View File

@ -1,166 +0,0 @@
# NerProcessor Refactoring Summary
## Overview
The `ner_processor.py` file has been successfully refactored from a monolithic 729-line class into a modular, maintainable architecture following SOLID principles.
## New Architecture
### Directory Structure
```
backend/app/core/document_handlers/
├── ner_processor.py # Original file (unchanged)
├── ner_processor_refactored.py # New refactored version
├── masker_factory.py # Factory for creating maskers
├── maskers/
│ ├── __init__.py
│ ├── base_masker.py # Abstract base class
│ ├── name_masker.py # Chinese/English name masking
│ ├── company_masker.py # Company name masking
│ ├── address_masker.py # Address masking
│ ├── id_masker.py # ID/social credit code masking
│ └── case_masker.py # Case number masking
├── extractors/
│ ├── __init__.py
│ ├── base_extractor.py # Abstract base class
│ ├── business_name_extractor.py # Business name extraction
│ └── address_extractor.py # Address component extraction
└── validators/ # (Placeholder for future use)
```
## Key Components
### 1. Base Classes
- **`BaseMasker`**: Abstract base class for all maskers
- **`BaseExtractor`**: Abstract base class for all extractors
### 2. Maskers
- **`ChineseNameMasker`**: Handles Chinese name masking (surname + pinyin initials)
- **`EnglishNameMasker`**: Handles English name masking (first letter + ***)
- **`CompanyMasker`**: Handles company name masking (business name replacement)
- **`AddressMasker`**: Handles address masking (component replacement)
- **`IDMasker`**: Handles ID and social credit code masking
- **`CaseMasker`**: Handles case number masking
### 3. Extractors
- **`BusinessNameExtractor`**: Extracts business names from company names using LLM + regex fallback
- **`AddressExtractor`**: Extracts address components using LLM + regex fallback
### 4. Factory
- **`MaskerFactory`**: Creates maskers with proper dependencies
### 5. Refactored Processor
- **`NerProcessorRefactored`**: Main orchestrator using the new architecture
## Benefits Achieved
### 1. Single Responsibility Principle
- Each class has one clear responsibility
- Maskers only handle masking logic
- Extractors only handle extraction logic
- Processor only handles orchestration
### 2. Open/Closed Principle
- Easy to add new maskers without modifying existing code
- New entity types can be supported by creating new maskers
### 3. Dependency Injection
- Dependencies are injected rather than hardcoded
- Easier to test and mock
### 4. Better Testing
- Each component can be tested in isolation
- Mock dependencies easily
### 5. Code Reusability
- Maskers can be used independently
- Common functionality shared through base classes
### 6. Maintainability
- Changes to one masking rule don't affect others
- Clear separation of concerns
## Migration Strategy
### Phase 1: ✅ Complete
- Created base classes and interfaces
- Extracted all maskers
- Created extractors
- Created factory pattern
- Created refactored processor
### Phase 2: Testing (Next)
- Run validation script: `python3 validate_refactoring.py`
- Run existing tests to ensure compatibility
- Create comprehensive unit tests for each component
### Phase 3: Integration (Future)
- Replace original processor with refactored version
- Update imports throughout the codebase
- Remove old code
### Phase 4: Enhancement (Future)
- Add configuration management
- Add more extractors as needed
- Add validation components
## Testing
### Validation Script
Run the validation script to test the refactored code:
```bash
cd backend
python3 validate_refactoring.py
```
### Unit Tests
Run the unit tests for the refactored components:
```bash
cd backend
python3 -m pytest tests/test_refactored_ner_processor.py -v
```
## Current Status
✅ **Completed:**
- All maskers extracted and implemented
- All extractors created
- Factory pattern implemented
- Refactored processor created
- Validation script created
- Unit tests created
🔄 **Next Steps:**
- Test the refactored code
- Ensure all existing functionality works
- Replace original processor when ready
## File Comparison
| Metric | Original | Refactored |
|--------|----------|------------|
| Main Class Lines | 729 | ~200 |
| Number of Classes | 1 | 10+ |
| Responsibilities | Multiple | Single |
| Testability | Low | High |
| Maintainability | Low | High |
| Extensibility | Low | High |
## Backward Compatibility
The refactored code maintains full backward compatibility:
- All existing masking rules are preserved
- All existing functionality works the same
- The public API remains unchanged
- The original `ner_processor.py` is untouched
## Future Enhancements
1. **Configuration Management**: Centralized configuration for masking rules
2. **Validation Framework**: Dedicated validation components
3. **Performance Optimization**: Caching and optimization strategies
4. **Monitoring**: Metrics and logging for each component
5. **Plugin System**: Dynamic loading of new maskers and extractors
## Conclusion
The refactoring successfully transforms the monolithic `NerProcessor` into a modular, maintainable, and extensible architecture while preserving all existing functionality. The new architecture follows SOLID principles and provides a solid foundation for future enhancements.

View File

@ -1,118 +0,0 @@
# Test Setup Guide
This document explains how to set up and run tests for the legal-doc-masker backend.
## Test Structure
```
backend/
├── tests/
│ ├── __init__.py
│ ├── test_ner_processor.py
│ ├── test1.py
│ └── test.txt
├── conftest.py
├── pytest.ini
└── run_tests.py
```
## VS Code Configuration
The `.vscode/settings.json` file has been configured to:
1. **Set pytest as the test framework**: `"python.testing.pytestEnabled": true`
2. **Point to the correct test directory**: `"python.testing.pytestArgs": ["backend/tests"]`
3. **Set the working directory**: `"python.testing.cwd": "${workspaceFolder}/backend"`
4. **Configure Python interpreter**: Points to backend virtual environment
## Running Tests
### From VS Code Test Explorer
1. Open the Test Explorer panel (Ctrl+Shift+P → "Python: Configure Tests")
2. Select "pytest" as the test framework
3. Select "backend/tests" as the test directory
4. Tests should now appear in the Test Explorer
### From Command Line
```bash
# From the project root
cd backend
python -m pytest tests/ -v
# Or use the test runner script
python run_tests.py
```
### From VS Code Terminal
```bash
# Make sure you're in the backend directory
cd backend
pytest tests/ -v
```
## Test Configuration
### pytest.ini
- **testpaths**: Points to the `tests` directory
- **python_files**: Looks for files starting with `test_` or ending with `_test.py`
- **python_functions**: Looks for functions starting with `test_`
- **markers**: Defines test markers for categorization
### conftest.py
- **Path setup**: Adds backend directory to Python path
- **Fixtures**: Provides common test fixtures
- **Environment setup**: Handles test environment initialization
## Troubleshooting
### Tests Not Discovered
1. **Check VS Code settings**: Ensure `python.testing.pytestArgs` points to `backend/tests`
2. **Verify working directory**: Ensure `python.testing.cwd` is set to `${workspaceFolder}/backend`
3. **Check Python interpreter**: Make sure it points to the backend virtual environment
### Import Errors
1. **Check conftest.py**: Ensures backend directory is in Python path
2. **Verify __init__.py**: Tests directory should have an `__init__.py` file
3. **Check relative imports**: Use absolute imports from the backend root
### Virtual Environment Issues
1. **Create virtual environment**: `python -m venv .venv`
2. **Activate environment**:
- Windows: `.venv\Scripts\activate`
- Unix/MacOS: `source .venv/bin/activate`
3. **Install dependencies**: `pip install -r requirements.txt`
## Test Examples
### Simple Test
```python
def test_simple_assertion():
"""Simple test to verify pytest is working"""
assert 1 == 1
assert 2 + 2 == 4
```
### Test with Fixture
```python
def test_with_fixture(sample_data):
"""Test using a fixture"""
assert sample_data["name"] == "test"
assert sample_data["value"] == 42
```
### Integration Test
```python
def test_ner_processor():
"""Test NER processor functionality"""
from app.core.document_handlers.ner_processor import NerProcessor
processor = NerProcessor()
# Test implementation...
```
## Best Practices
1. **Test naming**: Use descriptive test names starting with `test_`
2. **Test isolation**: Each test should be independent
3. **Use fixtures**: For common setup and teardown
4. **Add markers**: Use `@pytest.mark.slow` for slow tests
5. **Documentation**: Add docstrings to explain test purpose

View File

@ -1 +0,0 @@
# App package

View File

@ -1 +0,0 @@
# Core package

View File

@ -1 +0,0 @@
# Document handlers package

View File

@ -3,7 +3,7 @@ from typing import Optional
from .document_processor import DocumentProcessor from .document_processor import DocumentProcessor
from .processors import ( from .processors import (
TxtDocumentProcessor, TxtDocumentProcessor,
DocxDocumentProcessor, # DocxDocumentProcessor,
PdfDocumentProcessor, PdfDocumentProcessor,
MarkdownDocumentProcessor MarkdownDocumentProcessor
) )
@ -15,8 +15,8 @@ class DocumentProcessorFactory:
processors = { processors = {
'.txt': TxtDocumentProcessor, '.txt': TxtDocumentProcessor,
'.docx': DocxDocumentProcessor, # '.docx': DocxDocumentProcessor,
'.doc': DocxDocumentProcessor, # '.doc': DocxDocumentProcessor,
'.pdf': PdfDocumentProcessor, '.pdf': PdfDocumentProcessor,
'.md': MarkdownDocumentProcessor, '.md': MarkdownDocumentProcessor,
'.markdown': MarkdownDocumentProcessor '.markdown': MarkdownDocumentProcessor

View File

@ -1,17 +0,0 @@
"""
Extractors package for entity component extraction.
"""
from .base_extractor import BaseExtractor
from .llm_extractor import LLMExtractor
from .regex_extractor import RegexExtractor
from .business_name_extractor import BusinessNameExtractor
from .address_extractor import AddressExtractor
__all__ = [
'BaseExtractor',
'LLMExtractor',
'RegexExtractor',
'BusinessNameExtractor',
'AddressExtractor'
]

View File

@ -1,168 +0,0 @@
"""
Address extractor for address components.
"""
import re
import logging
from typing import Dict, Any, Optional
from ...services.ollama_client import OllamaClient
from ...utils.json_extractor import LLMJsonExtractor
from ...utils.llm_validator import LLMResponseValidator
from .base_extractor import BaseExtractor
logger = logging.getLogger(__name__)
class AddressExtractor(BaseExtractor):
"""Extractor for address components"""
def __init__(self, ollama_client: OllamaClient):
self.ollama_client = ollama_client
self._confidence = 0.5 # Default confidence for regex fallback
def extract(self, address: str) -> Optional[Dict[str, str]]:
"""
Extract address components from address.
Args:
address: The address to extract from
Returns:
Dictionary with address components and confidence, or None if extraction fails
"""
if not address:
return None
# Try LLM extraction first
try:
result = self._extract_with_llm(address)
if result:
self._confidence = result.get('confidence', 0.9)
return result
except Exception as e:
logger.warning(f"LLM extraction failed for {address}: {e}")
# Fallback to regex extraction
result = self._extract_with_regex(address)
self._confidence = 0.5 # Lower confidence for regex
return result
def _extract_with_llm(self, address: str) -> Optional[Dict[str, str]]:
"""Extract address components using LLM"""
prompt = f"""
你是一个专业的地址分析助手请从以下地址中提取需要脱敏的组件并严格按照JSON格式返回结果
地址{address}
脱敏规则
1. 保留区级以上地址县等
2. 路名路名需要脱敏以大写首字母替代
3. 门牌号门牌数字需要脱敏****代替
4. 大厦名小区名需要脱敏以大写首字母替代
示例
- 上海市静安区恒丰路66号白云大厦1607室
- 路名恒丰路
- 门牌号66
- 大厦名白云大厦
- 小区名
- 北京市朝阳区建国路88号SOHO现代城A座1001室
- 路名建国路
- 门牌号88
- 大厦名SOHO现代城
- 小区名
- 广州市天河区珠江新城花城大道123号富力中心B座2001室
- 路名花城大道
- 门牌号123
- 大厦名富力中心
- 小区名
请严格按照以下JSON格式输出不要包含任何其他文字
{{
"road_name": "提取的路名",
"house_number": "提取的门牌号",
"building_name": "提取的大厦名",
"community_name": "提取的小区名(如果没有则为空字符串)",
"confidence": 0.9
}}
注意
- road_name字段必须包含路名恒丰路建国路等
- house_number字段必须包含门牌号6688
- building_name字段必须包含大厦名白云大厦SOHO现代城等
- community_name字段包含小区名如果没有则为空字符串
- confidence字段是0-1之间的数字表示提取的置信度
- 必须严格按照JSON格式不要添加任何解释或额外文字
"""
try:
# Use the new enhanced generate method with validation
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='address_extraction',
return_parsed=True
)
if parsed_response:
logger.info(f"Successfully extracted address components: {parsed_response}")
return parsed_response
else:
logger.warning(f"Failed to extract address components for: {address}")
return None
except Exception as e:
logger.error(f"LLM extraction failed: {e}")
return None
def _extract_with_regex(self, address: str) -> Optional[Dict[str, str]]:
"""Extract address components using regex patterns"""
# Road name pattern: usually ends with "路", "街", "大道", etc.
road_pattern = r'([^省市区县]+[路街大道巷弄])'
# House number pattern: digits + 号
house_number_pattern = r'(\d+)号'
# Building name pattern: usually contains "大厦", "中心", "广场", etc.
building_pattern = r'([^号室]+(?:大厦|中心|广场|城|楼|座))'
# Community name pattern: usually contains "小区", "花园", "苑", etc.
community_pattern = r'([^号室]+(?:小区|花园|苑|园|庭))'
road_name = ""
house_number = ""
building_name = ""
community_name = ""
# Extract road name
road_match = re.search(road_pattern, address)
if road_match:
road_name = road_match.group(1).strip()
# Extract house number
house_match = re.search(house_number_pattern, address)
if house_match:
house_number = house_match.group(1)
# Extract building name
building_match = re.search(building_pattern, address)
if building_match:
building_name = building_match.group(1).strip()
# Extract community name
community_match = re.search(community_pattern, address)
if community_match:
community_name = community_match.group(1).strip()
return {
"road_name": road_name,
"house_number": house_number,
"building_name": building_name,
"community_name": community_name,
"confidence": 0.5 # Lower confidence for regex fallback
}
def get_confidence(self) -> float:
"""Return confidence level of extraction"""
return self._confidence

View File

@ -1,20 +0,0 @@
"""
Abstract base class for all extractors.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
class BaseExtractor(ABC):
"""Abstract base class for all extractors"""
@abstractmethod
def extract(self, text: str) -> Optional[Dict[str, Any]]:
"""Extract components from text"""
pass
@abstractmethod
def get_confidence(self) -> float:
"""Return confidence level of extraction"""
pass

View File

@ -1,192 +0,0 @@
"""
Business name extractor for company names.
"""
import re
import logging
from typing import Dict, Any, Optional
from ...services.ollama_client import OllamaClient
from ...utils.json_extractor import LLMJsonExtractor
from ...utils.llm_validator import LLMResponseValidator
from .base_extractor import BaseExtractor
logger = logging.getLogger(__name__)
class BusinessNameExtractor(BaseExtractor):
"""Extractor for business names from company names"""
def __init__(self, ollama_client: OllamaClient):
self.ollama_client = ollama_client
self._confidence = 0.5 # Default confidence for regex fallback
def extract(self, company_name: str) -> Optional[Dict[str, str]]:
"""
Extract business name from company name.
Args:
company_name: The company name to extract from
Returns:
Dictionary with business name and confidence, or None if extraction fails
"""
if not company_name:
return None
# Try LLM extraction first
try:
result = self._extract_with_llm(company_name)
if result:
self._confidence = result.get('confidence', 0.9)
return result
except Exception as e:
logger.warning(f"LLM extraction failed for {company_name}: {e}")
# Fallback to regex extraction
result = self._extract_with_regex(company_name)
self._confidence = 0.5 # Lower confidence for regex
return result
def _extract_with_llm(self, company_name: str) -> Optional[Dict[str, str]]:
"""Extract business name using LLM"""
prompt = f"""
你是一个专业的公司名称分析助手请从以下公司名称中提取商号企业字号并严格按照JSON格式返回结果
公司名称{company_name}
商号提取规则
1. 公司名通常为地域+商号+业务/行业+组织类型
2. 也有商号+地域+业务/行业+组织类型
3. 商号是企业名称中最具识别性的部分通常是2-4个汉字
4. 不要包含地域行业组织类型等信息
5. 律师事务所的商号通常是地域后的部分
示例
- 上海盒马网络科技有限公司 -> 盒马
- 丰田通商上海有限公司 -> 丰田通商
- 雅诗兰黛上海商贸有限公司 -> 雅诗兰黛
- 北京百度网讯科技有限公司 -> 百度
- 腾讯科技深圳有限公司 -> 腾讯
- 北京大成律师事务所 -> 大成
请严格按照以下JSON格式输出不要包含任何其他文字
{{
"business_name": "提取的商号",
"confidence": 0.9
}}
注意
- business_name字段必须包含提取的商号
- confidence字段是0-1之间的数字表示提取的置信度
- 必须严格按照JSON格式不要添加任何解释或额外文字
"""
try:
# Use the new enhanced generate method with validation
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='business_name_extraction',
return_parsed=True
)
if parsed_response:
business_name = parsed_response.get('business_name', '')
# Clean business name, keep only Chinese characters
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
logger.info(f"Successfully extracted business name: {business_name}")
return {
'business_name': business_name,
'confidence': parsed_response.get('confidence', 0.9)
}
else:
logger.warning(f"Failed to extract business name for: {company_name}")
return None
except Exception as e:
logger.error(f"LLM extraction failed: {e}")
return None
def _extract_with_regex(self, company_name: str) -> Optional[Dict[str, str]]:
"""Extract business name using regex patterns"""
# Handle law firms specially
if '律师事务所' in company_name:
return self._extract_law_firm_business_name(company_name)
# Common region prefixes
region_prefixes = [
'北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安',
'天津', '重庆', '青岛', '大连', '宁波', '厦门', '无锡', '长沙', '郑州', '济南',
'哈尔滨', '沈阳', '长春', '石家庄', '太原', '呼和浩特', '合肥', '福州', '南昌',
'南宁', '海口', '贵阳', '昆明', '兰州', '西宁', '银川', '乌鲁木齐', '拉萨',
'香港', '澳门', '台湾'
]
# Common organization type suffixes
org_suffixes = [
'有限公司', '股份有限公司', '有限责任公司', '股份公司', '集团公司', '集团',
'科技公司', '网络公司', '信息技术公司', '软件公司', '互联网公司',
'贸易公司', '商贸公司', '进出口公司', '物流公司', '运输公司',
'房地产公司', '置业公司', '投资公司', '金融公司', '银行',
'保险公司', '证券公司', '基金公司', '信托公司', '租赁公司',
'咨询公司', '服务公司', '管理公司', '广告公司', '传媒公司',
'教育公司', '培训公司', '医疗公司', '医药公司', '生物公司',
'制造公司', '工业公司', '化工公司', '能源公司', '电力公司',
'建筑公司', '工程公司', '建设公司', '开发公司', '设计公司',
'销售公司', '营销公司', '代理公司', '经销商', '零售商',
'连锁公司', '超市', '商场', '百货', '专卖店', '便利店'
]
name = company_name
# Remove region prefix
for region in region_prefixes:
if name.startswith(region):
name = name[len(region):].strip()
break
# Remove region information in parentheses
name = re.sub(r'[(].*?[)]', '', name).strip()
# Remove organization type suffix
for suffix in org_suffixes:
if name.endswith(suffix):
name = name[:-len(suffix)].strip()
break
# If remaining part is too long, try to extract first 2-4 characters
if len(name) > 4:
# Try to find a good break point
for i in range(2, min(5, len(name))):
if name[i] in ['', '', '', '', '', '', '', '', '', '', '', '', '', '']:
name = name[:i]
break
return {
'business_name': name if name else company_name[:2],
'confidence': 0.5
}
def _extract_law_firm_business_name(self, law_firm_name: str) -> Optional[Dict[str, str]]:
"""Extract business name from law firm names"""
# Remove "律师事务所" suffix
name = law_firm_name.replace('律师事务所', '').replace('分所', '').strip()
# Handle region information in parentheses
name = re.sub(r'[(].*?[)]', '', name).strip()
# Common region prefixes
region_prefixes = ['北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安']
for region in region_prefixes:
if name.startswith(region):
name = name[len(region):].strip()
break
return {
'business_name': name,
'confidence': 0.5
}
def get_confidence(self) -> float:
"""Return confidence level of extraction"""
return self._confidence

View File

@ -1,65 +0,0 @@
"""
Factory for creating maskers.
"""
from typing import Dict, Type, Any
from .maskers.base_masker import BaseMasker
from .maskers.name_masker import ChineseNameMasker, EnglishNameMasker
from .maskers.company_masker import CompanyMasker
from .maskers.address_masker import AddressMasker
from .maskers.id_masker import IDMasker
from .maskers.case_masker import CaseMasker
from ...services.ollama_client import OllamaClient
class MaskerFactory:
"""Factory for creating maskers"""
_maskers: Dict[str, Type[BaseMasker]] = {
'chinese_name': ChineseNameMasker,
'english_name': EnglishNameMasker,
'company': CompanyMasker,
'address': AddressMasker,
'id': IDMasker,
'case': CaseMasker,
}
@classmethod
def create_masker(cls, masker_type: str, ollama_client: OllamaClient = None, config: Dict[str, Any] = None) -> BaseMasker:
"""
Create a masker of the specified type.
Args:
masker_type: Type of masker to create
ollama_client: Ollama client for LLM-based maskers
config: Configuration for the masker
Returns:
Instance of the specified masker
Raises:
ValueError: If masker type is unknown
"""
if masker_type not in cls._maskers:
raise ValueError(f"Unknown masker type: {masker_type}")
masker_class = cls._maskers[masker_type]
# Handle maskers that need ollama_client
if masker_type in ['company', 'address']:
if not ollama_client:
raise ValueError(f"Ollama client is required for {masker_type} masker")
return masker_class(ollama_client)
# Handle maskers that don't need special parameters
return masker_class()
@classmethod
def get_available_maskers(cls) -> list[str]:
"""Get list of available masker types"""
return list(cls._maskers.keys())
@classmethod
def register_masker(cls, masker_type: str, masker_class: Type[BaseMasker]):
"""Register a new masker type"""
cls._maskers[masker_type] = masker_class

View File

@ -1,20 +0,0 @@
"""
Maskers package for entity masking functionality.
"""
from .base_masker import BaseMasker
from .name_masker import ChineseNameMasker, EnglishNameMasker
from .company_masker import CompanyMasker
from .address_masker import AddressMasker
from .id_masker import IDMasker
from .case_masker import CaseMasker
__all__ = [
'BaseMasker',
'ChineseNameMasker',
'EnglishNameMasker',
'CompanyMasker',
'AddressMasker',
'IDMasker',
'CaseMasker'
]

View File

@ -1,91 +0,0 @@
"""
Address masker for addresses.
"""
import re
import logging
from typing import Dict, Any
from pypinyin import pinyin, Style
from ...services.ollama_client import OllamaClient
from ..extractors.address_extractor import AddressExtractor
from .base_masker import BaseMasker
logger = logging.getLogger(__name__)
class AddressMasker(BaseMasker):
"""Masker for addresses"""
def __init__(self, ollama_client: OllamaClient):
self.extractor = AddressExtractor(ollama_client)
def mask(self, address: str, context: Dict[str, Any] = None) -> str:
"""
Mask address by replacing components with masked versions.
Args:
address: The address to mask
context: Additional context (not used for address masking)
Returns:
Masked address
"""
if not address:
return address
# Extract address components
components = self.extractor.extract(address)
if not components:
return address
masked_address = address
# Replace road name
if components.get("road_name"):
road_name = components["road_name"]
# Get pinyin initials for road name
try:
pinyin_list = pinyin(road_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
masked_address = masked_address.replace(road_name, initials + "")
except Exception as e:
logger.warning(f"Failed to get pinyin for road name {road_name}: {e}")
# Fallback to first character
masked_address = masked_address.replace(road_name, road_name[0].upper() + "")
# Replace house number
if components.get("house_number"):
house_number = components["house_number"]
masked_address = masked_address.replace(house_number + "", "**号")
# Replace building name
if components.get("building_name"):
building_name = components["building_name"]
# Get pinyin initials for building name
try:
pinyin_list = pinyin(building_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
masked_address = masked_address.replace(building_name, initials)
except Exception as e:
logger.warning(f"Failed to get pinyin for building name {building_name}: {e}")
# Fallback to first character
masked_address = masked_address.replace(building_name, building_name[0].upper())
# Replace community name
if components.get("community_name"):
community_name = components["community_name"]
# Get pinyin initials for community name
try:
pinyin_list = pinyin(community_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
masked_address = masked_address.replace(community_name, initials)
except Exception as e:
logger.warning(f"Failed to get pinyin for community name {community_name}: {e}")
# Fallback to first character
masked_address = masked_address.replace(community_name, community_name[0].upper())
return masked_address
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
return ['地址']

View File

@ -1,24 +0,0 @@
"""
Abstract base class for all maskers.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
class BaseMasker(ABC):
"""Abstract base class for all maskers"""
@abstractmethod
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
"""Mask the given text according to specific rules"""
pass
@abstractmethod
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
pass
def can_mask(self, entity_type: str) -> bool:
"""Check if this masker can handle the given entity type"""
return entity_type in self.get_supported_types()

View File

@ -1,33 +0,0 @@
"""
Case masker for case numbers.
"""
import re
from typing import Dict, Any
from .base_masker import BaseMasker
class CaseMasker(BaseMasker):
"""Masker for case numbers"""
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
"""
Mask case numbers by replacing digits with ***.
Args:
text: The text to mask
context: Additional context (not used for case masking)
Returns:
Masked text
"""
if not text:
return text
# Replace digits with *** while preserving structure
masked = re.sub(r'(\d[\d\s]*)(号)', r'***\2', text)
return masked
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
return ['案号']

View File

@ -1,98 +0,0 @@
"""
Company masker for company names.
"""
import re
import logging
from typing import Dict, Any
from pypinyin import pinyin, Style
from ...services.ollama_client import OllamaClient
from ..extractors.business_name_extractor import BusinessNameExtractor
from .base_masker import BaseMasker
logger = logging.getLogger(__name__)
class CompanyMasker(BaseMasker):
"""Masker for company names"""
def __init__(self, ollama_client: OllamaClient):
self.extractor = BusinessNameExtractor(ollama_client)
def mask(self, company_name: str, context: Dict[str, Any] = None) -> str:
"""
Mask company name by replacing business name with letters.
Args:
company_name: The company name to mask
context: Additional context (not used for company masking)
Returns:
Masked company name
"""
if not company_name:
return company_name
# Extract business name
extraction_result = self.extractor.extract(company_name)
if not extraction_result:
return company_name
business_name = extraction_result.get('business_name', '')
if not business_name:
return company_name
# Get pinyin first letter of business name
try:
pinyin_list = pinyin(business_name, style=Style.NORMAL)
first_letter = pinyin_list[0][0][0].upper() if pinyin_list and pinyin_list[0] else 'A'
except Exception as e:
logger.warning(f"Failed to get pinyin for {business_name}: {e}")
first_letter = 'A'
# Calculate next two letters
if first_letter >= 'Y':
# If first letter is Y or Z, use X and Y
letters = 'XY'
elif first_letter >= 'X':
# If first letter is X, use Y and Z
letters = 'YZ'
else:
# Normal case: use next two letters
letters = chr(ord(first_letter) + 1) + chr(ord(first_letter) + 2)
# Replace business name
if business_name in company_name:
masked_name = company_name.replace(business_name, letters)
else:
# Try smarter replacement
masked_name = self._replace_business_name_in_company(company_name, business_name, letters)
return masked_name
def _replace_business_name_in_company(self, company_name: str, business_name: str, letters: str) -> str:
"""Smart replacement of business name in company name"""
# Try different replacement patterns
patterns = [
business_name,
business_name + '',
business_name + '(',
'' + business_name + '',
'(' + business_name + ')',
]
for pattern in patterns:
if pattern in company_name:
if pattern.endswith('') or pattern.endswith('('):
return company_name.replace(pattern, letters + pattern[-1])
elif pattern.startswith('') or pattern.startswith('('):
return company_name.replace(pattern, pattern[0] + letters + pattern[-1])
else:
return company_name.replace(pattern, letters)
# If no pattern found, return original
return company_name
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
return ['公司名称', 'Company', '英文公司名', 'English Company']

View File

@ -1,39 +0,0 @@
"""
ID masker for ID numbers and social credit codes.
"""
from typing import Dict, Any
from .base_masker import BaseMasker
class IDMasker(BaseMasker):
"""Masker for ID numbers and social credit codes"""
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
"""
Mask ID numbers and social credit codes.
Args:
text: The text to mask
context: Additional context (not used for ID masking)
Returns:
Masked text
"""
if not text:
return text
# Determine the type based on length and format
if len(text) == 18 and text.isdigit():
# ID number: keep first 6 digits
return text[:6] + 'X' * (len(text) - 6)
elif len(text) == 18 and any(c.isalpha() for c in text):
# Social credit code: keep first 7 digits
return text[:7] + 'X' * (len(text) - 7)
else:
# Fallback for invalid formats
return text
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
return ['身份证号', '社会信用代码']

View File

@ -1,89 +0,0 @@
"""
Name maskers for Chinese and English names.
"""
from typing import Dict, Any
from pypinyin import pinyin, Style
from .base_masker import BaseMasker
class ChineseNameMasker(BaseMasker):
"""Masker for Chinese names"""
def __init__(self):
self.surname_counter = {}
def mask(self, name: str, context: Dict[str, Any] = None) -> str:
"""
Mask Chinese names: keep surname, convert given name to pinyin initials.
Args:
name: The name to mask
context: Additional context containing surname_counter
Returns:
Masked name
"""
if not name or len(name) < 2:
return name
# Use context surname_counter if provided, otherwise use instance counter
surname_counter = context.get('surname_counter', self.surname_counter) if context else self.surname_counter
surname = name[0]
given_name = name[1:]
# Get pinyin initials for given name
try:
pinyin_list = pinyin(given_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
except Exception:
# Fallback to original characters if pinyin fails
initials = given_name
# Initialize surname counter
if surname not in surname_counter:
surname_counter[surname] = {}
# Check for duplicate surname and initials combination
if initials in surname_counter[surname]:
surname_counter[surname][initials] += 1
masked_name = f"{surname}{initials}{surname_counter[surname][initials]}"
else:
surname_counter[surname][initials] = 1
masked_name = f"{surname}{initials}"
return masked_name
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
return ['人名', '律师姓名', '审判人员姓名']
class EnglishNameMasker(BaseMasker):
"""Masker for English names"""
def mask(self, name: str, context: Dict[str, Any] = None) -> str:
"""
Mask English names: convert each word to first letter + ***.
Args:
name: The name to mask
context: Additional context (not used for English name masking)
Returns:
Masked name
"""
if not name:
return name
masked_parts = []
for part in name.split():
if part:
masked_parts.append(part[0] + '***')
return ' '.join(masked_parts)
def get_supported_types(self) -> list[str]:
"""Return list of entity types this masker supports"""
return ['英文人名']

View File

@ -8,7 +8,6 @@ from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator from ..utils.llm_validator import LLMResponseValidator
import re import re
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
from pypinyin import pinyin, Style
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,466 +19,29 @@ class NerProcessor:
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool: def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
return LLMResponseValidator.validate_entity_extraction(mapping) return LLMResponseValidator.validate_entity_extraction(mapping)
def _mask_chinese_name(self, name: str, surname_counter: Dict[str, Dict[str, int]]) -> str:
"""
处理中文姓名脱敏
保留姓名变为大写首字母
同姓名同首字母者按12依次编号
"""
if not name or len(name) < 2:
return name
surname = name[0]
given_name = name[1:]
# 获取名的拼音首字母
try:
pinyin_list = pinyin(given_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
except Exception as e:
logger.warning(f"Failed to get pinyin for {given_name}: {e}")
# 如果拼音转换失败,使用原字符
initials = given_name
# 初始化姓氏计数器
if surname not in surname_counter:
surname_counter[surname] = {}
# 检查是否有相同姓氏和首字母的组合
if initials in surname_counter[surname]:
surname_counter[surname][initials] += 1
masked_name = f"{surname}{initials}{surname_counter[surname][initials]}"
else:
surname_counter[surname][initials] = 1
masked_name = f"{surname}{initials}"
return masked_name
def _extract_business_name(self, company_name: str) -> str:
"""
从公司名称中提取商号企业字号
公司名通常为地域+商号+业务/行业+组织类型
也有商号+地域+业务/行业+组织类型
"""
if not company_name:
return ""
# 律师事务所特殊处理
if '律师事务所' in company_name:
return self._extract_law_firm_business_name(company_name)
# 常见的地域前缀
region_prefixes = [
'北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安',
'天津', '重庆', '青岛', '大连', '宁波', '厦门', '无锡', '长沙', '郑州', '济南',
'哈尔滨', '沈阳', '长春', '石家庄', '太原', '呼和浩特', '合肥', '福州', '南昌',
'南宁', '海口', '贵阳', '昆明', '兰州', '西宁', '银川', '乌鲁木齐', '拉萨',
'香港', '澳门', '台湾'
]
# 常见的组织类型后缀
org_suffixes = [
'有限公司', '股份有限公司', '有限责任公司', '股份公司', '集团公司', '集团',
'科技公司', '网络公司', '信息技术公司', '软件公司', '互联网公司',
'贸易公司', '商贸公司', '进出口公司', '物流公司', '运输公司',
'房地产公司', '置业公司', '投资公司', '金融公司', '银行',
'保险公司', '证券公司', '基金公司', '信托公司', '租赁公司',
'咨询公司', '服务公司', '管理公司', '广告公司', '传媒公司',
'教育公司', '培训公司', '医疗公司', '医药公司', '生物公司',
'制造公司', '工业公司', '化工公司', '能源公司', '电力公司',
'建筑公司', '工程公司', '建设公司', '开发公司', '设计公司',
'销售公司', '营销公司', '代理公司', '经销商', '零售商',
'连锁公司', '超市', '商场', '百货', '专卖店', '便利店'
]
# 尝试使用LLM提取商号
try:
business_name = self._extract_business_name_with_llm(company_name)
if business_name:
return business_name
except Exception as e:
logger.warning(f"LLM extraction failed for {company_name}: {e}")
# 回退到正则表达式方法
return self._extract_business_name_with_regex(company_name, region_prefixes, org_suffixes)
def _extract_law_firm_business_name(self, law_firm_name: str) -> str:
"""
从律师事务所名称中提取商号
律师事务所通常为地域+商号+律师事务所或者地域+商号+律师事务所+地域+分所或者商号+地域+律师事务所
"""
# 移除"律师事务所"后缀
name = law_firm_name.replace('律师事务所', '').replace('分所', '').strip()
# 处理括号中的地域信息
name = re.sub(r'[(].*?[)]', '', name).strip()
# 常见地域前缀
region_prefixes = ['北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安']
for region in region_prefixes:
if name.startswith(region):
return name[len(region):].strip()
return name
def _extract_business_name_with_llm(self, company_name: str) -> str:
"""
使用LLM提取商号
"""
prompt = f"""
你是一个专业的公司名称分析助手请从以下公司名称中提取商号企业字号并严格按照JSON格式返回结果
公司名称{company_name}
商号提取规则
1. 公司名通常为地域+商号+业务/行业+组织类型
2. 也有商号+地域+业务/行业+组织类型
3. 商号是企业名称中最具识别性的部分通常是2-4个汉字
4. 不要包含地域行业组织类型等信息
5. 律师事务所的商号通常是地域后的部分
示例
- 上海盒马网络科技有限公司 -> 盒马
- 丰田通商上海有限公司 -> 丰田通商
- 雅诗兰黛上海商贸有限公司 -> 雅诗兰黛
- 北京百度网讯科技有限公司 -> 百度
- 腾讯科技深圳有限公司 -> 腾讯
- 北京大成律师事务所 -> 大成
请严格按照以下JSON格式输出不要包含任何其他文字
{{
"business_name": "提取的商号",
"confidence": 0.9
}}
注意
- business_name字段必须包含提取的商号
- confidence字段是0-1之间的数字表示提取的置信度
- 必须严格按照JSON格式不要添加任何解释或额外文字
"""
try:
# 使用新的增强generate方法进行验证
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='business_name_extraction',
return_parsed=True
)
if parsed_response:
business_name = parsed_response.get('business_name', '')
# 清理商号,只保留中文字符
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
logger.info(f"Successfully extracted business name: {business_name}")
return business_name if business_name else ""
else:
logger.warning(f"Failed to extract business name for: {company_name}")
return ""
except Exception as e:
logger.error(f"LLM extraction failed: {e}")
return ""
def _extract_business_name_with_regex(self, company_name: str, region_prefixes: list, org_suffixes: list) -> str:
"""
使用正则表达式提取商号回退方法
"""
name = company_name
# 移除地域前缀
for region in region_prefixes:
if name.startswith(region):
name = name[len(region):].strip()
break
# 移除括号中的地域信息
name = re.sub(r'[(].*?[)]', '', name).strip()
# 移除组织类型后缀
for suffix in org_suffixes:
if name.endswith(suffix):
name = name[:-len(suffix)].strip()
break
# 如果剩余部分太长尝试提取前2-4个字符作为商号
if len(name) > 4:
# 尝试找到合适的断点
for i in range(2, min(5, len(name))):
if name[i] in ['', '', '', '', '', '', '', '', '', '', '', '', '', '']:
name = name[:i]
break
return name if name else company_name[:2] # 回退到前两个字符
def _mask_company_name(self, company_name: str) -> str:
"""
对公司名称进行脱敏处理
将商号替换为大写字母规则是商号首字母在字母表上的后两位字母
"""
if not company_name:
return company_name
# 提取商号
business_name = self._extract_business_name(company_name)
if not business_name:
return company_name
# 获取商号的拼音首字母
try:
pinyin_list = pinyin(business_name, style=Style.NORMAL)
first_letter = pinyin_list[0][0][0].upper() if pinyin_list and pinyin_list[0] else 'A'
except Exception as e:
logger.warning(f"Failed to get pinyin for {business_name}: {e}")
first_letter = 'A'
# 计算后两位字母
if first_letter >= 'Y':
# 如果首字母是Y或Z回退到X和Y
letters = 'XY'
elif first_letter >= 'X':
# 如果首字母是X使用Y和Z
letters = 'YZ'
else:
# 正常情况:使用首字母后的两个字母
letters = chr(ord(first_letter) + 1) + chr(ord(first_letter) + 2)
# 替换商号
if business_name in company_name:
masked_name = company_name.replace(business_name, letters)
else:
# 如果无法直接替换,尝试更智能的替换
masked_name = self._replace_business_name_in_company(company_name, business_name, letters)
return masked_name
def _replace_business_name_in_company(self, company_name: str, business_name: str, letters: str) -> str:
"""
在公司名称中智能替换商号
"""
# 尝试不同的替换策略
patterns = [
business_name,
business_name + '',
business_name + '(',
'' + business_name + '',
'(' + business_name + ')',
]
for pattern in patterns:
if pattern in company_name:
if pattern.endswith('') or pattern.endswith('('):
return company_name.replace(pattern, letters + pattern[-1])
elif pattern.startswith('') or pattern.startswith('('):
return company_name.replace(pattern, pattern[0] + letters + pattern[-1])
else:
return company_name.replace(pattern, letters)
# 如果都找不到,尝试在合适的位置插入
# 这里可以根据具体的公司名称模式进行更复杂的处理
return company_name
def _extract_address_components(self, address: str) -> Dict[str, str]:
"""
使用LLM提取地址中的路名门牌号大厦名小区名
"""
prompt = f"""
你是一个专业的地址分析助手请从以下地址中提取需要脱敏的组件并严格按照JSON格式返回结果
地址{address}
脱敏规则
1. 保留区级以上地址县等
2. 路名路名需要脱敏以大写首字母替代
3. 门牌号门牌数字需要脱敏****代替
4. 大厦名小区名需要脱敏以大写首字母替代
示例
- 上海市静安区恒丰路66号白云大厦1607室
- 路名恒丰路
- 门牌号66
- 大厦名白云大厦
- 小区名
- 北京市朝阳区建国路88号SOHO现代城A座1001室
- 路名建国路
- 门牌号88
- 大厦名SOHO现代城
- 小区名
- 广州市天河区珠江新城花城大道123号富力中心B座2001室
- 路名花城大道
- 门牌号123
- 大厦名富力中心
- 小区名
请严格按照以下JSON格式输出不要包含任何其他文字
{{
"road_name": "提取的路名",
"house_number": "提取的门牌号",
"building_name": "提取的大厦名",
"community_name": "提取的小区名(如果没有则为空字符串)",
"confidence": 0.9
}}
注意
- road_name字段必须包含路名恒丰路建国路等
- house_number字段必须包含门牌号6688
- building_name字段必须包含大厦名白云大厦SOHO现代城等
- community_name字段包含小区名如果没有则为空字符串
- confidence字段是0-1之间的数字表示提取的置信度
- 必须严格按照JSON格式不要添加任何解释或额外文字
"""
try:
# 使用新的增强generate方法进行验证
parsed_response = self.ollama_client.generate_with_validation(
prompt=prompt,
response_type='address_extraction',
return_parsed=True
)
if parsed_response:
logger.info(f"Successfully extracted address components: {parsed_response}")
return parsed_response
else:
logger.warning(f"Failed to extract address components for: {address}")
return self._extract_address_components_with_regex(address)
except Exception as e:
logger.error(f"LLM extraction failed: {e}")
return self._extract_address_components_with_regex(address)
def _extract_address_components_with_regex(self, address: str) -> Dict[str, str]:
"""
使用正则表达式提取地址组件回退方法
"""
# 路名模式:通常以"路"、"街"、"大道"等结尾
road_pattern = r'([^省市区县]+[路街大道巷弄])'
# 门牌号模式:数字+号
house_number_pattern = r'(\d+)号'
# 大厦名模式:通常包含"大厦"、"中心"、"广场"等
building_pattern = r'([^号室]+(?:大厦|中心|广场|城|楼|座))'
# 小区名模式:通常包含"小区"、"花园"、"苑"等
community_pattern = r'([^号室]+(?:小区|花园|苑|园|庭))'
road_name = ""
house_number = ""
building_name = ""
community_name = ""
# 提取路名
road_match = re.search(road_pattern, address)
if road_match:
road_name = road_match.group(1).strip()
# 提取门牌号
house_match = re.search(house_number_pattern, address)
if house_match:
house_number = house_match.group(1)
# 提取大厦名
building_match = re.search(building_pattern, address)
if building_match:
building_name = building_match.group(1).strip()
# 提取小区名
community_match = re.search(community_pattern, address)
if community_match:
community_name = community_match.group(1).strip()
return {
"road_name": road_name,
"house_number": house_number,
"building_name": building_name,
"community_name": community_name,
"confidence": 0.5 # 较低置信度,因为是回退方法
}
def _mask_address(self, address: str) -> str:
"""
对地址进行脱敏处理
保留区级以上地址路名以大写首字母替代门牌数字以****代替大厦名小区名以大写首字母替代
"""
if not address:
return address
# 提取地址组件
components = self._extract_address_components(address)
masked_address = address
# 替换路名
if components.get("road_name"):
road_name = components["road_name"]
# 获取路名的拼音首字母
try:
pinyin_list = pinyin(road_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
masked_address = masked_address.replace(road_name, initials + "")
except Exception as e:
logger.warning(f"Failed to get pinyin for road name {road_name}: {e}")
# 如果拼音转换失败,使用原字符的首字母
masked_address = masked_address.replace(road_name, road_name[0].upper() + "")
# 替换门牌号
if components.get("house_number"):
house_number = components["house_number"]
masked_address = masked_address.replace(house_number + "", "**号")
# 替换大厦名
if components.get("building_name"):
building_name = components["building_name"]
# 获取大厦名的拼音首字母
try:
pinyin_list = pinyin(building_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
masked_address = masked_address.replace(building_name, initials)
except Exception as e:
logger.warning(f"Failed to get pinyin for building name {building_name}: {e}")
# 如果拼音转换失败,使用原字符的首字母
masked_address = masked_address.replace(building_name, building_name[0].upper())
# 替换小区名
if components.get("community_name"):
community_name = components["community_name"]
# 获取小区名的拼音首字母
try:
pinyin_list = pinyin(community_name, style=Style.NORMAL)
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
masked_address = masked_address.replace(community_name, initials)
except Exception as e:
logger.warning(f"Failed to get pinyin for community name {community_name}: {e}")
# 如果拼音转换失败,使用原字符的首字母
masked_address = masked_address.replace(community_name, community_name[0].upper())
return masked_address
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]: def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
try: for attempt in range(self.max_retries):
formatted_prompt = prompt_func(chunk) try:
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}") formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
# 使用新的增强generate方法进行验证 response = self.ollama_client.generate(formatted_prompt)
mapping = self.ollama_client.generate_with_validation( logger.info(f"Raw response from LLM: {response}")
prompt=formatted_prompt,
response_type='entity_extraction', mapping = LLMJsonExtractor.parse_raw_json_str(response)
return_parsed=True logger.info(f"Parsed mapping: {mapping}")
)
if mapping and self._validate_mapping_format(mapping):
logger.info(f"Parsed mapping: {mapping}") return mapping
else:
if mapping and self._validate_mapping_format(mapping): logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
return mapping except Exception as e:
else: logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
logger.warning(f"Invalid mapping format received for {entity_type}") if attempt < self.max_retries - 1:
return {} logger.info("Retrying...")
except Exception as e: else:
logger.error(f"Error generating {entity_type} mapping: {e}") logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
return {}
return {}
def build_mapping(self, chunk: str) -> list[Dict[str, str]]: def build_mapping(self, chunk: str) -> list[Dict[str, str]]:
mapping_pipeline = [] mapping_pipeline = []
@ -537,23 +99,22 @@ class NerProcessor:
def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]], linkage: Dict[str, Any]) -> Dict[str, str]: def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]], linkage: Dict[str, Any]) -> Dict[str, str]:
""" """
结合 linkage 信息按实体分组映射同一脱敏名并实现如下规则 结合 linkage 信息按实体分组映射同一脱敏名并实现如下规则
1. 中文人名保留姓名变为大写首字母同姓名同首字母者按12依次编号李强->李Q张韶涵->张SH张若宇->张RY白锦程->白JC 1. 人名/简称保留姓名变为某同姓编号
2. 律师姓名审判人员姓名同上中文人名规则 2. 公司名同组公司名映射为大写字母公司A公司B公司...
3. 公司名将商号替换为大写字母规则是商号首字母在字母表上的后两位字母上海盒马网络科技有限公司->上海JO网络科技有限公司丰田通商上海有限公司->HVVU上海有限公司 3. 英文人名每个单词首字母+***
4. 英文人名每个单词首字母+*** 4. 英文公司名替换为所属行业名称英文大写如无行业信息默认 COMPANY
5. 英文公司名替换为所属行业名称英文大写如无行业信息默认 COMPANY 5. 项目名项目名称变为小写英文字母 a项目b项目...
6. 项目名项目名称变为小写英文字母 a项目b项目... 6. 案号只替换案号中的数字部分为***保留前后结构和支持中间有空格
7. 案号只替换案号中的数字部分为***保留前后结构和""支持中间有空格 7. 身份证号6位X
8. 身份证号保留首6位其他位数变为"X"310103198802080000310103XXXXXXXXXXXX 8. 社会信用代码8位X
9. 社会信用代码保留首7位其他位数变为"X"9133021276453538XT913302XXXXXXXXXXXX 9. 地址保留区级及以上行政区划去除详细位置
10. 地址保留区级以上地址路名以大写首字母替代门牌数字以****代替大厦名小区名以大写首字母替代上海市静安区恒丰路66号白云大厦1607室上海市静安区HF路**号BY大厦**** 10. 其他类型按原有逻辑
11. 其他类型按原有逻辑
""" """
import re import re
entity_mapping = {} entity_mapping = {}
used_masked_names = set() used_masked_names = set()
group_mask_map = {} group_mask_map = {}
surname_counter = {} # 用于中文姓名脱敏的计数器 surname_counter = {}
company_letter = ord('A') company_letter = ord('A')
project_letter = ord('a') project_letter = ord('a')
# 优先区县级单位,后市、省等 # 优先区县级单位,后市、省等
@ -566,17 +127,23 @@ class NerProcessor:
group_type = group.get('group_type', '') group_type = group.get('group_type', '')
entities = group.get('entities', []) entities = group.get('entities', [])
if '公司' in group_type or 'Company' in group_type: if '公司' in group_type or 'Company' in group_type:
masked = chr(company_letter) + '公司'
company_letter += 1
for entity in entities: for entity in entities:
# 使用新的公司名称脱敏方法
masked = self._mask_company_name(entity['text'])
group_mask_map[entity['text']] = masked group_mask_map[entity['text']] = masked
elif '人名' in group_type: elif '人名' in group_type:
surname_local_counter = {}
for entity in entities: for entity in entities:
name = entity['text'] name = entity['text']
if not name: if not name:
continue continue
# 使用新的中文姓名脱敏方法 surname = name[0]
masked = self._mask_chinese_name(name, surname_counter) surname_local_counter.setdefault(surname, 0)
surname_local_counter[surname] += 1
if surname_local_counter[surname] == 1:
masked = f"{surname}"
else:
masked = f"{surname}{surname_local_counter[surname]}"
group_mask_map[name] = masked group_mask_map[name] = masked
elif '英文人名' in group_type: elif '英文人名' in group_type:
for entity in entities: for entity in entities:
@ -606,24 +173,20 @@ class NerProcessor:
entity_mapping[text] = masked entity_mapping[text] = masked
used_masked_names.add(masked) used_masked_names.add(masked)
elif '身份证号' in entity_type: elif '身份证号' in entity_type:
# 保留首6位其他位数变为"X" masked = 'X' * 6
if len(text) >= 6:
masked = text[:6] + 'X' * (len(text) - 6)
else:
masked = text # fallback for invalid length
entity_mapping[text] = masked entity_mapping[text] = masked
used_masked_names.add(masked) used_masked_names.add(masked)
elif '社会信用代码' in entity_type: elif '社会信用代码' in entity_type:
# 保留首7位其他位数变为"X" masked = 'X' * 8
if len(text) >= 7:
masked = text[:7] + 'X' * (len(text) - 7)
else:
masked = text # fallback for invalid length
entity_mapping[text] = masked entity_mapping[text] = masked
used_masked_names.add(masked) used_masked_names.add(masked)
elif '地址' in entity_type: elif '地址' in entity_type:
# 使用新的地址脱敏方法 # 保留区级及以上行政区划,去除详细位置
masked = self._mask_address(text) match = re.match(admin_pattern, text)
if match:
masked = match.group(1)
else:
masked = text # fallback
entity_mapping[text] = masked entity_mapping[text] = masked
used_masked_names.add(masked) used_masked_names.add(masked)
elif '人名' in entity_type: elif '人名' in entity_type:
@ -631,13 +194,18 @@ class NerProcessor:
if not name: if not name:
masked = '' masked = ''
else: else:
# 使用新的中文姓名脱敏方法 surname = name[0]
masked = self._mask_chinese_name(name, surname_counter) surname_counter.setdefault(surname, 0)
surname_counter[surname] += 1
if surname_counter[surname] == 1:
masked = f"{surname}"
else:
masked = f"{surname}{surname_counter[surname]}"
entity_mapping[text] = masked entity_mapping[text] = masked
used_masked_names.add(masked) used_masked_names.add(masked)
elif '公司' in entity_type or 'Company' in entity_type: elif '公司' in entity_type or 'Company' in entity_type:
# 使用新的公司名称脱敏方法 masked = chr(company_letter) + '公司'
masked = self._mask_company_name(text) company_letter += 1
entity_mapping[text] = masked entity_mapping[text] = masked
used_masked_names.add(masked) used_masked_names.add(masked)
elif '英文人名' in entity_type: elif '英文人名' in entity_type:
@ -679,28 +247,29 @@ class NerProcessor:
for entity in linkable_entities for entity in linkable_entities
]) ])
try: for attempt in range(self.max_retries):
formatted_prompt = get_entity_linkage_prompt(entities_text) try:
logger.info(f"Calling ollama to generate entity linkage") formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
# 使用新的增强generate方法进行验证 response = self.ollama_client.generate(formatted_prompt)
linkage = self.ollama_client.generate_with_validation( logger.info(f"Raw entity linkage response from LLM: {response}")
prompt=formatted_prompt,
response_type='entity_linkage', linkage = LLMJsonExtractor.parse_raw_json_str(response)
return_parsed=True logger.info(f"Parsed entity linkage: {linkage}")
)
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Parsed entity linkage: {linkage}") logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
if linkage and self._validate_linkage_format(linkage): else:
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups") logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
return linkage except Exception as e:
else: logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
logger.warning(f"Invalid entity linkage format received") if attempt < self.max_retries - 1:
return {"entity_groups": []} logger.info("Retrying...")
except Exception as e: else:
logger.error(f"Error generating entity linkage: {e}") logger.error("Max retries reached for entity linkage, returning empty linkage")
return {"entity_groups": []}
return {"entity_groups": []}
def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]: def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]:
""" """

View File

@ -1,279 +0,0 @@
"""
Refactored NerProcessor using the new masker architecture.
"""
import logging
from typing import Any, Dict, List, Optional
from ..prompts.masking_prompts import (
get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt,
get_ner_project_prompt, get_ner_case_number_prompt, get_entity_linkage_prompt
)
from ..services.ollama_client import OllamaClient
from ...core.config import settings
from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
from .masker_factory import MaskerFactory
from .maskers.base_masker import BaseMasker
logger = logging.getLogger(__name__)
class NerProcessorRefactored:
"""Refactored NerProcessor using the new masker architecture"""
def __init__(self):
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
self.max_retries = 3
self.maskers = self._initialize_maskers()
self.surname_counter = {} # Shared counter for Chinese names
def _initialize_maskers(self) -> Dict[str, BaseMasker]:
"""Initialize all maskers"""
maskers = {}
# Create maskers that don't need ollama_client
maskers['chinese_name'] = MaskerFactory.create_masker('chinese_name')
maskers['english_name'] = MaskerFactory.create_masker('english_name')
maskers['id'] = MaskerFactory.create_masker('id')
maskers['case'] = MaskerFactory.create_masker('case')
# Create maskers that need ollama_client
maskers['company'] = MaskerFactory.create_masker('company', self.ollama_client)
maskers['address'] = MaskerFactory.create_masker('address', self.ollama_client)
return maskers
def _get_masker_for_type(self, entity_type: str) -> Optional[BaseMasker]:
"""Get the appropriate masker for the given entity type"""
for masker in self.maskers.values():
if masker.can_mask(entity_type):
return masker
return None
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
"""Validate entity extraction mapping format"""
return LLMResponseValidator.validate_entity_extraction(mapping)
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
"""Process entities of a specific type using LLM"""
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
# Use the new enhanced generate method with validation
mapping = self.ollama_client.generate_with_validation(
prompt=formatted_prompt,
response_type='entity_extraction',
return_parsed=True
)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received for {entity_type}")
return {}
except Exception as e:
logger.error(f"Error generating {entity_type} mapping: {e}")
return {}
def build_mapping(self, chunk: str) -> List[Dict[str, str]]:
"""Build entity mappings from text chunk"""
mapping_pipeline = []
# Process different entity types
entity_configs = [
(get_ner_name_prompt, "people names"),
(get_ner_company_prompt, "company names"),
(get_ner_address_prompt, "addresses"),
(get_ner_project_prompt, "project names"),
(get_ner_case_number_prompt, "case numbers")
]
for prompt_func, entity_type in entity_configs:
mapping = self._process_entity_type(chunk, prompt_func, entity_type)
if mapping:
mapping_pipeline.append(mapping)
# Process regex-based entities
regex_entity_extractors = [
extract_id_number_entities,
extract_social_credit_code_entities
]
for extractor in regex_entity_extractors:
mapping = extractor(chunk)
if mapping and LLMResponseValidator.validate_regex_entity(mapping):
mapping_pipeline.append(mapping)
elif mapping:
logger.warning(f"Invalid regex entity mapping format: {mapping}")
return mapping_pipeline
def _merge_entity_mappings(self, chunk_mappings: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Merge entity mappings from multiple chunks"""
all_entities = []
for mapping in chunk_mappings:
if isinstance(mapping, dict) and 'entities' in mapping:
entities = mapping['entities']
if isinstance(entities, list):
all_entities.extend(entities)
unique_entities = []
seen_texts = set()
for entity in all_entities:
if isinstance(entity, dict) and 'text' in entity:
text = entity['text'].strip()
if text and text not in seen_texts:
seen_texts.add(text)
unique_entities.append(entity)
elif text and text in seen_texts:
logger.info(f"Duplicate entity found: {entity}")
continue
logger.info(f"Merged {len(unique_entities)} unique entities")
return unique_entities
def _generate_masked_mapping(self, unique_entities: List[Dict[str, str]], linkage: Dict[str, Any]) -> Dict[str, str]:
"""Generate masked mappings for entities"""
entity_mapping = {}
used_masked_names = set()
group_mask_map = {}
# Process entity groups from linkage
for group in linkage.get('entity_groups', []):
group_type = group.get('group_type', '')
entities = group.get('entities', [])
# Handle company groups
if any(keyword in group_type for keyword in ['公司', 'Company']):
for entity in entities:
masker = self._get_masker_for_type('公司名称')
if masker:
masked = masker.mask(entity['text'])
group_mask_map[entity['text']] = masked
# Handle name groups
elif '人名' in group_type:
for entity in entities:
masker = self._get_masker_for_type('人名')
if masker:
context = {'surname_counter': self.surname_counter}
masked = masker.mask(entity['text'], context)
group_mask_map[entity['text']] = masked
# Handle English name groups
elif '英文人名' in group_type:
for entity in entities:
masker = self._get_masker_for_type('英文人名')
if masker:
masked = masker.mask(entity['text'])
group_mask_map[entity['text']] = masked
# Process individual entities
for entity in unique_entities:
text = entity['text']
entity_type = entity.get('type', '')
# Check if entity is in group mapping
if text in group_mask_map:
entity_mapping[text] = group_mask_map[text]
used_masked_names.add(group_mask_map[text])
continue
# Get appropriate masker for entity type
masker = self._get_masker_for_type(entity_type)
if masker:
# Prepare context for maskers that need it
context = {}
if entity_type in ['人名', '律师姓名', '审判人员姓名']:
context['surname_counter'] = self.surname_counter
masked = masker.mask(text, context)
entity_mapping[text] = masked
used_masked_names.add(masked)
else:
# Fallback for unknown entity types
base_name = ''
masked = base_name
counter = 1
while masked in used_masked_names:
if counter <= 10:
suffixes = ['', '', '', '', '', '', '', '', '', '']
masked = base_name + suffixes[counter - 1]
else:
masked = f"{base_name}{counter}"
counter += 1
entity_mapping[text] = masked
used_masked_names.add(masked)
return entity_mapping
def _validate_linkage_format(self, linkage: Dict[str, Any]) -> bool:
"""Validate entity linkage format"""
return LLMResponseValidator.validate_entity_linkage(linkage)
def _create_entity_linkage(self, unique_entities: List[Dict[str, str]]) -> Dict[str, Any]:
"""Create entity linkage information"""
linkable_entities = []
for entity in unique_entities:
entity_type = entity.get('type', '')
if any(keyword in entity_type for keyword in ['公司', 'Company', '人名', '英文人名']):
linkable_entities.append(entity)
if not linkable_entities:
logger.info("No linkable entities found")
return {"entity_groups": []}
entities_text = "\n".join([
f"- {entity['text']} (类型: {entity['type']})"
for entity in linkable_entities
])
try:
formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage")
# Use the new enhanced generate method with validation
linkage = self.ollama_client.generate_with_validation(
prompt=formatted_prompt,
response_type='entity_linkage',
return_parsed=True
)
logger.info(f"Parsed entity linkage: {linkage}")
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
else:
logger.warning(f"Invalid entity linkage format received")
return {"entity_groups": []}
except Exception as e:
logger.error(f"Error generating entity linkage: {e}")
return {"entity_groups": []}
def process(self, chunks: List[str]) -> Dict[str, str]:
"""Main processing method"""
chunk_mappings = []
for i, chunk in enumerate(chunks):
logger.info(f"Processing chunk {i+1}/{len(chunks)}")
chunk_mapping = self.build_mapping(chunk)
logger.info(f"Chunk mapping: {chunk_mapping}")
chunk_mappings.extend(chunk_mapping)
logger.info(f"Final chunk mappings: {chunk_mappings}")
unique_entities = self._merge_entity_mappings(chunk_mappings)
logger.info(f"Unique entities: {unique_entities}")
entity_linkage = self._create_entity_linkage(unique_entities)
logger.info(f"Entity linkage: {entity_linkage}")
combined_mapping = self._generate_masked_mapping(unique_entities, entity_linkage)
logger.info(f"Combined mapping: {combined_mapping}")
return combined_mapping

View File

@ -1,6 +1,7 @@
from .txt_processor import TxtDocumentProcessor from .txt_processor import TxtDocumentProcessor
from .docx_processor import DocxDocumentProcessor # from .docx_processor import DocxDocumentProcessor
from .pdf_processor import PdfDocumentProcessor from .pdf_processor import PdfDocumentProcessor
from .md_processor import MarkdownDocumentProcessor from .md_processor import MarkdownDocumentProcessor
__all__ = ['TxtDocumentProcessor', 'DocxDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor'] # __all__ = ['TxtDocumentProcessor', 'DocxDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor']
__all__ = ['TxtDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor']

View File

@ -1,219 +0,0 @@
import os
import requests
import logging
from typing import Dict, Any, Optional
from ...document_handlers.document_processor import DocumentProcessor
from ...services.ollama_client import OllamaClient
from ...config import settings
logger = logging.getLogger(__name__)
class DocxDocumentProcessor(DocumentProcessor):
def __init__(self, input_path: str, output_path: str):
super().__init__() # Call parent class's __init__
self.input_path = input_path
self.output_path = output_path
self.output_dir = os.path.dirname(output_path)
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
# Setup work directory for temporary files
self.work_dir = os.path.join(
os.path.dirname(output_path),
".work",
os.path.splitext(os.path.basename(input_path))[0]
)
os.makedirs(self.work_dir, exist_ok=True)
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
# Mineru API configuration
self.mineru_base_url = getattr(settings, 'MINERU_API_URL', 'http://mineru-api:8000')
self.mineru_timeout = getattr(settings, 'MINERU_TIMEOUT', 300) # 5 minutes timeout
self.mineru_lang_list = getattr(settings, 'MINERU_LANG_LIST', ['ch'])
self.mineru_backend = getattr(settings, 'MINERU_BACKEND', 'pipeline')
self.mineru_parse_method = getattr(settings, 'MINERU_PARSE_METHOD', 'auto')
self.mineru_formula_enable = getattr(settings, 'MINERU_FORMULA_ENABLE', True)
self.mineru_table_enable = getattr(settings, 'MINERU_TABLE_ENABLE', True)
def _call_mineru_api(self, file_path: str) -> Optional[Dict[str, Any]]:
"""
Call Mineru API to convert DOCX to markdown
Args:
file_path: Path to the DOCX file
Returns:
API response as dictionary or None if failed
"""
try:
url = f"{self.mineru_base_url}/file_parse"
with open(file_path, 'rb') as file:
files = {'files': (os.path.basename(file_path), file, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')}
# Prepare form data according to Mineru API specification
data = {
'output_dir': './output',
'lang_list': self.mineru_lang_list,
'backend': self.mineru_backend,
'parse_method': self.mineru_parse_method,
'formula_enable': self.mineru_formula_enable,
'table_enable': self.mineru_table_enable,
'return_md': True,
'return_middle_json': False,
'return_model_output': False,
'return_content_list': False,
'return_images': False,
'start_page_id': 0,
'end_page_id': 99999
}
logger.info(f"Calling Mineru API for DOCX processing at {url}")
response = requests.post(
url,
files=files,
data=data,
timeout=self.mineru_timeout
)
if response.status_code == 200:
result = response.json()
logger.info("Successfully received response from Mineru API for DOCX")
return result
else:
error_msg = f"Mineru API returned status code {response.status_code}: {response.text}"
logger.error(error_msg)
# For 400 errors, include more specific information
if response.status_code == 400:
try:
error_data = response.json()
if 'error' in error_data:
error_msg = f"Mineru API error: {error_data['error']}"
except:
pass
raise Exception(error_msg)
except requests.exceptions.Timeout:
error_msg = f"Mineru API request timed out after {self.mineru_timeout} seconds"
logger.error(error_msg)
raise Exception(error_msg)
except requests.exceptions.RequestException as e:
error_msg = f"Error calling Mineru API for DOCX: {str(e)}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
error_msg = f"Unexpected error calling Mineru API for DOCX: {str(e)}"
logger.error(error_msg)
raise Exception(error_msg)
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
"""
Extract markdown content from Mineru API response
Args:
response: Mineru API response dictionary
Returns:
Extracted markdown content as string
"""
try:
logger.debug(f"Mineru API response structure for DOCX: {response}")
# Try different possible response formats based on Mineru API
if 'markdown' in response:
return response['markdown']
elif 'md' in response:
return response['md']
elif 'content' in response:
return response['content']
elif 'text' in response:
return response['text']
elif 'result' in response and isinstance(response['result'], dict):
result = response['result']
if 'markdown' in result:
return result['markdown']
elif 'md' in result:
return result['md']
elif 'content' in result:
return result['content']
elif 'text' in result:
return result['text']
elif 'data' in response and isinstance(response['data'], dict):
data = response['data']
if 'markdown' in data:
return data['markdown']
elif 'md' in data:
return data['md']
elif 'content' in data:
return data['content']
elif 'text' in data:
return data['text']
elif isinstance(response, list) and len(response) > 0:
# If response is a list, try to extract from first item
first_item = response[0]
if isinstance(first_item, dict):
return self._extract_markdown_from_response(first_item)
elif isinstance(first_item, str):
return first_item
else:
# If no standard format found, try to extract from the response structure
logger.warning("Could not find standard markdown field in Mineru response for DOCX")
# Return the response as string if it's simple, or empty string
if isinstance(response, str):
return response
elif isinstance(response, dict):
# Try to find any text-like content
for key, value in response.items():
if isinstance(value, str) and len(value) > 100: # Likely content
return value
elif isinstance(value, dict):
# Recursively search in nested dictionaries
nested_content = self._extract_markdown_from_response(value)
if nested_content:
return nested_content
return ""
except Exception as e:
logger.error(f"Error extracting markdown from Mineru response for DOCX: {str(e)}")
return ""
def read_content(self) -> str:
logger.info("Starting DOCX content processing with Mineru API")
# Call Mineru API to convert DOCX to markdown
# This will raise an exception if the API call fails
mineru_response = self._call_mineru_api(self.input_path)
# Extract markdown content from the response
markdown_content = self._extract_markdown_from_response(mineru_response)
if not markdown_content:
raise Exception("No markdown content found in Mineru API response for DOCX")
logger.info(f"Successfully extracted {len(markdown_content)} characters of markdown content from DOCX")
# Save the raw markdown content to work directory for reference
md_output_path = os.path.join(self.work_dir, f"{self.name_without_suff}.md")
with open(md_output_path, 'w', encoding='utf-8') as file:
file.write(markdown_content)
logger.info(f"Saved raw markdown content from DOCX to {md_output_path}")
return markdown_content
def save_content(self, content: str) -> None:
# Ensure output path has .md extension
output_dir = os.path.dirname(self.output_path)
base_name = os.path.splitext(os.path.basename(self.output_path))[0]
md_output_path = os.path.join(output_dir, f"{base_name}.md")
logger.info(f"Saving masked DOCX content to: {md_output_path}")
try:
with open(md_output_path, 'w', encoding='utf-8') as file:
file.write(content)
logger.info(f"Successfully saved masked DOCX content to {md_output_path}")
except Exception as e:
logger.error(f"Error saving masked DOCX content: {e}")
raise

View File

@ -0,0 +1,77 @@
import os
import docx
from ...document_handlers.document_processor import DocumentProcessor
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_office
import logging
from ...services.ollama_client import OllamaClient
from ...config import settings
from ...prompts.masking_prompts import get_masking_mapping_prompt
logger = logging.getLogger(__name__)
class DocxDocumentProcessor(DocumentProcessor):
def __init__(self, input_path: str, output_path: str):
super().__init__() # Call parent class's __init__
self.input_path = input_path
self.output_path = output_path
self.output_dir = os.path.dirname(output_path)
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
# Setup output directories
self.local_image_dir = os.path.join(self.output_dir, "images")
self.image_dir = os.path.basename(self.local_image_dir)
os.makedirs(self.local_image_dir, exist_ok=True)
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
def read_content(self) -> str:
try:
# Initialize writers
image_writer = FileBasedDataWriter(self.local_image_dir)
md_writer = FileBasedDataWriter(self.output_dir)
# Create Dataset Instance and process
ds = read_local_office(self.input_path)[0]
pipe_result = ds.apply(doc_analyze, ocr=True).pipe_txt_mode(image_writer)
# Generate markdown
md_content = pipe_result.get_markdown(self.image_dir)
pipe_result.dump_md(md_writer, f"{self.name_without_suff}.md", self.image_dir)
return md_content
except Exception as e:
logger.error(f"Error converting DOCX to MD: {e}")
raise
# def process_content(self, content: str) -> str:
# logger.info("Processing DOCX content")
# # Split content into sentences and apply masking
# sentences = content.split("。")
# final_md = ""
# for sentence in sentences:
# if sentence.strip(): # Only process non-empty sentences
# formatted_prompt = get_masking_mapping_prompt(sentence)
# logger.info("Calling ollama to generate response, prompt: %s", formatted_prompt)
# response = self.ollama_client.generate(formatted_prompt)
# logger.info(f"Response generated: {response}")
# final_md += response + "。"
# return final_md
def save_content(self, content: str) -> None:
# Ensure output path has .md extension
output_dir = os.path.dirname(self.output_path)
base_name = os.path.splitext(os.path.basename(self.output_path))[0]
md_output_path = os.path.join(output_dir, f"{base_name}.md")
logger.info(f"Saving masked content to: {md_output_path}")
try:
with open(md_output_path, 'w', encoding='utf-8') as file:
file.write(content)
logger.info(f"Successfully saved content to {md_output_path}")
except Exception as e:
logger.error(f"Error saving content: {e}")
raise

View File

@ -81,30 +81,18 @@ class PdfDocumentProcessor(DocumentProcessor):
logger.info("Successfully received response from Mineru API") logger.info("Successfully received response from Mineru API")
return result return result
else: else:
error_msg = f"Mineru API returned status code {response.status_code}: {response.text}" logger.error(f"Mineru API returned status code {response.status_code}: {response.text}")
logger.error(error_msg) return None
# For 400 errors, include more specific information
if response.status_code == 400:
try:
error_data = response.json()
if 'error' in error_data:
error_msg = f"Mineru API error: {error_data['error']}"
except:
pass
raise Exception(error_msg)
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
error_msg = f"Mineru API request timed out after {self.mineru_timeout} seconds" logger.error(f"Mineru API request timed out after {self.mineru_timeout} seconds")
logger.error(error_msg) return None
raise Exception(error_msg)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
error_msg = f"Error calling Mineru API: {str(e)}" logger.error(f"Error calling Mineru API: {str(e)}")
logger.error(error_msg) return None
raise Exception(error_msg)
except Exception as e: except Exception as e:
error_msg = f"Unexpected error calling Mineru API: {str(e)}" logger.error(f"Unexpected error calling Mineru API: {str(e)}")
logger.error(error_msg) return None
raise Exception(error_msg)
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str: def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
""" """
@ -183,9 +171,11 @@ class PdfDocumentProcessor(DocumentProcessor):
logger.info("Starting PDF content processing with Mineru API") logger.info("Starting PDF content processing with Mineru API")
# Call Mineru API to convert PDF to markdown # Call Mineru API to convert PDF to markdown
# This will raise an exception if the API call fails
mineru_response = self._call_mineru_api(self.input_path) mineru_response = self._call_mineru_api(self.input_path)
if not mineru_response:
raise Exception("Failed to get response from Mineru API")
# Extract markdown content from the response # Extract markdown content from the response
markdown_content = self._extract_markdown_from_response(mineru_response) markdown_content = self._extract_markdown_from_response(mineru_response)

View File

@ -13,7 +13,7 @@ class DocumentService:
processor = DocumentProcessorFactory.create_processor(input_path, output_path) processor = DocumentProcessorFactory.create_processor(input_path, output_path)
if not processor: if not processor:
logger.error(f"Unsupported file format: {input_path}") logger.error(f"Unsupported file format: {input_path}")
raise Exception(f"Unsupported file format: {input_path}") return False
# Read content # Read content
content = processor.read_content() content = processor.read_content()
@ -27,5 +27,4 @@ class DocumentService:
except Exception as e: except Exception as e:
logger.error(f"Error processing document {input_path}: {str(e)}") logger.error(f"Error processing document {input_path}: {str(e)}")
# Re-raise the exception so the Celery task can handle it properly return False
raise

View File

@ -1,222 +1,72 @@
import requests import requests
import logging import logging
from typing import Dict, Any, Optional, Callable, Union from typing import Dict, Any
from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OllamaClient: class OllamaClient:
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3): def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
"""Initialize Ollama client. """Initialize Ollama client.
Args: Args:
model_name (str): Name of the Ollama model to use model_name (str): Name of the Ollama model to use
base_url (str): Ollama server base URL host (str): Ollama server host address
max_retries (int): Maximum number of retries for failed requests port (int): Ollama server port
""" """
self.model_name = model_name self.model_name = model_name
self.base_url = base_url self.base_url = base_url
self.max_retries = max_retries
self.headers = {"Content-Type": "application/json"} self.headers = {"Content-Type": "application/json"}
def generate(self, def generate(self, prompt: str, strip_think: bool = True) -> str:
prompt: str, """Process a document using the Ollama API.
strip_think: bool = True,
validation_schema: Optional[Dict[str, Any]] = None,
response_type: Optional[str] = None,
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
"""Process a document using the Ollama API with optional validation and retry.
Args: Args:
prompt (str): The prompt to send to the model document_text (str): The text content to process
strip_think (bool): Whether to strip thinking tags from response
validation_schema (Optional[Dict]): JSON schema for validation
response_type (Optional[str]): Type of response for validation ('entity_extraction', 'entity_linkage', etc.)
return_parsed (bool): Whether to return parsed JSON instead of raw string
Returns: Returns:
Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON) str: Processed text response from the model
Raises: Raises:
RequestException: If the API call fails after all retries RequestException: If the API call fails
ValueError: If validation fails after all retries
"""
for attempt in range(self.max_retries):
try:
# Make the API call
raw_response = self._make_api_call(prompt, strip_think)
# If no validation required, return the response
if not validation_schema and not response_type and not return_parsed:
return raw_response
# Parse JSON if needed
if return_parsed or validation_schema or response_type:
parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
if not parsed_response:
logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
continue
else:
raise ValueError("Failed to parse JSON response after all retries")
# Validate if schema or response type provided
if validation_schema:
if not self._validate_with_schema(parsed_response, validation_schema):
logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
continue
else:
raise ValueError("Schema validation failed after all retries")
if response_type:
if not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
continue
else:
raise ValueError(f"Response type validation failed after all retries")
# Return parsed response if requested
if return_parsed:
return parsed_response
else:
return raw_response
return raw_response
except requests.exceptions.RequestException as e:
logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error("Max retries reached, raising exception")
raise
except Exception as e:
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error("Max retries reached, raising exception")
raise
# This should never be reached, but just in case
raise Exception("Unexpected error: max retries exceeded without proper exception handling")
def generate_with_validation(self,
prompt: str,
response_type: str,
strip_think: bool = True,
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
"""Generate response with automatic validation based on response type.
Args:
prompt (str): The prompt to send to the model
response_type (str): Type of response for validation
strip_think (bool): Whether to strip thinking tags from response
return_parsed (bool): Whether to return parsed JSON instead of raw string
Returns:
Union[str, Dict[str, Any]]: Validated response from the model
"""
return self.generate(
prompt=prompt,
strip_think=strip_think,
response_type=response_type,
return_parsed=return_parsed
)
def generate_with_schema(self,
prompt: str,
schema: Dict[str, Any],
strip_think: bool = True,
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
"""Generate response with custom schema validation.
Args:
prompt (str): The prompt to send to the model
schema (Dict): JSON schema for validation
strip_think (bool): Whether to strip thinking tags from response
return_parsed (bool): Whether to return parsed JSON instead of raw string
Returns:
Union[str, Dict[str, Any]]: Validated response from the model
"""
return self.generate(
prompt=prompt,
strip_think=strip_think,
validation_schema=schema,
return_parsed=return_parsed
)
def _make_api_call(self, prompt: str, strip_think: bool) -> str:
"""Make the actual API call to Ollama.
Args:
prompt (str): The prompt to send
strip_think (bool): Whether to strip thinking tags
Returns:
str: Raw response from the API
"""
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": False
}
logger.debug(f"Sending request to Ollama API: {url}")
response = requests.post(url, json=payload, headers=self.headers)
response.raise_for_status()
result = response.json()
logger.debug(f"Received response from Ollama API: {result}")
if strip_think:
# Remove the "thinking" part from the response
# the response is expected to be <think>...</think>response_text
# Check if the response contains <think> tag
if "<think>" in result.get("response", ""):
# Split the response and take the part after </think>
response_parts = result["response"].split("</think>")
if len(response_parts) > 1:
# Return the part after </think>
return response_parts[1].strip()
else:
# If no closing tag, return the full response
return result.get("response", "").strip()
else:
# If no <think> tag, return the full response
return result.get("response", "").strip()
else:
# If strip_think is False, return the full response
return result.get("response", "")
def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
"""Validate response against a JSON schema.
Args:
response (Dict): The parsed response to validate
schema (Dict): The JSON schema to validate against
Returns:
bool: True if valid, False otherwise
""" """
try: try:
from jsonschema import validate, ValidationError url = f"{self.base_url}/api/generate"
validate(instance=response, schema=schema) payload = {
logger.debug(f"Schema validation passed for response: {response}") "model": self.model_name,
return True "prompt": prompt,
except ValidationError as e: "stream": False
logger.warning(f"Schema validation failed: {e}") }
logger.warning(f"Response that failed validation: {response}")
return False logger.debug(f"Sending request to Ollama API: {url}")
except ImportError: response = requests.post(url, json=payload, headers=self.headers)
logger.error("jsonschema library not available for validation") response.raise_for_status()
return False
result = response.json()
logger.debug(f"Received response from Ollama API: {result}")
if strip_think:
# Remove the "thinking" part from the response
# the response is expected to be <think>...</think>response_text
# Check if the response contains <think> tag
if "<think>" in result.get("response", ""):
# Split the response and take the part after </think>
response_parts = result["response"].split("</think>")
if len(response_parts) > 1:
# Return the part after </think>
return response_parts[1].strip()
else:
# If no closing tag, return the full response
return result.get("response", "").strip()
else:
# If no <think> tag, return the full response
return result.get("response", "").strip()
else:
# If strip_think is False, return the full response
return result.get("response", "")
except requests.exceptions.RequestException as e:
logger.error(f"Error calling Ollama API: {str(e)}")
raise
def get_model_info(self) -> Dict[str, Any]: def get_model_info(self) -> Dict[str, Any]:
"""Get information about the current model. """Get information about the current model.

View File

@ -77,54 +77,6 @@ class LLMResponseValidator:
"required": ["entities"] "required": ["entities"]
} }
# Schema for business name extraction responses
BUSINESS_NAME_EXTRACTION_SCHEMA = {
"type": "object",
"properties": {
"business_name": {
"type": "string",
"description": "The extracted business name (商号) from the company name"
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1,
"description": "Confidence level of the extraction (0-1)"
}
},
"required": ["business_name"]
}
# Schema for address extraction responses
ADDRESS_EXTRACTION_SCHEMA = {
"type": "object",
"properties": {
"road_name": {
"type": "string",
"description": "The road name (路名) to be masked"
},
"house_number": {
"type": "string",
"description": "The house number (门牌号) to be masked"
},
"building_name": {
"type": "string",
"description": "The building name (大厦名) to be masked"
},
"community_name": {
"type": "string",
"description": "The community name (小区名) to be masked"
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1,
"description": "Confidence level of the extraction (0-1)"
}
},
"required": ["road_name", "house_number", "building_name", "community_name"]
}
@classmethod @classmethod
def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool: def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool:
""" """
@ -190,46 +142,6 @@ class LLMResponseValidator:
logger.warning(f"Response that failed validation: {response}") logger.warning(f"Response that failed validation: {response}")
return False return False
@classmethod
def validate_business_name_extraction(cls, response: Dict[str, Any]) -> bool:
"""
Validate business name extraction response from LLM.
Args:
response: The parsed JSON response from LLM
Returns:
bool: True if valid, False otherwise
"""
try:
validate(instance=response, schema=cls.BUSINESS_NAME_EXTRACTION_SCHEMA)
logger.debug(f"Business name extraction validation passed for response: {response}")
return True
except ValidationError as e:
logger.warning(f"Business name extraction validation failed: {e}")
logger.warning(f"Response that failed validation: {response}")
return False
@classmethod
def validate_address_extraction(cls, response: Dict[str, Any]) -> bool:
"""
Validate address extraction response from LLM.
Args:
response: The parsed JSON response from LLM
Returns:
bool: True if valid, False otherwise
"""
try:
validate(instance=response, schema=cls.ADDRESS_EXTRACTION_SCHEMA)
logger.debug(f"Address extraction validation passed for response: {response}")
return True
except ValidationError as e:
logger.warning(f"Address extraction validation failed: {e}")
logger.warning(f"Response that failed validation: {response}")
return False
@classmethod @classmethod
def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool: def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool:
""" """
@ -289,9 +201,7 @@ class LLMResponseValidator:
validators = { validators = {
'entity_extraction': cls.validate_entity_extraction, 'entity_extraction': cls.validate_entity_extraction,
'entity_linkage': cls.validate_entity_linkage, 'entity_linkage': cls.validate_entity_linkage,
'regex_entity': cls.validate_regex_entity, 'regex_entity': cls.validate_regex_entity
'business_name_extraction': cls.validate_business_name_extraction,
'address_extraction': cls.validate_address_extraction
} }
validator = validators.get(response_type) validator = validators.get(response_type)
@ -322,10 +232,6 @@ class LLMResponseValidator:
return "Content validation failed for entity linkage" return "Content validation failed for entity linkage"
elif response_type == 'regex_entity': elif response_type == 'regex_entity':
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA) validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
elif response_type == 'business_name_extraction':
validate(instance=response, schema=cls.BUSINESS_NAME_EXTRACTION_SCHEMA)
elif response_type == 'address_extraction':
validate(instance=response, schema=cls.ADDRESS_EXTRACTION_SCHEMA)
else: else:
return f"Unknown response type: {response_type}" return f"Unknown response type: {response_type}"

View File

@ -70,7 +70,6 @@ def process_file(file_id: str):
output_path = str(settings.PROCESSED_FOLDER / output_filename) output_path = str(settings.PROCESSED_FOLDER / output_filename)
# Process document with both input and output paths # Process document with both input and output paths
# This will raise an exception if processing fails
process_service.process_document(file.original_path, output_path) process_service.process_document(file.original_path, output_path)
# Update file record with processed path # Update file record with processed path
@ -82,7 +81,6 @@ def process_file(file_id: str):
file.status = FileStatus.FAILED file.status = FileStatus.FAILED
file.error_message = str(e) file.error_message = str(e)
db.commit() db.commit()
# Re-raise the exception to ensure Celery marks the task as failed
raise raise
finally: finally:

View File

@ -1,33 +0,0 @@
import pytest
import sys
import os
from pathlib import Path
# Add the backend directory to Python path for imports
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))
# Also add the current directory to ensure imports work
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))
@pytest.fixture
def sample_data():
"""Sample data fixture for testing"""
return {
"name": "test",
"value": 42,
"items": [1, 2, 3]
}
@pytest.fixture
def test_files_dir():
"""Fixture to get the test files directory"""
return Path(__file__).parent / "tests"
@pytest.fixture(autouse=True)
def setup_test_environment():
"""Setup test environment before each test"""
# Add any test environment setup here
yield
# Add any cleanup here

View File

@ -7,6 +7,7 @@ services:
- "8000:8000" - "8000:8000"
volumes: volumes:
- ./storage:/app/storage - ./storage:/app/storage
- ./legal_doc_masker.db:/app/legal_doc_masker.db
env_file: env_file:
- .env - .env
environment: environment:
@ -20,6 +21,7 @@ services:
command: celery -A app.services.file_service worker --loglevel=info command: celery -A app.services.file_service worker --loglevel=info
volumes: volumes:
- ./storage:/app/storage - ./storage:/app/storage
- ./legal_doc_masker.db:/app/legal_doc_masker.db
env_file: env_file:
- .env - .env
environment: environment:

View File

@ -1,15 +0,0 @@
[tool:pytest]
testpaths = tests
pythonpath = .
python_files = test_*.py *_test.py
python_classes = Test*
python_functions = test_*
addopts =
-v
--tb=short
--strict-markers
--disable-warnings
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks tests as integration tests
unit: marks tests as unit tests

View File

@ -29,7 +29,4 @@ python-docx>=0.8.11
PyPDF2>=3.0.0 PyPDF2>=3.0.0
pandas>=2.0.0 pandas>=2.0.0
# magic-pdf[full] # magic-pdf[full]
jsonschema>=4.20.0 jsonschema>=4.20.0
# Chinese text processing
pypinyin>=0.50.0

View File

@ -1,32 +0,0 @@
#!/usr/bin/env python3
"""
Simple test runner script to verify test discovery and execution
"""
import subprocess
import sys
import os
from pathlib import Path
def run_tests():
"""Run pytest with proper configuration"""
# Change to backend directory
backend_dir = Path(__file__).parent
os.chdir(backend_dir)
# Run pytest
cmd = [sys.executable, "-m", "pytest", "tests/", "-v", "--tb=short"]
print(f"Running tests from: {backend_dir}")
print(f"Command: {' '.join(cmd)}")
print("-" * 50)
try:
result = subprocess.run(cmd, capture_output=False, text=True)
return result.returncode
except Exception as e:
print(f"Error running tests: {e}")
return 1
if __name__ == "__main__":
exit_code = run_tests()
sys.exit(exit_code)

View File

@ -1,230 +0,0 @@
"""
Test file for the enhanced OllamaClient with validation and retry mechanisms.
"""
import sys
import os
import json
from unittest.mock import Mock, patch
# Add the current directory to the Python path
sys.path.insert(0, os.path.dirname(__file__))
def test_ollama_client_initialization():
"""Test OllamaClient initialization with new parameters"""
from app.core.services.ollama_client import OllamaClient
# Test with default parameters
client = OllamaClient("test-model")
assert client.model_name == "test-model"
assert client.base_url == "http://localhost:11434"
assert client.max_retries == 3
# Test with custom parameters
client = OllamaClient("test-model", "http://custom:11434", 5)
assert client.model_name == "test-model"
assert client.base_url == "http://custom:11434"
assert client.max_retries == 5
print("✓ OllamaClient initialization tests passed")
def test_generate_with_validation():
"""Test generate_with_validation method"""
from app.core.services.ollama_client import OllamaClient
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": '{"business_name": "测试公司", "confidence": 0.9}'
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test with business name extraction validation
result = client.generate_with_validation(
prompt="Extract business name from: 测试公司",
response_type='business_name_extraction',
return_parsed=True
)
assert isinstance(result, dict)
assert result.get('business_name') == '测试公司'
assert result.get('confidence') == 0.9
print("✓ generate_with_validation test passed")
def test_generate_with_schema():
"""Test generate_with_schema method"""
from app.core.services.ollama_client import OllamaClient
# Define a custom schema
custom_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"}
},
"required": ["name", "age"]
}
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": '{"name": "张三", "age": 30}'
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test with custom schema validation
result = client.generate_with_schema(
prompt="Generate person info",
schema=custom_schema,
return_parsed=True
)
assert isinstance(result, dict)
assert result.get('name') == '张三'
assert result.get('age') == 30
print("✓ generate_with_schema test passed")
def test_backward_compatibility():
"""Test backward compatibility with original generate method"""
from app.core.services.ollama_client import OllamaClient
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": "Simple text response"
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test original generate method (should still work)
result = client.generate("Simple prompt")
assert result == "Simple text response"
# Test with strip_think=False
result = client.generate("Simple prompt", strip_think=False)
assert result == "Simple text response"
print("✓ Backward compatibility tests passed")
def test_retry_mechanism():
"""Test retry mechanism for failed requests"""
from app.core.services.ollama_client import OllamaClient
import requests
# Mock failed requests followed by success
mock_failed_response = Mock()
mock_failed_response.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed")
mock_success_response = Mock()
mock_success_response.json.return_value = {
"response": "Success response"
}
mock_success_response.raise_for_status.return_value = None
with patch('requests.post', side_effect=[mock_failed_response, mock_success_response]):
client = OllamaClient("test-model", max_retries=2)
# Should retry and eventually succeed
result = client.generate("Test prompt")
assert result == "Success response"
print("✓ Retry mechanism test passed")
def test_validation_failure():
"""Test validation failure handling"""
from app.core.services.ollama_client import OllamaClient
# Mock API response with invalid JSON
mock_response = Mock()
mock_response.json.return_value = {
"response": "Invalid JSON response"
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model", max_retries=2)
try:
# This should fail validation and retry
result = client.generate_with_validation(
prompt="Test prompt",
response_type='business_name_extraction',
return_parsed=True
)
# If we get here, it means validation failed and retries were exhausted
print("✓ Validation failure handling test passed")
except ValueError as e:
# Expected behavior - validation failed after retries
assert "Failed to parse JSON response after all retries" in str(e)
print("✓ Validation failure handling test passed")
def test_enhanced_methods():
"""Test the new enhanced methods"""
from app.core.services.ollama_client import OllamaClient
# Mock the API response
mock_response = Mock()
mock_response.json.return_value = {
"response": '{"entities": [{"text": "张三", "type": "人名"}]}'
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
client = OllamaClient("test-model")
# Test generate_with_validation
result = client.generate_with_validation(
prompt="Extract entities",
response_type='entity_extraction',
return_parsed=True
)
assert isinstance(result, dict)
assert 'entities' in result
assert len(result['entities']) == 1
assert result['entities'][0]['text'] == '张三'
print("✓ Enhanced methods tests passed")
def main():
"""Run all tests"""
print("Testing enhanced OllamaClient...")
print("=" * 50)
try:
test_ollama_client_initialization()
test_generate_with_validation()
test_generate_with_schema()
test_backward_compatibility()
test_retry_mechanism()
test_validation_failure()
test_enhanced_methods()
print("\n" + "=" * 50)
print("✓ All enhanced OllamaClient tests passed!")
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@ -1 +0,0 @@
# Tests package

1
backend/tests/test.txt Normal file
View File

@ -0,0 +1 @@
关于张三天和北京易见天树有限公司的劳动纠纷

View File

@ -1,129 +0,0 @@
#!/usr/bin/env python3
"""
Test file for address masking functionality
"""
import pytest
import sys
import os
# Add the backend directory to the Python path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.core.document_handlers.ner_processor import NerProcessor
def test_address_masking():
"""Test address masking with the new rules"""
processor = NerProcessor()
# Test cases based on the requirements
test_cases = [
("上海市静安区恒丰路66号白云大厦1607室", "上海市静安区HF路**号BY大厦****室"),
("北京市朝阳区建国路88号SOHO现代城A座1001室", "北京市朝阳区JG路**号SOHO现代城A座****室"),
("广州市天河区珠江新城花城大道123号富力中心B座2001室", "广州市天河区珠江新城HC大道**号FL中心B座****室"),
("深圳市南山区科技园南区深南大道9988号腾讯大厦T1栋15楼", "深圳市南山区科技园南区SN大道**号TX大厦T1栋**楼"),
]
for original_address, expected_masked in test_cases:
masked = processor._mask_address(original_address)
print(f"Original: {original_address}")
print(f"Masked: {masked}")
print(f"Expected: {expected_masked}")
print("-" * 50)
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
def test_address_component_extraction():
"""Test address component extraction"""
processor = NerProcessor()
# Test address component extraction
test_cases = [
("上海市静安区恒丰路66号白云大厦1607室", {
"road_name": "恒丰路",
"house_number": "66",
"building_name": "白云大厦",
"community_name": ""
}),
("北京市朝阳区建国路88号SOHO现代城A座1001室", {
"road_name": "建国路",
"house_number": "88",
"building_name": "SOHO现代城",
"community_name": ""
}),
]
for address, expected_components in test_cases:
components = processor._extract_address_components(address)
print(f"Address: {address}")
print(f"Extracted components: {components}")
print(f"Expected: {expected_components}")
print("-" * 50)
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
def test_regex_fallback():
"""Test regex fallback for address extraction"""
processor = NerProcessor()
# Test regex extraction (fallback method)
test_address = "上海市静安区恒丰路66号白云大厦1607室"
components = processor._extract_address_components_with_regex(test_address)
print(f"Address: {test_address}")
print(f"Regex extracted components: {components}")
# Basic validation
assert "road_name" in components
assert "house_number" in components
assert "building_name" in components
assert "community_name" in components
assert "confidence" in components
def test_json_validation_for_address():
"""Test JSON validation for address extraction responses"""
from app.core.utils.llm_validator import LLMResponseValidator
# Test valid JSON response
valid_response = {
"road_name": "恒丰路",
"house_number": "66",
"building_name": "白云大厦",
"community_name": "",
"confidence": 0.9
}
assert LLMResponseValidator.validate_address_extraction(valid_response) == True
# Test invalid JSON response (missing required field)
invalid_response = {
"road_name": "恒丰路",
"house_number": "66",
"building_name": "白云大厦",
"confidence": 0.9
}
assert LLMResponseValidator.validate_address_extraction(invalid_response) == False
# Test invalid JSON response (wrong type)
invalid_response2 = {
"road_name": 123,
"house_number": "66",
"building_name": "白云大厦",
"community_name": "",
"confidence": 0.9
}
assert LLMResponseValidator.validate_address_extraction(invalid_response2) == False
if __name__ == "__main__":
print("Testing Address Masking Functionality")
print("=" * 50)
test_regex_fallback()
print()
test_json_validation_for_address()
print()
test_address_component_extraction()
print()
test_address_masking()

View File

@ -1,18 +0,0 @@
import pytest
def test_basic_discovery():
"""Basic test to verify pytest discovery is working"""
assert True
def test_import_works():
"""Test that we can import from the app module"""
try:
from app.core.document_handlers.ner_processor import NerProcessor
assert NerProcessor is not None
except ImportError as e:
pytest.fail(f"Failed to import NerProcessor: {e}")
def test_simple_math():
"""Simple math test"""
assert 1 + 1 == 2
assert 2 * 3 == 6

View File

@ -1,169 +0,0 @@
#!/usr/bin/env python3
"""
Test file for ID and social credit code masking functionality
"""
import pytest
import sys
import os
# Add the backend directory to the Python path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.core.document_handlers.ner_processor import NerProcessor
def test_id_number_masking():
"""Test ID number masking with the new rules"""
processor = NerProcessor()
# Test cases based on the requirements
test_cases = [
("310103198802080000", "310103XXXXXXXXXXXX"),
("110101199001011234", "110101XXXXXXXXXXXX"),
("440301199505151234", "440301XXXXXXXXXXXX"),
("320102198712345678", "320102XXXXXXXXXXXX"),
("12345", "12345"), # Edge case: too short
]
for original_id, expected_masked in test_cases:
# Create a mock entity for testing
entity = {'text': original_id, 'type': '身份证号'}
unique_entities = [entity]
linkage = {'entity_groups': []}
# Test the masking through the full pipeline
mapping = processor._generate_masked_mapping(unique_entities, linkage)
masked = mapping.get(original_id, original_id)
print(f"Original ID: {original_id}")
print(f"Masked ID: {masked}")
print(f"Expected: {expected_masked}")
print(f"Match: {masked == expected_masked}")
print("-" * 50)
def test_social_credit_code_masking():
"""Test social credit code masking with the new rules"""
processor = NerProcessor()
# Test cases based on the requirements
test_cases = [
("9133021276453538XT", "913302XXXXXXXXXXXX"),
("91110000100000000X", "9111000XXXXXXXXXXX"),
("914403001922038216", "9144030XXXXXXXXXXX"),
("91310000132209458G", "9131000XXXXXXXXXXX"),
("123456", "123456"), # Edge case: too short
]
for original_code, expected_masked in test_cases:
# Create a mock entity for testing
entity = {'text': original_code, 'type': '社会信用代码'}
unique_entities = [entity]
linkage = {'entity_groups': []}
# Test the masking through the full pipeline
mapping = processor._generate_masked_mapping(unique_entities, linkage)
masked = mapping.get(original_code, original_code)
print(f"Original Code: {original_code}")
print(f"Masked Code: {masked}")
print(f"Expected: {expected_masked}")
print(f"Match: {masked == expected_masked}")
print("-" * 50)
def test_edge_cases():
"""Test edge cases for ID and social credit code masking"""
processor = NerProcessor()
# Test edge cases
edge_cases = [
("", ""), # Empty string
("123", "123"), # Too short for ID
("123456", "123456"), # Too short for social credit code
("123456789012345678901234567890", "123456XXXXXXXXXXXXXXXXXX"), # Very long ID
]
for original, expected in edge_cases:
# Test ID number
entity_id = {'text': original, 'type': '身份证号'}
mapping_id = processor._generate_masked_mapping([entity_id], {'entity_groups': []})
masked_id = mapping_id.get(original, original)
# Test social credit code
entity_code = {'text': original, 'type': '社会信用代码'}
mapping_code = processor._generate_masked_mapping([entity_code], {'entity_groups': []})
masked_code = mapping_code.get(original, original)
print(f"Original: {original}")
print(f"ID Masked: {masked_id}")
print(f"Code Masked: {masked_code}")
print("-" * 30)
def test_mixed_entities():
"""Test masking with mixed entity types"""
processor = NerProcessor()
# Create mixed entities
entities = [
{'text': '310103198802080000', 'type': '身份证号'},
{'text': '9133021276453538XT', 'type': '社会信用代码'},
{'text': '李强', 'type': '人名'},
{'text': '上海盒马网络科技有限公司', 'type': '公司名称'},
]
linkage = {'entity_groups': []}
# Test the masking through the full pipeline
mapping = processor._generate_masked_mapping(entities, linkage)
print("Mixed Entities Test:")
print("=" * 30)
for entity in entities:
original = entity['text']
entity_type = entity['type']
masked = mapping.get(original, original)
print(f"{entity_type}: {original} -> {masked}")
def test_id_masking():
"""Test ID number and social credit code masking"""
from app.core.document_handlers.ner_processor import NerProcessor
processor = NerProcessor()
# Test ID number masking
id_entity = {'text': '310103198802080000', 'type': '身份证号'}
id_mapping = processor._generate_masked_mapping([id_entity], {'entity_groups': []})
masked_id = id_mapping.get('310103198802080000', '')
# Test social credit code masking
code_entity = {'text': '9133021276453538XT', 'type': '社会信用代码'}
code_mapping = processor._generate_masked_mapping([code_entity], {'entity_groups': []})
masked_code = code_mapping.get('9133021276453538XT', '')
# Verify the masking rules
assert masked_id.startswith('310103') # First 6 digits preserved
assert masked_id.endswith('XXXXXXXXXXXX') # Rest masked with X
assert len(masked_id) == 18 # Total length preserved
assert masked_code.startswith('913302') # First 7 digits preserved
assert masked_code.endswith('XXXXXXXXXXXX') # Rest masked with X
assert len(masked_code) == 18 # Total length preserved
print(f"ID masking: 310103198802080000 -> {masked_id}")
print(f"Code masking: 9133021276453538XT -> {masked_code}")
if __name__ == "__main__":
print("Testing ID and Social Credit Code Masking")
print("=" * 50)
test_id_number_masking()
print()
test_social_credit_code_masking()
print()
test_edge_cases()
print()
test_mixed_entities()

View File

@ -4,9 +4,9 @@ from app.core.document_handlers.ner_processor import NerProcessor
def test_generate_masked_mapping(): def test_generate_masked_mapping():
processor = NerProcessor() processor = NerProcessor()
unique_entities = [ unique_entities = [
{'text': '', 'type': '人名'}, {'text': '', 'type': '人名'},
{'text': '', 'type': '人名'}, # Duplicate to test numbering {'text': '', 'type': '人名'},
{'text': '小明', 'type': '人名'}, {'text': '', 'type': '人名'},
{'text': 'Acme Manufacturing Inc.', 'type': '英文公司名', 'industry': 'manufacturing'}, {'text': 'Acme Manufacturing Inc.', 'type': '英文公司名', 'industry': 'manufacturing'},
{'text': 'Google LLC', 'type': '英文公司名'}, {'text': 'Google LLC', 'type': '英文公司名'},
{'text': 'A公司', 'type': '公司名称'}, {'text': 'A公司', 'type': '公司名称'},
@ -32,23 +32,23 @@ def test_generate_masked_mapping():
'group_id': 'g2', 'group_id': 'g2',
'group_type': '人名', 'group_type': '人名',
'entities': [ 'entities': [
{'text': '', 'type': '人名', 'is_primary': True}, {'text': '', 'type': '人名', 'is_primary': True},
{'text': '', 'type': '人名', 'is_primary': False}, {'text': '', 'type': '人名', 'is_primary': False},
] ]
} }
] ]
} }
mapping = processor._generate_masked_mapping(unique_entities, linkage) mapping = processor._generate_masked_mapping(unique_entities, linkage)
# 人名 - Updated for new Chinese name masking rules # 人名
assert mapping['李强'] == '李Q' assert mapping['李雷'].startswith('李某')
assert mapping['王小明'] == '王XM' assert mapping['李明'].startswith('李某')
assert mapping['王强'].startswith('王某')
# 英文公司名 # 英文公司名
assert mapping['Acme Manufacturing Inc.'] == 'MANUFACTURING' assert mapping['Acme Manufacturing Inc.'] == 'MANUFACTURING'
assert mapping['Google LLC'] == 'COMPANY' assert mapping['Google LLC'] == 'COMPANY'
# 公司名同组 - Updated for new company masking rules # 公司名同组
# Note: The exact results may vary due to LLM extraction assert mapping['A公司'] == mapping['B公司']
assert '公司' in mapping['A公司'] or mapping['A公司'] != 'A公司' assert mapping['A公司'].endswith('公司')
assert '公司' in mapping['B公司'] or mapping['B公司'] != 'B公司'
# 英文人名 # 英文人名
assert mapping['John Smith'] == 'J*** S***' assert mapping['John Smith'] == 'J*** S***'
assert mapping['Elizabeth Windsor'] == 'E*** W***' assert mapping['Elizabeth Windsor'] == 'E*** W***'
@ -59,217 +59,4 @@ def test_generate_masked_mapping():
# 身份证号 # 身份证号
assert mapping['310101198802080000'] == 'XXXXXX' assert mapping['310101198802080000'] == 'XXXXXX'
# 社会信用代码 # 社会信用代码
assert mapping['9133021276453538XT'] == 'XXXXXXXX' assert mapping['9133021276453538XT'] == 'XXXXXXXX'
def test_chinese_name_pinyin_masking():
"""Test Chinese name masking with pinyin functionality"""
processor = NerProcessor()
# Test basic Chinese name masking
test_cases = [
("李强", "李Q"),
("张韶涵", "张SH"),
("张若宇", "张RY"),
("白锦程", "白JC"),
("王小明", "王XM"),
("陈志强", "陈ZQ"),
]
surname_counter = {}
for original_name, expected_masked in test_cases:
masked = processor._mask_chinese_name(original_name, surname_counter)
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
# Test duplicate handling
duplicate_test_cases = [
("李强", "李Q"),
("李强", "李Q2"), # Should be numbered
("李倩", "李Q3"), # Should be numbered
("张韶涵", "张SH"),
("张韶涵", "张SH2"), # Should be numbered
("张若宇", "张RY"), # Different initials, should not be numbered
]
surname_counter = {} # Reset counter
for original_name, expected_masked in duplicate_test_cases:
masked = processor._mask_chinese_name(original_name, surname_counter)
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
# Test edge cases
edge_cases = [
("", ""), # Empty string
("", ""), # Single character
("李强强", "李QQ"), # Multiple characters with same pinyin
]
surname_counter = {} # Reset counter
for original_name, expected_masked in edge_cases:
masked = processor._mask_chinese_name(original_name, surname_counter)
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
def test_chinese_name_integration():
"""Test Chinese name masking integrated with the full mapping process"""
processor = NerProcessor()
# Test Chinese names in the full mapping context
unique_entities = [
{'text': '李强', 'type': '人名'},
{'text': '张韶涵', 'type': '人名'},
{'text': '张若宇', 'type': '人名'},
{'text': '白锦程', 'type': '人名'},
{'text': '李强', 'type': '人名'}, # Duplicate
{'text': '张韶涵', 'type': '人名'}, # Duplicate
]
linkage = {
'entity_groups': [
{
'group_id': 'g1',
'group_type': '人名',
'entities': [
{'text': '李强', 'type': '人名', 'is_primary': True},
{'text': '张韶涵', 'type': '人名', 'is_primary': True},
{'text': '张若宇', 'type': '人名', 'is_primary': True},
{'text': '白锦程', 'type': '人名', 'is_primary': True},
]
}
]
}
mapping = processor._generate_masked_mapping(unique_entities, linkage)
# Verify the mapping results
assert mapping['李强'] == '李Q'
assert mapping['张韶涵'] == '张SH'
assert mapping['张若宇'] == '张RY'
assert mapping['白锦程'] == '白JC'
# Check that duplicates are handled correctly
# The second occurrence should be numbered
assert '李Q2' in mapping.values() or '张SH2' in mapping.values()
def test_lawyer_and_judge_names():
"""Test that lawyer and judge names follow the same Chinese name rules"""
processor = NerProcessor()
# Test lawyer and judge names
test_entities = [
{'text': '王律师', 'type': '律师姓名'},
{'text': '李法官', 'type': '审判人员姓名'},
{'text': '张检察官', 'type': '检察官姓名'},
]
linkage = {
'entity_groups': [
{
'group_id': 'g1',
'group_type': '律师姓名',
'entities': [{'text': '王律师', 'type': '律师姓名', 'is_primary': True}]
},
{
'group_id': 'g2',
'group_type': '审判人员姓名',
'entities': [{'text': '李法官', 'type': '审判人员姓名', 'is_primary': True}]
},
{
'group_id': 'g3',
'group_type': '检察官姓名',
'entities': [{'text': '张检察官', 'type': '检察官姓名', 'is_primary': True}]
}
]
}
mapping = processor._generate_masked_mapping(test_entities, linkage)
# These should follow the same Chinese name masking rules
assert mapping['王律师'] == '王L'
assert mapping['李法官'] == '李F'
assert mapping['张检察官'] == '张JC'
def test_company_name_masking():
"""Test company name masking with business name extraction"""
processor = NerProcessor()
# Test basic company name masking
test_cases = [
("上海盒马网络科技有限公司", "上海JO网络科技有限公司"),
("丰田通商(上海)有限公司", "HVVU上海有限公司"),
("雅诗兰黛(上海)商贸有限公司", "AUNF上海商贸有限公司"),
("北京百度网讯科技有限公司", "北京BC网讯科技有限公司"),
("腾讯科技(深圳)有限公司", "TU科技深圳有限公司"),
("阿里巴巴集团控股有限公司", "阿里巴巴集团控股有限公司"), # 商号可能无法正确提取
]
for original_name, expected_masked in test_cases:
masked = processor._mask_company_name(original_name)
print(f"{original_name} -> {masked} (expected: {expected_masked})")
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
def test_business_name_extraction():
"""Test business name extraction from company names"""
processor = NerProcessor()
# Test business name extraction
test_cases = [
("上海盒马网络科技有限公司", "盒马"),
("丰田通商(上海)有限公司", "丰田通商"),
("雅诗兰黛(上海)商贸有限公司", "雅诗兰黛"),
("北京百度网讯科技有限公司", "百度"),
("腾讯科技(深圳)有限公司", "腾讯"),
("律师事务所", "律师事务所"), # Edge case
]
for company_name, expected_business_name in test_cases:
business_name = processor._extract_business_name(company_name)
print(f"Company: {company_name} -> Business Name: {business_name} (expected: {expected_business_name})")
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
def test_json_validation_for_business_name():
"""Test JSON validation for business name extraction responses"""
from app.core.utils.llm_validator import LLMResponseValidator
# Test valid JSON response
valid_response = {
"business_name": "盒马",
"confidence": 0.9
}
assert LLMResponseValidator.validate_business_name_extraction(valid_response) == True
# Test invalid JSON response (missing required field)
invalid_response = {
"confidence": 0.9
}
assert LLMResponseValidator.validate_business_name_extraction(invalid_response) == False
# Test invalid JSON response (wrong type)
invalid_response2 = {
"business_name": 123,
"confidence": 0.9
}
assert LLMResponseValidator.validate_business_name_extraction(invalid_response2) == False
def test_law_firm_masking():
"""Test law firm name masking"""
processor = NerProcessor()
# Test law firm name masking
test_cases = [
("北京大成律师事务所", "北京D律师事务所"),
("上海锦天城律师事务所", "上海JTC律师事务所"),
("广东广信君达律师事务所", "广东GXJD律师事务所"),
]
for original_name, expected_masked in test_cases:
masked = processor._mask_company_name(original_name)
print(f"{original_name} -> {masked} (expected: {expected_masked})")
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification

View File

@ -1,128 +0,0 @@
"""
Tests for the refactored NerProcessor.
"""
import pytest
import sys
import os
# Add the backend directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
from app.core.document_handlers.maskers.id_masker import IDMasker
from app.core.document_handlers.maskers.case_masker import CaseMasker
def test_chinese_name_masker():
"""Test Chinese name masker"""
masker = ChineseNameMasker()
# Test basic masking
result1 = masker.mask("李强")
assert result1 == "李Q"
result2 = masker.mask("张韶涵")
assert result2 == "张SH"
result3 = masker.mask("张若宇")
assert result3 == "张RY"
result4 = masker.mask("白锦程")
assert result4 == "白JC"
# Test duplicate handling
result5 = masker.mask("李强") # Should get a number
assert result5 == "李Q2"
print(f"Chinese name masking tests passed")
def test_english_name_masker():
"""Test English name masker"""
masker = EnglishNameMasker()
result = masker.mask("John Smith")
assert result == "J*** S***"
result2 = masker.mask("Mary Jane Watson")
assert result2 == "M*** J*** W***"
print(f"English name masking tests passed")
def test_id_masker():
"""Test ID masker"""
masker = IDMasker()
# Test ID number
result1 = masker.mask("310103198802080000")
assert result1 == "310103XXXXXXXXXXXX"
assert len(result1) == 18
# Test social credit code
result2 = masker.mask("9133021276453538XT")
assert result2 == "913302XXXXXXXXXXXX"
assert len(result2) == 18
print(f"ID masking tests passed")
def test_case_masker():
"""Test case masker"""
masker = CaseMasker()
result1 = masker.mask("(2022)京 03 民终 3852 号")
assert "***号" in result1
result2 = masker.mask("2020京0105 民初69754 号")
assert "***号" in result2
print(f"Case masking tests passed")
def test_masker_factory():
"""Test masker factory"""
from app.core.document_handlers.masker_factory import MaskerFactory
# Test creating maskers
chinese_masker = MaskerFactory.create_masker('chinese_name')
assert isinstance(chinese_masker, ChineseNameMasker)
english_masker = MaskerFactory.create_masker('english_name')
assert isinstance(english_masker, EnglishNameMasker)
id_masker = MaskerFactory.create_masker('id')
assert isinstance(id_masker, IDMasker)
case_masker = MaskerFactory.create_masker('case')
assert isinstance(case_masker, CaseMasker)
print(f"Masker factory tests passed")
def test_refactored_processor_initialization():
"""Test that the refactored processor can be initialized"""
try:
processor = NerProcessorRefactored()
assert processor is not None
assert hasattr(processor, 'maskers')
assert len(processor.maskers) > 0
print(f"Refactored processor initialization test passed")
except Exception as e:
print(f"Refactored processor initialization failed: {e}")
# This might fail if Ollama is not running, which is expected in test environment
if __name__ == "__main__":
print("Running refactored NerProcessor tests...")
test_chinese_name_masker()
test_english_name_masker()
test_id_masker()
test_case_masker()
test_masker_factory()
test_refactored_processor_initialization()
print("All refactored NerProcessor tests completed!")

View File

@ -1,213 +0,0 @@
"""
Validation script for the refactored NerProcessor.
"""
import sys
import os
# Add the current directory to the Python path
sys.path.insert(0, os.path.dirname(__file__))
def test_imports():
"""Test that all modules can be imported"""
print("Testing imports...")
try:
from app.core.document_handlers.maskers.base_masker import BaseMasker
print("✓ BaseMasker imported successfully")
except Exception as e:
print(f"✗ Failed to import BaseMasker: {e}")
return False
try:
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
print("✓ Name maskers imported successfully")
except Exception as e:
print(f"✗ Failed to import name maskers: {e}")
return False
try:
from app.core.document_handlers.maskers.id_masker import IDMasker
print("✓ IDMasker imported successfully")
except Exception as e:
print(f"✗ Failed to import IDMasker: {e}")
return False
try:
from app.core.document_handlers.maskers.case_masker import CaseMasker
print("✓ CaseMasker imported successfully")
except Exception as e:
print(f"✗ Failed to import CaseMasker: {e}")
return False
try:
from app.core.document_handlers.maskers.company_masker import CompanyMasker
print("✓ CompanyMasker imported successfully")
except Exception as e:
print(f"✗ Failed to import CompanyMasker: {e}")
return False
try:
from app.core.document_handlers.maskers.address_masker import AddressMasker
print("✓ AddressMasker imported successfully")
except Exception as e:
print(f"✗ Failed to import AddressMasker: {e}")
return False
try:
from app.core.document_handlers.masker_factory import MaskerFactory
print("✓ MaskerFactory imported successfully")
except Exception as e:
print(f"✗ Failed to import MaskerFactory: {e}")
return False
try:
from app.core.document_handlers.extractors.business_name_extractor import BusinessNameExtractor
print("✓ BusinessNameExtractor imported successfully")
except Exception as e:
print(f"✗ Failed to import BusinessNameExtractor: {e}")
return False
try:
from app.core.document_handlers.extractors.address_extractor import AddressExtractor
print("✓ AddressExtractor imported successfully")
except Exception as e:
print(f"✗ Failed to import AddressExtractor: {e}")
return False
try:
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
print("✓ NerProcessorRefactored imported successfully")
except Exception as e:
print(f"✗ Failed to import NerProcessorRefactored: {e}")
return False
return True
def test_masker_functionality():
"""Test basic masker functionality"""
print("\nTesting masker functionality...")
try:
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker
masker = ChineseNameMasker()
result = masker.mask("李强")
assert result == "李Q", f"Expected '李Q', got '{result}'"
print("✓ ChineseNameMasker works correctly")
except Exception as e:
print(f"✗ ChineseNameMasker test failed: {e}")
return False
try:
from app.core.document_handlers.maskers.name_masker import EnglishNameMasker
masker = EnglishNameMasker()
result = masker.mask("John Smith")
assert result == "J*** S***", f"Expected 'J*** S***', got '{result}'"
print("✓ EnglishNameMasker works correctly")
except Exception as e:
print(f"✗ EnglishNameMasker test failed: {e}")
return False
try:
from app.core.document_handlers.maskers.id_masker import IDMasker
masker = IDMasker()
result = masker.mask("310103198802080000")
assert result == "310103XXXXXXXXXXXX", f"Expected '310103XXXXXXXXXXXX', got '{result}'"
print("✓ IDMasker works correctly")
except Exception as e:
print(f"✗ IDMasker test failed: {e}")
return False
try:
from app.core.document_handlers.maskers.case_masker import CaseMasker
masker = CaseMasker()
result = masker.mask("(2022)京 03 民终 3852 号")
assert "***号" in result, f"Expected '***号' in result, got '{result}'"
print("✓ CaseMasker works correctly")
except Exception as e:
print(f"✗ CaseMasker test failed: {e}")
return False
return True
def test_factory():
"""Test masker factory"""
print("\nTesting masker factory...")
try:
from app.core.document_handlers.masker_factory import MaskerFactory
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker
masker = MaskerFactory.create_masker('chinese_name')
assert isinstance(masker, ChineseNameMasker), f"Expected ChineseNameMasker, got {type(masker)}"
print("✓ MaskerFactory works correctly")
except Exception as e:
print(f"✗ MaskerFactory test failed: {e}")
return False
return True
def test_processor_initialization():
"""Test processor initialization"""
print("\nTesting processor initialization...")
try:
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
processor = NerProcessorRefactored()
assert processor is not None, "Processor should not be None"
assert hasattr(processor, 'maskers'), "Processor should have maskers attribute"
assert len(processor.maskers) > 0, "Processor should have at least one masker"
print("✓ NerProcessorRefactored initializes correctly")
except Exception as e:
print(f"✗ NerProcessorRefactored initialization failed: {e}")
# This might fail if Ollama is not running, which is expected
print(" (This is expected if Ollama is not running)")
return True # Don't fail the validation for this
return True
def main():
"""Main validation function"""
print("Validating refactored NerProcessor...")
print("=" * 50)
success = True
# Test imports
if not test_imports():
success = False
# Test functionality
if not test_masker_functionality():
success = False
# Test factory
if not test_factory():
success = False
# Test processor initialization
if not test_processor_initialization():
success = False
print("\n" + "=" * 50)
if success:
print("✓ All validation tests passed!")
print("The refactored code is working correctly.")
else:
print("✗ Some validation tests failed.")
print("Please check the errors above.")
return success
if __name__ == "__main__":
main()

View File

@ -34,6 +34,7 @@ services:
- "8000:8000" - "8000:8000"
volumes: volumes:
- ./backend/storage:/app/storage - ./backend/storage:/app/storage
- ./backend/legal_doc_masker.db:/app/legal_doc_masker.db
env_file: env_file:
- ./backend/.env - ./backend/.env
environment: environment:
@ -54,6 +55,7 @@ services:
command: celery -A app.services.file_service worker --loglevel=info command: celery -A app.services.file_service worker --loglevel=info
volumes: volumes:
- ./backend/storage:/app/storage - ./backend/storage:/app/storage
- ./backend/legal_doc_masker.db:/app/legal_doc_masker.db
env_file: env_file:
- ./backend/.env - ./backend/.env
environment: environment:

67
download_models.py Normal file
View File

@ -0,0 +1,67 @@
import json
import shutil
import os
import requests
from modelscope import snapshot_download
def download_json(url):
# 下载JSON文件
response = requests.get(url)
response.raise_for_status() # 检查请求是否成功
return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.2.0':
data = download_json(url)
else:
data = download_json(url)
# 修改内容
for key, value in modifications.items():
data[key] = value
# 保存修改后的内容
with open(local_filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
if __name__ == '__main__':
mineru_patterns = [
# "models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_hf_small_2503/*",
"models/OCR/paddleocr_torch/*",
# "models/TabRec/TableMaster/*",
# "models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
# paddleocr_model_dir = model_dir + '/OCR/paddleocr'
# user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
# if os.path.exists(user_paddleocr_dir):
# shutil.rmtree(user_paddleocr_dir)
# shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
'models-dir': model_dir,
'layoutreader-model-dir': layoutreader_model_dir,
}
download_and_modify_json(json_url, config_file, json_mods)
print(f'The configuration file has been configured successfully, the path is: {config_file}')

View File

@ -16,9 +16,8 @@ import {
DialogContent, DialogContent,
DialogActions, DialogActions,
Typography, Typography,
Tooltip,
} from '@mui/material'; } from '@mui/material';
import { Download as DownloadIcon, Delete as DeleteIcon, Error as ErrorIcon } from '@mui/icons-material'; import { Download as DownloadIcon, Delete as DeleteIcon } from '@mui/icons-material';
import { File, FileStatus } from '../types/file'; import { File, FileStatus } from '../types/file';
import { api } from '../services/api'; import { api } from '../services/api';
@ -173,50 +172,6 @@ const FileList: React.FC<FileListProps> = ({ files, onFileStatusChange }) => {
color={getStatusColor(file.status) as any} color={getStatusColor(file.status) as any}
size="small" size="small"
/> />
{file.status === FileStatus.FAILED && file.error_message && (
<div style={{ marginTop: '4px' }}>
<Tooltip
title={file.error_message}
placement="top-start"
arrow
sx={{ maxWidth: '400px' }}
>
<div
style={{
display: 'flex',
alignItems: 'flex-start',
gap: '4px',
padding: '4px 8px',
backgroundColor: '#ffebee',
borderRadius: '4px',
border: '1px solid #ffcdd2'
}}
>
<ErrorIcon
color="error"
sx={{ fontSize: '16px', marginTop: '1px', flexShrink: 0 }}
/>
<Typography
variant="caption"
color="error"
sx={{
display: 'block',
wordBreak: 'break-word',
maxWidth: '300px',
lineHeight: '1.2',
cursor: 'help',
fontWeight: 500
}}
>
{file.error_message.length > 50
? `${file.error_message.substring(0, 50)}...`
: file.error_message
}
</Typography>
</div>
</Tooltip>
</div>
)}
</TableCell> </TableCell>
<TableCell> <TableCell>
{new Date(file.created_at).toLocaleString()} {new Date(file.created_at).toLocaleString()}