Compare commits
No commits in common. "main" and "feature-ner-keyword-detect" have entirely different histories.
main
...
feature-ne
|
|
@ -86,7 +86,7 @@ docker-compose build frontend
|
||||||
docker-compose build mineru-api
|
docker-compose build mineru-api
|
||||||
|
|
||||||
# Build multiple specific services
|
# Build multiple specific services
|
||||||
docker-compose build backend-api frontend celery-worker
|
docker-compose build backend-api frontend
|
||||||
```
|
```
|
||||||
|
|
||||||
### Building and restarting specific services
|
### Building and restarting specific services
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,9 @@ TARGET_DIRECTORY_PATH=/Users/tigeren/Dev/digisky/legal-doc-masker/data/doc_dest
|
||||||
INTERMEDIATE_DIR_PATH=/Users/tigeren/Dev/digisky/legal-doc-masker/data/doc_intermediate
|
INTERMEDIATE_DIR_PATH=/Users/tigeren/Dev/digisky/legal-doc-masker/data/doc_intermediate
|
||||||
|
|
||||||
# Ollama API Configuration
|
# Ollama API Configuration
|
||||||
# 3060 GPU
|
OLLAMA_API_URL=http://192.168.2.245:11434
|
||||||
# OLLAMA_API_URL=http://192.168.2.245:11434
|
|
||||||
# Mac Mini M4
|
|
||||||
OLLAMA_API_URL=http://192.168.2.224:11434
|
|
||||||
|
|
||||||
# OLLAMA_API_KEY=your_api_key_here
|
# OLLAMA_API_KEY=your_api_key_here
|
||||||
# OLLAMA_MODEL=qwen3:8b
|
OLLAMA_MODEL=qwen3:8b
|
||||||
OLLAMA_MODEL=phi4:14b
|
|
||||||
|
|
||||||
# Application Settings
|
# Application Settings
|
||||||
MONITOR_INTERVAL=5
|
MONITOR_INTERVAL=5
|
||||||
|
|
|
||||||
|
|
@ -7,31 +7,20 @@ RUN apt-get update && apt-get install -y \
|
||||||
build-essential \
|
build-essential \
|
||||||
libreoffice \
|
libreoffice \
|
||||||
wget \
|
wget \
|
||||||
git \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
|
||||||
# Copy requirements first to leverage Docker cache
|
# Copy requirements first to leverage Docker cache
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
|
# RUN pip install huggingface_hub
|
||||||
|
# RUN wget https://github.com/opendatalab/MinerU/raw/master/scripts/download_models_hf.py -O download_models_hf.py
|
||||||
|
# RUN wget https://raw.githubusercontent.com/opendatalab/MinerU/refs/heads/release-1.3.1/scripts/download_models_hf.py -O download_models_hf.py
|
||||||
|
|
||||||
# Upgrade pip and install core dependencies
|
# RUN python download_models_hf.py
|
||||||
RUN pip install --upgrade pip setuptools wheel
|
|
||||||
|
|
||||||
# Install PyTorch CPU version first (for better caching and smaller size)
|
|
||||||
RUN pip install --no-cache-dir torch==2.7.0 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
|
|
||||||
# Install the rest of the requirements
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
# RUN pip install -U magic-pdf[full]
|
||||||
# Pre-download NER model during build (larger image but faster startup)
|
|
||||||
# RUN python -c "
|
|
||||||
# from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
||||||
# model_name = 'uer/roberta-base-finetuned-cluener2020-chinese'
|
|
||||||
# print('Downloading NER model...')
|
|
||||||
# AutoTokenizer.from_pretrained(model_name)
|
|
||||||
# AutoModelForTokenClassification.from_pretrained(model_name)
|
|
||||||
# print('NER model downloaded successfully')
|
|
||||||
# "
|
|
||||||
|
|
||||||
|
|
||||||
# Copy the rest of the application
|
# Copy the rest of the application
|
||||||
|
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
# App package
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
# Core package
|
|
||||||
|
|
@ -42,10 +42,6 @@ class Settings(BaseSettings):
|
||||||
MINERU_FORMULA_ENABLE: bool = True # Enable formula parsing
|
MINERU_FORMULA_ENABLE: bool = True # Enable formula parsing
|
||||||
MINERU_TABLE_ENABLE: bool = True # Enable table parsing
|
MINERU_TABLE_ENABLE: bool = True # Enable table parsing
|
||||||
|
|
||||||
# MagicDoc API settings
|
|
||||||
# MAGICDOC_API_URL: str = "http://magicdoc-api:8000"
|
|
||||||
# MAGICDOC_TIMEOUT: int = 300 # 5 minutes timeout
|
|
||||||
|
|
||||||
# Logging settings
|
# Logging settings
|
||||||
LOG_LEVEL: str = "INFO"
|
LOG_LEVEL: str = "INFO"
|
||||||
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
# Document handlers package
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||||
from .document_processor import DocumentProcessor
|
from .document_processor import DocumentProcessor
|
||||||
from .processors import (
|
from .processors import (
|
||||||
TxtDocumentProcessor,
|
TxtDocumentProcessor,
|
||||||
DocxDocumentProcessor,
|
# DocxDocumentProcessor,
|
||||||
PdfDocumentProcessor,
|
PdfDocumentProcessor,
|
||||||
MarkdownDocumentProcessor
|
MarkdownDocumentProcessor
|
||||||
)
|
)
|
||||||
|
|
@ -15,8 +15,8 @@ class DocumentProcessorFactory:
|
||||||
|
|
||||||
processors = {
|
processors = {
|
||||||
'.txt': TxtDocumentProcessor,
|
'.txt': TxtDocumentProcessor,
|
||||||
'.docx': DocxDocumentProcessor,
|
# '.docx': DocxDocumentProcessor,
|
||||||
'.doc': DocxDocumentProcessor,
|
# '.doc': DocxDocumentProcessor,
|
||||||
'.pdf': PdfDocumentProcessor,
|
'.pdf': PdfDocumentProcessor,
|
||||||
'.md': MarkdownDocumentProcessor,
|
'.md': MarkdownDocumentProcessor,
|
||||||
'.markdown': MarkdownDocumentProcessor
|
'.markdown': MarkdownDocumentProcessor
|
||||||
|
|
|
||||||
|
|
@ -40,35 +40,16 @@ class DocumentProcessor(ABC):
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def _apply_mapping_with_alignment(self, text: str, mapping: Dict[str, str]) -> str:
|
|
||||||
"""
|
|
||||||
Apply the mapping to replace sensitive information using character-by-character alignment.
|
|
||||||
|
|
||||||
This method uses the new alignment-based masking to handle spacing issues
|
|
||||||
between NER results and original document text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Original document text
|
|
||||||
mapping: Dictionary mapping original entity text to masked text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked document text
|
|
||||||
"""
|
|
||||||
logger.info(f"Applying entity mapping with alignment to text of length {len(text)}")
|
|
||||||
logger.debug(f"Entity mapping: {mapping}")
|
|
||||||
|
|
||||||
# Use the new alignment-based masking method
|
|
||||||
masked_text = self.ner_processor.apply_entity_masking_with_alignment(text, mapping)
|
|
||||||
|
|
||||||
logger.info("Successfully applied entity masking with alignment")
|
|
||||||
return masked_text
|
|
||||||
|
|
||||||
def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
|
def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
|
||||||
"""
|
"""Apply the mapping to replace sensitive information"""
|
||||||
Legacy method for simple string replacement.
|
masked_text = text
|
||||||
Now delegates to the new alignment-based method.
|
for original, masked in mapping.items():
|
||||||
"""
|
if isinstance(masked, dict):
|
||||||
return self._apply_mapping_with_alignment(text, mapping)
|
masked = next(iter(masked.values()), "某")
|
||||||
|
elif not isinstance(masked, str):
|
||||||
|
masked = str(masked) if masked is not None else "某"
|
||||||
|
masked_text = masked_text.replace(original, masked)
|
||||||
|
return masked_text
|
||||||
|
|
||||||
def process_content(self, content: str) -> str:
|
def process_content(self, content: str) -> str:
|
||||||
"""Process document content by masking sensitive information"""
|
"""Process document content by masking sensitive information"""
|
||||||
|
|
@ -78,11 +59,9 @@ class DocumentProcessor(ABC):
|
||||||
logger.info(f"Split content into {len(chunks)} chunks")
|
logger.info(f"Split content into {len(chunks)} chunks")
|
||||||
|
|
||||||
final_mapping = self.ner_processor.process(chunks)
|
final_mapping = self.ner_processor.process(chunks)
|
||||||
logger.info(f"Generated entity mapping with {len(final_mapping)} entities")
|
|
||||||
|
|
||||||
# Use the new alignment-based masking
|
masked_content = self._apply_mapping(content, final_mapping)
|
||||||
masked_content = self._apply_mapping_with_alignment(content, final_mapping)
|
logger.info("Successfully masked content")
|
||||||
logger.info("Successfully masked content using character alignment")
|
|
||||||
|
|
||||||
return masked_content
|
return masked_content
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
"""
|
|
||||||
Extractors package for entity component extraction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .base_extractor import BaseExtractor
|
|
||||||
from .business_name_extractor import BusinessNameExtractor
|
|
||||||
from .address_extractor import AddressExtractor
|
|
||||||
from .ner_extractor import NERExtractor
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'BaseExtractor',
|
|
||||||
'BusinessNameExtractor',
|
|
||||||
'AddressExtractor',
|
|
||||||
'NERExtractor'
|
|
||||||
]
|
|
||||||
|
|
@ -1,168 +0,0 @@
|
||||||
"""
|
|
||||||
Address extractor for address components.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from ...services.ollama_client import OllamaClient
|
|
||||||
from ...utils.json_extractor import LLMJsonExtractor
|
|
||||||
from ...utils.llm_validator import LLMResponseValidator
|
|
||||||
from .base_extractor import BaseExtractor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AddressExtractor(BaseExtractor):
|
|
||||||
"""Extractor for address components"""
|
|
||||||
|
|
||||||
def __init__(self, ollama_client: OllamaClient):
|
|
||||||
self.ollama_client = ollama_client
|
|
||||||
self._confidence = 0.5 # Default confidence for regex fallback
|
|
||||||
|
|
||||||
def extract(self, address: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
Extract address components from address.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
address: The address to extract from
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with address components and confidence, or None if extraction fails
|
|
||||||
"""
|
|
||||||
if not address:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try LLM extraction first
|
|
||||||
try:
|
|
||||||
result = self._extract_with_llm(address)
|
|
||||||
if result:
|
|
||||||
self._confidence = result.get('confidence', 0.9)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM extraction failed for {address}: {e}")
|
|
||||||
|
|
||||||
# Fallback to regex extraction
|
|
||||||
result = self._extract_with_regex(address)
|
|
||||||
self._confidence = 0.5 # Lower confidence for regex
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _extract_with_llm(self, address: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""Extract address components using LLM"""
|
|
||||||
prompt = f"""
|
|
||||||
你是一个专业的地址分析助手。请从以下地址中提取需要脱敏的组件,并严格按照JSON格式返回结果。
|
|
||||||
|
|
||||||
地址:{address}
|
|
||||||
|
|
||||||
脱敏规则:
|
|
||||||
1. 保留区级以上地址(省、市、区、县等)
|
|
||||||
2. 路名(路名)需要脱敏:以大写首字母替代
|
|
||||||
3. 门牌号(门牌数字)需要脱敏:以****代替
|
|
||||||
4. 大厦名、小区名需要脱敏:以大写首字母替代
|
|
||||||
|
|
||||||
示例:
|
|
||||||
- 上海市静安区恒丰路66号白云大厦1607室
|
|
||||||
- 路名:恒丰路
|
|
||||||
- 门牌号:66
|
|
||||||
- 大厦名:白云大厦
|
|
||||||
- 小区名:(空)
|
|
||||||
|
|
||||||
- 北京市朝阳区建国路88号SOHO现代城A座1001室
|
|
||||||
- 路名:建国路
|
|
||||||
- 门牌号:88
|
|
||||||
- 大厦名:SOHO现代城
|
|
||||||
- 小区名:(空)
|
|
||||||
|
|
||||||
- 广州市天河区珠江新城花城大道123号富力中心B座2001室
|
|
||||||
- 路名:花城大道
|
|
||||||
- 门牌号:123
|
|
||||||
- 大厦名:富力中心
|
|
||||||
- 小区名:(空)
|
|
||||||
|
|
||||||
请严格按照以下JSON格式输出,不要包含任何其他文字:
|
|
||||||
|
|
||||||
{{
|
|
||||||
"road_name": "提取的路名",
|
|
||||||
"house_number": "提取的门牌号",
|
|
||||||
"building_name": "提取的大厦名",
|
|
||||||
"community_name": "提取的小区名(如果没有则为空字符串)",
|
|
||||||
"confidence": 0.9
|
|
||||||
}}
|
|
||||||
|
|
||||||
注意:
|
|
||||||
- road_name字段必须包含路名(如:恒丰路、建国路等)
|
|
||||||
- house_number字段必须包含门牌号(如:66、88等)
|
|
||||||
- building_name字段必须包含大厦名(如:白云大厦、SOHO现代城等)
|
|
||||||
- community_name字段包含小区名,如果没有则为空字符串
|
|
||||||
- confidence字段是0-1之间的数字,表示提取的置信度
|
|
||||||
- 必须严格按照JSON格式,不要添加任何解释或额外文字
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use the new enhanced generate method with validation
|
|
||||||
parsed_response = self.ollama_client.generate_with_validation(
|
|
||||||
prompt=prompt,
|
|
||||||
response_type='address_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if parsed_response:
|
|
||||||
logger.info(f"Successfully extracted address components: {parsed_response}")
|
|
||||||
return parsed_response
|
|
||||||
else:
|
|
||||||
logger.warning(f"Failed to extract address components for: {address}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"LLM extraction failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_with_regex(self, address: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""Extract address components using regex patterns"""
|
|
||||||
# Road name pattern: usually ends with "路", "街", "大道", etc.
|
|
||||||
road_pattern = r'([^省市区县]+[路街大道巷弄])'
|
|
||||||
|
|
||||||
# House number pattern: digits + 号
|
|
||||||
house_number_pattern = r'(\d+)号'
|
|
||||||
|
|
||||||
# Building name pattern: usually contains "大厦", "中心", "广场", etc.
|
|
||||||
building_pattern = r'([^号室]+(?:大厦|中心|广场|城|楼|座))'
|
|
||||||
|
|
||||||
# Community name pattern: usually contains "小区", "花园", "苑", etc.
|
|
||||||
community_pattern = r'([^号室]+(?:小区|花园|苑|园|庭))'
|
|
||||||
|
|
||||||
road_name = ""
|
|
||||||
house_number = ""
|
|
||||||
building_name = ""
|
|
||||||
community_name = ""
|
|
||||||
|
|
||||||
# Extract road name
|
|
||||||
road_match = re.search(road_pattern, address)
|
|
||||||
if road_match:
|
|
||||||
road_name = road_match.group(1).strip()
|
|
||||||
|
|
||||||
# Extract house number
|
|
||||||
house_match = re.search(house_number_pattern, address)
|
|
||||||
if house_match:
|
|
||||||
house_number = house_match.group(1)
|
|
||||||
|
|
||||||
# Extract building name
|
|
||||||
building_match = re.search(building_pattern, address)
|
|
||||||
if building_match:
|
|
||||||
building_name = building_match.group(1).strip()
|
|
||||||
|
|
||||||
# Extract community name
|
|
||||||
community_match = re.search(community_pattern, address)
|
|
||||||
if community_match:
|
|
||||||
community_name = community_match.group(1).strip()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"road_name": road_name,
|
|
||||||
"house_number": house_number,
|
|
||||||
"building_name": building_name,
|
|
||||||
"community_name": community_name,
|
|
||||||
"confidence": 0.5 # Lower confidence for regex fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
"""Return confidence level of extraction"""
|
|
||||||
return self._confidence
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
"""
|
|
||||||
Abstract base class for all extractors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseExtractor(ABC):
|
|
||||||
"""Abstract base class for all extractors"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def extract(self, text: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""Extract components from text"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
"""Return confidence level of extraction"""
|
|
||||||
pass
|
|
||||||
|
|
@ -1,192 +0,0 @@
|
||||||
"""
|
|
||||||
Business name extractor for company names.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from ...services.ollama_client import OllamaClient
|
|
||||||
from ...utils.json_extractor import LLMJsonExtractor
|
|
||||||
from ...utils.llm_validator import LLMResponseValidator
|
|
||||||
from .base_extractor import BaseExtractor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BusinessNameExtractor(BaseExtractor):
|
|
||||||
"""Extractor for business names from company names"""
|
|
||||||
|
|
||||||
def __init__(self, ollama_client: OllamaClient):
|
|
||||||
self.ollama_client = ollama_client
|
|
||||||
self._confidence = 0.5 # Default confidence for regex fallback
|
|
||||||
|
|
||||||
def extract(self, company_name: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
Extract business name from company name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
company_name: The company name to extract from
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with business name and confidence, or None if extraction fails
|
|
||||||
"""
|
|
||||||
if not company_name:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try LLM extraction first
|
|
||||||
try:
|
|
||||||
result = self._extract_with_llm(company_name)
|
|
||||||
if result:
|
|
||||||
self._confidence = result.get('confidence', 0.9)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM extraction failed for {company_name}: {e}")
|
|
||||||
|
|
||||||
# Fallback to regex extraction
|
|
||||||
result = self._extract_with_regex(company_name)
|
|
||||||
self._confidence = 0.5 # Lower confidence for regex
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _extract_with_llm(self, company_name: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""Extract business name using LLM"""
|
|
||||||
prompt = f"""
|
|
||||||
你是一个专业的公司名称分析助手。请从以下公司名称中提取商号(企业字号),并严格按照JSON格式返回结果。
|
|
||||||
|
|
||||||
公司名称:{company_name}
|
|
||||||
|
|
||||||
商号提取规则:
|
|
||||||
1. 公司名通常为:地域+商号+业务/行业+组织类型
|
|
||||||
2. 也有:商号+(地域)+业务/行业+组织类型
|
|
||||||
3. 商号是企业名称中最具识别性的部分,通常是2-4个汉字
|
|
||||||
4. 不要包含地域、行业、组织类型等信息
|
|
||||||
5. 律师事务所的商号通常是地域后的部分
|
|
||||||
|
|
||||||
示例:
|
|
||||||
- 上海盒马网络科技有限公司 -> 盒马
|
|
||||||
- 丰田通商(上海)有限公司 -> 丰田通商
|
|
||||||
- 雅诗兰黛(上海)商贸有限公司 -> 雅诗兰黛
|
|
||||||
- 北京百度网讯科技有限公司 -> 百度
|
|
||||||
- 腾讯科技(深圳)有限公司 -> 腾讯
|
|
||||||
- 北京大成律师事务所 -> 大成
|
|
||||||
|
|
||||||
请严格按照以下JSON格式输出,不要包含任何其他文字:
|
|
||||||
|
|
||||||
{{
|
|
||||||
"business_name": "提取的商号",
|
|
||||||
"confidence": 0.9
|
|
||||||
}}
|
|
||||||
|
|
||||||
注意:
|
|
||||||
- business_name字段必须包含提取的商号
|
|
||||||
- confidence字段是0-1之间的数字,表示提取的置信度
|
|
||||||
- 必须严格按照JSON格式,不要添加任何解释或额外文字
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use the new enhanced generate method with validation
|
|
||||||
parsed_response = self.ollama_client.generate_with_validation(
|
|
||||||
prompt=prompt,
|
|
||||||
response_type='business_name_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if parsed_response:
|
|
||||||
business_name = parsed_response.get('business_name', '')
|
|
||||||
# Clean business name, keep only Chinese characters
|
|
||||||
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
|
|
||||||
logger.info(f"Successfully extracted business name: {business_name}")
|
|
||||||
return {
|
|
||||||
'business_name': business_name,
|
|
||||||
'confidence': parsed_response.get('confidence', 0.9)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.warning(f"Failed to extract business name for: {company_name}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"LLM extraction failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_with_regex(self, company_name: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""Extract business name using regex patterns"""
|
|
||||||
# Handle law firms specially
|
|
||||||
if '律师事务所' in company_name:
|
|
||||||
return self._extract_law_firm_business_name(company_name)
|
|
||||||
|
|
||||||
# Common region prefixes
|
|
||||||
region_prefixes = [
|
|
||||||
'北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安',
|
|
||||||
'天津', '重庆', '青岛', '大连', '宁波', '厦门', '无锡', '长沙', '郑州', '济南',
|
|
||||||
'哈尔滨', '沈阳', '长春', '石家庄', '太原', '呼和浩特', '合肥', '福州', '南昌',
|
|
||||||
'南宁', '海口', '贵阳', '昆明', '兰州', '西宁', '银川', '乌鲁木齐', '拉萨',
|
|
||||||
'香港', '澳门', '台湾'
|
|
||||||
]
|
|
||||||
|
|
||||||
# Common organization type suffixes
|
|
||||||
org_suffixes = [
|
|
||||||
'有限公司', '股份有限公司', '有限责任公司', '股份公司', '集团公司', '集团',
|
|
||||||
'科技公司', '网络公司', '信息技术公司', '软件公司', '互联网公司',
|
|
||||||
'贸易公司', '商贸公司', '进出口公司', '物流公司', '运输公司',
|
|
||||||
'房地产公司', '置业公司', '投资公司', '金融公司', '银行',
|
|
||||||
'保险公司', '证券公司', '基金公司', '信托公司', '租赁公司',
|
|
||||||
'咨询公司', '服务公司', '管理公司', '广告公司', '传媒公司',
|
|
||||||
'教育公司', '培训公司', '医疗公司', '医药公司', '生物公司',
|
|
||||||
'制造公司', '工业公司', '化工公司', '能源公司', '电力公司',
|
|
||||||
'建筑公司', '工程公司', '建设公司', '开发公司', '设计公司',
|
|
||||||
'销售公司', '营销公司', '代理公司', '经销商', '零售商',
|
|
||||||
'连锁公司', '超市', '商场', '百货', '专卖店', '便利店'
|
|
||||||
]
|
|
||||||
|
|
||||||
name = company_name
|
|
||||||
|
|
||||||
# Remove region prefix
|
|
||||||
for region in region_prefixes:
|
|
||||||
if name.startswith(region):
|
|
||||||
name = name[len(region):].strip()
|
|
||||||
break
|
|
||||||
|
|
||||||
# Remove region information in parentheses
|
|
||||||
name = re.sub(r'[((].*?[))]', '', name).strip()
|
|
||||||
|
|
||||||
# Remove organization type suffix
|
|
||||||
for suffix in org_suffixes:
|
|
||||||
if name.endswith(suffix):
|
|
||||||
name = name[:-len(suffix)].strip()
|
|
||||||
break
|
|
||||||
|
|
||||||
# If remaining part is too long, try to extract first 2-4 characters
|
|
||||||
if len(name) > 4:
|
|
||||||
# Try to find a good break point
|
|
||||||
for i in range(2, min(5, len(name))):
|
|
||||||
if name[i] in ['网', '科', '技', '信', '息', '软', '件', '互', '联', '网', '电', '子', '商', '务']:
|
|
||||||
name = name[:i]
|
|
||||||
break
|
|
||||||
|
|
||||||
return {
|
|
||||||
'business_name': name if name else company_name[:2],
|
|
||||||
'confidence': 0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
def _extract_law_firm_business_name(self, law_firm_name: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""Extract business name from law firm names"""
|
|
||||||
# Remove "律师事务所" suffix
|
|
||||||
name = law_firm_name.replace('律师事务所', '').replace('分所', '').strip()
|
|
||||||
|
|
||||||
# Handle region information in parentheses
|
|
||||||
name = re.sub(r'[((].*?[))]', '', name).strip()
|
|
||||||
|
|
||||||
# Common region prefixes
|
|
||||||
region_prefixes = ['北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安']
|
|
||||||
|
|
||||||
for region in region_prefixes:
|
|
||||||
if name.startswith(region):
|
|
||||||
name = name[len(region):].strip()
|
|
||||||
break
|
|
||||||
|
|
||||||
return {
|
|
||||||
'business_name': name,
|
|
||||||
'confidence': 0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
"""Return confidence level of extraction"""
|
|
||||||
return self._confidence
|
|
||||||
|
|
@ -1,469 +0,0 @@
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
|
||||||
from .base_extractor import BaseExtractor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class NERExtractor(BaseExtractor):
|
|
||||||
"""
|
|
||||||
Named Entity Recognition extractor using Chinese NER model.
|
|
||||||
Uses the uer/roberta-base-finetuned-cluener2020-chinese model for Chinese NER.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.model_checkpoint = "uer/roberta-base-finetuned-cluener2020-chinese"
|
|
||||||
self.tokenizer = None
|
|
||||||
self.model = None
|
|
||||||
self.ner_pipeline = None
|
|
||||||
self._model_initialized = False
|
|
||||||
self.confidence_threshold = 0.95
|
|
||||||
|
|
||||||
# Map CLUENER model labels to our desired categories
|
|
||||||
self.label_map = {
|
|
||||||
'company': '公司名称',
|
|
||||||
'organization': '组织机构名',
|
|
||||||
'name': '人名',
|
|
||||||
'address': '地址'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Don't initialize the model here - use lazy loading
|
|
||||||
|
|
||||||
def _initialize_model(self):
|
|
||||||
"""Initialize the NER model and pipeline"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Loading NER model: {self.model_checkpoint}")
|
|
||||||
|
|
||||||
# Load the tokenizer and model
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint)
|
|
||||||
self.model = AutoModelForTokenClassification.from_pretrained(self.model_checkpoint)
|
|
||||||
|
|
||||||
# Create the NER pipeline with proper configuration
|
|
||||||
self.ner_pipeline = pipeline(
|
|
||||||
"ner",
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
aggregation_strategy="simple"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure the tokenizer to handle max length
|
|
||||||
if hasattr(self.tokenizer, 'model_max_length'):
|
|
||||||
self.tokenizer.model_max_length = 512
|
|
||||||
|
|
||||||
self._model_initialized = True
|
|
||||||
logger.info("NER model loaded successfully")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load NER model: {str(e)}")
|
|
||||||
raise Exception(f"NER model initialization failed: {str(e)}")
|
|
||||||
|
|
||||||
def _split_text_by_sentences(self, text: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Split text into sentences using Chinese sentence boundaries
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to split
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of sentences
|
|
||||||
"""
|
|
||||||
# Chinese sentence endings: 。!?;\n
|
|
||||||
# Also consider English sentence endings for mixed text
|
|
||||||
sentence_pattern = r'[。!?;\n]+|[.!?;]+'
|
|
||||||
sentences = re.split(sentence_pattern, text)
|
|
||||||
|
|
||||||
# Clean up sentences and filter out empty ones
|
|
||||||
cleaned_sentences = []
|
|
||||||
for sentence in sentences:
|
|
||||||
sentence = sentence.strip()
|
|
||||||
if sentence:
|
|
||||||
cleaned_sentences.append(sentence)
|
|
||||||
|
|
||||||
return cleaned_sentences
|
|
||||||
|
|
||||||
def _is_entity_boundary_safe(self, text: str, position: int) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a position is safe for splitting (won't break entities)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to check
|
|
||||||
position: Position to check for safety
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if safe to split at this position
|
|
||||||
"""
|
|
||||||
if position <= 0 or position >= len(text):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Common entity suffixes that indicate incomplete entities
|
|
||||||
entity_suffixes = ['公', '司', '所', '院', '厅', '局', '部', '会', '团', '社', '处', '室', '楼', '号']
|
|
||||||
|
|
||||||
# Check if we're in the middle of a potential entity
|
|
||||||
for suffix in entity_suffixes:
|
|
||||||
# Look for incomplete entity patterns
|
|
||||||
if text[position-1:position+1] in [f'公{suffix}', f'司{suffix}', f'所{suffix}']:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for incomplete company names
|
|
||||||
if text[position-2:position+1] in ['公司', '事务所', '协会', '研究院']:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for incomplete address patterns
|
|
||||||
address_patterns = ['省', '市', '区', '县', '路', '街', '巷', '号', '室']
|
|
||||||
for pattern in address_patterns:
|
|
||||||
if text[position-1:position+1] in [f'省{pattern}', f'市{pattern}', f'区{pattern}', f'县{pattern}']:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _create_sentence_chunks(self, sentences: List[str], max_tokens: int = 400) -> List[str]:
|
|
||||||
"""
|
|
||||||
Create chunks from sentences while respecting token limits and entity boundaries
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sentences: List of sentences
|
|
||||||
max_tokens: Maximum tokens per chunk
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks
|
|
||||||
"""
|
|
||||||
chunks = []
|
|
||||||
current_chunk = []
|
|
||||||
current_token_count = 0
|
|
||||||
|
|
||||||
for sentence in sentences:
|
|
||||||
# Estimate token count for this sentence
|
|
||||||
sentence_tokens = len(self.tokenizer.tokenize(sentence))
|
|
||||||
|
|
||||||
# If adding this sentence would exceed the limit
|
|
||||||
if current_token_count + sentence_tokens > max_tokens and current_chunk:
|
|
||||||
# Check if we can split the sentence to fit better
|
|
||||||
if sentence_tokens > max_tokens // 2: # If sentence is too long
|
|
||||||
# Try to split the sentence at a safe boundary
|
|
||||||
split_sentence = self._split_long_sentence(sentence, max_tokens - current_token_count)
|
|
||||||
if split_sentence:
|
|
||||||
# Add the first part to current chunk
|
|
||||||
current_chunk.append(split_sentence[0])
|
|
||||||
chunks.append(''.join(current_chunk))
|
|
||||||
|
|
||||||
# Start new chunk with remaining parts
|
|
||||||
current_chunk = split_sentence[1:]
|
|
||||||
current_token_count = sum(len(self.tokenizer.tokenize(s)) for s in current_chunk)
|
|
||||||
else:
|
|
||||||
# Finalize current chunk and start new one
|
|
||||||
chunks.append(''.join(current_chunk))
|
|
||||||
current_chunk = [sentence]
|
|
||||||
current_token_count = sentence_tokens
|
|
||||||
else:
|
|
||||||
# Finalize current chunk and start new one
|
|
||||||
chunks.append(''.join(current_chunk))
|
|
||||||
current_chunk = [sentence]
|
|
||||||
current_token_count = sentence_tokens
|
|
||||||
else:
|
|
||||||
# Add sentence to current chunk
|
|
||||||
current_chunk.append(sentence)
|
|
||||||
current_token_count += sentence_tokens
|
|
||||||
|
|
||||||
# Add the last chunk if it has content
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append(''.join(current_chunk))
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
def _split_long_sentence(self, sentence: str, max_tokens: int) -> Optional[List[str]]:
|
|
||||||
"""
|
|
||||||
Split a long sentence at safe boundaries
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sentence: The sentence to split
|
|
||||||
max_tokens: Maximum tokens for the first part
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of sentence parts, or None if splitting is not possible
|
|
||||||
"""
|
|
||||||
if len(self.tokenizer.tokenize(sentence)) <= max_tokens:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try to find safe splitting points
|
|
||||||
# Look for punctuation marks that are safe to split at
|
|
||||||
safe_splitters = [',', ',', ';', ';', '、', ':', ':']
|
|
||||||
|
|
||||||
for splitter in safe_splitters:
|
|
||||||
if splitter in sentence:
|
|
||||||
parts = sentence.split(splitter)
|
|
||||||
current_part = ""
|
|
||||||
|
|
||||||
for i, part in enumerate(parts):
|
|
||||||
test_part = current_part + part + (splitter if i < len(parts) - 1 else "")
|
|
||||||
if len(self.tokenizer.tokenize(test_part)) > max_tokens:
|
|
||||||
if current_part:
|
|
||||||
# Found a safe split point
|
|
||||||
remaining = splitter.join(parts[i:])
|
|
||||||
return [current_part, remaining]
|
|
||||||
break
|
|
||||||
current_part = test_part
|
|
||||||
|
|
||||||
# If no safe split point found, try character-based splitting with entity boundary check
|
|
||||||
target_chars = int(max_tokens / 1.5) # Rough character estimate
|
|
||||||
|
|
||||||
for i in range(target_chars, len(sentence)):
|
|
||||||
if self._is_entity_boundary_safe(sentence, i):
|
|
||||||
part1 = sentence[:i]
|
|
||||||
part2 = sentence[i:]
|
|
||||||
if len(self.tokenizer.tokenize(part1)) <= max_tokens:
|
|
||||||
return [part1, part2]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def extract(self, text: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Extract named entities from the given text
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing extracted entities in the format expected by the system
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not text or not text.strip():
|
|
||||||
logger.warning("Empty text provided for NER processing")
|
|
||||||
return {"entities": []}
|
|
||||||
|
|
||||||
# Initialize model if not already done
|
|
||||||
if not self._model_initialized:
|
|
||||||
self._initialize_model()
|
|
||||||
|
|
||||||
logger.info(f"Processing text with NER (length: {len(text)} characters)")
|
|
||||||
|
|
||||||
# Check if text needs chunking
|
|
||||||
if len(text) > 400: # Character-based threshold for chunking
|
|
||||||
logger.info("Text is long, using chunking approach")
|
|
||||||
return self._extract_with_chunking(text)
|
|
||||||
else:
|
|
||||||
logger.info("Text is short, processing directly")
|
|
||||||
return self._extract_single(text)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during NER processing: {str(e)}")
|
|
||||||
raise Exception(f"NER processing failed: {str(e)}")
|
|
||||||
|
|
||||||
def _extract_single(self, text: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Extract entities from a single text chunk
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing extracted entities
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Run the NER pipeline - it handles truncation automatically
|
|
||||||
logger.info(f"Running NER pipeline with text: {text}")
|
|
||||||
results = self.ner_pipeline(text)
|
|
||||||
logger.info(f"NER results: {results}")
|
|
||||||
|
|
||||||
# Filter and process entities
|
|
||||||
filtered_entities = []
|
|
||||||
for entity in results:
|
|
||||||
entity_group = entity['entity_group']
|
|
||||||
|
|
||||||
# Only process entities that we care about
|
|
||||||
if entity_group in self.label_map:
|
|
||||||
entity_type = self.label_map[entity_group]
|
|
||||||
entity_text = entity['word']
|
|
||||||
confidence_score = entity['score']
|
|
||||||
|
|
||||||
# Clean up the tokenized text (remove spaces between Chinese characters)
|
|
||||||
cleaned_text = self._clean_tokenized_text(entity_text)
|
|
||||||
|
|
||||||
# Add to our list with both original and cleaned text, only add if confidence score is above threshold
|
|
||||||
# if entity_group is 'address' or 'company', and only has characters less then 3, then filter it out
|
|
||||||
if confidence_score > self.confidence_threshold:
|
|
||||||
filtered_entities.append({
|
|
||||||
"text": cleaned_text, # Clean text for display/processing
|
|
||||||
"tokenized_text": entity_text, # Original tokenized text from model
|
|
||||||
"type": entity_type,
|
|
||||||
"entity_group": entity_group,
|
|
||||||
"confidence": confidence_score
|
|
||||||
})
|
|
||||||
logger.info(f"Filtered entities: {filtered_entities}")
|
|
||||||
# filter out entities that are less then 3 characters with entity_group is 'address' or 'company'
|
|
||||||
filtered_entities = [entity for entity in filtered_entities if entity['entity_group'] not in ['address', 'company'] or len(entity['text']) > 3]
|
|
||||||
logger.info(f"Final Filtered entities: {filtered_entities}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"entities": filtered_entities,
|
|
||||||
"total_count": len(filtered_entities)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during single NER processing: {str(e)}")
|
|
||||||
raise Exception(f"Single NER processing failed: {str(e)}")
|
|
||||||
|
|
||||||
def _extract_with_chunking(self, text: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Extract entities from long text using sentence-based chunking approach
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing extracted entities
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Using sentence-based chunking for text of length: {len(text)}")
|
|
||||||
|
|
||||||
# Split text into sentences
|
|
||||||
sentences = self._split_text_by_sentences(text)
|
|
||||||
logger.info(f"Split text into {len(sentences)} sentences")
|
|
||||||
|
|
||||||
# Create chunks from sentences
|
|
||||||
chunks = self._create_sentence_chunks(sentences, max_tokens=400)
|
|
||||||
logger.info(f"Created {len(chunks)} chunks from sentences")
|
|
||||||
|
|
||||||
all_entities = []
|
|
||||||
|
|
||||||
# Process each chunk
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
# Verify chunk won't exceed token limit
|
|
||||||
chunk_tokens = len(self.tokenizer.tokenize(chunk))
|
|
||||||
logger.info(f"Processing chunk {i+1}: {len(chunk)} chars, {chunk_tokens} tokens")
|
|
||||||
|
|
||||||
if chunk_tokens > 512:
|
|
||||||
logger.warning(f"Chunk {i+1} has {chunk_tokens} tokens, truncating")
|
|
||||||
# Truncate the chunk to fit within token limit
|
|
||||||
chunk = self.tokenizer.convert_tokens_to_string(
|
|
||||||
self.tokenizer.tokenize(chunk)[:512]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract entities from this chunk
|
|
||||||
chunk_result = self._extract_single(chunk)
|
|
||||||
chunk_entities = chunk_result.get("entities", [])
|
|
||||||
|
|
||||||
all_entities.extend(chunk_entities)
|
|
||||||
logger.info(f"Chunk {i+1} extracted {len(chunk_entities)} entities")
|
|
||||||
|
|
||||||
# Remove duplicates while preserving order
|
|
||||||
unique_entities = []
|
|
||||||
seen_texts = set()
|
|
||||||
|
|
||||||
for entity in all_entities:
|
|
||||||
text = entity['text'].strip()
|
|
||||||
if text and text not in seen_texts:
|
|
||||||
seen_texts.add(text)
|
|
||||||
unique_entities.append(entity)
|
|
||||||
|
|
||||||
logger.info(f"Sentence-based chunking completed: {len(all_entities)} total entities, {len(unique_entities)} unique entities")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"entities": unique_entities,
|
|
||||||
"total_count": len(unique_entities)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during sentence-based chunked NER processing: {str(e)}")
|
|
||||||
raise Exception(f"Sentence-based chunked NER processing failed: {str(e)}")
|
|
||||||
|
|
||||||
def _clean_tokenized_text(self, tokenized_text: str) -> str:
|
|
||||||
"""
|
|
||||||
Clean up tokenized text by removing spaces between Chinese characters
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokenized_text: Text with spaces between characters (e.g., "北 京 市")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Cleaned text without spaces (e.g., "北京市")
|
|
||||||
"""
|
|
||||||
if not tokenized_text:
|
|
||||||
return tokenized_text
|
|
||||||
|
|
||||||
# Remove spaces between Chinese characters
|
|
||||||
# This handles cases like "北 京 市" -> "北京市"
|
|
||||||
cleaned = tokenized_text.replace(" ", "")
|
|
||||||
|
|
||||||
# Also handle cases where there might be multiple spaces
|
|
||||||
cleaned = " ".join(cleaned.split())
|
|
||||||
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
def get_entity_summary(self, entities: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Generate a summary of extracted entities by type
|
|
||||||
|
|
||||||
Args:
|
|
||||||
entities: List of extracted entities
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Summary dictionary with counts by entity type
|
|
||||||
"""
|
|
||||||
summary = {}
|
|
||||||
for entity in entities:
|
|
||||||
entity_type = entity['type']
|
|
||||||
if entity_type not in summary:
|
|
||||||
summary[entity_type] = []
|
|
||||||
summary[entity_type].append(entity['text'])
|
|
||||||
|
|
||||||
# Convert to count format
|
|
||||||
summary_counts = {entity_type: len(texts) for entity_type, texts in summary.items()}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"summary": summary,
|
|
||||||
"counts": summary_counts,
|
|
||||||
"total_entities": len(entities)
|
|
||||||
}
|
|
||||||
|
|
||||||
def extract_and_summarize(self, text: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Extract entities and provide a summary in one call
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing entities and summary
|
|
||||||
"""
|
|
||||||
entities_result = self.extract(text)
|
|
||||||
entities = entities_result.get("entities", [])
|
|
||||||
|
|
||||||
summary_result = self.get_entity_summary(entities)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"entities": entities,
|
|
||||||
"summary": summary_result,
|
|
||||||
"total_count": len(entities)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_confidence(self) -> float:
|
|
||||||
"""
|
|
||||||
Return confidence level of extraction
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Confidence level as a float between 0.0 and 1.0
|
|
||||||
"""
|
|
||||||
# NER models typically have high confidence for well-trained entities
|
|
||||||
# This is a reasonable default confidence level for NER extraction
|
|
||||||
return 0.85
|
|
||||||
|
|
||||||
def get_model_info(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get information about the NER model
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing model information
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"model_name": self.model_checkpoint,
|
|
||||||
"model_type": "Chinese NER",
|
|
||||||
"supported_entities": [
|
|
||||||
"人名 (Person Names)",
|
|
||||||
"公司名称 (Company Names)",
|
|
||||||
"组织机构名 (Organization Names)",
|
|
||||||
"地址 (Addresses)"
|
|
||||||
],
|
|
||||||
"description": "Fine-tuned RoBERTa model for Chinese Named Entity Recognition on CLUENER2020 dataset"
|
|
||||||
}
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
||||||
"""
|
|
||||||
Factory for creating maskers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Type, Any
|
|
||||||
from .maskers.base_masker import BaseMasker
|
|
||||||
from .maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
|
||||||
from .maskers.company_masker import CompanyMasker
|
|
||||||
from .maskers.address_masker import AddressMasker
|
|
||||||
from .maskers.id_masker import IDMasker
|
|
||||||
from .maskers.case_masker import CaseMasker
|
|
||||||
from ..services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
|
|
||||||
class MaskerFactory:
|
|
||||||
"""Factory for creating maskers"""
|
|
||||||
|
|
||||||
_maskers: Dict[str, Type[BaseMasker]] = {
|
|
||||||
'chinese_name': ChineseNameMasker,
|
|
||||||
'english_name': EnglishNameMasker,
|
|
||||||
'company': CompanyMasker,
|
|
||||||
'address': AddressMasker,
|
|
||||||
'id': IDMasker,
|
|
||||||
'case': CaseMasker,
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_masker(cls, masker_type: str, ollama_client: OllamaClient = None, config: Dict[str, Any] = None) -> BaseMasker:
|
|
||||||
"""
|
|
||||||
Create a masker of the specified type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
masker_type: Type of masker to create
|
|
||||||
ollama_client: Ollama client for LLM-based maskers
|
|
||||||
config: Configuration for the masker
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Instance of the specified masker
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If masker type is unknown
|
|
||||||
"""
|
|
||||||
if masker_type not in cls._maskers:
|
|
||||||
raise ValueError(f"Unknown masker type: {masker_type}")
|
|
||||||
|
|
||||||
masker_class = cls._maskers[masker_type]
|
|
||||||
|
|
||||||
# Handle maskers that need ollama_client
|
|
||||||
if masker_type in ['company', 'address']:
|
|
||||||
if not ollama_client:
|
|
||||||
raise ValueError(f"Ollama client is required for {masker_type} masker")
|
|
||||||
return masker_class(ollama_client)
|
|
||||||
|
|
||||||
# Handle maskers that don't need special parameters
|
|
||||||
return masker_class()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_available_maskers(cls) -> list[str]:
|
|
||||||
"""Get list of available masker types"""
|
|
||||||
return list(cls._maskers.keys())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_masker(cls, masker_type: str, masker_class: Type[BaseMasker]):
|
|
||||||
"""Register a new masker type"""
|
|
||||||
cls._maskers[masker_type] = masker_class
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
"""
|
|
||||||
Maskers package for entity masking functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .base_masker import BaseMasker
|
|
||||||
from .name_masker import ChineseNameMasker, EnglishNameMasker
|
|
||||||
from .company_masker import CompanyMasker
|
|
||||||
from .address_masker import AddressMasker
|
|
||||||
from .id_masker import IDMasker
|
|
||||||
from .case_masker import CaseMasker
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'BaseMasker',
|
|
||||||
'ChineseNameMasker',
|
|
||||||
'EnglishNameMasker',
|
|
||||||
'CompanyMasker',
|
|
||||||
'AddressMasker',
|
|
||||||
'IDMasker',
|
|
||||||
'CaseMasker'
|
|
||||||
]
|
|
||||||
|
|
@ -1,91 +0,0 @@
|
||||||
"""
|
|
||||||
Address masker for addresses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any
|
|
||||||
from pypinyin import pinyin, Style
|
|
||||||
from ...services.ollama_client import OllamaClient
|
|
||||||
from ..extractors.address_extractor import AddressExtractor
|
|
||||||
from .base_masker import BaseMasker
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AddressMasker(BaseMasker):
|
|
||||||
"""Masker for addresses"""
|
|
||||||
|
|
||||||
def __init__(self, ollama_client: OllamaClient):
|
|
||||||
self.extractor = AddressExtractor(ollama_client)
|
|
||||||
|
|
||||||
def mask(self, address: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""
|
|
||||||
Mask address by replacing components with masked versions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
address: The address to mask
|
|
||||||
context: Additional context (not used for address masking)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked address
|
|
||||||
"""
|
|
||||||
if not address:
|
|
||||||
return address
|
|
||||||
|
|
||||||
# Extract address components
|
|
||||||
components = self.extractor.extract(address)
|
|
||||||
if not components:
|
|
||||||
return address
|
|
||||||
|
|
||||||
masked_address = address
|
|
||||||
|
|
||||||
# Replace road name
|
|
||||||
if components.get("road_name"):
|
|
||||||
road_name = components["road_name"]
|
|
||||||
# Get pinyin initials for road name
|
|
||||||
try:
|
|
||||||
pinyin_list = pinyin(road_name, style=Style.NORMAL)
|
|
||||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
|
||||||
masked_address = masked_address.replace(road_name, initials + "路")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get pinyin for road name {road_name}: {e}")
|
|
||||||
# Fallback to first character
|
|
||||||
masked_address = masked_address.replace(road_name, road_name[0].upper() + "路")
|
|
||||||
|
|
||||||
# Replace house number
|
|
||||||
if components.get("house_number"):
|
|
||||||
house_number = components["house_number"]
|
|
||||||
masked_address = masked_address.replace(house_number + "号", "**号")
|
|
||||||
|
|
||||||
# Replace building name
|
|
||||||
if components.get("building_name"):
|
|
||||||
building_name = components["building_name"]
|
|
||||||
# Get pinyin initials for building name
|
|
||||||
try:
|
|
||||||
pinyin_list = pinyin(building_name, style=Style.NORMAL)
|
|
||||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
|
||||||
masked_address = masked_address.replace(building_name, initials)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get pinyin for building name {building_name}: {e}")
|
|
||||||
# Fallback to first character
|
|
||||||
masked_address = masked_address.replace(building_name, building_name[0].upper())
|
|
||||||
|
|
||||||
# Replace community name
|
|
||||||
if components.get("community_name"):
|
|
||||||
community_name = components["community_name"]
|
|
||||||
# Get pinyin initials for community name
|
|
||||||
try:
|
|
||||||
pinyin_list = pinyin(community_name, style=Style.NORMAL)
|
|
||||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
|
||||||
masked_address = masked_address.replace(community_name, initials)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get pinyin for community name {community_name}: {e}")
|
|
||||||
# Fallback to first character
|
|
||||||
masked_address = masked_address.replace(community_name, community_name[0].upper())
|
|
||||||
|
|
||||||
return masked_address
|
|
||||||
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
return ['地址']
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
"""
|
|
||||||
Abstract base class for all maskers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMasker(ABC):
|
|
||||||
"""Abstract base class for all maskers"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""Mask the given text according to specific rules"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def can_mask(self, entity_type: str) -> bool:
|
|
||||||
"""Check if this masker can handle the given entity type"""
|
|
||||||
return entity_type in self.get_supported_types()
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
"""
|
|
||||||
Case masker for case numbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Dict, Any
|
|
||||||
from .base_masker import BaseMasker
|
|
||||||
|
|
||||||
|
|
||||||
class CaseMasker(BaseMasker):
|
|
||||||
"""Masker for case numbers"""
|
|
||||||
|
|
||||||
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""
|
|
||||||
Mask case numbers by replacing digits with ***.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to mask
|
|
||||||
context: Additional context (not used for case masking)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked text
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
|
|
||||||
# Replace digits with *** while preserving structure
|
|
||||||
masked = re.sub(r'(\d[\d\s]*)(号)', r'***\2', text)
|
|
||||||
return masked
|
|
||||||
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
return ['案号']
|
|
||||||
|
|
@ -1,98 +0,0 @@
|
||||||
"""
|
|
||||||
Company masker for company names.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any
|
|
||||||
from pypinyin import pinyin, Style
|
|
||||||
from ...services.ollama_client import OllamaClient
|
|
||||||
from ..extractors.business_name_extractor import BusinessNameExtractor
|
|
||||||
from .base_masker import BaseMasker
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyMasker(BaseMasker):
|
|
||||||
"""Masker for company names"""
|
|
||||||
|
|
||||||
def __init__(self, ollama_client: OllamaClient):
|
|
||||||
self.extractor = BusinessNameExtractor(ollama_client)
|
|
||||||
|
|
||||||
def mask(self, company_name: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""
|
|
||||||
Mask company name by replacing business name with letters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
company_name: The company name to mask
|
|
||||||
context: Additional context (not used for company masking)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked company name
|
|
||||||
"""
|
|
||||||
if not company_name:
|
|
||||||
return company_name
|
|
||||||
|
|
||||||
# Extract business name
|
|
||||||
extraction_result = self.extractor.extract(company_name)
|
|
||||||
if not extraction_result:
|
|
||||||
return company_name
|
|
||||||
|
|
||||||
business_name = extraction_result.get('business_name', '')
|
|
||||||
if not business_name:
|
|
||||||
return company_name
|
|
||||||
|
|
||||||
# Get pinyin first letter of business name
|
|
||||||
try:
|
|
||||||
pinyin_list = pinyin(business_name, style=Style.NORMAL)
|
|
||||||
first_letter = pinyin_list[0][0][0].upper() if pinyin_list and pinyin_list[0] else 'A'
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get pinyin for {business_name}: {e}")
|
|
||||||
first_letter = 'A'
|
|
||||||
|
|
||||||
# Calculate next two letters
|
|
||||||
if first_letter >= 'Y':
|
|
||||||
# If first letter is Y or Z, use X and Y
|
|
||||||
letters = 'XY'
|
|
||||||
elif first_letter >= 'X':
|
|
||||||
# If first letter is X, use Y and Z
|
|
||||||
letters = 'YZ'
|
|
||||||
else:
|
|
||||||
# Normal case: use next two letters
|
|
||||||
letters = chr(ord(first_letter) + 1) + chr(ord(first_letter) + 2)
|
|
||||||
|
|
||||||
# Replace business name
|
|
||||||
if business_name in company_name:
|
|
||||||
masked_name = company_name.replace(business_name, letters)
|
|
||||||
else:
|
|
||||||
# Try smarter replacement
|
|
||||||
masked_name = self._replace_business_name_in_company(company_name, business_name, letters)
|
|
||||||
|
|
||||||
return masked_name
|
|
||||||
|
|
||||||
def _replace_business_name_in_company(self, company_name: str, business_name: str, letters: str) -> str:
|
|
||||||
"""Smart replacement of business name in company name"""
|
|
||||||
# Try different replacement patterns
|
|
||||||
patterns = [
|
|
||||||
business_name,
|
|
||||||
business_name + '(',
|
|
||||||
business_name + '(',
|
|
||||||
'(' + business_name + ')',
|
|
||||||
'(' + business_name + ')',
|
|
||||||
]
|
|
||||||
|
|
||||||
for pattern in patterns:
|
|
||||||
if pattern in company_name:
|
|
||||||
if pattern.endswith('(') or pattern.endswith('('):
|
|
||||||
return company_name.replace(pattern, letters + pattern[-1])
|
|
||||||
elif pattern.startswith('(') or pattern.startswith('('):
|
|
||||||
return company_name.replace(pattern, pattern[0] + letters + pattern[-1])
|
|
||||||
else:
|
|
||||||
return company_name.replace(pattern, letters)
|
|
||||||
|
|
||||||
# If no pattern found, return original
|
|
||||||
return company_name
|
|
||||||
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
return ['公司名称', 'Company', '英文公司名', 'English Company']
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
"""
|
|
||||||
ID masker for ID numbers and social credit codes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Any
|
|
||||||
from .base_masker import BaseMasker
|
|
||||||
|
|
||||||
|
|
||||||
class IDMasker(BaseMasker):
|
|
||||||
"""Masker for ID numbers and social credit codes"""
|
|
||||||
|
|
||||||
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""
|
|
||||||
Mask ID numbers and social credit codes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to mask
|
|
||||||
context: Additional context (not used for ID masking)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked text
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
|
|
||||||
# Determine the type based on length and format
|
|
||||||
if len(text) == 18 and text.isdigit():
|
|
||||||
# ID number: keep first 6 digits
|
|
||||||
return text[:6] + 'X' * (len(text) - 6)
|
|
||||||
elif len(text) == 18 and any(c.isalpha() for c in text):
|
|
||||||
# Social credit code: keep first 7 digits
|
|
||||||
return text[:7] + 'X' * (len(text) - 7)
|
|
||||||
else:
|
|
||||||
# Fallback for invalid formats
|
|
||||||
return text
|
|
||||||
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
return ['身份证号', '社会信用代码']
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
"""
|
|
||||||
Name maskers for Chinese and English names.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Any
|
|
||||||
from pypinyin import pinyin, Style
|
|
||||||
from .base_masker import BaseMasker
|
|
||||||
|
|
||||||
|
|
||||||
class ChineseNameMasker(BaseMasker):
|
|
||||||
"""Masker for Chinese names"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.surname_counter = {}
|
|
||||||
|
|
||||||
def mask(self, name: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""
|
|
||||||
Mask Chinese names: keep surname, convert given name to pinyin initials.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The name to mask
|
|
||||||
context: Additional context containing surname_counter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked name
|
|
||||||
"""
|
|
||||||
if not name or len(name) < 2:
|
|
||||||
return name
|
|
||||||
|
|
||||||
# Use context surname_counter if provided, otherwise use instance counter
|
|
||||||
surname_counter = context.get('surname_counter', self.surname_counter) if context else self.surname_counter
|
|
||||||
|
|
||||||
surname = name[0]
|
|
||||||
given_name = name[1:]
|
|
||||||
|
|
||||||
# Get pinyin initials for given name
|
|
||||||
try:
|
|
||||||
pinyin_list = pinyin(given_name, style=Style.NORMAL)
|
|
||||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
|
||||||
except Exception:
|
|
||||||
# Fallback to original characters if pinyin fails
|
|
||||||
initials = given_name
|
|
||||||
|
|
||||||
# Initialize surname counter
|
|
||||||
if surname not in surname_counter:
|
|
||||||
surname_counter[surname] = {}
|
|
||||||
|
|
||||||
# Check for duplicate surname and initials combination
|
|
||||||
if initials in surname_counter[surname]:
|
|
||||||
surname_counter[surname][initials] += 1
|
|
||||||
masked_name = f"{surname}{initials}{surname_counter[surname][initials]}"
|
|
||||||
else:
|
|
||||||
surname_counter[surname][initials] = 1
|
|
||||||
masked_name = f"{surname}{initials}"
|
|
||||||
|
|
||||||
return masked_name
|
|
||||||
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
return ['人名', '律师姓名', '审判人员姓名']
|
|
||||||
|
|
||||||
|
|
||||||
class EnglishNameMasker(BaseMasker):
|
|
||||||
"""Masker for English names"""
|
|
||||||
|
|
||||||
def mask(self, name: str, context: Dict[str, Any] = None) -> str:
|
|
||||||
"""
|
|
||||||
Mask English names: convert each word to first letter + ***.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The name to mask
|
|
||||||
context: Additional context (not used for English name masking)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked name
|
|
||||||
"""
|
|
||||||
if not name:
|
|
||||||
return name
|
|
||||||
|
|
||||||
masked_parts = []
|
|
||||||
for part in name.split():
|
|
||||||
if part:
|
|
||||||
masked_parts.append(part[0] + '***')
|
|
||||||
|
|
||||||
return ' '.join(masked_parts)
|
|
||||||
|
|
||||||
def get_supported_types(self) -> list[str]:
|
|
||||||
"""Return list of entity types this masker supports"""
|
|
||||||
return ['英文人名']
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,410 +0,0 @@
|
||||||
"""
|
|
||||||
Refactored NerProcessor using the new masker architecture.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
from ..prompts.masking_prompts import (
|
|
||||||
get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt,
|
|
||||||
get_ner_project_prompt, get_ner_case_number_prompt, get_entity_linkage_prompt
|
|
||||||
)
|
|
||||||
from ..services.ollama_client import OllamaClient
|
|
||||||
from ...core.config import settings
|
|
||||||
from ..utils.json_extractor import LLMJsonExtractor
|
|
||||||
from ..utils.llm_validator import LLMResponseValidator
|
|
||||||
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
|
|
||||||
from .masker_factory import MaskerFactory
|
|
||||||
from .maskers.base_masker import BaseMasker
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class NerProcessorRefactored:
|
|
||||||
"""Refactored NerProcessor using the new masker architecture"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
|
||||||
self.max_retries = 3
|
|
||||||
self.maskers = self._initialize_maskers()
|
|
||||||
self.surname_counter = {} # Shared counter for Chinese names
|
|
||||||
|
|
||||||
def _find_entity_alignment(self, entity_text: str, original_document_text: str) -> Optional[Tuple[int, int, str]]:
|
|
||||||
"""
|
|
||||||
Find entity in original document using character-by-character alignment.
|
|
||||||
|
|
||||||
This method handles the case where the original document may have spaces
|
|
||||||
that are not from tokenization, and the entity text may have different
|
|
||||||
spacing patterns.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
entity_text: The entity text to find (may have spaces from tokenization)
|
|
||||||
original_document_text: The original document text (may have spaces)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (start_pos, end_pos, found_text) or None if not found
|
|
||||||
"""
|
|
||||||
# Remove all spaces from entity text to get clean characters
|
|
||||||
clean_entity = entity_text.replace(" ", "")
|
|
||||||
|
|
||||||
# Create character lists ignoring spaces from both entity and document
|
|
||||||
entity_chars = [c for c in clean_entity]
|
|
||||||
doc_chars = [c for c in original_document_text if c != ' ']
|
|
||||||
|
|
||||||
# Find the sequence in document characters
|
|
||||||
for i in range(len(doc_chars) - len(entity_chars) + 1):
|
|
||||||
if doc_chars[i:i+len(entity_chars)] == entity_chars:
|
|
||||||
# Found match, now map back to original positions
|
|
||||||
return self._map_char_positions_to_original(i, len(entity_chars), original_document_text)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _map_char_positions_to_original(self, clean_start: int, entity_length: int, original_text: str) -> Tuple[int, int, str]:
|
|
||||||
"""
|
|
||||||
Map positions from clean text (without spaces) back to original text positions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
clean_start: Start position in clean text (without spaces)
|
|
||||||
entity_length: Length of entity in characters
|
|
||||||
original_text: Original document text with spaces
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (start_pos, end_pos, found_text) in original text
|
|
||||||
"""
|
|
||||||
original_pos = 0
|
|
||||||
clean_pos = 0
|
|
||||||
|
|
||||||
# Find the start position in original text
|
|
||||||
while clean_pos < clean_start and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
clean_pos += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
start_pos = original_pos
|
|
||||||
|
|
||||||
# Find the end position by counting non-space characters
|
|
||||||
chars_found = 0
|
|
||||||
while chars_found < entity_length and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
chars_found += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
end_pos = original_pos
|
|
||||||
|
|
||||||
# Extract the actual text from the original document
|
|
||||||
found_text = original_text[start_pos:end_pos]
|
|
||||||
|
|
||||||
return start_pos, end_pos, found_text
|
|
||||||
|
|
||||||
def apply_entity_masking_with_alignment(self, original_document_text: str, entity_mapping: Dict[str, str], mask_char: str = "*") -> str:
|
|
||||||
"""
|
|
||||||
Apply entity masking to original document text using character-by-character alignment.
|
|
||||||
|
|
||||||
This method finds each entity in the original document using alignment and
|
|
||||||
replaces it with the corresponding masked version. It handles multiple
|
|
||||||
occurrences of the same entity by finding all instances before moving
|
|
||||||
to the next entity.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
original_document_text: The original document text to mask
|
|
||||||
entity_mapping: Dictionary mapping original entity text to masked text
|
|
||||||
mask_char: Character to use for masking (default: "*")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Masked document text
|
|
||||||
"""
|
|
||||||
masked_document = original_document_text
|
|
||||||
|
|
||||||
# Sort entities by length (longest first) to avoid partial matches
|
|
||||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
|
||||||
|
|
||||||
for entity_text in sorted_entities:
|
|
||||||
masked_text = entity_mapping[entity_text]
|
|
||||||
|
|
||||||
# Skip if masked text is the same as original text (prevents infinite loop)
|
|
||||||
if entity_text == masked_text:
|
|
||||||
logger.debug(f"Skipping entity '{entity_text}' as masked text is identical")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find ALL occurrences of this entity in the document
|
|
||||||
# We need to loop until no more matches are found
|
|
||||||
# Add safety counter to prevent infinite loops
|
|
||||||
max_iterations = 100 # Safety limit
|
|
||||||
iteration_count = 0
|
|
||||||
|
|
||||||
while iteration_count < max_iterations:
|
|
||||||
iteration_count += 1
|
|
||||||
|
|
||||||
# Find the entity in the current masked document using alignment
|
|
||||||
alignment_result = self._find_entity_alignment(entity_text, masked_document)
|
|
||||||
|
|
||||||
if alignment_result:
|
|
||||||
start_pos, end_pos, found_text = alignment_result
|
|
||||||
|
|
||||||
# Replace the found text with the masked version
|
|
||||||
masked_document = (
|
|
||||||
masked_document[:start_pos] +
|
|
||||||
masked_text +
|
|
||||||
masked_document[end_pos:]
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos} (iteration {iteration_count})")
|
|
||||||
else:
|
|
||||||
# No more occurrences found for this entity, move to next entity
|
|
||||||
logger.debug(f"No more occurrences of '{entity_text}' found in document after {iteration_count} iterations")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Log warning if we hit the safety limit
|
|
||||||
if iteration_count >= max_iterations:
|
|
||||||
logger.warning(f"Reached maximum iterations ({max_iterations}) for entity '{entity_text}', stopping to prevent infinite loop")
|
|
||||||
|
|
||||||
return masked_document
|
|
||||||
|
|
||||||
def _initialize_maskers(self) -> Dict[str, BaseMasker]:
|
|
||||||
"""Initialize all maskers"""
|
|
||||||
maskers = {}
|
|
||||||
|
|
||||||
# Create maskers that don't need ollama_client
|
|
||||||
maskers['chinese_name'] = MaskerFactory.create_masker('chinese_name')
|
|
||||||
maskers['english_name'] = MaskerFactory.create_masker('english_name')
|
|
||||||
maskers['id'] = MaskerFactory.create_masker('id')
|
|
||||||
maskers['case'] = MaskerFactory.create_masker('case')
|
|
||||||
|
|
||||||
# Create maskers that need ollama_client
|
|
||||||
maskers['company'] = MaskerFactory.create_masker('company', self.ollama_client)
|
|
||||||
maskers['address'] = MaskerFactory.create_masker('address', self.ollama_client)
|
|
||||||
|
|
||||||
return maskers
|
|
||||||
|
|
||||||
def _get_masker_for_type(self, entity_type: str) -> Optional[BaseMasker]:
|
|
||||||
"""Get the appropriate masker for the given entity type"""
|
|
||||||
for masker in self.maskers.values():
|
|
||||||
if masker.can_mask(entity_type):
|
|
||||||
return masker
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
|
|
||||||
"""Validate entity extraction mapping format"""
|
|
||||||
return LLMResponseValidator.validate_entity_extraction(mapping)
|
|
||||||
|
|
||||||
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
|
|
||||||
"""Process entities of a specific type using LLM"""
|
|
||||||
try:
|
|
||||||
formatted_prompt = prompt_func(chunk)
|
|
||||||
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
|
|
||||||
|
|
||||||
# Use the new enhanced generate method with validation
|
|
||||||
mapping = self.ollama_client.generate_with_validation(
|
|
||||||
prompt=formatted_prompt,
|
|
||||||
response_type='entity_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Parsed mapping: {mapping}")
|
|
||||||
|
|
||||||
if mapping and self._validate_mapping_format(mapping):
|
|
||||||
return mapping
|
|
||||||
else:
|
|
||||||
logger.warning(f"Invalid mapping format received for {entity_type}")
|
|
||||||
return {}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating {entity_type} mapping: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def build_mapping(self, chunk: str) -> List[Dict[str, str]]:
|
|
||||||
"""Build entity mappings from text chunk"""
|
|
||||||
mapping_pipeline = []
|
|
||||||
|
|
||||||
# Process different entity types
|
|
||||||
entity_configs = [
|
|
||||||
(get_ner_name_prompt, "people names"),
|
|
||||||
(get_ner_company_prompt, "company names"),
|
|
||||||
(get_ner_address_prompt, "addresses"),
|
|
||||||
(get_ner_project_prompt, "project names"),
|
|
||||||
(get_ner_case_number_prompt, "case numbers")
|
|
||||||
]
|
|
||||||
|
|
||||||
for prompt_func, entity_type in entity_configs:
|
|
||||||
mapping = self._process_entity_type(chunk, prompt_func, entity_type)
|
|
||||||
if mapping:
|
|
||||||
mapping_pipeline.append(mapping)
|
|
||||||
|
|
||||||
# Process regex-based entities
|
|
||||||
regex_entity_extractors = [
|
|
||||||
extract_id_number_entities,
|
|
||||||
extract_social_credit_code_entities
|
|
||||||
]
|
|
||||||
|
|
||||||
for extractor in regex_entity_extractors:
|
|
||||||
mapping = extractor(chunk)
|
|
||||||
if mapping and LLMResponseValidator.validate_regex_entity(mapping):
|
|
||||||
mapping_pipeline.append(mapping)
|
|
||||||
elif mapping:
|
|
||||||
logger.warning(f"Invalid regex entity mapping format: {mapping}")
|
|
||||||
|
|
||||||
return mapping_pipeline
|
|
||||||
|
|
||||||
def _merge_entity_mappings(self, chunk_mappings: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
|
||||||
"""Merge entity mappings from multiple chunks"""
|
|
||||||
all_entities = []
|
|
||||||
for mapping in chunk_mappings:
|
|
||||||
if isinstance(mapping, dict) and 'entities' in mapping:
|
|
||||||
entities = mapping['entities']
|
|
||||||
if isinstance(entities, list):
|
|
||||||
all_entities.extend(entities)
|
|
||||||
|
|
||||||
unique_entities = []
|
|
||||||
seen_texts = set()
|
|
||||||
|
|
||||||
for entity in all_entities:
|
|
||||||
if isinstance(entity, dict) and 'text' in entity:
|
|
||||||
text = entity['text'].strip()
|
|
||||||
if text and text not in seen_texts:
|
|
||||||
seen_texts.add(text)
|
|
||||||
unique_entities.append(entity)
|
|
||||||
elif text and text in seen_texts:
|
|
||||||
logger.info(f"Duplicate entity found: {entity}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"Merged {len(unique_entities)} unique entities")
|
|
||||||
return unique_entities
|
|
||||||
|
|
||||||
def _generate_masked_mapping(self, unique_entities: List[Dict[str, str]], linkage: Dict[str, Any]) -> Dict[str, str]:
|
|
||||||
"""Generate masked mappings for entities"""
|
|
||||||
entity_mapping = {}
|
|
||||||
used_masked_names = set()
|
|
||||||
group_mask_map = {}
|
|
||||||
|
|
||||||
# Process entity groups from linkage
|
|
||||||
for group in linkage.get('entity_groups', []):
|
|
||||||
group_type = group.get('group_type', '')
|
|
||||||
entities = group.get('entities', [])
|
|
||||||
|
|
||||||
# Handle company groups
|
|
||||||
if any(keyword in group_type for keyword in ['公司', 'Company']):
|
|
||||||
for entity in entities:
|
|
||||||
masker = self._get_masker_for_type('公司名称')
|
|
||||||
if masker:
|
|
||||||
masked = masker.mask(entity['text'])
|
|
||||||
group_mask_map[entity['text']] = masked
|
|
||||||
|
|
||||||
# Handle name groups
|
|
||||||
elif '人名' in group_type:
|
|
||||||
for entity in entities:
|
|
||||||
masker = self._get_masker_for_type('人名')
|
|
||||||
if masker:
|
|
||||||
context = {'surname_counter': self.surname_counter}
|
|
||||||
masked = masker.mask(entity['text'], context)
|
|
||||||
group_mask_map[entity['text']] = masked
|
|
||||||
|
|
||||||
# Handle English name groups
|
|
||||||
elif '英文人名' in group_type:
|
|
||||||
for entity in entities:
|
|
||||||
masker = self._get_masker_for_type('英文人名')
|
|
||||||
if masker:
|
|
||||||
masked = masker.mask(entity['text'])
|
|
||||||
group_mask_map[entity['text']] = masked
|
|
||||||
|
|
||||||
# Process individual entities
|
|
||||||
for entity in unique_entities:
|
|
||||||
text = entity['text']
|
|
||||||
entity_type = entity.get('type', '')
|
|
||||||
|
|
||||||
# Check if entity is in group mapping
|
|
||||||
if text in group_mask_map:
|
|
||||||
entity_mapping[text] = group_mask_map[text]
|
|
||||||
used_masked_names.add(group_mask_map[text])
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get appropriate masker for entity type
|
|
||||||
masker = self._get_masker_for_type(entity_type)
|
|
||||||
if masker:
|
|
||||||
# Prepare context for maskers that need it
|
|
||||||
context = {}
|
|
||||||
if entity_type in ['人名', '律师姓名', '审判人员姓名']:
|
|
||||||
context['surname_counter'] = self.surname_counter
|
|
||||||
|
|
||||||
masked = masker.mask(text, context)
|
|
||||||
entity_mapping[text] = masked
|
|
||||||
used_masked_names.add(masked)
|
|
||||||
else:
|
|
||||||
# Fallback for unknown entity types
|
|
||||||
base_name = '某'
|
|
||||||
masked = base_name
|
|
||||||
counter = 1
|
|
||||||
while masked in used_masked_names:
|
|
||||||
if counter <= 10:
|
|
||||||
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
|
|
||||||
masked = base_name + suffixes[counter - 1]
|
|
||||||
else:
|
|
||||||
masked = f"{base_name}{counter}"
|
|
||||||
counter += 1
|
|
||||||
entity_mapping[text] = masked
|
|
||||||
used_masked_names.add(masked)
|
|
||||||
|
|
||||||
return entity_mapping
|
|
||||||
|
|
||||||
def _validate_linkage_format(self, linkage: Dict[str, Any]) -> bool:
|
|
||||||
"""Validate entity linkage format"""
|
|
||||||
return LLMResponseValidator.validate_entity_linkage(linkage)
|
|
||||||
|
|
||||||
def _create_entity_linkage(self, unique_entities: List[Dict[str, str]]) -> Dict[str, Any]:
|
|
||||||
"""Create entity linkage information"""
|
|
||||||
linkable_entities = []
|
|
||||||
for entity in unique_entities:
|
|
||||||
entity_type = entity.get('type', '')
|
|
||||||
if any(keyword in entity_type for keyword in ['公司', 'Company', '人名', '英文人名']):
|
|
||||||
linkable_entities.append(entity)
|
|
||||||
|
|
||||||
if not linkable_entities:
|
|
||||||
logger.info("No linkable entities found")
|
|
||||||
return {"entity_groups": []}
|
|
||||||
|
|
||||||
entities_text = "\n".join([
|
|
||||||
f"- {entity['text']} (类型: {entity['type']})"
|
|
||||||
for entity in linkable_entities
|
|
||||||
])
|
|
||||||
|
|
||||||
try:
|
|
||||||
formatted_prompt = get_entity_linkage_prompt(entities_text)
|
|
||||||
logger.info(f"Calling ollama to generate entity linkage")
|
|
||||||
|
|
||||||
# Use the new enhanced generate method with validation
|
|
||||||
linkage = self.ollama_client.generate_with_validation(
|
|
||||||
prompt=formatted_prompt,
|
|
||||||
response_type='entity_linkage',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Parsed entity linkage: {linkage}")
|
|
||||||
|
|
||||||
if linkage and self._validate_linkage_format(linkage):
|
|
||||||
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
|
|
||||||
return linkage
|
|
||||||
else:
|
|
||||||
logger.warning(f"Invalid entity linkage format received")
|
|
||||||
return {"entity_groups": []}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating entity linkage: {e}")
|
|
||||||
return {"entity_groups": []}
|
|
||||||
|
|
||||||
def process(self, chunks: List[str]) -> Dict[str, str]:
|
|
||||||
"""Main processing method"""
|
|
||||||
chunk_mappings = []
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
logger.info(f"Processing chunk {i+1}/{len(chunks)}")
|
|
||||||
chunk_mapping = self.build_mapping(chunk)
|
|
||||||
logger.info(f"Chunk mapping: {chunk_mapping}")
|
|
||||||
chunk_mappings.extend(chunk_mapping)
|
|
||||||
|
|
||||||
logger.info(f"Final chunk mappings: {chunk_mappings}")
|
|
||||||
|
|
||||||
unique_entities = self._merge_entity_mappings(chunk_mappings)
|
|
||||||
logger.info(f"Unique entities: {unique_entities}")
|
|
||||||
|
|
||||||
entity_linkage = self._create_entity_linkage(unique_entities)
|
|
||||||
logger.info(f"Entity linkage: {entity_linkage}")
|
|
||||||
|
|
||||||
combined_mapping = self._generate_masked_mapping(unique_entities, entity_linkage)
|
|
||||||
logger.info(f"Combined mapping: {combined_mapping}")
|
|
||||||
|
|
||||||
return combined_mapping
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from .txt_processor import TxtDocumentProcessor
|
from .txt_processor import TxtDocumentProcessor
|
||||||
from .docx_processor import DocxDocumentProcessor
|
# from .docx_processor import DocxDocumentProcessor
|
||||||
from .pdf_processor import PdfDocumentProcessor
|
from .pdf_processor import PdfDocumentProcessor
|
||||||
from .md_processor import MarkdownDocumentProcessor
|
from .md_processor import MarkdownDocumentProcessor
|
||||||
|
|
||||||
__all__ = ['TxtDocumentProcessor', 'DocxDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor']
|
# __all__ = ['TxtDocumentProcessor', 'DocxDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor']
|
||||||
|
__all__ = ['TxtDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor']
|
||||||
|
|
@ -1,222 +0,0 @@
|
||||||
import os
|
|
||||||
import requests
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from ...document_handlers.document_processor import DocumentProcessor
|
|
||||||
from ...services.ollama_client import OllamaClient
|
|
||||||
from ...config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class DocxDocumentProcessor(DocumentProcessor):
|
|
||||||
def __init__(self, input_path: str, output_path: str):
|
|
||||||
super().__init__() # Call parent class's __init__
|
|
||||||
self.input_path = input_path
|
|
||||||
self.output_path = output_path
|
|
||||||
self.output_dir = os.path.dirname(output_path)
|
|
||||||
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
|
|
||||||
|
|
||||||
# Setup work directory for temporary files
|
|
||||||
self.work_dir = os.path.join(
|
|
||||||
os.path.dirname(output_path),
|
|
||||||
".work",
|
|
||||||
os.path.splitext(os.path.basename(input_path))[0]
|
|
||||||
)
|
|
||||||
os.makedirs(self.work_dir, exist_ok=True)
|
|
||||||
|
|
||||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
|
||||||
|
|
||||||
# MagicDoc API configuration (replacing Mineru)
|
|
||||||
self.magicdoc_base_url = getattr(settings, 'MAGICDOC_API_URL', 'http://magicdoc-api:8000')
|
|
||||||
self.magicdoc_timeout = getattr(settings, 'MAGICDOC_TIMEOUT', 300) # 5 minutes timeout
|
|
||||||
# MagicDoc uses simpler parameters, but we keep compatibility with existing interface
|
|
||||||
self.magicdoc_lang_list = getattr(settings, 'MAGICDOC_LANG_LIST', 'ch')
|
|
||||||
self.magicdoc_backend = getattr(settings, 'MAGICDOC_BACKEND', 'pipeline')
|
|
||||||
self.magicdoc_parse_method = getattr(settings, 'MAGICDOC_PARSE_METHOD', 'auto')
|
|
||||||
self.magicdoc_formula_enable = getattr(settings, 'MAGICDOC_FORMULA_ENABLE', True)
|
|
||||||
self.magicdoc_table_enable = getattr(settings, 'MAGICDOC_TABLE_ENABLE', True)
|
|
||||||
|
|
||||||
def _call_magicdoc_api(self, file_path: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Call MagicDoc API to convert DOCX to markdown
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the DOCX file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
API response as dictionary or None if failed
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
url = f"{self.magicdoc_base_url}/file_parse"
|
|
||||||
|
|
||||||
with open(file_path, 'rb') as file:
|
|
||||||
files = {'files': (os.path.basename(file_path), file, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')}
|
|
||||||
|
|
||||||
# Prepare form data according to MagicDoc API specification (compatible with Mineru)
|
|
||||||
data = {
|
|
||||||
'output_dir': './output',
|
|
||||||
'lang_list': self.magicdoc_lang_list,
|
|
||||||
'backend': self.magicdoc_backend,
|
|
||||||
'parse_method': self.magicdoc_parse_method,
|
|
||||||
'formula_enable': self.magicdoc_formula_enable,
|
|
||||||
'table_enable': self.magicdoc_table_enable,
|
|
||||||
'return_md': True,
|
|
||||||
'return_middle_json': False,
|
|
||||||
'return_model_output': False,
|
|
||||||
'return_content_list': False,
|
|
||||||
'return_images': False,
|
|
||||||
'start_page_id': 0,
|
|
||||||
'end_page_id': 99999
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Calling MagicDoc API for DOCX processing at {url}")
|
|
||||||
response = requests.post(
|
|
||||||
url,
|
|
||||||
files=files,
|
|
||||||
data=data,
|
|
||||||
timeout=self.magicdoc_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
logger.info("Successfully received response from MagicDoc API for DOCX")
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
error_msg = f"MagicDoc API returned status code {response.status_code}: {response.text}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
# For 400 errors, include more specific information
|
|
||||||
if response.status_code == 400:
|
|
||||||
try:
|
|
||||||
error_data = response.json()
|
|
||||||
if 'error' in error_data:
|
|
||||||
error_msg = f"MagicDoc API error: {error_data['error']}"
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise Exception(error_msg)
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
error_msg = f"MagicDoc API request timed out after {self.magicdoc_timeout} seconds"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
error_msg = f"Error calling MagicDoc API for DOCX: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling MagicDoc API for DOCX: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
|
|
||||||
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
Extract markdown content from MagicDoc API response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: MagicDoc API response dictionary
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Extracted markdown content as string
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"MagicDoc API response structure for DOCX: {response}")
|
|
||||||
|
|
||||||
# Try different possible response formats based on MagicDoc API
|
|
||||||
if 'markdown' in response:
|
|
||||||
return response['markdown']
|
|
||||||
elif 'md' in response:
|
|
||||||
return response['md']
|
|
||||||
elif 'content' in response:
|
|
||||||
return response['content']
|
|
||||||
elif 'text' in response:
|
|
||||||
return response['text']
|
|
||||||
elif 'result' in response and isinstance(response['result'], dict):
|
|
||||||
result = response['result']
|
|
||||||
if 'markdown' in result:
|
|
||||||
return result['markdown']
|
|
||||||
elif 'md' in result:
|
|
||||||
return result['md']
|
|
||||||
elif 'content' in result:
|
|
||||||
return result['content']
|
|
||||||
elif 'text' in result:
|
|
||||||
return result['text']
|
|
||||||
elif 'data' in response and isinstance(response['data'], dict):
|
|
||||||
data = response['data']
|
|
||||||
if 'markdown' in data:
|
|
||||||
return data['markdown']
|
|
||||||
elif 'md' in data:
|
|
||||||
return data['md']
|
|
||||||
elif 'content' in data:
|
|
||||||
return data['content']
|
|
||||||
elif 'text' in data:
|
|
||||||
return data['text']
|
|
||||||
elif isinstance(response, list) and len(response) > 0:
|
|
||||||
# If response is a list, try to extract from first item
|
|
||||||
first_item = response[0]
|
|
||||||
if isinstance(first_item, dict):
|
|
||||||
return self._extract_markdown_from_response(first_item)
|
|
||||||
elif isinstance(first_item, str):
|
|
||||||
return first_item
|
|
||||||
else:
|
|
||||||
# If no standard format found, try to extract from the response structure
|
|
||||||
logger.warning("Could not find standard markdown field in MagicDoc response for DOCX")
|
|
||||||
|
|
||||||
# Return the response as string if it's simple, or empty string
|
|
||||||
if isinstance(response, str):
|
|
||||||
return response
|
|
||||||
elif isinstance(response, dict):
|
|
||||||
# Try to find any text-like content
|
|
||||||
for key, value in response.items():
|
|
||||||
if isinstance(value, str) and len(value) > 100: # Likely content
|
|
||||||
return value
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
# Recursively search in nested dictionaries
|
|
||||||
nested_content = self._extract_markdown_from_response(value)
|
|
||||||
if nested_content:
|
|
||||||
return nested_content
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error extracting markdown from MagicDoc response for DOCX: {str(e)}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def read_content(self) -> str:
|
|
||||||
logger.info("Starting DOCX content processing with MagicDoc API")
|
|
||||||
|
|
||||||
# Call MagicDoc API to convert DOCX to markdown
|
|
||||||
# This will raise an exception if the API call fails
|
|
||||||
magicdoc_response = self._call_magicdoc_api(self.input_path)
|
|
||||||
|
|
||||||
# Extract markdown content from the response
|
|
||||||
markdown_content = self._extract_markdown_from_response(magicdoc_response)
|
|
||||||
|
|
||||||
logger.info(f"MagicDoc API response: {markdown_content}")
|
|
||||||
|
|
||||||
if not markdown_content:
|
|
||||||
raise Exception("No markdown content found in MagicDoc API response for DOCX")
|
|
||||||
|
|
||||||
logger.info(f"Successfully extracted {len(markdown_content)} characters of markdown content from DOCX")
|
|
||||||
|
|
||||||
# Save the raw markdown content to work directory for reference
|
|
||||||
md_output_path = os.path.join(self.work_dir, f"{self.name_without_suff}.md")
|
|
||||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
|
||||||
file.write(markdown_content)
|
|
||||||
|
|
||||||
logger.info(f"Saved raw markdown content from DOCX to {md_output_path}")
|
|
||||||
|
|
||||||
return markdown_content
|
|
||||||
|
|
||||||
def save_content(self, content: str) -> None:
|
|
||||||
# Ensure output path has .md extension
|
|
||||||
output_dir = os.path.dirname(self.output_path)
|
|
||||||
base_name = os.path.splitext(os.path.basename(self.output_path))[0]
|
|
||||||
md_output_path = os.path.join(output_dir, f"{base_name}.md")
|
|
||||||
|
|
||||||
logger.info(f"Saving masked DOCX content to: {md_output_path}")
|
|
||||||
try:
|
|
||||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
|
||||||
file.write(content)
|
|
||||||
logger.info(f"Successfully saved masked DOCX content to {md_output_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error saving masked DOCX content: {e}")
|
|
||||||
raise
|
|
||||||
|
|
@ -0,0 +1,77 @@
|
||||||
|
import os
|
||||||
|
import docx
|
||||||
|
from ...document_handlers.document_processor import DocumentProcessor
|
||||||
|
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
|
||||||
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
||||||
|
from magic_pdf.data.read_api import read_local_office
|
||||||
|
import logging
|
||||||
|
from ...services.ollama_client import OllamaClient
|
||||||
|
from ...config import settings
|
||||||
|
from ...prompts.masking_prompts import get_masking_mapping_prompt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DocxDocumentProcessor(DocumentProcessor):
|
||||||
|
def __init__(self, input_path: str, output_path: str):
|
||||||
|
super().__init__() # Call parent class's __init__
|
||||||
|
self.input_path = input_path
|
||||||
|
self.output_path = output_path
|
||||||
|
self.output_dir = os.path.dirname(output_path)
|
||||||
|
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
|
||||||
|
|
||||||
|
# Setup output directories
|
||||||
|
self.local_image_dir = os.path.join(self.output_dir, "images")
|
||||||
|
self.image_dir = os.path.basename(self.local_image_dir)
|
||||||
|
os.makedirs(self.local_image_dir, exist_ok=True)
|
||||||
|
|
||||||
|
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
||||||
|
|
||||||
|
def read_content(self) -> str:
|
||||||
|
try:
|
||||||
|
# Initialize writers
|
||||||
|
image_writer = FileBasedDataWriter(self.local_image_dir)
|
||||||
|
md_writer = FileBasedDataWriter(self.output_dir)
|
||||||
|
|
||||||
|
# Create Dataset Instance and process
|
||||||
|
ds = read_local_office(self.input_path)[0]
|
||||||
|
pipe_result = ds.apply(doc_analyze, ocr=True).pipe_txt_mode(image_writer)
|
||||||
|
|
||||||
|
# Generate markdown
|
||||||
|
md_content = pipe_result.get_markdown(self.image_dir)
|
||||||
|
pipe_result.dump_md(md_writer, f"{self.name_without_suff}.md", self.image_dir)
|
||||||
|
|
||||||
|
return md_content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error converting DOCX to MD: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# def process_content(self, content: str) -> str:
|
||||||
|
# logger.info("Processing DOCX content")
|
||||||
|
|
||||||
|
# # Split content into sentences and apply masking
|
||||||
|
# sentences = content.split("。")
|
||||||
|
# final_md = ""
|
||||||
|
# for sentence in sentences:
|
||||||
|
# if sentence.strip(): # Only process non-empty sentences
|
||||||
|
# formatted_prompt = get_masking_mapping_prompt(sentence)
|
||||||
|
# logger.info("Calling ollama to generate response, prompt: %s", formatted_prompt)
|
||||||
|
# response = self.ollama_client.generate(formatted_prompt)
|
||||||
|
# logger.info(f"Response generated: {response}")
|
||||||
|
# final_md += response + "。"
|
||||||
|
|
||||||
|
# return final_md
|
||||||
|
|
||||||
|
def save_content(self, content: str) -> None:
|
||||||
|
# Ensure output path has .md extension
|
||||||
|
output_dir = os.path.dirname(self.output_path)
|
||||||
|
base_name = os.path.splitext(os.path.basename(self.output_path))[0]
|
||||||
|
md_output_path = os.path.join(output_dir, f"{base_name}.md")
|
||||||
|
|
||||||
|
logger.info(f"Saving masked content to: {md_output_path}")
|
||||||
|
try:
|
||||||
|
with open(md_output_path, 'w', encoding='utf-8') as file:
|
||||||
|
file.write(content)
|
||||||
|
logger.info(f"Successfully saved content to {md_output_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving content: {e}")
|
||||||
|
raise
|
||||||
|
|
@ -81,30 +81,18 @@ class PdfDocumentProcessor(DocumentProcessor):
|
||||||
logger.info("Successfully received response from Mineru API")
|
logger.info("Successfully received response from Mineru API")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
error_msg = f"Mineru API returned status code {response.status_code}: {response.text}"
|
logger.error(f"Mineru API returned status code {response.status_code}: {response.text}")
|
||||||
logger.error(error_msg)
|
return None
|
||||||
# For 400 errors, include more specific information
|
|
||||||
if response.status_code == 400:
|
|
||||||
try:
|
|
||||||
error_data = response.json()
|
|
||||||
if 'error' in error_data:
|
|
||||||
error_msg = f"Mineru API error: {error_data['error']}"
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
raise Exception(error_msg)
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
error_msg = f"Mineru API request timed out after {self.mineru_timeout} seconds"
|
logger.error(f"Mineru API request timed out after {self.mineru_timeout} seconds")
|
||||||
logger.error(error_msg)
|
return None
|
||||||
raise Exception(error_msg)
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
error_msg = f"Error calling Mineru API: {str(e)}"
|
logger.error(f"Error calling Mineru API: {str(e)}")
|
||||||
logger.error(error_msg)
|
return None
|
||||||
raise Exception(error_msg)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Unexpected error calling Mineru API: {str(e)}"
|
logger.error(f"Unexpected error calling Mineru API: {str(e)}")
|
||||||
logger.error(error_msg)
|
return None
|
||||||
raise Exception(error_msg)
|
|
||||||
|
|
||||||
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
|
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -183,9 +171,11 @@ class PdfDocumentProcessor(DocumentProcessor):
|
||||||
logger.info("Starting PDF content processing with Mineru API")
|
logger.info("Starting PDF content processing with Mineru API")
|
||||||
|
|
||||||
# Call Mineru API to convert PDF to markdown
|
# Call Mineru API to convert PDF to markdown
|
||||||
# This will raise an exception if the API call fails
|
|
||||||
mineru_response = self._call_mineru_api(self.input_path)
|
mineru_response = self._call_mineru_api(self.input_path)
|
||||||
|
|
||||||
|
if not mineru_response:
|
||||||
|
raise Exception("Failed to get response from Mineru API")
|
||||||
|
|
||||||
# Extract markdown content from the response
|
# Extract markdown content from the response
|
||||||
markdown_content = self._extract_markdown_from_response(mineru_response)
|
markdown_content = self._extract_markdown_from_response(mineru_response)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,12 +16,3 @@ def extract_social_credit_code_entities(chunk: str) -> dict:
|
||||||
for match in re.findall(credit_pattern, chunk):
|
for match in re.findall(credit_pattern, chunk):
|
||||||
entities.append({"text": match, "type": "统一社会信用代码"})
|
entities.append({"text": match, "type": "统一社会信用代码"})
|
||||||
return {"entities": entities} if entities else {}
|
return {"entities": entities} if entities else {}
|
||||||
|
|
||||||
def extract_case_number_entities(chunk: str) -> dict:
|
|
||||||
"""Extract case numbers and return in entity mapping format."""
|
|
||||||
# Pattern for Chinese case numbers: (2022)京 03 民终 3852 号, (2020)京0105 民初69754 号
|
|
||||||
case_pattern = r'[((]\d{4}[))][^\d]*\d+[^\d]*\d+[^\d]*号'
|
|
||||||
entities = []
|
|
||||||
for match in re.findall(case_pattern, chunk):
|
|
||||||
entities.append({"text": match, "type": "案号"})
|
|
||||||
return {"entities": entities} if entities else {}
|
|
||||||
|
|
@ -112,46 +112,6 @@ def get_ner_address_prompt(text: str) -> str:
|
||||||
return prompt.format(text=text)
|
return prompt.format(text=text)
|
||||||
|
|
||||||
|
|
||||||
def get_address_masking_prompt(address: str) -> str:
|
|
||||||
"""
|
|
||||||
Returns a prompt that generates a masked version of an address following specific rules.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
address (str): The original address to be masked
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The formatted prompt that will generate a masked address
|
|
||||||
"""
|
|
||||||
prompt = textwrap.dedent("""
|
|
||||||
你是一个专业的地址脱敏助手。请对给定的地址进行脱敏处理,遵循以下规则:
|
|
||||||
|
|
||||||
脱敏规则:
|
|
||||||
1. 保留区级以上地址(省、市、区、县)
|
|
||||||
2. 路名以大写首字母替代,例如:恒丰路 -> HF路
|
|
||||||
3. 门牌数字以**代替,例如:66号 -> **号
|
|
||||||
4. 大厦名、小区名以大写首字母替代,例如:白云大厦 -> BY大厦
|
|
||||||
5. 房间号以****代替,例如:1607室 -> ****室
|
|
||||||
|
|
||||||
示例:
|
|
||||||
- 输入:上海市静安区恒丰路66号白云大厦1607室
|
|
||||||
- 输出:上海市静安区HF路**号BY大厦****室
|
|
||||||
|
|
||||||
- 输入:北京市海淀区北小马厂6号1号楼华天大厦1306室
|
|
||||||
- 输出:北京市海淀区北小马厂**号**号楼HT大厦****室
|
|
||||||
|
|
||||||
请严格按照JSON格式输出结果:
|
|
||||||
|
|
||||||
{{
|
|
||||||
"masked_address": "脱敏后的地址"
|
|
||||||
}}
|
|
||||||
|
|
||||||
原始地址:{address}
|
|
||||||
|
|
||||||
请严格按照JSON格式输出结果。
|
|
||||||
""")
|
|
||||||
return prompt.format(address=address)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ner_project_prompt(text: str) -> str:
|
def get_ner_project_prompt(text: str) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a prompt that generates a mapping of original project names to their masked versions.
|
Returns a prompt that generates a mapping of original project names to their masked versions.
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ class DocumentService:
|
||||||
processor = DocumentProcessorFactory.create_processor(input_path, output_path)
|
processor = DocumentProcessorFactory.create_processor(input_path, output_path)
|
||||||
if not processor:
|
if not processor:
|
||||||
logger.error(f"Unsupported file format: {input_path}")
|
logger.error(f"Unsupported file format: {input_path}")
|
||||||
raise Exception(f"Unsupported file format: {input_path}")
|
return False
|
||||||
|
|
||||||
# Read content
|
# Read content
|
||||||
content = processor.read_content()
|
content = processor.read_content()
|
||||||
|
|
@ -27,5 +27,4 @@ class DocumentService:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing document {input_path}: {str(e)}")
|
logger.error(f"Error processing document {input_path}: {str(e)}")
|
||||||
# Re-raise the exception so the Celery task can handle it properly
|
return False
|
||||||
raise
|
|
||||||
|
|
@ -1,222 +1,72 @@
|
||||||
import requests
|
import requests
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Any, Optional, Callable, Union
|
from typing import Dict, Any
|
||||||
from ..utils.json_extractor import LLMJsonExtractor
|
|
||||||
from ..utils.llm_validator import LLMResponseValidator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OllamaClient:
|
class OllamaClient:
|
||||||
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
|
def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
|
||||||
"""Initialize Ollama client.
|
"""Initialize Ollama client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name (str): Name of the Ollama model to use
|
model_name (str): Name of the Ollama model to use
|
||||||
base_url (str): Ollama server base URL
|
host (str): Ollama server host address
|
||||||
max_retries (int): Maximum number of retries for failed requests
|
port (int): Ollama server port
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.max_retries = max_retries
|
|
||||||
self.headers = {"Content-Type": "application/json"}
|
self.headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
def generate(self,
|
def generate(self, prompt: str, strip_think: bool = True) -> str:
|
||||||
prompt: str,
|
"""Process a document using the Ollama API.
|
||||||
strip_think: bool = True,
|
|
||||||
validation_schema: Optional[Dict[str, Any]] = None,
|
|
||||||
response_type: Optional[str] = None,
|
|
||||||
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
|
|
||||||
"""Process a document using the Ollama API with optional validation and retry.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (str): The prompt to send to the model
|
document_text (str): The text content to process
|
||||||
strip_think (bool): Whether to strip thinking tags from response
|
|
||||||
validation_schema (Optional[Dict]): JSON schema for validation
|
|
||||||
response_type (Optional[str]): Type of response for validation ('entity_extraction', 'entity_linkage', etc.)
|
|
||||||
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON)
|
str: Processed text response from the model
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RequestException: If the API call fails after all retries
|
RequestException: If the API call fails
|
||||||
ValueError: If validation fails after all retries
|
|
||||||
"""
|
|
||||||
for attempt in range(self.max_retries):
|
|
||||||
try:
|
|
||||||
# Make the API call
|
|
||||||
raw_response = self._make_api_call(prompt, strip_think)
|
|
||||||
|
|
||||||
# If no validation required, return the response
|
|
||||||
if not validation_schema and not response_type and not return_parsed:
|
|
||||||
return raw_response
|
|
||||||
|
|
||||||
# Parse JSON if needed
|
|
||||||
if return_parsed or validation_schema or response_type:
|
|
||||||
parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
|
|
||||||
if not parsed_response:
|
|
||||||
logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
|
|
||||||
if attempt < self.max_retries - 1:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError("Failed to parse JSON response after all retries")
|
|
||||||
|
|
||||||
# Validate if schema or response type provided
|
|
||||||
if validation_schema:
|
|
||||||
if not self._validate_with_schema(parsed_response, validation_schema):
|
|
||||||
logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
|
|
||||||
if attempt < self.max_retries - 1:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError("Schema validation failed after all retries")
|
|
||||||
|
|
||||||
if response_type:
|
|
||||||
if not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
|
|
||||||
logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
|
|
||||||
if attempt < self.max_retries - 1:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Response type validation failed after all retries")
|
|
||||||
|
|
||||||
# Return parsed response if requested
|
|
||||||
if return_parsed:
|
|
||||||
return parsed_response
|
|
||||||
else:
|
|
||||||
return raw_response
|
|
||||||
|
|
||||||
return raw_response
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
|
|
||||||
if attempt < self.max_retries - 1:
|
|
||||||
logger.info("Retrying...")
|
|
||||||
else:
|
|
||||||
logger.error("Max retries reached, raising exception")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
|
|
||||||
if attempt < self.max_retries - 1:
|
|
||||||
logger.info("Retrying...")
|
|
||||||
else:
|
|
||||||
logger.error("Max retries reached, raising exception")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# This should never be reached, but just in case
|
|
||||||
raise Exception("Unexpected error: max retries exceeded without proper exception handling")
|
|
||||||
|
|
||||||
def generate_with_validation(self,
|
|
||||||
prompt: str,
|
|
||||||
response_type: str,
|
|
||||||
strip_think: bool = True,
|
|
||||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
|
||||||
"""Generate response with automatic validation based on response type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt to send to the model
|
|
||||||
response_type (str): Type of response for validation
|
|
||||||
strip_think (bool): Whether to strip thinking tags from response
|
|
||||||
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[str, Dict[str, Any]]: Validated response from the model
|
|
||||||
"""
|
|
||||||
return self.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
strip_think=strip_think,
|
|
||||||
response_type=response_type,
|
|
||||||
return_parsed=return_parsed
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_with_schema(self,
|
|
||||||
prompt: str,
|
|
||||||
schema: Dict[str, Any],
|
|
||||||
strip_think: bool = True,
|
|
||||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
|
||||||
"""Generate response with custom schema validation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt to send to the model
|
|
||||||
schema (Dict): JSON schema for validation
|
|
||||||
strip_think (bool): Whether to strip thinking tags from response
|
|
||||||
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[str, Dict[str, Any]]: Validated response from the model
|
|
||||||
"""
|
|
||||||
return self.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
strip_think=strip_think,
|
|
||||||
validation_schema=schema,
|
|
||||||
return_parsed=return_parsed
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_api_call(self, prompt: str, strip_think: bool) -> str:
|
|
||||||
"""Make the actual API call to Ollama.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt to send
|
|
||||||
strip_think (bool): Whether to strip thinking tags
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Raw response from the API
|
|
||||||
"""
|
|
||||||
url = f"{self.base_url}/api/generate"
|
|
||||||
payload = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"prompt": prompt,
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"Sending request to Ollama API: {url}")
|
|
||||||
response = requests.post(url, json=payload, headers=self.headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
logger.debug(f"Received response from Ollama API: {result}")
|
|
||||||
|
|
||||||
if strip_think:
|
|
||||||
# Remove the "thinking" part from the response
|
|
||||||
# the response is expected to be <think>...</think>response_text
|
|
||||||
# Check if the response contains <think> tag
|
|
||||||
if "<think>" in result.get("response", ""):
|
|
||||||
# Split the response and take the part after </think>
|
|
||||||
response_parts = result["response"].split("</think>")
|
|
||||||
if len(response_parts) > 1:
|
|
||||||
# Return the part after </think>
|
|
||||||
return response_parts[1].strip()
|
|
||||||
else:
|
|
||||||
# If no closing tag, return the full response
|
|
||||||
return result.get("response", "").strip()
|
|
||||||
else:
|
|
||||||
# If no <think> tag, return the full response
|
|
||||||
return result.get("response", "").strip()
|
|
||||||
else:
|
|
||||||
# If strip_think is False, return the full response
|
|
||||||
return result.get("response", "")
|
|
||||||
|
|
||||||
def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
|
|
||||||
"""Validate response against a JSON schema.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response (Dict): The parsed response to validate
|
|
||||||
schema (Dict): The JSON schema to validate against
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from jsonschema import validate, ValidationError
|
url = f"{self.base_url}/api/generate"
|
||||||
validate(instance=response, schema=schema)
|
payload = {
|
||||||
logger.debug(f"Schema validation passed for response: {response}")
|
"model": self.model_name,
|
||||||
return True
|
"prompt": prompt,
|
||||||
except ValidationError as e:
|
"stream": False
|
||||||
logger.warning(f"Schema validation failed: {e}")
|
}
|
||||||
logger.warning(f"Response that failed validation: {response}")
|
|
||||||
return False
|
logger.debug(f"Sending request to Ollama API: {url}")
|
||||||
except ImportError:
|
response = requests.post(url, json=payload, headers=self.headers)
|
||||||
logger.error("jsonschema library not available for validation")
|
response.raise_for_status()
|
||||||
return False
|
|
||||||
|
result = response.json()
|
||||||
|
logger.debug(f"Received response from Ollama API: {result}")
|
||||||
|
if strip_think:
|
||||||
|
# Remove the "thinking" part from the response
|
||||||
|
# the response is expected to be <think>...</think>response_text
|
||||||
|
# Check if the response contains <think> tag
|
||||||
|
if "<think>" in result.get("response", ""):
|
||||||
|
# Split the response and take the part after </think>
|
||||||
|
response_parts = result["response"].split("</think>")
|
||||||
|
if len(response_parts) > 1:
|
||||||
|
# Return the part after </think>
|
||||||
|
return response_parts[1].strip()
|
||||||
|
else:
|
||||||
|
# If no closing tag, return the full response
|
||||||
|
return result.get("response", "").strip()
|
||||||
|
else:
|
||||||
|
# If no <think> tag, return the full response
|
||||||
|
return result.get("response", "").strip()
|
||||||
|
else:
|
||||||
|
# If strip_think is False, return the full response
|
||||||
|
return result.get("response", "")
|
||||||
|
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Error calling Ollama API: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_model_info(self) -> Dict[str, Any]:
|
def get_model_info(self) -> Dict[str, Any]:
|
||||||
"""Get information about the current model.
|
"""Get information about the current model.
|
||||||
|
|
|
||||||
|
|
@ -77,66 +77,6 @@ class LLMResponseValidator:
|
||||||
"required": ["entities"]
|
"required": ["entities"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Schema for business name extraction responses
|
|
||||||
BUSINESS_NAME_EXTRACTION_SCHEMA = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"business_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The extracted business name (商号) from the company name"
|
|
||||||
},
|
|
||||||
"confidence": {
|
|
||||||
"type": "number",
|
|
||||||
"minimum": 0,
|
|
||||||
"maximum": 1,
|
|
||||||
"description": "Confidence level of the extraction (0-1)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["business_name"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Schema for address extraction responses
|
|
||||||
ADDRESS_EXTRACTION_SCHEMA = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"road_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The road name (路名) to be masked"
|
|
||||||
},
|
|
||||||
"house_number": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The house number (门牌号) to be masked"
|
|
||||||
},
|
|
||||||
"building_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The building name (大厦名) to be masked"
|
|
||||||
},
|
|
||||||
"community_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The community name (小区名) to be masked"
|
|
||||||
},
|
|
||||||
"confidence": {
|
|
||||||
"type": "number",
|
|
||||||
"minimum": 0,
|
|
||||||
"maximum": 1,
|
|
||||||
"description": "Confidence level of the extraction (0-1)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["road_name", "house_number", "building_name", "community_name"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Schema for address masking responses
|
|
||||||
ADDRESS_MASKING_SCHEMA = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"masked_address": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The masked address following the specified rules"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["masked_address"]
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool:
|
def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -202,66 +142,6 @@ class LLMResponseValidator:
|
||||||
logger.warning(f"Response that failed validation: {response}")
|
logger.warning(f"Response that failed validation: {response}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_business_name_extraction(cls, response: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
Validate business name extraction response from LLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The parsed JSON response from LLM
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
validate(instance=response, schema=cls.BUSINESS_NAME_EXTRACTION_SCHEMA)
|
|
||||||
logger.debug(f"Business name extraction validation passed for response: {response}")
|
|
||||||
return True
|
|
||||||
except ValidationError as e:
|
|
||||||
logger.warning(f"Business name extraction validation failed: {e}")
|
|
||||||
logger.warning(f"Response that failed validation: {response}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_address_extraction(cls, response: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
Validate address extraction response from LLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The parsed JSON response from LLM
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
validate(instance=response, schema=cls.ADDRESS_EXTRACTION_SCHEMA)
|
|
||||||
logger.debug(f"Address extraction validation passed for response: {response}")
|
|
||||||
return True
|
|
||||||
except ValidationError as e:
|
|
||||||
logger.warning(f"Address extraction validation failed: {e}")
|
|
||||||
logger.warning(f"Response that failed validation: {response}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_address_masking(cls, response: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
Validate address masking response from LLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The parsed JSON response from LLM
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if valid, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
validate(instance=response, schema=cls.ADDRESS_MASKING_SCHEMA)
|
|
||||||
logger.debug(f"Address masking validation passed for response: {response}")
|
|
||||||
return True
|
|
||||||
except ValidationError as e:
|
|
||||||
logger.warning(f"Address masking validation failed: {e}")
|
|
||||||
logger.warning(f"Response that failed validation: {response}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool:
|
def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -321,10 +201,7 @@ class LLMResponseValidator:
|
||||||
validators = {
|
validators = {
|
||||||
'entity_extraction': cls.validate_entity_extraction,
|
'entity_extraction': cls.validate_entity_extraction,
|
||||||
'entity_linkage': cls.validate_entity_linkage,
|
'entity_linkage': cls.validate_entity_linkage,
|
||||||
'regex_entity': cls.validate_regex_entity,
|
'regex_entity': cls.validate_regex_entity
|
||||||
'business_name_extraction': cls.validate_business_name_extraction,
|
|
||||||
'address_extraction': cls.validate_address_extraction,
|
|
||||||
'address_masking': cls.validate_address_masking
|
|
||||||
}
|
}
|
||||||
|
|
||||||
validator = validators.get(response_type)
|
validator = validators.get(response_type)
|
||||||
|
|
@ -355,12 +232,6 @@ class LLMResponseValidator:
|
||||||
return "Content validation failed for entity linkage"
|
return "Content validation failed for entity linkage"
|
||||||
elif response_type == 'regex_entity':
|
elif response_type == 'regex_entity':
|
||||||
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
|
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
|
||||||
elif response_type == 'business_name_extraction':
|
|
||||||
validate(instance=response, schema=cls.BUSINESS_NAME_EXTRACTION_SCHEMA)
|
|
||||||
elif response_type == 'address_extraction':
|
|
||||||
validate(instance=response, schema=cls.ADDRESS_EXTRACTION_SCHEMA)
|
|
||||||
elif response_type == 'address_masking':
|
|
||||||
validate(instance=response, schema=cls.ADDRESS_MASKING_SCHEMA)
|
|
||||||
else:
|
else:
|
||||||
return f"Unknown response type: {response_type}"
|
return f"Unknown response type: {response_type}"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,6 @@ def process_file(file_id: str):
|
||||||
output_path = str(settings.PROCESSED_FOLDER / output_filename)
|
output_path = str(settings.PROCESSED_FOLDER / output_filename)
|
||||||
|
|
||||||
# Process document with both input and output paths
|
# Process document with both input and output paths
|
||||||
# This will raise an exception if processing fails
|
|
||||||
process_service.process_document(file.original_path, output_path)
|
process_service.process_document(file.original_path, output_path)
|
||||||
|
|
||||||
# Update file record with processed path
|
# Update file record with processed path
|
||||||
|
|
@ -82,7 +81,6 @@ def process_file(file_id: str):
|
||||||
file.status = FileStatus.FAILED
|
file.status = FileStatus.FAILED
|
||||||
file.error_message = str(e)
|
file.error_message = str(e)
|
||||||
db.commit()
|
db.commit()
|
||||||
# Re-raise the exception to ensure Celery marks the task as failed
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
import pytest
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add the backend directory to Python path for imports
|
|
||||||
backend_dir = Path(__file__).parent
|
|
||||||
sys.path.insert(0, str(backend_dir))
|
|
||||||
|
|
||||||
# Also add the current directory to ensure imports work
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
sys.path.insert(0, str(current_dir))
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_data():
|
|
||||||
"""Sample data fixture for testing"""
|
|
||||||
return {
|
|
||||||
"name": "test",
|
|
||||||
"value": 42,
|
|
||||||
"items": [1, 2, 3]
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_files_dir():
|
|
||||||
"""Fixture to get the test files directory"""
|
|
||||||
return Path(__file__).parent / "tests"
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def setup_test_environment():
|
|
||||||
"""Setup test environment before each test"""
|
|
||||||
# Add any test environment setup here
|
|
||||||
yield
|
|
||||||
# Add any cleanup here
|
|
||||||
|
|
@ -7,6 +7,7 @@ services:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./storage:/app/storage
|
- ./storage:/app/storage
|
||||||
|
- ./legal_doc_masker.db:/app/legal_doc_masker.db
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
environment:
|
||||||
|
|
@ -20,6 +21,7 @@ services:
|
||||||
command: celery -A app.services.file_service worker --loglevel=info
|
command: celery -A app.services.file_service worker --loglevel=info
|
||||||
volumes:
|
volumes:
|
||||||
- ./storage:/app/storage
|
- ./storage:/app/storage
|
||||||
|
- ./legal_doc_masker.db:/app/legal_doc_masker.db
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
environment:
|
||||||
|
|
|
||||||
|
|
@ -1,239 +0,0 @@
|
||||||
# 地址脱敏改进文档
|
|
||||||
|
|
||||||
## 问题描述
|
|
||||||
|
|
||||||
原始的地址脱敏方法使用正则表达式和拼音转换来手动处理地址组件,存在以下问题:
|
|
||||||
- 需要手动维护复杂的正则表达式模式
|
|
||||||
- 拼音转换可能失败,需要回退处理
|
|
||||||
- 难以处理复杂的地址格式
|
|
||||||
- 代码维护成本高
|
|
||||||
|
|
||||||
## 解决方案
|
|
||||||
|
|
||||||
### 1. LLM 直接生成脱敏地址
|
|
||||||
|
|
||||||
使用 LLM 直接生成脱敏后的地址,遵循指定的脱敏规则:
|
|
||||||
|
|
||||||
- **保留区级以上地址**:省、市、区、县
|
|
||||||
- **路名缩写**:以大写首字母替代,如:恒丰路 -> HF路
|
|
||||||
- **门牌号脱敏**:数字以**代替,如:66号 -> **号
|
|
||||||
- **大厦名缩写**:以大写首字母替代,如:白云大厦 -> BY大厦
|
|
||||||
- **房间号脱敏**:以****代替,如:1607室 -> ****室
|
|
||||||
|
|
||||||
### 2. 实现架构
|
|
||||||
|
|
||||||
#### 核心组件
|
|
||||||
|
|
||||||
1. **`get_address_masking_prompt()`** - 生成地址脱敏 prompt
|
|
||||||
2. **`_mask_address()`** - 主要的脱敏方法,使用 LLM
|
|
||||||
3. **`_mask_address_fallback()`** - 回退方法,使用原有逻辑
|
|
||||||
|
|
||||||
#### 调用流程
|
|
||||||
|
|
||||||
```
|
|
||||||
输入地址
|
|
||||||
↓
|
|
||||||
生成脱敏 prompt
|
|
||||||
↓
|
|
||||||
调用 Ollama LLM
|
|
||||||
↓
|
|
||||||
解析 JSON 响应
|
|
||||||
↓
|
|
||||||
返回脱敏地址
|
|
||||||
↓
|
|
||||||
失败时使用回退方法
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Prompt 设计
|
|
||||||
|
|
||||||
#### 脱敏规则说明
|
|
||||||
```
|
|
||||||
脱敏规则:
|
|
||||||
1. 保留区级以上地址(省、市、区、县)
|
|
||||||
2. 路名以大写首字母替代,例如:恒丰路 -> HF路
|
|
||||||
3. 门牌数字以**代替,例如:66号 -> **号
|
|
||||||
4. 大厦名、小区名以大写首字母替代,例如:白云大厦 -> BY大厦
|
|
||||||
5. 房间号以****代替,例如:1607室 -> ****室
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 示例展示
|
|
||||||
```
|
|
||||||
示例:
|
|
||||||
- 输入:上海市静安区恒丰路66号白云大厦1607室
|
|
||||||
- 输出:上海市静安区HF路**号BY大厦****室
|
|
||||||
|
|
||||||
- 输入:北京市海淀区北小马厂6号1号楼华天大厦1306室
|
|
||||||
- 输出:北京市海淀区北小马厂**号**号楼HT大厦****室
|
|
||||||
```
|
|
||||||
|
|
||||||
#### JSON 输出格式
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"masked_address": "脱敏后的地址"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 实现细节
|
|
||||||
|
|
||||||
### 1. 主要方法
|
|
||||||
|
|
||||||
#### `_mask_address(address: str) -> str`
|
|
||||||
```python
|
|
||||||
def _mask_address(self, address: str) -> str:
|
|
||||||
"""
|
|
||||||
对地址进行脱敏处理,使用LLM直接生成脱敏地址
|
|
||||||
"""
|
|
||||||
if not address:
|
|
||||||
return address
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 使用LLM生成脱敏地址
|
|
||||||
prompt = get_address_masking_prompt(address)
|
|
||||||
response = self.ollama_client.generate_with_validation(
|
|
||||||
prompt=prompt,
|
|
||||||
response_type='address_masking',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if response and isinstance(response, dict) and "masked_address" in response:
|
|
||||||
return response["masked_address"]
|
|
||||||
else:
|
|
||||||
return self._mask_address_fallback(address)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error masking address with LLM: {e}")
|
|
||||||
return self._mask_address_fallback(address)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### `_mask_address_fallback(address: str) -> str`
|
|
||||||
```python
|
|
||||||
def _mask_address_fallback(self, address: str) -> str:
|
|
||||||
"""
|
|
||||||
地址脱敏的回退方法,使用原有的正则表达式和拼音转换逻辑
|
|
||||||
"""
|
|
||||||
# 原有的脱敏逻辑作为回退
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Ollama 调用模式
|
|
||||||
|
|
||||||
遵循现有的 Ollama 客户端调用模式,使用验证:
|
|
||||||
|
|
||||||
```python
|
|
||||||
response = self.ollama_client.generate_with_validation(
|
|
||||||
prompt=prompt,
|
|
||||||
response_type='address_masking',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
- `response_type='address_masking'`:指定响应类型进行验证
|
|
||||||
- `return_parsed=True`:返回解析后的 JSON
|
|
||||||
- 自动验证响应格式是否符合 schema
|
|
||||||
|
|
||||||
## 测试结果
|
|
||||||
|
|
||||||
### 测试案例
|
|
||||||
|
|
||||||
| 原始地址 | 期望脱敏结果 |
|
|
||||||
|----------|-------------|
|
|
||||||
| 上海市静安区恒丰路66号白云大厦1607室 | 上海市静安区HF路**号BY大厦****室 |
|
|
||||||
| 北京市海淀区北小马厂6号1号楼华天大厦1306室 | 北京市海淀区北小马厂**号**号楼HT大厦****室 |
|
|
||||||
| 天津市津南区双港镇工业园区优谷产业园5号楼-1505 | 天津市津南区双港镇工业园区优谷产业园**号楼-**** |
|
|
||||||
|
|
||||||
### Prompt 验证
|
|
||||||
|
|
||||||
- ✓ 包含脱敏规则说明
|
|
||||||
- ✓ 提供具体示例
|
|
||||||
- ✓ 指定 JSON 输出格式
|
|
||||||
- ✓ 包含原始地址
|
|
||||||
- ✓ 指定输出字段名
|
|
||||||
|
|
||||||
## 优势
|
|
||||||
|
|
||||||
### 1. 智能化处理
|
|
||||||
- LLM 能够理解复杂的地址格式
|
|
||||||
- 自动处理各种地址变体
|
|
||||||
- 减少手动维护成本
|
|
||||||
|
|
||||||
### 2. 可靠性
|
|
||||||
- 回退机制确保服务可用性
|
|
||||||
- 错误处理和日志记录
|
|
||||||
- 保持向后兼容性
|
|
||||||
|
|
||||||
### 3. 可扩展性
|
|
||||||
- 易于添加新的脱敏规则
|
|
||||||
- 支持多语言地址处理
|
|
||||||
- 可配置的脱敏策略
|
|
||||||
|
|
||||||
### 4. 一致性
|
|
||||||
- 统一的脱敏标准
|
|
||||||
- 可预测的输出格式
|
|
||||||
- 便于测试和验证
|
|
||||||
|
|
||||||
## 性能影响
|
|
||||||
|
|
||||||
### 1. 延迟
|
|
||||||
- LLM 调用增加处理时间
|
|
||||||
- 网络延迟影响响应速度
|
|
||||||
- 回退机制提供快速响应
|
|
||||||
|
|
||||||
### 2. 成本
|
|
||||||
- LLM API 调用成本
|
|
||||||
- 需要稳定的网络连接
|
|
||||||
- 回退机制降低依赖风险
|
|
||||||
|
|
||||||
### 3. 准确性
|
|
||||||
- 显著提高脱敏准确性
|
|
||||||
- 减少人工错误
|
|
||||||
- 更好的地址理解能力
|
|
||||||
|
|
||||||
## 配置参数
|
|
||||||
|
|
||||||
- `response_type`: 响应类型,用于验证 (默认: 'address_masking')
|
|
||||||
- `return_parsed`: 是否返回解析后的 JSON (默认: True)
|
|
||||||
- `max_retries`: 最大重试次数 (默认: 3)
|
|
||||||
|
|
||||||
## 验证 Schema
|
|
||||||
|
|
||||||
地址脱敏响应必须符合以下 JSON schema:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"masked_address": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The masked address following the specified rules"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["masked_address"]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
|
|
||||||
processor = NerProcessor()
|
|
||||||
original_address = "上海市静安区恒丰路66号白云大厦1607室"
|
|
||||||
masked_address = processor._mask_address(original_address)
|
|
||||||
print(f"Original: {original_address}")
|
|
||||||
print(f"Masked: {masked_address}")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 未来改进方向
|
|
||||||
|
|
||||||
1. **缓存机制**:缓存常见地址的脱敏结果
|
|
||||||
2. **批量处理**:支持批量地址脱敏
|
|
||||||
3. **自定义规则**:支持用户自定义脱敏规则
|
|
||||||
4. **多语言支持**:扩展到其他语言的地址处理
|
|
||||||
5. **性能优化**:异步处理和并发调用
|
|
||||||
|
|
||||||
## 相关文件
|
|
||||||
|
|
||||||
- `backend/app/core/document_handlers/ner_processor.py` - 主要实现
|
|
||||||
- `backend/app/core/prompts/masking_prompts.py` - Prompt 函数
|
|
||||||
- `backend/app/core/services/ollama_client.py` - Ollama 客户端
|
|
||||||
- `backend/app/core/utils/llm_validator.py` - 验证 schema 和验证方法
|
|
||||||
- `backend/test_validation_schema.py` - 验证 schema 测试
|
|
||||||
|
|
@ -1,255 +0,0 @@
|
||||||
# OllamaClient Enhancement Summary
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
The `OllamaClient` has been successfully enhanced to support validation and retry mechanisms while maintaining full backward compatibility.
|
|
||||||
|
|
||||||
## Key Enhancements
|
|
||||||
|
|
||||||
### 1. **Enhanced Constructor**
|
|
||||||
```python
|
|
||||||
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
|
|
||||||
```
|
|
||||||
- Added `max_retries` parameter for configurable retry attempts
|
|
||||||
- Default retry count: 3 attempts
|
|
||||||
|
|
||||||
### 2. **Enhanced Generate Method**
|
|
||||||
```python
|
|
||||||
def generate(self,
|
|
||||||
prompt: str,
|
|
||||||
strip_think: bool = True,
|
|
||||||
validation_schema: Optional[Dict[str, Any]] = None,
|
|
||||||
response_type: Optional[str] = None,
|
|
||||||
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
|
|
||||||
```
|
|
||||||
|
|
||||||
**New Parameters:**
|
|
||||||
- `validation_schema`: Custom JSON schema for validation
|
|
||||||
- `response_type`: Predefined response type for validation
|
|
||||||
- `return_parsed`: Return parsed JSON instead of raw string
|
|
||||||
|
|
||||||
**Return Type:**
|
|
||||||
- `Union[str, Dict[str, Any]]`: Can return either raw string or parsed JSON
|
|
||||||
|
|
||||||
### 3. **New Convenience Methods**
|
|
||||||
|
|
||||||
#### `generate_with_validation()`
|
|
||||||
```python
|
|
||||||
def generate_with_validation(self,
|
|
||||||
prompt: str,
|
|
||||||
response_type: str,
|
|
||||||
strip_think: bool = True,
|
|
||||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
|
||||||
```
|
|
||||||
- Uses predefined validation schemas based on response type
|
|
||||||
- Automatically handles retries and validation
|
|
||||||
- Returns parsed JSON by default
|
|
||||||
|
|
||||||
#### `generate_with_schema()`
|
|
||||||
```python
|
|
||||||
def generate_with_schema(self,
|
|
||||||
prompt: str,
|
|
||||||
schema: Dict[str, Any],
|
|
||||||
strip_think: bool = True,
|
|
||||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
|
||||||
```
|
|
||||||
- Uses custom JSON schema for validation
|
|
||||||
- Automatically handles retries and validation
|
|
||||||
- Returns parsed JSON by default
|
|
||||||
|
|
||||||
### 4. **Supported Response Types**
|
|
||||||
The following response types are supported for automatic validation:
|
|
||||||
|
|
||||||
- `'entity_extraction'`: Entity extraction responses
|
|
||||||
- `'entity_linkage'`: Entity linkage responses
|
|
||||||
- `'regex_entity'`: Regex-based entity responses
|
|
||||||
- `'business_name_extraction'`: Business name extraction responses
|
|
||||||
- `'address_extraction'`: Address component extraction responses
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
### 1. **Automatic Retry Mechanism**
|
|
||||||
- Retries failed API calls up to `max_retries` times
|
|
||||||
- Retries on validation failures
|
|
||||||
- Retries on JSON parsing failures
|
|
||||||
- Configurable retry count per client instance
|
|
||||||
|
|
||||||
### 2. **Built-in Validation**
|
|
||||||
- JSON schema validation using `jsonschema` library
|
|
||||||
- Predefined schemas for common response types
|
|
||||||
- Custom schema support for specialized use cases
|
|
||||||
- Detailed validation error logging
|
|
||||||
|
|
||||||
### 3. **Automatic JSON Parsing**
|
|
||||||
- Uses `LLMJsonExtractor.parse_raw_json_str()` for robust JSON extraction
|
|
||||||
- Handles malformed JSON responses gracefully
|
|
||||||
- Returns parsed Python dictionaries when requested
|
|
||||||
|
|
||||||
### 4. **Backward Compatibility**
|
|
||||||
- All existing code continues to work without changes
|
|
||||||
- Original `generate()` method signature preserved
|
|
||||||
- Default behavior unchanged
|
|
||||||
|
|
||||||
## Usage Examples
|
|
||||||
|
|
||||||
### 1. **Basic Usage (Backward Compatible)**
|
|
||||||
```python
|
|
||||||
client = OllamaClient("llama2")
|
|
||||||
response = client.generate("Hello, world!")
|
|
||||||
# Returns: "Hello, world!"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. **With Response Type Validation**
|
|
||||||
```python
|
|
||||||
client = OllamaClient("llama2")
|
|
||||||
result = client.generate_with_validation(
|
|
||||||
prompt="Extract business name from: 上海盒马网络科技有限公司",
|
|
||||||
response_type='business_name_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
# Returns: {"business_name": "盒马", "confidence": 0.9}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. **With Custom Schema Validation**
|
|
||||||
```python
|
|
||||||
client = OllamaClient("llama2")
|
|
||||||
custom_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"age": {"type": "number"}
|
|
||||||
},
|
|
||||||
"required": ["name", "age"]
|
|
||||||
}
|
|
||||||
|
|
||||||
result = client.generate_with_schema(
|
|
||||||
prompt="Generate person info",
|
|
||||||
schema=custom_schema,
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
# Returns: {"name": "张三", "age": 30}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. **Advanced Usage with All Options**
|
|
||||||
```python
|
|
||||||
client = OllamaClient("llama2", max_retries=5)
|
|
||||||
result = client.generate(
|
|
||||||
prompt="Complex prompt",
|
|
||||||
strip_think=True,
|
|
||||||
validation_schema=custom_schema,
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Updated Components
|
|
||||||
|
|
||||||
### 1. **Extractors**
|
|
||||||
- `BusinessNameExtractor`: Now uses `generate_with_validation()`
|
|
||||||
- `AddressExtractor`: Now uses `generate_with_validation()`
|
|
||||||
|
|
||||||
### 2. **Processors**
|
|
||||||
- `NerProcessor`: Updated to use enhanced methods
|
|
||||||
- `NerProcessorRefactored`: Updated to use enhanced methods
|
|
||||||
|
|
||||||
### 3. **Benefits in Processors**
|
|
||||||
- Simplified code: No more manual retry loops
|
|
||||||
- Automatic validation: No more manual JSON parsing
|
|
||||||
- Better error handling: Automatic fallback to regex methods
|
|
||||||
- Cleaner code: Reduced boilerplate
|
|
||||||
|
|
||||||
## Error Handling
|
|
||||||
|
|
||||||
### 1. **API Failures**
|
|
||||||
- Automatic retry on network errors
|
|
||||||
- Configurable retry count
|
|
||||||
- Detailed error logging
|
|
||||||
|
|
||||||
### 2. **Validation Failures**
|
|
||||||
- Automatic retry on schema validation failures
|
|
||||||
- Automatic retry on JSON parsing failures
|
|
||||||
- Graceful fallback to alternative methods
|
|
||||||
|
|
||||||
### 3. **Exception Types**
|
|
||||||
- `RequestException`: API call failures after all retries
|
|
||||||
- `ValueError`: Validation failures after all retries
|
|
||||||
- `Exception`: Unexpected errors
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
### 1. **Test Coverage**
|
|
||||||
- Initialization with new parameters
|
|
||||||
- Enhanced generate methods
|
|
||||||
- Backward compatibility
|
|
||||||
- Retry mechanism
|
|
||||||
- Validation failure handling
|
|
||||||
- Mock-based testing for reliability
|
|
||||||
|
|
||||||
### 2. **Run Tests**
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
python3 test_enhanced_ollama_client.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Migration Guide
|
|
||||||
|
|
||||||
### 1. **No Changes Required**
|
|
||||||
Existing code continues to work without modification:
|
|
||||||
```python
|
|
||||||
# This still works exactly the same
|
|
||||||
client = OllamaClient("llama2")
|
|
||||||
response = client.generate("prompt")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. **Optional Enhancements**
|
|
||||||
To take advantage of new features:
|
|
||||||
```python
|
|
||||||
# Old way (still works)
|
|
||||||
response = client.generate(prompt)
|
|
||||||
parsed = LLMJsonExtractor.parse_raw_json_str(response)
|
|
||||||
if LLMResponseValidator.validate_entity_extraction(parsed):
|
|
||||||
# use parsed
|
|
||||||
|
|
||||||
# New way (recommended)
|
|
||||||
parsed = client.generate_with_validation(
|
|
||||||
prompt=prompt,
|
|
||||||
response_type='entity_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
# parsed is already validated and ready to use
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. **Benefits of Migration**
|
|
||||||
- **Reduced Code**: Eliminates manual retry loops
|
|
||||||
- **Better Reliability**: Automatic retry and validation
|
|
||||||
- **Cleaner Code**: Less boilerplate
|
|
||||||
- **Better Error Handling**: Automatic fallbacks
|
|
||||||
|
|
||||||
## Performance Impact
|
|
||||||
|
|
||||||
### 1. **Positive Impact**
|
|
||||||
- Reduced code complexity
|
|
||||||
- Better error recovery
|
|
||||||
- Automatic retry reduces manual intervention
|
|
||||||
|
|
||||||
### 2. **Minimal Overhead**
|
|
||||||
- Validation only occurs when requested
|
|
||||||
- JSON parsing only occurs when needed
|
|
||||||
- Retry mechanism only activates on failures
|
|
||||||
|
|
||||||
## Future Enhancements
|
|
||||||
|
|
||||||
### 1. **Potential Additions**
|
|
||||||
- Circuit breaker pattern for API failures
|
|
||||||
- Caching for repeated requests
|
|
||||||
- Async/await support
|
|
||||||
- Streaming response support
|
|
||||||
- Custom retry strategies
|
|
||||||
|
|
||||||
### 2. **Configuration Options**
|
|
||||||
- Per-request retry configuration
|
|
||||||
- Custom validation error handling
|
|
||||||
- Response transformation hooks
|
|
||||||
- Metrics and monitoring
|
|
||||||
|
|
||||||
## Conclusion
|
|
||||||
|
|
||||||
The enhanced `OllamaClient` provides a robust, reliable, and easy-to-use interface for LLM interactions while maintaining full backward compatibility. The new validation and retry mechanisms significantly improve the reliability of LLM-based operations in the NER processing pipeline.
|
|
||||||
|
|
@ -1,166 +0,0 @@
|
||||||
# NerProcessor Refactoring Summary
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
The `ner_processor.py` file has been successfully refactored from a monolithic 729-line class into a modular, maintainable architecture following SOLID principles.
|
|
||||||
|
|
||||||
## New Architecture
|
|
||||||
|
|
||||||
### Directory Structure
|
|
||||||
```
|
|
||||||
backend/app/core/document_handlers/
|
|
||||||
├── ner_processor.py # Original file (unchanged)
|
|
||||||
├── ner_processor_refactored.py # New refactored version
|
|
||||||
├── masker_factory.py # Factory for creating maskers
|
|
||||||
├── maskers/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── base_masker.py # Abstract base class
|
|
||||||
│ ├── name_masker.py # Chinese/English name masking
|
|
||||||
│ ├── company_masker.py # Company name masking
|
|
||||||
│ ├── address_masker.py # Address masking
|
|
||||||
│ ├── id_masker.py # ID/social credit code masking
|
|
||||||
│ └── case_masker.py # Case number masking
|
|
||||||
├── extractors/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── base_extractor.py # Abstract base class
|
|
||||||
│ ├── business_name_extractor.py # Business name extraction
|
|
||||||
│ └── address_extractor.py # Address component extraction
|
|
||||||
└── validators/ # (Placeholder for future use)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Key Components
|
|
||||||
|
|
||||||
### 1. Base Classes
|
|
||||||
- **`BaseMasker`**: Abstract base class for all maskers
|
|
||||||
- **`BaseExtractor`**: Abstract base class for all extractors
|
|
||||||
|
|
||||||
### 2. Maskers
|
|
||||||
- **`ChineseNameMasker`**: Handles Chinese name masking (surname + pinyin initials)
|
|
||||||
- **`EnglishNameMasker`**: Handles English name masking (first letter + ***)
|
|
||||||
- **`CompanyMasker`**: Handles company name masking (business name replacement)
|
|
||||||
- **`AddressMasker`**: Handles address masking (component replacement)
|
|
||||||
- **`IDMasker`**: Handles ID and social credit code masking
|
|
||||||
- **`CaseMasker`**: Handles case number masking
|
|
||||||
|
|
||||||
### 3. Extractors
|
|
||||||
- **`BusinessNameExtractor`**: Extracts business names from company names using LLM + regex fallback
|
|
||||||
- **`AddressExtractor`**: Extracts address components using LLM + regex fallback
|
|
||||||
|
|
||||||
### 4. Factory
|
|
||||||
- **`MaskerFactory`**: Creates maskers with proper dependencies
|
|
||||||
|
|
||||||
### 5. Refactored Processor
|
|
||||||
- **`NerProcessorRefactored`**: Main orchestrator using the new architecture
|
|
||||||
|
|
||||||
## Benefits Achieved
|
|
||||||
|
|
||||||
### 1. Single Responsibility Principle
|
|
||||||
- Each class has one clear responsibility
|
|
||||||
- Maskers only handle masking logic
|
|
||||||
- Extractors only handle extraction logic
|
|
||||||
- Processor only handles orchestration
|
|
||||||
|
|
||||||
### 2. Open/Closed Principle
|
|
||||||
- Easy to add new maskers without modifying existing code
|
|
||||||
- New entity types can be supported by creating new maskers
|
|
||||||
|
|
||||||
### 3. Dependency Injection
|
|
||||||
- Dependencies are injected rather than hardcoded
|
|
||||||
- Easier to test and mock
|
|
||||||
|
|
||||||
### 4. Better Testing
|
|
||||||
- Each component can be tested in isolation
|
|
||||||
- Mock dependencies easily
|
|
||||||
|
|
||||||
### 5. Code Reusability
|
|
||||||
- Maskers can be used independently
|
|
||||||
- Common functionality shared through base classes
|
|
||||||
|
|
||||||
### 6. Maintainability
|
|
||||||
- Changes to one masking rule don't affect others
|
|
||||||
- Clear separation of concerns
|
|
||||||
|
|
||||||
## Migration Strategy
|
|
||||||
|
|
||||||
### Phase 1: ✅ Complete
|
|
||||||
- Created base classes and interfaces
|
|
||||||
- Extracted all maskers
|
|
||||||
- Created extractors
|
|
||||||
- Created factory pattern
|
|
||||||
- Created refactored processor
|
|
||||||
|
|
||||||
### Phase 2: Testing (Next)
|
|
||||||
- Run validation script: `python3 validate_refactoring.py`
|
|
||||||
- Run existing tests to ensure compatibility
|
|
||||||
- Create comprehensive unit tests for each component
|
|
||||||
|
|
||||||
### Phase 3: Integration (Future)
|
|
||||||
- Replace original processor with refactored version
|
|
||||||
- Update imports throughout the codebase
|
|
||||||
- Remove old code
|
|
||||||
|
|
||||||
### Phase 4: Enhancement (Future)
|
|
||||||
- Add configuration management
|
|
||||||
- Add more extractors as needed
|
|
||||||
- Add validation components
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
### Validation Script
|
|
||||||
Run the validation script to test the refactored code:
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
python3 validate_refactoring.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Unit Tests
|
|
||||||
Run the unit tests for the refactored components:
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
python3 -m pytest tests/test_refactored_ner_processor.py -v
|
|
||||||
```
|
|
||||||
|
|
||||||
## Current Status
|
|
||||||
|
|
||||||
✅ **Completed:**
|
|
||||||
- All maskers extracted and implemented
|
|
||||||
- All extractors created
|
|
||||||
- Factory pattern implemented
|
|
||||||
- Refactored processor created
|
|
||||||
- Validation script created
|
|
||||||
- Unit tests created
|
|
||||||
|
|
||||||
🔄 **Next Steps:**
|
|
||||||
- Test the refactored code
|
|
||||||
- Ensure all existing functionality works
|
|
||||||
- Replace original processor when ready
|
|
||||||
|
|
||||||
## File Comparison
|
|
||||||
|
|
||||||
| Metric | Original | Refactored |
|
|
||||||
|--------|----------|------------|
|
|
||||||
| Main Class Lines | 729 | ~200 |
|
|
||||||
| Number of Classes | 1 | 10+ |
|
|
||||||
| Responsibilities | Multiple | Single |
|
|
||||||
| Testability | Low | High |
|
|
||||||
| Maintainability | Low | High |
|
|
||||||
| Extensibility | Low | High |
|
|
||||||
|
|
||||||
## Backward Compatibility
|
|
||||||
|
|
||||||
The refactored code maintains full backward compatibility:
|
|
||||||
- All existing masking rules are preserved
|
|
||||||
- All existing functionality works the same
|
|
||||||
- The public API remains unchanged
|
|
||||||
- The original `ner_processor.py` is untouched
|
|
||||||
|
|
||||||
## Future Enhancements
|
|
||||||
|
|
||||||
1. **Configuration Management**: Centralized configuration for masking rules
|
|
||||||
2. **Validation Framework**: Dedicated validation components
|
|
||||||
3. **Performance Optimization**: Caching and optimization strategies
|
|
||||||
4. **Monitoring**: Metrics and logging for each component
|
|
||||||
5. **Plugin System**: Dynamic loading of new maskers and extractors
|
|
||||||
|
|
||||||
## Conclusion
|
|
||||||
|
|
||||||
The refactoring successfully transforms the monolithic `NerProcessor` into a modular, maintainable, and extensible architecture while preserving all existing functionality. The new architecture follows SOLID principles and provides a solid foundation for future enhancements.
|
|
||||||
|
|
@ -1,130 +0,0 @@
|
||||||
# 句子分块改进文档
|
|
||||||
|
|
||||||
## 问题描述
|
|
||||||
|
|
||||||
在原始的NER提取过程中,我们发现了一些实体被截断的问题,比如:
|
|
||||||
- `"丰复久信公"` (应该是 `"丰复久信营销科技有限公司"`)
|
|
||||||
- `"康达律师事"` (应该是 `"北京市康达律师事务所"`)
|
|
||||||
|
|
||||||
这些截断问题是由于原始的基于字符数量的简单分块策略导致的,该策略没有考虑实体的完整性。
|
|
||||||
|
|
||||||
## 解决方案
|
|
||||||
|
|
||||||
### 1. 句子分块策略
|
|
||||||
|
|
||||||
我们实现了基于句子的智能分块策略,主要特点:
|
|
||||||
|
|
||||||
- **自然边界分割**:使用中文句子结束符(。!?;\n)和英文句子结束符(.!?;)进行分割
|
|
||||||
- **实体完整性保护**:避免在实体名称中间进行分割
|
|
||||||
- **智能长度控制**:基于token数量而非字符数量进行分块
|
|
||||||
|
|
||||||
### 2. 实体边界安全检查
|
|
||||||
|
|
||||||
实现了 `_is_entity_boundary_safe()` 方法来检查分割点是否安全:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def _is_entity_boundary_safe(self, text: str, position: int) -> bool:
|
|
||||||
# 检查常见实体后缀
|
|
||||||
entity_suffixes = ['公', '司', '所', '院', '厅', '局', '部', '会', '团', '社', '处', '室', '楼', '号']
|
|
||||||
|
|
||||||
# 检查不完整的实体模式
|
|
||||||
if text[position-2:position+1] in ['公司', '事务所', '协会', '研究院']:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 检查地址模式
|
|
||||||
address_patterns = ['省', '市', '区', '县', '路', '街', '巷', '号', '室']
|
|
||||||
# ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 长句子智能分割
|
|
||||||
|
|
||||||
对于超过token限制的长句子,实现了智能分割策略:
|
|
||||||
|
|
||||||
1. **标点符号分割**:优先在逗号、分号等标点符号处分割
|
|
||||||
2. **实体边界分割**:如果标点分割不可行,在安全的实体边界处分割
|
|
||||||
3. **强制分割**:最后才使用字符级别的强制分割
|
|
||||||
|
|
||||||
## 实现细节
|
|
||||||
|
|
||||||
### 核心方法
|
|
||||||
|
|
||||||
1. **`_split_text_by_sentences()`**: 将文本按句子分割
|
|
||||||
2. **`_create_sentence_chunks()`**: 基于句子创建分块
|
|
||||||
3. **`_split_long_sentence()`**: 智能分割长句子
|
|
||||||
4. **`_is_entity_boundary_safe()`**: 检查分割点安全性
|
|
||||||
|
|
||||||
### 分块流程
|
|
||||||
|
|
||||||
```
|
|
||||||
输入文本
|
|
||||||
↓
|
|
||||||
按句子分割
|
|
||||||
↓
|
|
||||||
估算token数量
|
|
||||||
↓
|
|
||||||
创建句子分块
|
|
||||||
↓
|
|
||||||
检查实体边界
|
|
||||||
↓
|
|
||||||
输出最终分块
|
|
||||||
```
|
|
||||||
|
|
||||||
## 测试结果
|
|
||||||
|
|
||||||
### 改进前 vs 改进后
|
|
||||||
|
|
||||||
| 指标 | 改进前 | 改进后 |
|
|
||||||
|------|--------|--------|
|
|
||||||
| 截断实体数量 | 较多 | 显著减少 |
|
|
||||||
| 实体完整性 | 经常被破坏 | 得到保护 |
|
|
||||||
| 分块质量 | 基于字符 | 基于语义 |
|
|
||||||
|
|
||||||
### 测试案例
|
|
||||||
|
|
||||||
1. **"丰复久信公" 问题**:
|
|
||||||
- 改进前:`"丰复久信公"` (截断)
|
|
||||||
- 改进后:`"北京丰复久信营销科技有限公司"` (完整)
|
|
||||||
|
|
||||||
2. **长句子处理**:
|
|
||||||
- 改进前:可能在实体中间截断
|
|
||||||
- 改进后:在句子边界或安全位置分割
|
|
||||||
|
|
||||||
## 配置参数
|
|
||||||
|
|
||||||
- `max_tokens`: 每个分块的最大token数量 (默认: 400)
|
|
||||||
- `confidence_threshold`: 实体置信度阈值 (默认: 0.95)
|
|
||||||
- `sentence_pattern`: 句子分割正则表达式
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
from app.core.document_handlers.extractors.ner_extractor import NERExtractor
|
|
||||||
|
|
||||||
extractor = NERExtractor()
|
|
||||||
result = extractor.extract(long_text)
|
|
||||||
|
|
||||||
# 结果中的实体将更加完整
|
|
||||||
entities = result.get("entities", [])
|
|
||||||
for entity in entities:
|
|
||||||
print(f"{entity['text']} ({entity['type']})")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 性能影响
|
|
||||||
|
|
||||||
- **内存使用**:略有增加(需要存储句子分割结果)
|
|
||||||
- **处理速度**:基本无影响(句子分割很快)
|
|
||||||
- **准确性**:显著提升(减少截断实体)
|
|
||||||
|
|
||||||
## 未来改进方向
|
|
||||||
|
|
||||||
1. **更智能的实体识别**:使用预训练模型识别实体边界
|
|
||||||
2. **动态分块大小**:根据文本复杂度调整分块大小
|
|
||||||
3. **多语言支持**:扩展到其他语言的分块策略
|
|
||||||
4. **缓存优化**:缓存句子分割结果以提高性能
|
|
||||||
|
|
||||||
## 相关文件
|
|
||||||
|
|
||||||
- `backend/app/core/document_handlers/extractors/ner_extractor.py` - 主要实现
|
|
||||||
- `backend/test_improved_chunking.py` - 测试脚本
|
|
||||||
- `backend/test_truncation_fix.py` - 截断问题测试
|
|
||||||
- `backend/test_chunking_logic.py` - 分块逻辑测试
|
|
||||||
|
|
@ -1,118 +0,0 @@
|
||||||
# Test Setup Guide
|
|
||||||
|
|
||||||
This document explains how to set up and run tests for the legal-doc-masker backend.
|
|
||||||
|
|
||||||
## Test Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
backend/
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── test_ner_processor.py
|
|
||||||
│ ├── test1.py
|
|
||||||
│ └── test.txt
|
|
||||||
├── conftest.py
|
|
||||||
├── pytest.ini
|
|
||||||
└── run_tests.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## VS Code Configuration
|
|
||||||
|
|
||||||
The `.vscode/settings.json` file has been configured to:
|
|
||||||
|
|
||||||
1. **Set pytest as the test framework**: `"python.testing.pytestEnabled": true`
|
|
||||||
2. **Point to the correct test directory**: `"python.testing.pytestArgs": ["backend/tests"]`
|
|
||||||
3. **Set the working directory**: `"python.testing.cwd": "${workspaceFolder}/backend"`
|
|
||||||
4. **Configure Python interpreter**: Points to backend virtual environment
|
|
||||||
|
|
||||||
## Running Tests
|
|
||||||
|
|
||||||
### From VS Code Test Explorer
|
|
||||||
1. Open the Test Explorer panel (Ctrl+Shift+P → "Python: Configure Tests")
|
|
||||||
2. Select "pytest" as the test framework
|
|
||||||
3. Select "backend/tests" as the test directory
|
|
||||||
4. Tests should now appear in the Test Explorer
|
|
||||||
|
|
||||||
### From Command Line
|
|
||||||
```bash
|
|
||||||
# From the project root
|
|
||||||
cd backend
|
|
||||||
python -m pytest tests/ -v
|
|
||||||
|
|
||||||
# Or use the test runner script
|
|
||||||
python run_tests.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### From VS Code Terminal
|
|
||||||
```bash
|
|
||||||
# Make sure you're in the backend directory
|
|
||||||
cd backend
|
|
||||||
pytest tests/ -v
|
|
||||||
```
|
|
||||||
|
|
||||||
## Test Configuration
|
|
||||||
|
|
||||||
### pytest.ini
|
|
||||||
- **testpaths**: Points to the `tests` directory
|
|
||||||
- **python_files**: Looks for files starting with `test_` or ending with `_test.py`
|
|
||||||
- **python_functions**: Looks for functions starting with `test_`
|
|
||||||
- **markers**: Defines test markers for categorization
|
|
||||||
|
|
||||||
### conftest.py
|
|
||||||
- **Path setup**: Adds backend directory to Python path
|
|
||||||
- **Fixtures**: Provides common test fixtures
|
|
||||||
- **Environment setup**: Handles test environment initialization
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Tests Not Discovered
|
|
||||||
1. **Check VS Code settings**: Ensure `python.testing.pytestArgs` points to `backend/tests`
|
|
||||||
2. **Verify working directory**: Ensure `python.testing.cwd` is set to `${workspaceFolder}/backend`
|
|
||||||
3. **Check Python interpreter**: Make sure it points to the backend virtual environment
|
|
||||||
|
|
||||||
### Import Errors
|
|
||||||
1. **Check conftest.py**: Ensures backend directory is in Python path
|
|
||||||
2. **Verify __init__.py**: Tests directory should have an `__init__.py` file
|
|
||||||
3. **Check relative imports**: Use absolute imports from the backend root
|
|
||||||
|
|
||||||
### Virtual Environment Issues
|
|
||||||
1. **Create virtual environment**: `python -m venv .venv`
|
|
||||||
2. **Activate environment**:
|
|
||||||
- Windows: `.venv\Scripts\activate`
|
|
||||||
- Unix/MacOS: `source .venv/bin/activate`
|
|
||||||
3. **Install dependencies**: `pip install -r requirements.txt`
|
|
||||||
|
|
||||||
## Test Examples
|
|
||||||
|
|
||||||
### Simple Test
|
|
||||||
```python
|
|
||||||
def test_simple_assertion():
|
|
||||||
"""Simple test to verify pytest is working"""
|
|
||||||
assert 1 == 1
|
|
||||||
assert 2 + 2 == 4
|
|
||||||
```
|
|
||||||
|
|
||||||
### Test with Fixture
|
|
||||||
```python
|
|
||||||
def test_with_fixture(sample_data):
|
|
||||||
"""Test using a fixture"""
|
|
||||||
assert sample_data["name"] == "test"
|
|
||||||
assert sample_data["value"] == 42
|
|
||||||
```
|
|
||||||
|
|
||||||
### Integration Test
|
|
||||||
```python
|
|
||||||
def test_ner_processor():
|
|
||||||
"""Test NER processor functionality"""
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
processor = NerProcessor()
|
|
||||||
# Test implementation...
|
|
||||||
```
|
|
||||||
|
|
||||||
## Best Practices
|
|
||||||
|
|
||||||
1. **Test naming**: Use descriptive test names starting with `test_`
|
|
||||||
2. **Test isolation**: Each test should be independent
|
|
||||||
3. **Use fixtures**: For common setup and teardown
|
|
||||||
4. **Add markers**: Use `@pytest.mark.slow` for slow tests
|
|
||||||
5. **Documentation**: Add docstrings to explain test purpose
|
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
[2025-07-14 14:20:19,015: INFO/ForkPoolWorker-4] Raw response from LLM: {
|
||||||
|
celery_worker-1 | "entities": []
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:19,016: INFO/ForkPoolWorker-4] Parsed mapping: {'entities': []}
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:19,020: INFO/ForkPoolWorker-4] Calling ollama to generate case numbers mapping for chunk (attempt 1/3):
|
||||||
|
celery_worker-1 | 你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 实体类别包括:
|
||||||
|
celery_worker-1 | - 案号
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 待处理文本:
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 二审案件受理费450892 元,由北京丰复久信营销科技有限公司负担(已交纳)。
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 29. 本判决为终审判决。
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 审 判 长 史晓霞审 判 员 邓青菁审 判 员 李 淼二〇二二年七月七日法 官 助 理 黎 铧书 记 员 郑海兴
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 输出格式:
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "entities": [
|
||||||
|
celery_worker-1 | {"text": "原始文本内容", "type": "案号"},
|
||||||
|
celery_worker-1 | ...
|
||||||
|
celery_worker-1 | ]
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 |
|
||||||
|
celery_worker-1 | 请严格按照JSON格式输出结果。
|
||||||
|
celery_worker-1 |
|
||||||
|
api-1 | INFO: 192.168.65.1:60045 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:34054 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:34054 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:22084 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,279: INFO/ForkPoolWorker-4] Raw response from LLM: {
|
||||||
|
celery_worker-1 | "entities": []
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,281: INFO/ForkPoolWorker-4] Parsed mapping: {'entities': []}
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,287: INFO/ForkPoolWorker-4] Chunk mapping: [{'entities': []}, {'entities': [{'text': '北京丰复久信营销科技有限公司', 'type': '公司名称'}]}, {'entities': []}, {'entities': []}, {'entities': []}]
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Final chunk mappings: [{'entities': [{'text': '郭东军', 'type': '人名'}, {'text': '王欢子', 'type': '人名'}]}, {'entities': [{'text': '北京丰复久信营销科技有限公司', 'type': '公司名称'}, {'text': '丰复久信公司', 'type': '公司名称简称'}, {'text': '中研智创区块链技术有限公司', 'type': '公司名称'}, {'text': '中研智才公司', 'type': '公司名称简称'}]}, {'entities': [{'text': '北京市海淀区北小马厂6 号1 号楼华天大厦1306 室', 'type': '地址'}, {'text': '天津市津南区双港镇工业园区优谷产业园5 号楼-1505', 'type': '地址'}]}, {'entities': [{'text': '服务合同', 'type': '项目名'}]}, {'entities': [{'text': '(2022)京 03 民终 3852 号', 'type': '案号'}, {'text': '(2020)京0105 民初69754 号', 'type': '案号'}]}, {'entities': [{'text': '李圣艳', 'type': '人名'}, {'text': '闫向东', 'type': '人名'}, {'text': '李敏', 'type': '人名'}, {'text': '布兰登·斯密特', 'type': '英文人名'}]}, {'entities': [{'text': '丰复久信公司', 'type': '公司名称'}, {'text': '中研智创公司', 'type': '公司名称'}, {'text': '丰复久信', 'type': '公司名称简称'}, {'text': '中研智创', 'type': '公司名称简称'}]}, {'entities': [{'text': '上海市', 'type': '地址'}, {'text': '北京', 'type': '地址'}]}, {'entities': [{'text': '《计算机设备采购合同》', 'type': '项目名'}]}, {'entities': []}, {'entities': []}, {'entities': [{'text': '丰复久信公司', 'type': '公司名称'}, {'text': '中研智创公司', 'type': '公司名称'}]}, {'entities': []}, {'entities': [{'text': '《服务合同书》', 'type': '项目名'}]}, {'entities': []}, {'entities': []}, {'entities': [{'text': '北京丰复久信营销科技有限公司', 'type': '公司名称'}]}, {'entities': []}, {'entities': []}, {'entities': []}]
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Duplicate entity found: {'text': '丰复久信公司', 'type': '公司名称'}
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Duplicate entity found: {'text': '丰复久信公司', 'type': '公司名称'}
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Duplicate entity found: {'text': '中研智创公司', 'type': '公司名称'}
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Duplicate entity found: {'text': '北京丰复久信营销科技有限公司', 'type': '公司名称'}
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Merged 22 unique entities
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,288: INFO/ForkPoolWorker-4] Unique entities: [{'text': '郭东军', 'type': '人名'}, {'text': '王欢子', 'type': '人名'}, {'text': '北京丰复久信营销科技有限公司', 'type': '公司名称'}, {'text': '丰复久信公司', 'type': '公司名称简称'}, {'text': '中研智创区块链技术有限公司', 'type': '公司名称'}, {'text': '中研智才公司', 'type': '公司名称简称'}, {'text': '北京市海淀区北小马厂6 号1 号楼华天大厦1306 室', 'type': '地址'}, {'text': '天津市津南区双港镇工业园区优谷产业园5 号楼-1505', 'type': '地址'}, {'text': '服务合同', 'type': '项目名'}, {'text': '(2022)京 03 民终 3852 号', 'type': '案号'}, {'text': '(2020)京0105 民初69754 号', 'type': '案号'}, {'text': '李圣艳', 'type': '人名'}, {'text': '闫向东', 'type': '人名'}, {'text': '李敏', 'type': '人名'}, {'text': '布兰登·斯密特', 'type': '英文人名'}, {'text': '中研智创公司', 'type': '公司名称'}, {'text': '丰复久信', 'type': '公司名称简称'}, {'text': '中研智创', 'type': '公司名称简称'}, {'text': '上海市', 'type': '地址'}, {'text': '北京', 'type': '地址'}, {'text': '《计算机设备采购合同》', 'type': '项目名'}, {'text': '《服务合同书》', 'type': '项目名'}]
|
||||||
|
celery_worker-1 | [2025-07-14 14:20:31,289: INFO/ForkPoolWorker-4] Calling ollama to generate entity linkage (attempt 1/3)
|
||||||
|
api-1 | INFO: 192.168.65.1:52168 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:61426 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:30702 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:48159 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:16860 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:21262 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:45564 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:32142 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:27769 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:21196 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,436: INFO/ForkPoolWorker-4] Raw entity linkage response from LLM: {
|
||||||
|
celery_worker-1 | "entity_groups": [
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "group_id": "group_1",
|
||||||
|
celery_worker-1 | "group_type": "公司名称",
|
||||||
|
celery_worker-1 | "entities": [
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "text": "北京丰复久信营销科技有限公司",
|
||||||
|
celery_worker-1 | "type": "公司名称",
|
||||||
|
celery_worker-1 | "is_primary": true
|
||||||
|
celery_worker-1 | },
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "text": "丰复久信公司",
|
||||||
|
celery_worker-1 | "type": "公司名称简称",
|
||||||
|
celery_worker-1 | "is_primary": false
|
||||||
|
celery_worker-1 | },
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "text": "丰复久信",
|
||||||
|
celery_worker-1 | "type": "公司名称简称",
|
||||||
|
celery_worker-1 | "is_primary": false
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 | ]
|
||||||
|
celery_worker-1 | },
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "group_id": "group_2",
|
||||||
|
celery_worker-1 | "group_type": "公司名称",
|
||||||
|
celery_worker-1 | "entities": [
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "text": "中研智创区块链技术有限公司",
|
||||||
|
celery_worker-1 | "type": "公司名称",
|
||||||
|
celery_worker-1 | "is_primary": true
|
||||||
|
celery_worker-1 | },
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "text": "中研智创公司",
|
||||||
|
celery_worker-1 | "type": "公司名称简称",
|
||||||
|
celery_worker-1 | "is_primary": false
|
||||||
|
celery_worker-1 | },
|
||||||
|
celery_worker-1 | {
|
||||||
|
celery_worker-1 | "text": "中研智创",
|
||||||
|
celery_worker-1 | "type": "公司名称简称",
|
||||||
|
celery_worker-1 | "is_primary": false
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 | ]
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 | ]
|
||||||
|
celery_worker-1 | }
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,437: INFO/ForkPoolWorker-4] Parsed entity linkage: {'entity_groups': [{'group_id': 'group_1', 'group_type': '公司名称', 'entities': [{'text': '北京丰复久信营销科技有限公司', 'type': '公司名称', 'is_primary': True}, {'text': '丰复久信公司', 'type': '公司名称简称', 'is_primary': False}, {'text': '丰复久信', 'type': '公司名称简称', 'is_primary': False}]}, {'group_id': 'group_2', 'group_type': '公司名称', 'entities': [{'text': '中研智创区块链技术有限公司', 'type': '公司名称', 'is_primary': True}, {'text': '中研智创公司', 'type': '公司名称简称', 'is_primary': False}, {'text': '中研智创', 'type': '公司名称简称', 'is_primary': False}]}]}
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,445: INFO/ForkPoolWorker-4] Successfully created entity linkage with 2 groups
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,445: INFO/ForkPoolWorker-4] Entity linkage: {'entity_groups': [{'group_id': 'group_1', 'group_type': '公司名称', 'entities': [{'text': '北京丰复久信营销科技有限公司', 'type': '公司名称', 'is_primary': True}, {'text': '丰复久信公司', 'type': '公司名称简称', 'is_primary': False}, {'text': '丰复久信', 'type': '公司名称简称', 'is_primary': False}]}, {'group_id': 'group_2', 'group_type': '公司名称', 'entities': [{'text': '中研智创区块链技术有限公司', 'type': '公司名称', 'is_primary': True}, {'text': '中研智创公司', 'type': '公司名称简称', 'is_primary': False}, {'text': '中研智创', 'type': '公司名称简称', 'is_primary': False}]}]}
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Generated masked mapping for 22 entities
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Combined mapping: {'郭东军': '某', '王欢子': '某甲', '北京丰复久信营销科技有限公司': '某公司', '丰复久信公司': '某公司甲', '中研智创区块链技术有限公司': '某公司乙', '中研智才公司': '某公司丙', '北京市海淀区北小马厂6 号1 号楼华天大厦1306 室': '某乙', '天津市津南区双港镇工业园区优谷产业园5 号楼-1505': '某丙', '服务合同': '某丁', '(2022)京 03 民终 3852 号': '某戊', '(2020)京0105 民初69754 号': '某己', '李圣艳': '某庚', '闫向东': '某辛', '李敏': '某壬', '布兰登·斯密特': '某癸', '中研智创公司': '某公司丁', '丰复久信': '某公司戊', '中研智创': '某公司己', '上海市': '某11', '北京': '某12', '《计算机设备采购合同》': '某13', '《服务合同书》': '某14'}
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Linked entity '北京丰复久信营销科技有限公司' to '北京丰复久信营销科技有限公司' with masked name '某公司'
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Linked entity '丰复久信公司' to '北京丰复久信营销科技有限公司' with masked name '某公司'
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Linked entity '丰复久信' to '北京丰复久信营销科技有限公司' with masked name '某公司'
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Linked entity '中研智创区块链技术有限公司' to '中研智创区块链技术有限公司' with masked name '某公司乙'
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Linked entity '中研智创公司' to '中研智创区块链技术有限公司' with masked name '某公司乙'
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Linked entity '中研智创' to '中研智创区块链技术有限公司' with masked name '某公司乙'
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Final mapping: {'郭东军': '某', '王欢子': '某甲', '北京丰复久信营销科技有限公司': '某公司', '丰复久信公司': '某公司', '中研智创区块链技术有限公司': '某公司乙', '中研智才公司': '某公司丙', '北京市海淀区北小马厂6 号1 号楼华天大厦1306 室': '某乙', '天津市津南区双港镇工业园区优谷产业园5 号楼-1505': '某丙', '服务合同': '某丁', '(2022)京 03 民终 3852 号': '某戊', '(2020)京0105 民初69754 号': '某己', '李圣艳': '某庚', '闫向东': '某辛', '李敏': '某壬', '布兰登·斯密特': '某癸', '中研智创公司': '某公司乙', '丰复久信': '某公司', '中研智创': '某公司乙', '上海市': '某11', '北京': '某12', '《计算机设备采购合同》': '某13', '《服务合同书》': '某14'}
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,446: INFO/ForkPoolWorker-4] Successfully masked content
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,449: INFO/ForkPoolWorker-4] Successfully saved masked content to /app/storage/processed/47522ea9-c259-4304-bfe4-1d3ed6902ede.md
|
||||||
|
celery_worker-1 | [2025-07-14 14:21:21,470: INFO/ForkPoolWorker-4] Task app.services.file_service.process_file[5cfbca4c-0f6f-4c71-a66b-b22ee2d28139] succeeded in 311.847165101s: None
|
||||||
|
api-1 | INFO: 192.168.65.1:33432 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:40073 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:29550 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:61350 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:61755 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:63726 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:43446 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:45624 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:25256 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
api-1 | INFO: 192.168.65.1:43464 - "GET /api/v1/files/files HTTP/1.1" 200 OK
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
[tool:pytest]
|
|
||||||
testpaths = tests
|
|
||||||
pythonpath = .
|
|
||||||
python_files = test_*.py *_test.py
|
|
||||||
python_classes = Test*
|
|
||||||
python_functions = test_*
|
|
||||||
addopts =
|
|
||||||
-v
|
|
||||||
--tb=short
|
|
||||||
--strict-markers
|
|
||||||
--disable-warnings
|
|
||||||
markers =
|
|
||||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
|
||||||
integration: marks tests as integration tests
|
|
||||||
unit: marks tests as unit tests
|
|
||||||
|
|
@ -30,11 +30,3 @@ PyPDF2>=3.0.0
|
||||||
pandas>=2.0.0
|
pandas>=2.0.0
|
||||||
# magic-pdf[full]
|
# magic-pdf[full]
|
||||||
jsonschema>=4.20.0
|
jsonschema>=4.20.0
|
||||||
|
|
||||||
# Chinese text processing
|
|
||||||
pypinyin>=0.50.0
|
|
||||||
|
|
||||||
# NER and ML dependencies
|
|
||||||
# torch is installed separately in Dockerfile for CPU optimization
|
|
||||||
transformers>=4.30.0
|
|
||||||
tokenizers>=0.13.0
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
# Tests package
|
|
||||||
|
|
@ -1,130 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Debug script to understand the position mapping issue after masking.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
|
||||||
"""Simplified version of the alignment method for testing"""
|
|
||||||
clean_entity = entity_text.replace(" ", "")
|
|
||||||
doc_chars = [c for c in original_document_text if c != ' ']
|
|
||||||
|
|
||||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
|
||||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
|
||||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
|
||||||
"""Simplified version of position mapping for testing"""
|
|
||||||
original_pos = 0
|
|
||||||
clean_pos = 0
|
|
||||||
|
|
||||||
while clean_pos < clean_start and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
clean_pos += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
start_pos = original_pos
|
|
||||||
|
|
||||||
chars_found = 0
|
|
||||||
while chars_found < entity_length and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
chars_found += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
end_pos = original_pos
|
|
||||||
found_text = original_text[start_pos:end_pos]
|
|
||||||
|
|
||||||
return start_pos, end_pos, found_text
|
|
||||||
|
|
||||||
def debug_position_issue():
|
|
||||||
"""Debug the position mapping issue"""
|
|
||||||
|
|
||||||
print("Debugging Position Mapping Issue")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Test document
|
|
||||||
original_doc = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
|
||||||
entity = "李淼"
|
|
||||||
masked_text = "李M"
|
|
||||||
|
|
||||||
print(f"Original document: '{original_doc}'")
|
|
||||||
print(f"Entity to mask: '{entity}'")
|
|
||||||
print(f"Masked text: '{masked_text}'")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# First occurrence
|
|
||||||
print("=== First Occurrence ===")
|
|
||||||
result1 = find_entity_alignment(entity, original_doc)
|
|
||||||
if result1:
|
|
||||||
start1, end1, found1 = result1
|
|
||||||
print(f"Found at positions {start1}-{end1}: '{found1}'")
|
|
||||||
|
|
||||||
# Apply first mask
|
|
||||||
masked_doc = original_doc[:start1] + masked_text + original_doc[end1:]
|
|
||||||
print(f"After first mask: '{masked_doc}'")
|
|
||||||
print(f"Length changed from {len(original_doc)} to {len(masked_doc)}")
|
|
||||||
|
|
||||||
# Try to find second occurrence in the masked document
|
|
||||||
print("\n=== Second Occurrence (in masked document) ===")
|
|
||||||
result2 = find_entity_alignment(entity, masked_doc)
|
|
||||||
if result2:
|
|
||||||
start2, end2, found2 = result2
|
|
||||||
print(f"Found at positions {start2}-{end2}: '{found2}'")
|
|
||||||
|
|
||||||
# Apply second mask
|
|
||||||
masked_doc2 = masked_doc[:start2] + masked_text + masked_doc[end2:]
|
|
||||||
print(f"After second mask: '{masked_doc2}'")
|
|
||||||
|
|
||||||
# Try to find third occurrence
|
|
||||||
print("\n=== Third Occurrence (in double-masked document) ===")
|
|
||||||
result3 = find_entity_alignment(entity, masked_doc2)
|
|
||||||
if result3:
|
|
||||||
start3, end3, found3 = result3
|
|
||||||
print(f"Found at positions {start3}-{end3}: '{found3}'")
|
|
||||||
else:
|
|
||||||
print("No third occurrence found")
|
|
||||||
else:
|
|
||||||
print("No second occurrence found")
|
|
||||||
else:
|
|
||||||
print("No first occurrence found")
|
|
||||||
|
|
||||||
def debug_infinite_loop():
|
|
||||||
"""Debug the infinite loop issue"""
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("Debugging Infinite Loop Issue")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Test document that causes infinite loop
|
|
||||||
original_doc = "上诉人李淼因合同纠纷,法定代表人李淼。北京丰复久信营销科技有限公司,丰复久信公司。"
|
|
||||||
entity = "丰复久信公司"
|
|
||||||
masked_text = "丰复久信公司" # Same text (no change)
|
|
||||||
|
|
||||||
print(f"Original document: '{original_doc}'")
|
|
||||||
print(f"Entity to mask: '{entity}'")
|
|
||||||
print(f"Masked text: '{masked_text}' (same as original)")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# This will cause infinite loop because we're replacing with the same text
|
|
||||||
print("=== This will cause infinite loop ===")
|
|
||||||
print("Because we're replacing '丰复久信公司' with '丰复久信公司'")
|
|
||||||
print("The document doesn't change, so we keep finding the same position")
|
|
||||||
|
|
||||||
# Show what happens
|
|
||||||
masked_doc = original_doc
|
|
||||||
for i in range(3): # Limit to 3 iterations for demo
|
|
||||||
result = find_entity_alignment(entity, masked_doc)
|
|
||||||
if result:
|
|
||||||
start, end, found = result
|
|
||||||
print(f"Iteration {i+1}: Found at positions {start}-{end}: '{found}'")
|
|
||||||
|
|
||||||
# Apply mask (but it's the same text)
|
|
||||||
masked_doc = masked_doc[:start] + masked_text + masked_doc[end:]
|
|
||||||
print(f"After mask: '{masked_doc}'")
|
|
||||||
else:
|
|
||||||
print(f"Iteration {i+1}: No occurrence found")
|
|
||||||
break
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
debug_position_issue()
|
|
||||||
debug_infinite_loop()
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
关于张三天和北京易见天树有限公司的劳动纠纷
|
||||||
|
|
@ -1,129 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test file for address masking functionality
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Add the backend directory to the Python path for imports
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
|
|
||||||
|
|
||||||
def test_address_masking():
|
|
||||||
"""Test address masking with the new rules"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test cases based on the requirements
|
|
||||||
test_cases = [
|
|
||||||
("上海市静安区恒丰路66号白云大厦1607室", "上海市静安区HF路**号BY大厦****室"),
|
|
||||||
("北京市朝阳区建国路88号SOHO现代城A座1001室", "北京市朝阳区JG路**号SOHO现代城A座****室"),
|
|
||||||
("广州市天河区珠江新城花城大道123号富力中心B座2001室", "广州市天河区珠江新城HC大道**号FL中心B座****室"),
|
|
||||||
("深圳市南山区科技园南区深南大道9988号腾讯大厦T1栋15楼", "深圳市南山区科技园南区SN大道**号TX大厦T1栋**楼"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for original_address, expected_masked in test_cases:
|
|
||||||
masked = processor._mask_address(original_address)
|
|
||||||
print(f"Original: {original_address}")
|
|
||||||
print(f"Masked: {masked}")
|
|
||||||
print(f"Expected: {expected_masked}")
|
|
||||||
print("-" * 50)
|
|
||||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
|
||||||
|
|
||||||
|
|
||||||
def test_address_component_extraction():
|
|
||||||
"""Test address component extraction"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test address component extraction
|
|
||||||
test_cases = [
|
|
||||||
("上海市静安区恒丰路66号白云大厦1607室", {
|
|
||||||
"road_name": "恒丰路",
|
|
||||||
"house_number": "66",
|
|
||||||
"building_name": "白云大厦",
|
|
||||||
"community_name": ""
|
|
||||||
}),
|
|
||||||
("北京市朝阳区建国路88号SOHO现代城A座1001室", {
|
|
||||||
"road_name": "建国路",
|
|
||||||
"house_number": "88",
|
|
||||||
"building_name": "SOHO现代城",
|
|
||||||
"community_name": ""
|
|
||||||
}),
|
|
||||||
]
|
|
||||||
|
|
||||||
for address, expected_components in test_cases:
|
|
||||||
components = processor._extract_address_components(address)
|
|
||||||
print(f"Address: {address}")
|
|
||||||
print(f"Extracted components: {components}")
|
|
||||||
print(f"Expected: {expected_components}")
|
|
||||||
print("-" * 50)
|
|
||||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
|
||||||
|
|
||||||
|
|
||||||
def test_regex_fallback():
|
|
||||||
"""Test regex fallback for address extraction"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test regex extraction (fallback method)
|
|
||||||
test_address = "上海市静安区恒丰路66号白云大厦1607室"
|
|
||||||
components = processor._extract_address_components_with_regex(test_address)
|
|
||||||
|
|
||||||
print(f"Address: {test_address}")
|
|
||||||
print(f"Regex extracted components: {components}")
|
|
||||||
|
|
||||||
# Basic validation
|
|
||||||
assert "road_name" in components
|
|
||||||
assert "house_number" in components
|
|
||||||
assert "building_name" in components
|
|
||||||
assert "community_name" in components
|
|
||||||
assert "confidence" in components
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_validation_for_address():
|
|
||||||
"""Test JSON validation for address extraction responses"""
|
|
||||||
from app.core.utils.llm_validator import LLMResponseValidator
|
|
||||||
|
|
||||||
# Test valid JSON response
|
|
||||||
valid_response = {
|
|
||||||
"road_name": "恒丰路",
|
|
||||||
"house_number": "66",
|
|
||||||
"building_name": "白云大厦",
|
|
||||||
"community_name": "",
|
|
||||||
"confidence": 0.9
|
|
||||||
}
|
|
||||||
assert LLMResponseValidator.validate_address_extraction(valid_response) == True
|
|
||||||
|
|
||||||
# Test invalid JSON response (missing required field)
|
|
||||||
invalid_response = {
|
|
||||||
"road_name": "恒丰路",
|
|
||||||
"house_number": "66",
|
|
||||||
"building_name": "白云大厦",
|
|
||||||
"confidence": 0.9
|
|
||||||
}
|
|
||||||
assert LLMResponseValidator.validate_address_extraction(invalid_response) == False
|
|
||||||
|
|
||||||
# Test invalid JSON response (wrong type)
|
|
||||||
invalid_response2 = {
|
|
||||||
"road_name": 123,
|
|
||||||
"house_number": "66",
|
|
||||||
"building_name": "白云大厦",
|
|
||||||
"community_name": "",
|
|
||||||
"confidence": 0.9
|
|
||||||
}
|
|
||||||
assert LLMResponseValidator.validate_address_extraction(invalid_response2) == False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("Testing Address Masking Functionality")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
test_regex_fallback()
|
|
||||||
print()
|
|
||||||
test_json_validation_for_address()
|
|
||||||
print()
|
|
||||||
test_address_component_extraction()
|
|
||||||
print()
|
|
||||||
test_address_masking()
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
def test_basic_discovery():
|
|
||||||
"""Basic test to verify pytest discovery is working"""
|
|
||||||
assert True
|
|
||||||
|
|
||||||
def test_import_works():
|
|
||||||
"""Test that we can import from the app module"""
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
assert NerProcessor is not None
|
|
||||||
except ImportError as e:
|
|
||||||
pytest.fail(f"Failed to import NerProcessor: {e}")
|
|
||||||
|
|
||||||
def test_simple_math():
|
|
||||||
"""Simple math test"""
|
|
||||||
assert 1 + 1 == 2
|
|
||||||
assert 2 * 3 == 6
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test script for character-by-character alignment functionality.
|
|
||||||
This script demonstrates how the alignment handles different spacing patterns
|
|
||||||
between entity text and original document text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'backend'))
|
|
||||||
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Test the character alignment functionality."""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
print("Testing Character-by-Character Alignment")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Test the alignment functionality
|
|
||||||
processor.test_character_alignment()
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("Testing Entity Masking with Alignment")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Test entity masking with alignment
|
|
||||||
original_document = "上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。法定代表人:郭东军,执行董事、经理。委托诉讼代理人:周大海,北京市康达律师事务所律师。"
|
|
||||||
|
|
||||||
# Example entity mapping (from your NER results)
|
|
||||||
entity_mapping = {
|
|
||||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
|
||||||
"郭东军": "郭DJ",
|
|
||||||
"周大海": "周DH",
|
|
||||||
"北京市康达律师事务所": "北京市KD律师事务所"
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Original document: {original_document}")
|
|
||||||
print(f"Entity mapping: {entity_mapping}")
|
|
||||||
|
|
||||||
# Apply masking with alignment
|
|
||||||
masked_document = processor.apply_entity_masking_with_alignment(
|
|
||||||
original_document,
|
|
||||||
entity_mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Masked document: {masked_document}")
|
|
||||||
|
|
||||||
# Test with document that has spaces
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("Testing with Document Containing Spaces")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
spaced_document = "上诉人(原审原告):北京 丰复久信 营销科技 有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。法定代表人:郭 东 军,执行董事、经理。"
|
|
||||||
|
|
||||||
print(f"Spaced document: {spaced_document}")
|
|
||||||
|
|
||||||
masked_spaced_document = processor.apply_entity_masking_with_alignment(
|
|
||||||
spaced_document,
|
|
||||||
entity_mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Masked spaced document: {masked_spaced_document}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,230 +0,0 @@
|
||||||
"""
|
|
||||||
Test file for the enhanced OllamaClient with validation and retry mechanisms.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
# Add the current directory to the Python path
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
|
||||||
|
|
||||||
def test_ollama_client_initialization():
|
|
||||||
"""Test OllamaClient initialization with new parameters"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
# Test with default parameters
|
|
||||||
client = OllamaClient("test-model")
|
|
||||||
assert client.model_name == "test-model"
|
|
||||||
assert client.base_url == "http://localhost:11434"
|
|
||||||
assert client.max_retries == 3
|
|
||||||
|
|
||||||
# Test with custom parameters
|
|
||||||
client = OllamaClient("test-model", "http://custom:11434", 5)
|
|
||||||
assert client.model_name == "test-model"
|
|
||||||
assert client.base_url == "http://custom:11434"
|
|
||||||
assert client.max_retries == 5
|
|
||||||
|
|
||||||
print("✓ OllamaClient initialization tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_with_validation():
|
|
||||||
"""Test generate_with_validation method"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
# Mock the API response
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"response": '{"business_name": "测试公司", "confidence": 0.9}'
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
with patch('requests.post', return_value=mock_response):
|
|
||||||
client = OllamaClient("test-model")
|
|
||||||
|
|
||||||
# Test with business name extraction validation
|
|
||||||
result = client.generate_with_validation(
|
|
||||||
prompt="Extract business name from: 测试公司",
|
|
||||||
response_type='business_name_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert result.get('business_name') == '测试公司'
|
|
||||||
assert result.get('confidence') == 0.9
|
|
||||||
|
|
||||||
print("✓ generate_with_validation test passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_with_schema():
|
|
||||||
"""Test generate_with_schema method"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
# Define a custom schema
|
|
||||||
custom_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"age": {"type": "number"}
|
|
||||||
},
|
|
||||||
"required": ["name", "age"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock the API response
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"response": '{"name": "张三", "age": 30}'
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
with patch('requests.post', return_value=mock_response):
|
|
||||||
client = OllamaClient("test-model")
|
|
||||||
|
|
||||||
# Test with custom schema validation
|
|
||||||
result = client.generate_with_schema(
|
|
||||||
prompt="Generate person info",
|
|
||||||
schema=custom_schema,
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert result.get('name') == '张三'
|
|
||||||
assert result.get('age') == 30
|
|
||||||
|
|
||||||
print("✓ generate_with_schema test passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_backward_compatibility():
|
|
||||||
"""Test backward compatibility with original generate method"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
# Mock the API response
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"response": "Simple text response"
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
with patch('requests.post', return_value=mock_response):
|
|
||||||
client = OllamaClient("test-model")
|
|
||||||
|
|
||||||
# Test original generate method (should still work)
|
|
||||||
result = client.generate("Simple prompt")
|
|
||||||
assert result == "Simple text response"
|
|
||||||
|
|
||||||
# Test with strip_think=False
|
|
||||||
result = client.generate("Simple prompt", strip_think=False)
|
|
||||||
assert result == "Simple text response"
|
|
||||||
|
|
||||||
print("✓ Backward compatibility tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_retry_mechanism():
|
|
||||||
"""Test retry mechanism for failed requests"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
import requests
|
|
||||||
|
|
||||||
# Mock failed requests followed by success
|
|
||||||
mock_failed_response = Mock()
|
|
||||||
mock_failed_response.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed")
|
|
||||||
|
|
||||||
mock_success_response = Mock()
|
|
||||||
mock_success_response.json.return_value = {
|
|
||||||
"response": "Success response"
|
|
||||||
}
|
|
||||||
mock_success_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
with patch('requests.post', side_effect=[mock_failed_response, mock_success_response]):
|
|
||||||
client = OllamaClient("test-model", max_retries=2)
|
|
||||||
|
|
||||||
# Should retry and eventually succeed
|
|
||||||
result = client.generate("Test prompt")
|
|
||||||
assert result == "Success response"
|
|
||||||
|
|
||||||
print("✓ Retry mechanism test passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_validation_failure():
|
|
||||||
"""Test validation failure handling"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
# Mock API response with invalid JSON
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"response": "Invalid JSON response"
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
with patch('requests.post', return_value=mock_response):
|
|
||||||
client = OllamaClient("test-model", max_retries=2)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# This should fail validation and retry
|
|
||||||
result = client.generate_with_validation(
|
|
||||||
prompt="Test prompt",
|
|
||||||
response_type='business_name_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
# If we get here, it means validation failed and retries were exhausted
|
|
||||||
print("✓ Validation failure handling test passed")
|
|
||||||
except ValueError as e:
|
|
||||||
# Expected behavior - validation failed after retries
|
|
||||||
assert "Failed to parse JSON response after all retries" in str(e)
|
|
||||||
print("✓ Validation failure handling test passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_enhanced_methods():
|
|
||||||
"""Test the new enhanced methods"""
|
|
||||||
from app.core.services.ollama_client import OllamaClient
|
|
||||||
|
|
||||||
# Mock the API response
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"response": '{"entities": [{"text": "张三", "type": "人名"}]}'
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status.return_value = None
|
|
||||||
|
|
||||||
with patch('requests.post', return_value=mock_response):
|
|
||||||
client = OllamaClient("test-model")
|
|
||||||
|
|
||||||
# Test generate_with_validation
|
|
||||||
result = client.generate_with_validation(
|
|
||||||
prompt="Extract entities",
|
|
||||||
response_type='entity_extraction',
|
|
||||||
return_parsed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert 'entities' in result
|
|
||||||
assert len(result['entities']) == 1
|
|
||||||
assert result['entities'][0]['text'] == '张三'
|
|
||||||
|
|
||||||
print("✓ Enhanced methods tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests"""
|
|
||||||
print("Testing enhanced OllamaClient...")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_ollama_client_initialization()
|
|
||||||
test_generate_with_validation()
|
|
||||||
test_generate_with_schema()
|
|
||||||
test_backward_compatibility()
|
|
||||||
test_retry_mechanism()
|
|
||||||
test_validation_failure()
|
|
||||||
test_enhanced_methods()
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("✓ All enhanced OllamaClient tests passed!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n✗ Test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,186 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Final test to verify the fix handles multiple occurrences and prevents infinite loops.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
|
||||||
"""Simplified version of the alignment method for testing"""
|
|
||||||
clean_entity = entity_text.replace(" ", "")
|
|
||||||
doc_chars = [c for c in original_document_text if c != ' ']
|
|
||||||
|
|
||||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
|
||||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
|
||||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
|
||||||
"""Simplified version of position mapping for testing"""
|
|
||||||
original_pos = 0
|
|
||||||
clean_pos = 0
|
|
||||||
|
|
||||||
while clean_pos < clean_start and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
clean_pos += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
start_pos = original_pos
|
|
||||||
|
|
||||||
chars_found = 0
|
|
||||||
while chars_found < entity_length and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
chars_found += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
end_pos = original_pos
|
|
||||||
found_text = original_text[start_pos:end_pos]
|
|
||||||
|
|
||||||
return start_pos, end_pos, found_text
|
|
||||||
|
|
||||||
def apply_entity_masking_with_alignment_fixed(original_document_text: str, entity_mapping: dict):
|
|
||||||
"""Fixed implementation that handles multiple occurrences and prevents infinite loops"""
|
|
||||||
masked_document = original_document_text
|
|
||||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
|
||||||
|
|
||||||
for entity_text in sorted_entities:
|
|
||||||
masked_text = entity_mapping[entity_text]
|
|
||||||
|
|
||||||
# Skip if masked text is the same as original text (prevents infinite loop)
|
|
||||||
if entity_text == masked_text:
|
|
||||||
print(f"Skipping entity '{entity_text}' as masked text is identical")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find ALL occurrences of this entity in the document
|
|
||||||
# Add safety counter to prevent infinite loops
|
|
||||||
max_iterations = 100 # Safety limit
|
|
||||||
iteration_count = 0
|
|
||||||
|
|
||||||
while iteration_count < max_iterations:
|
|
||||||
iteration_count += 1
|
|
||||||
|
|
||||||
# Find the entity in the current masked document using alignment
|
|
||||||
alignment_result = find_entity_alignment(entity_text, masked_document)
|
|
||||||
|
|
||||||
if alignment_result:
|
|
||||||
start_pos, end_pos, found_text = alignment_result
|
|
||||||
|
|
||||||
# Replace the found text with the masked version
|
|
||||||
masked_document = (
|
|
||||||
masked_document[:start_pos] +
|
|
||||||
masked_text +
|
|
||||||
masked_document[end_pos:]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos} (iteration {iteration_count})")
|
|
||||||
else:
|
|
||||||
# No more occurrences found for this entity, move to next entity
|
|
||||||
print(f"No more occurrences of '{entity_text}' found in document after {iteration_count} iterations")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Log warning if we hit the safety limit
|
|
||||||
if iteration_count >= max_iterations:
|
|
||||||
print(f"WARNING: Reached maximum iterations ({max_iterations}) for entity '{entity_text}', stopping to prevent infinite loop")
|
|
||||||
|
|
||||||
return masked_document
|
|
||||||
|
|
||||||
def test_final_fix():
|
|
||||||
"""Test the final fix with various scenarios"""
|
|
||||||
|
|
||||||
print("Testing Final Fix for Multiple Occurrences and Infinite Loop Prevention")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
# Test case 1: Multiple occurrences of the same entity (should work)
|
|
||||||
print("\nTest Case 1: Multiple occurrences of same entity")
|
|
||||||
test_document_1 = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
|
||||||
entity_mapping_1 = {"李淼": "李M"}
|
|
||||||
|
|
||||||
print(f"Original: {test_document_1}")
|
|
||||||
result_1 = apply_entity_masking_with_alignment_fixed(test_document_1, entity_mapping_1)
|
|
||||||
print(f"Result: {result_1}")
|
|
||||||
|
|
||||||
remaining_1 = result_1.count("李淼")
|
|
||||||
expected_1 = "上诉人李M因合同纠纷,法定代表人李M,委托代理人李M。"
|
|
||||||
|
|
||||||
if result_1 == expected_1 and remaining_1 == 0:
|
|
||||||
print("✅ PASS: All occurrences masked correctly")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Expected '{expected_1}', got '{result_1}'")
|
|
||||||
print(f" Remaining '李淼' occurrences: {remaining_1}")
|
|
||||||
|
|
||||||
# Test case 2: Entity with same masked text (should skip to prevent infinite loop)
|
|
||||||
print("\nTest Case 2: Entity with same masked text (should skip)")
|
|
||||||
test_document_2 = "上诉人李淼因合同纠纷,法定代表人李淼。北京丰复久信营销科技有限公司,丰复久信公司。"
|
|
||||||
entity_mapping_2 = {
|
|
||||||
"李淼": "李M",
|
|
||||||
"丰复久信公司": "丰复久信公司" # Same text - should be skipped
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Original: {test_document_2}")
|
|
||||||
result_2 = apply_entity_masking_with_alignment_fixed(test_document_2, entity_mapping_2)
|
|
||||||
print(f"Result: {result_2}")
|
|
||||||
|
|
||||||
remaining_2_li = result_2.count("李淼")
|
|
||||||
remaining_2_company = result_2.count("丰复久信公司")
|
|
||||||
|
|
||||||
if remaining_2_li == 0 and remaining_2_company == 1: # Company should remain unmasked
|
|
||||||
print("✅ PASS: Infinite loop prevented, only different text masked")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Remaining '李淼': {remaining_2_li}, '丰复久信公司': {remaining_2_company}")
|
|
||||||
|
|
||||||
# Test case 3: Mixed spacing scenarios
|
|
||||||
print("\nTest Case 3: Mixed spacing scenarios")
|
|
||||||
test_document_3 = "上诉人李 淼因合同纠纷,法定代表人李淼,委托代理人李 淼。"
|
|
||||||
entity_mapping_3 = {"李 淼": "李M", "李淼": "李M"}
|
|
||||||
|
|
||||||
print(f"Original: {test_document_3}")
|
|
||||||
result_3 = apply_entity_masking_with_alignment_fixed(test_document_3, entity_mapping_3)
|
|
||||||
print(f"Result: {result_3}")
|
|
||||||
|
|
||||||
remaining_3 = result_3.count("李淼") + result_3.count("李 淼")
|
|
||||||
|
|
||||||
if remaining_3 == 0:
|
|
||||||
print("✅ PASS: Mixed spacing handled correctly")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Remaining occurrences: {remaining_3}")
|
|
||||||
|
|
||||||
# Test case 4: Complex document with real examples
|
|
||||||
print("\nTest Case 4: Complex document with real examples")
|
|
||||||
test_document_4 = """上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
|
||||||
法定代表人:郭东军,执行董事、经理。
|
|
||||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
|
||||||
委托诉讼代理人:王乃哲,北京市康达律师事务所律师。
|
|
||||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
|
||||||
法定代表人:王欢子,总经理。
|
|
||||||
委托诉讼代理人:魏鑫,北京市昊衡律师事务所律师。"""
|
|
||||||
|
|
||||||
entity_mapping_4 = {
|
|
||||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
|
||||||
"郭东军": "郭DJ",
|
|
||||||
"周大海": "周DH",
|
|
||||||
"王乃哲": "王NZ",
|
|
||||||
"中研智创区块链技术有限公司": "中研智创区块链技术有限公司", # Same text - should be skipped
|
|
||||||
"王欢子": "王HZ",
|
|
||||||
"魏鑫": "魏X",
|
|
||||||
"北京市康达律师事务所": "北京市KD律师事务所",
|
|
||||||
"北京市昊衡律师事务所": "北京市HH律师事务所"
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Original length: {len(test_document_4)} characters")
|
|
||||||
result_4 = apply_entity_masking_with_alignment_fixed(test_document_4, entity_mapping_4)
|
|
||||||
print(f"Result length: {len(result_4)} characters")
|
|
||||||
|
|
||||||
# Check that entities were masked correctly
|
|
||||||
unmasked_entities = []
|
|
||||||
for entity in entity_mapping_4.keys():
|
|
||||||
if entity in result_4 and entity != entity_mapping_4[entity]: # Skip if masked text is same
|
|
||||||
unmasked_entities.append(entity)
|
|
||||||
|
|
||||||
if not unmasked_entities:
|
|
||||||
print("✅ PASS: All entities masked correctly in complex document")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Unmasked entities: {unmasked_entities}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Final Fix Verification Completed!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_final_fix()
|
|
||||||
|
|
@ -1,173 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test to verify the fix for multiple occurrence issue in apply_entity_masking_with_alignment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
|
||||||
"""Simplified version of the alignment method for testing"""
|
|
||||||
clean_entity = entity_text.replace(" ", "")
|
|
||||||
doc_chars = [c for c in original_document_text if c != ' ']
|
|
||||||
|
|
||||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
|
||||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
|
||||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
|
||||||
"""Simplified version of position mapping for testing"""
|
|
||||||
original_pos = 0
|
|
||||||
clean_pos = 0
|
|
||||||
|
|
||||||
while clean_pos < clean_start and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
clean_pos += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
start_pos = original_pos
|
|
||||||
|
|
||||||
chars_found = 0
|
|
||||||
while chars_found < entity_length and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
chars_found += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
end_pos = original_pos
|
|
||||||
found_text = original_text[start_pos:end_pos]
|
|
||||||
|
|
||||||
return start_pos, end_pos, found_text
|
|
||||||
|
|
||||||
def apply_entity_masking_with_alignment_fixed(original_document_text: str, entity_mapping: dict):
|
|
||||||
"""Fixed implementation that handles multiple occurrences"""
|
|
||||||
masked_document = original_document_text
|
|
||||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
|
||||||
|
|
||||||
for entity_text in sorted_entities:
|
|
||||||
masked_text = entity_mapping[entity_text]
|
|
||||||
|
|
||||||
# Find ALL occurrences of this entity in the document
|
|
||||||
# We need to loop until no more matches are found
|
|
||||||
while True:
|
|
||||||
# Find the entity in the current masked document using alignment
|
|
||||||
alignment_result = find_entity_alignment(entity_text, masked_document)
|
|
||||||
|
|
||||||
if alignment_result:
|
|
||||||
start_pos, end_pos, found_text = alignment_result
|
|
||||||
|
|
||||||
# Replace the found text with the masked version
|
|
||||||
masked_document = (
|
|
||||||
masked_document[:start_pos] +
|
|
||||||
masked_text +
|
|
||||||
masked_document[end_pos:]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos}")
|
|
||||||
else:
|
|
||||||
# No more occurrences found for this entity, move to next entity
|
|
||||||
print(f"No more occurrences of '{entity_text}' found in document")
|
|
||||||
break
|
|
||||||
|
|
||||||
return masked_document
|
|
||||||
|
|
||||||
def test_fix_verification():
|
|
||||||
"""Test to verify the fix works correctly"""
|
|
||||||
|
|
||||||
print("Testing Fix for Multiple Occurrence Issue")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Test case 1: Multiple occurrences of the same entity
|
|
||||||
print("\nTest Case 1: Multiple occurrences of same entity")
|
|
||||||
test_document_1 = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
|
||||||
entity_mapping_1 = {"李淼": "李M"}
|
|
||||||
|
|
||||||
print(f"Original: {test_document_1}")
|
|
||||||
result_1 = apply_entity_masking_with_alignment_fixed(test_document_1, entity_mapping_1)
|
|
||||||
print(f"Result: {result_1}")
|
|
||||||
|
|
||||||
remaining_1 = result_1.count("李淼")
|
|
||||||
expected_1 = "上诉人李M因合同纠纷,法定代表人李M,委托代理人李M。"
|
|
||||||
|
|
||||||
if result_1 == expected_1 and remaining_1 == 0:
|
|
||||||
print("✅ PASS: All occurrences masked correctly")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Expected '{expected_1}', got '{result_1}'")
|
|
||||||
print(f" Remaining '李淼' occurrences: {remaining_1}")
|
|
||||||
|
|
||||||
# Test case 2: Multiple entities with multiple occurrences
|
|
||||||
print("\nTest Case 2: Multiple entities with multiple occurrences")
|
|
||||||
test_document_2 = "上诉人李淼因合同纠纷,法定代表人李淼。北京丰复久信营销科技有限公司,丰复久信公司。"
|
|
||||||
entity_mapping_2 = {
|
|
||||||
"李淼": "李M",
|
|
||||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
|
||||||
"丰复久信公司": "丰复久信公司"
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Original: {test_document_2}")
|
|
||||||
result_2 = apply_entity_masking_with_alignment_fixed(test_document_2, entity_mapping_2)
|
|
||||||
print(f"Result: {result_2}")
|
|
||||||
|
|
||||||
remaining_2_li = result_2.count("李淼")
|
|
||||||
remaining_2_company = result_2.count("北京丰复久信营销科技有限公司")
|
|
||||||
|
|
||||||
if remaining_2_li == 0 and remaining_2_company == 0:
|
|
||||||
print("✅ PASS: All entities masked correctly")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Remaining '李淼': {remaining_2_li}, '北京丰复久信营销科技有限公司': {remaining_2_company}")
|
|
||||||
|
|
||||||
# Test case 3: Mixed spacing scenarios
|
|
||||||
print("\nTest Case 3: Mixed spacing scenarios")
|
|
||||||
test_document_3 = "上诉人李 淼因合同纠纷,法定代表人李淼,委托代理人李 淼。"
|
|
||||||
entity_mapping_3 = {"李 淼": "李M", "李淼": "李M"}
|
|
||||||
|
|
||||||
print(f"Original: {test_document_3}")
|
|
||||||
result_3 = apply_entity_masking_with_alignment_fixed(test_document_3, entity_mapping_3)
|
|
||||||
print(f"Result: {result_3}")
|
|
||||||
|
|
||||||
remaining_3 = result_3.count("李淼") + result_3.count("李 淼")
|
|
||||||
|
|
||||||
if remaining_3 == 0:
|
|
||||||
print("✅ PASS: Mixed spacing handled correctly")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Remaining occurrences: {remaining_3}")
|
|
||||||
|
|
||||||
# Test case 4: Complex document with real examples
|
|
||||||
print("\nTest Case 4: Complex document with real examples")
|
|
||||||
test_document_4 = """上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
|
||||||
法定代表人:郭东军,执行董事、经理。
|
|
||||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
|
||||||
委托诉讼代理人:王乃哲,北京市康达律师事务所律师。
|
|
||||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
|
||||||
法定代表人:王欢子,总经理。
|
|
||||||
委托诉讼代理人:魏鑫,北京市昊衡律师事务所律师。"""
|
|
||||||
|
|
||||||
entity_mapping_4 = {
|
|
||||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
|
||||||
"郭东军": "郭DJ",
|
|
||||||
"周大海": "周DH",
|
|
||||||
"王乃哲": "王NZ",
|
|
||||||
"中研智创区块链技术有限公司": "中研智创区块链技术有限公司",
|
|
||||||
"王欢子": "王HZ",
|
|
||||||
"魏鑫": "魏X",
|
|
||||||
"北京市康达律师事务所": "北京市KD律师事务所",
|
|
||||||
"北京市昊衡律师事务所": "北京市HH律师事务所"
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Original length: {len(test_document_4)} characters")
|
|
||||||
result_4 = apply_entity_masking_with_alignment_fixed(test_document_4, entity_mapping_4)
|
|
||||||
print(f"Result length: {len(result_4)} characters")
|
|
||||||
|
|
||||||
# Check that all entities were masked
|
|
||||||
unmasked_entities = []
|
|
||||||
for entity in entity_mapping_4.keys():
|
|
||||||
if entity in result_4:
|
|
||||||
unmasked_entities.append(entity)
|
|
||||||
|
|
||||||
if not unmasked_entities:
|
|
||||||
print("✅ PASS: All entities masked in complex document")
|
|
||||||
else:
|
|
||||||
print(f"❌ FAIL: Unmasked entities: {unmasked_entities}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Fix Verification Completed!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_fix_verification()
|
|
||||||
|
|
@ -1,169 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test file for ID and social credit code masking functionality
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Add the backend directory to the Python path for imports
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_number_masking():
|
|
||||||
"""Test ID number masking with the new rules"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test cases based on the requirements
|
|
||||||
test_cases = [
|
|
||||||
("310103198802080000", "310103XXXXXXXXXXXX"),
|
|
||||||
("110101199001011234", "110101XXXXXXXXXXXX"),
|
|
||||||
("440301199505151234", "440301XXXXXXXXXXXX"),
|
|
||||||
("320102198712345678", "320102XXXXXXXXXXXX"),
|
|
||||||
("12345", "12345"), # Edge case: too short
|
|
||||||
]
|
|
||||||
|
|
||||||
for original_id, expected_masked in test_cases:
|
|
||||||
# Create a mock entity for testing
|
|
||||||
entity = {'text': original_id, 'type': '身份证号'}
|
|
||||||
unique_entities = [entity]
|
|
||||||
linkage = {'entity_groups': []}
|
|
||||||
|
|
||||||
# Test the masking through the full pipeline
|
|
||||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
|
||||||
masked = mapping.get(original_id, original_id)
|
|
||||||
|
|
||||||
print(f"Original ID: {original_id}")
|
|
||||||
print(f"Masked ID: {masked}")
|
|
||||||
print(f"Expected: {expected_masked}")
|
|
||||||
print(f"Match: {masked == expected_masked}")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
|
|
||||||
def test_social_credit_code_masking():
|
|
||||||
"""Test social credit code masking with the new rules"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test cases based on the requirements
|
|
||||||
test_cases = [
|
|
||||||
("9133021276453538XT", "913302XXXXXXXXXXXX"),
|
|
||||||
("91110000100000000X", "9111000XXXXXXXXXXX"),
|
|
||||||
("914403001922038216", "9144030XXXXXXXXXXX"),
|
|
||||||
("91310000132209458G", "9131000XXXXXXXXXXX"),
|
|
||||||
("123456", "123456"), # Edge case: too short
|
|
||||||
]
|
|
||||||
|
|
||||||
for original_code, expected_masked in test_cases:
|
|
||||||
# Create a mock entity for testing
|
|
||||||
entity = {'text': original_code, 'type': '社会信用代码'}
|
|
||||||
unique_entities = [entity]
|
|
||||||
linkage = {'entity_groups': []}
|
|
||||||
|
|
||||||
# Test the masking through the full pipeline
|
|
||||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
|
||||||
masked = mapping.get(original_code, original_code)
|
|
||||||
|
|
||||||
print(f"Original Code: {original_code}")
|
|
||||||
print(f"Masked Code: {masked}")
|
|
||||||
print(f"Expected: {expected_masked}")
|
|
||||||
print(f"Match: {masked == expected_masked}")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
|
|
||||||
def test_edge_cases():
|
|
||||||
"""Test edge cases for ID and social credit code masking"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test edge cases
|
|
||||||
edge_cases = [
|
|
||||||
("", ""), # Empty string
|
|
||||||
("123", "123"), # Too short for ID
|
|
||||||
("123456", "123456"), # Too short for social credit code
|
|
||||||
("123456789012345678901234567890", "123456XXXXXXXXXXXXXXXXXX"), # Very long ID
|
|
||||||
]
|
|
||||||
|
|
||||||
for original, expected in edge_cases:
|
|
||||||
# Test ID number
|
|
||||||
entity_id = {'text': original, 'type': '身份证号'}
|
|
||||||
mapping_id = processor._generate_masked_mapping([entity_id], {'entity_groups': []})
|
|
||||||
masked_id = mapping_id.get(original, original)
|
|
||||||
|
|
||||||
# Test social credit code
|
|
||||||
entity_code = {'text': original, 'type': '社会信用代码'}
|
|
||||||
mapping_code = processor._generate_masked_mapping([entity_code], {'entity_groups': []})
|
|
||||||
masked_code = mapping_code.get(original, original)
|
|
||||||
|
|
||||||
print(f"Original: {original}")
|
|
||||||
print(f"ID Masked: {masked_id}")
|
|
||||||
print(f"Code Masked: {masked_code}")
|
|
||||||
print("-" * 30)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mixed_entities():
|
|
||||||
"""Test masking with mixed entity types"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Create mixed entities
|
|
||||||
entities = [
|
|
||||||
{'text': '310103198802080000', 'type': '身份证号'},
|
|
||||||
{'text': '9133021276453538XT', 'type': '社会信用代码'},
|
|
||||||
{'text': '李强', 'type': '人名'},
|
|
||||||
{'text': '上海盒马网络科技有限公司', 'type': '公司名称'},
|
|
||||||
]
|
|
||||||
|
|
||||||
linkage = {'entity_groups': []}
|
|
||||||
|
|
||||||
# Test the masking through the full pipeline
|
|
||||||
mapping = processor._generate_masked_mapping(entities, linkage)
|
|
||||||
|
|
||||||
print("Mixed Entities Test:")
|
|
||||||
print("=" * 30)
|
|
||||||
for entity in entities:
|
|
||||||
original = entity['text']
|
|
||||||
entity_type = entity['type']
|
|
||||||
masked = mapping.get(original, original)
|
|
||||||
print(f"{entity_type}: {original} -> {masked}")
|
|
||||||
|
|
||||||
def test_id_masking():
|
|
||||||
"""Test ID number and social credit code masking"""
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test ID number masking
|
|
||||||
id_entity = {'text': '310103198802080000', 'type': '身份证号'}
|
|
||||||
id_mapping = processor._generate_masked_mapping([id_entity], {'entity_groups': []})
|
|
||||||
masked_id = id_mapping.get('310103198802080000', '')
|
|
||||||
|
|
||||||
# Test social credit code masking
|
|
||||||
code_entity = {'text': '9133021276453538XT', 'type': '社会信用代码'}
|
|
||||||
code_mapping = processor._generate_masked_mapping([code_entity], {'entity_groups': []})
|
|
||||||
masked_code = code_mapping.get('9133021276453538XT', '')
|
|
||||||
|
|
||||||
# Verify the masking rules
|
|
||||||
assert masked_id.startswith('310103') # First 6 digits preserved
|
|
||||||
assert masked_id.endswith('XXXXXXXXXXXX') # Rest masked with X
|
|
||||||
assert len(masked_id) == 18 # Total length preserved
|
|
||||||
|
|
||||||
assert masked_code.startswith('913302') # First 7 digits preserved
|
|
||||||
assert masked_code.endswith('XXXXXXXXXXXX') # Rest masked with X
|
|
||||||
assert len(masked_code) == 18 # Total length preserved
|
|
||||||
|
|
||||||
print(f"ID masking: 310103198802080000 -> {masked_id}")
|
|
||||||
print(f"Code masking: 9133021276453538XT -> {masked_code}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("Testing ID and Social Credit Code Masking")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
test_id_number_masking()
|
|
||||||
print()
|
|
||||||
test_social_credit_code_masking()
|
|
||||||
print()
|
|
||||||
test_edge_cases()
|
|
||||||
print()
|
|
||||||
test_mixed_entities()
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test to verify the multiple occurrence issue in apply_entity_masking_with_alignment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
|
||||||
"""Simplified version of the alignment method for testing"""
|
|
||||||
clean_entity = entity_text.replace(" ", "")
|
|
||||||
doc_chars = [c for c in original_document_text if c != ' ']
|
|
||||||
|
|
||||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
|
||||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
|
||||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
|
||||||
"""Simplified version of position mapping for testing"""
|
|
||||||
original_pos = 0
|
|
||||||
clean_pos = 0
|
|
||||||
|
|
||||||
while clean_pos < clean_start and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
clean_pos += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
start_pos = original_pos
|
|
||||||
|
|
||||||
chars_found = 0
|
|
||||||
while chars_found < entity_length and original_pos < len(original_text):
|
|
||||||
if original_text[original_pos] != ' ':
|
|
||||||
chars_found += 1
|
|
||||||
original_pos += 1
|
|
||||||
|
|
||||||
end_pos = original_pos
|
|
||||||
found_text = original_text[start_pos:end_pos]
|
|
||||||
|
|
||||||
return start_pos, end_pos, found_text
|
|
||||||
|
|
||||||
def apply_entity_masking_with_alignment_current(original_document_text: str, entity_mapping: dict):
|
|
||||||
"""Current implementation with the bug"""
|
|
||||||
masked_document = original_document_text
|
|
||||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
|
||||||
|
|
||||||
for entity_text in sorted_entities:
|
|
||||||
masked_text = entity_mapping[entity_text]
|
|
||||||
|
|
||||||
# Find the entity in the original document using alignment
|
|
||||||
alignment_result = find_entity_alignment(entity_text, masked_document)
|
|
||||||
|
|
||||||
if alignment_result:
|
|
||||||
start_pos, end_pos, found_text = alignment_result
|
|
||||||
|
|
||||||
# Replace the found text with the masked version
|
|
||||||
masked_document = (
|
|
||||||
masked_document[:start_pos] +
|
|
||||||
masked_text +
|
|
||||||
masked_document[end_pos:]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos}")
|
|
||||||
else:
|
|
||||||
print(f"Could not find entity '{entity_text}' in document for masking")
|
|
||||||
|
|
||||||
return masked_document
|
|
||||||
|
|
||||||
def test_multiple_occurrences():
|
|
||||||
"""Test the multiple occurrence issue"""
|
|
||||||
|
|
||||||
print("Testing Multiple Occurrence Issue")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Test document with multiple occurrences of the same entity
|
|
||||||
test_document = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
|
||||||
entity_mapping = {
|
|
||||||
"李淼": "李M"
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Original document: {test_document}")
|
|
||||||
print(f"Entity mapping: {entity_mapping}")
|
|
||||||
print(f"Expected: All 3 occurrences of '李淼' should be masked")
|
|
||||||
|
|
||||||
# Test current implementation
|
|
||||||
result = apply_entity_masking_with_alignment_current(test_document, entity_mapping)
|
|
||||||
print(f"Current result: {result}")
|
|
||||||
|
|
||||||
# Count remaining occurrences
|
|
||||||
remaining_count = result.count("李淼")
|
|
||||||
print(f"Remaining '李淼' occurrences: {remaining_count}")
|
|
||||||
|
|
||||||
if remaining_count > 0:
|
|
||||||
print("❌ ISSUE CONFIRMED: Multiple occurrences are not being masked!")
|
|
||||||
else:
|
|
||||||
print("✅ No issue found (unexpected)")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_multiple_occurrences()
|
|
||||||
|
|
@ -1,134 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test script for NER extractor integration
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# Add the backend directory to the Python path
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
|
|
||||||
|
|
||||||
from app.core.document_handlers.extractors.ner_extractor import NERExtractor
|
|
||||||
from app.core.document_handlers.ner_processor import NerProcessor
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_ner_extractor():
|
|
||||||
"""Test the NER extractor directly"""
|
|
||||||
print("🧪 Testing NER Extractor")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Sample legal text
|
|
||||||
text_to_analyze = """
|
|
||||||
上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
|
||||||
法定代表人:郭东军,执行董事、经理。
|
|
||||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
|
||||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
|
||||||
法定代表人:王欢子,总经理。
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test NER extractor
|
|
||||||
print("1. Testing NER Extractor...")
|
|
||||||
ner_extractor = NERExtractor()
|
|
||||||
|
|
||||||
# Get model info
|
|
||||||
model_info = ner_extractor.get_model_info()
|
|
||||||
print(f" Model: {model_info['model_name']}")
|
|
||||||
print(f" Supported entities: {model_info['supported_entities']}")
|
|
||||||
|
|
||||||
# Extract entities
|
|
||||||
result = ner_extractor.extract_and_summarize(text_to_analyze)
|
|
||||||
|
|
||||||
print(f"\n2. Extraction Results:")
|
|
||||||
print(f" Total entities found: {result['total_count']}")
|
|
||||||
|
|
||||||
for entity in result['entities']:
|
|
||||||
print(f" - '{entity['text']}' ({entity['type']}) - Confidence: {entity['confidence']:.4f}")
|
|
||||||
|
|
||||||
print(f"\n3. Summary:")
|
|
||||||
for entity_type, texts in result['summary']['summary'].items():
|
|
||||||
print(f" {entity_type}: {len(texts)} entities")
|
|
||||||
for text in texts:
|
|
||||||
print(f" - {text}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ NER Extractor test failed: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_ner_processor():
|
|
||||||
"""Test the NER processor integration"""
|
|
||||||
print("\n🧪 Testing NER Processor Integration")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Sample legal text
|
|
||||||
text_to_analyze = """
|
|
||||||
上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
|
||||||
法定代表人:郭东军,执行董事、经理。
|
|
||||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
|
||||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
|
||||||
法定代表人:王欢子,总经理。
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test NER processor
|
|
||||||
print("1. Testing NER Processor...")
|
|
||||||
ner_processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test NER-only extraction
|
|
||||||
print("2. Testing NER-only entity extraction...")
|
|
||||||
ner_entities = ner_processor.extract_entities_with_ner(text_to_analyze)
|
|
||||||
print(f" Extracted {len(ner_entities)} entities with NER model")
|
|
||||||
|
|
||||||
for entity in ner_entities:
|
|
||||||
print(f" - '{entity['text']}' ({entity['type']}) - Confidence: {entity['confidence']:.4f}")
|
|
||||||
|
|
||||||
# Test NER-only processing
|
|
||||||
print("\n3. Testing NER-only document processing...")
|
|
||||||
chunks = [text_to_analyze] # Single chunk for testing
|
|
||||||
mapping = ner_processor.process_ner_only(chunks)
|
|
||||||
|
|
||||||
print(f" Generated {len(mapping)} masking mappings")
|
|
||||||
for original, masked in mapping.items():
|
|
||||||
print(f" '{original}' -> '{masked}'")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ NER Processor test failed: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main test function"""
|
|
||||||
print("🧪 NER Integration Test Suite")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Test 1: NER Extractor
|
|
||||||
extractor_success = test_ner_extractor()
|
|
||||||
|
|
||||||
# Test 2: NER Processor Integration
|
|
||||||
processor_success = test_ner_processor()
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("📊 Test Summary:")
|
|
||||||
print(f" NER Extractor: {'✅' if extractor_success else '❌'}")
|
|
||||||
print(f" NER Processor: {'✅' if processor_success else '❌'}")
|
|
||||||
|
|
||||||
if extractor_success and processor_success:
|
|
||||||
print("\n🎉 All tests passed! NER integration is working correctly.")
|
|
||||||
print("\nNext steps:")
|
|
||||||
print("1. The NER extractor is ready to use in the document processing pipeline")
|
|
||||||
print("2. You can use process_ner_only() for ML-based entity extraction")
|
|
||||||
print("3. The existing process() method now includes NER extraction")
|
|
||||||
else:
|
|
||||||
print("\n⚠️ Some tests failed. Please check the error messages above.")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -4,9 +4,9 @@ from app.core.document_handlers.ner_processor import NerProcessor
|
||||||
def test_generate_masked_mapping():
|
def test_generate_masked_mapping():
|
||||||
processor = NerProcessor()
|
processor = NerProcessor()
|
||||||
unique_entities = [
|
unique_entities = [
|
||||||
{'text': '李强', 'type': '人名'},
|
{'text': '李雷', 'type': '人名'},
|
||||||
{'text': '李强', 'type': '人名'}, # Duplicate to test numbering
|
{'text': '李明', 'type': '人名'},
|
||||||
{'text': '王小明', 'type': '人名'},
|
{'text': '王强', 'type': '人名'},
|
||||||
{'text': 'Acme Manufacturing Inc.', 'type': '英文公司名', 'industry': 'manufacturing'},
|
{'text': 'Acme Manufacturing Inc.', 'type': '英文公司名', 'industry': 'manufacturing'},
|
||||||
{'text': 'Google LLC', 'type': '英文公司名'},
|
{'text': 'Google LLC', 'type': '英文公司名'},
|
||||||
{'text': 'A公司', 'type': '公司名称'},
|
{'text': 'A公司', 'type': '公司名称'},
|
||||||
|
|
@ -32,23 +32,23 @@ def test_generate_masked_mapping():
|
||||||
'group_id': 'g2',
|
'group_id': 'g2',
|
||||||
'group_type': '人名',
|
'group_type': '人名',
|
||||||
'entities': [
|
'entities': [
|
||||||
{'text': '李强', 'type': '人名', 'is_primary': True},
|
{'text': '李雷', 'type': '人名', 'is_primary': True},
|
||||||
{'text': '李强', 'type': '人名', 'is_primary': False},
|
{'text': '李明', 'type': '人名', 'is_primary': False},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
||||||
# 人名 - Updated for new Chinese name masking rules
|
# 人名
|
||||||
assert mapping['李强'] == '李Q'
|
assert mapping['李雷'].startswith('李某')
|
||||||
assert mapping['王小明'] == '王XM'
|
assert mapping['李明'].startswith('李某')
|
||||||
|
assert mapping['王强'].startswith('王某')
|
||||||
# 英文公司名
|
# 英文公司名
|
||||||
assert mapping['Acme Manufacturing Inc.'] == 'MANUFACTURING'
|
assert mapping['Acme Manufacturing Inc.'] == 'MANUFACTURING'
|
||||||
assert mapping['Google LLC'] == 'COMPANY'
|
assert mapping['Google LLC'] == 'COMPANY'
|
||||||
# 公司名同组 - Updated for new company masking rules
|
# 公司名同组
|
||||||
# Note: The exact results may vary due to LLM extraction
|
assert mapping['A公司'] == mapping['B公司']
|
||||||
assert '公司' in mapping['A公司'] or mapping['A公司'] != 'A公司'
|
assert mapping['A公司'].endswith('公司')
|
||||||
assert '公司' in mapping['B公司'] or mapping['B公司'] != 'B公司'
|
|
||||||
# 英文人名
|
# 英文人名
|
||||||
assert mapping['John Smith'] == 'J*** S***'
|
assert mapping['John Smith'] == 'J*** S***'
|
||||||
assert mapping['Elizabeth Windsor'] == 'E*** W***'
|
assert mapping['Elizabeth Windsor'] == 'E*** W***'
|
||||||
|
|
@ -60,216 +60,3 @@ def test_generate_masked_mapping():
|
||||||
assert mapping['310101198802080000'] == 'XXXXXX'
|
assert mapping['310101198802080000'] == 'XXXXXX'
|
||||||
# 社会信用代码
|
# 社会信用代码
|
||||||
assert mapping['9133021276453538XT'] == 'XXXXXXXX'
|
assert mapping['9133021276453538XT'] == 'XXXXXXXX'
|
||||||
|
|
||||||
|
|
||||||
def test_chinese_name_pinyin_masking():
|
|
||||||
"""Test Chinese name masking with pinyin functionality"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test basic Chinese name masking
|
|
||||||
test_cases = [
|
|
||||||
("李强", "李Q"),
|
|
||||||
("张韶涵", "张SH"),
|
|
||||||
("张若宇", "张RY"),
|
|
||||||
("白锦程", "白JC"),
|
|
||||||
("王小明", "王XM"),
|
|
||||||
("陈志强", "陈ZQ"),
|
|
||||||
]
|
|
||||||
|
|
||||||
surname_counter = {}
|
|
||||||
|
|
||||||
for original_name, expected_masked in test_cases:
|
|
||||||
masked = processor._mask_chinese_name(original_name, surname_counter)
|
|
||||||
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
|
|
||||||
|
|
||||||
# Test duplicate handling
|
|
||||||
duplicate_test_cases = [
|
|
||||||
("李强", "李Q"),
|
|
||||||
("李强", "李Q2"), # Should be numbered
|
|
||||||
("李倩", "李Q3"), # Should be numbered
|
|
||||||
("张韶涵", "张SH"),
|
|
||||||
("张韶涵", "张SH2"), # Should be numbered
|
|
||||||
("张若宇", "张RY"), # Different initials, should not be numbered
|
|
||||||
]
|
|
||||||
|
|
||||||
surname_counter = {} # Reset counter
|
|
||||||
|
|
||||||
for original_name, expected_masked in duplicate_test_cases:
|
|
||||||
masked = processor._mask_chinese_name(original_name, surname_counter)
|
|
||||||
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
|
|
||||||
|
|
||||||
# Test edge cases
|
|
||||||
edge_cases = [
|
|
||||||
("", ""), # Empty string
|
|
||||||
("李", "李"), # Single character
|
|
||||||
("李强强", "李QQ"), # Multiple characters with same pinyin
|
|
||||||
]
|
|
||||||
|
|
||||||
surname_counter = {} # Reset counter
|
|
||||||
|
|
||||||
for original_name, expected_masked in edge_cases:
|
|
||||||
masked = processor._mask_chinese_name(original_name, surname_counter)
|
|
||||||
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_chinese_name_integration():
|
|
||||||
"""Test Chinese name masking integrated with the full mapping process"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test Chinese names in the full mapping context
|
|
||||||
unique_entities = [
|
|
||||||
{'text': '李强', 'type': '人名'},
|
|
||||||
{'text': '张韶涵', 'type': '人名'},
|
|
||||||
{'text': '张若宇', 'type': '人名'},
|
|
||||||
{'text': '白锦程', 'type': '人名'},
|
|
||||||
{'text': '李强', 'type': '人名'}, # Duplicate
|
|
||||||
{'text': '张韶涵', 'type': '人名'}, # Duplicate
|
|
||||||
]
|
|
||||||
|
|
||||||
linkage = {
|
|
||||||
'entity_groups': [
|
|
||||||
{
|
|
||||||
'group_id': 'g1',
|
|
||||||
'group_type': '人名',
|
|
||||||
'entities': [
|
|
||||||
{'text': '李强', 'type': '人名', 'is_primary': True},
|
|
||||||
{'text': '张韶涵', 'type': '人名', 'is_primary': True},
|
|
||||||
{'text': '张若宇', 'type': '人名', 'is_primary': True},
|
|
||||||
{'text': '白锦程', 'type': '人名', 'is_primary': True},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
|
||||||
|
|
||||||
# Verify the mapping results
|
|
||||||
assert mapping['李强'] == '李Q'
|
|
||||||
assert mapping['张韶涵'] == '张SH'
|
|
||||||
assert mapping['张若宇'] == '张RY'
|
|
||||||
assert mapping['白锦程'] == '白JC'
|
|
||||||
|
|
||||||
# Check that duplicates are handled correctly
|
|
||||||
# The second occurrence should be numbered
|
|
||||||
assert '李Q2' in mapping.values() or '张SH2' in mapping.values()
|
|
||||||
|
|
||||||
|
|
||||||
def test_lawyer_and_judge_names():
|
|
||||||
"""Test that lawyer and judge names follow the same Chinese name rules"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test lawyer and judge names
|
|
||||||
test_entities = [
|
|
||||||
{'text': '王律师', 'type': '律师姓名'},
|
|
||||||
{'text': '李法官', 'type': '审判人员姓名'},
|
|
||||||
{'text': '张检察官', 'type': '检察官姓名'},
|
|
||||||
]
|
|
||||||
|
|
||||||
linkage = {
|
|
||||||
'entity_groups': [
|
|
||||||
{
|
|
||||||
'group_id': 'g1',
|
|
||||||
'group_type': '律师姓名',
|
|
||||||
'entities': [{'text': '王律师', 'type': '律师姓名', 'is_primary': True}]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'group_id': 'g2',
|
|
||||||
'group_type': '审判人员姓名',
|
|
||||||
'entities': [{'text': '李法官', 'type': '审判人员姓名', 'is_primary': True}]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'group_id': 'g3',
|
|
||||||
'group_type': '检察官姓名',
|
|
||||||
'entities': [{'text': '张检察官', 'type': '检察官姓名', 'is_primary': True}]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping = processor._generate_masked_mapping(test_entities, linkage)
|
|
||||||
|
|
||||||
# These should follow the same Chinese name masking rules
|
|
||||||
assert mapping['王律师'] == '王L'
|
|
||||||
assert mapping['李法官'] == '李F'
|
|
||||||
assert mapping['张检察官'] == '张JC'
|
|
||||||
|
|
||||||
|
|
||||||
def test_company_name_masking():
|
|
||||||
"""Test company name masking with business name extraction"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test basic company name masking
|
|
||||||
test_cases = [
|
|
||||||
("上海盒马网络科技有限公司", "上海JO网络科技有限公司"),
|
|
||||||
("丰田通商(上海)有限公司", "HVVU(上海)有限公司"),
|
|
||||||
("雅诗兰黛(上海)商贸有限公司", "AUNF(上海)商贸有限公司"),
|
|
||||||
("北京百度网讯科技有限公司", "北京BC网讯科技有限公司"),
|
|
||||||
("腾讯科技(深圳)有限公司", "TU科技(深圳)有限公司"),
|
|
||||||
("阿里巴巴集团控股有限公司", "阿里巴巴集团控股有限公司"), # 商号可能无法正确提取
|
|
||||||
]
|
|
||||||
|
|
||||||
for original_name, expected_masked in test_cases:
|
|
||||||
masked = processor._mask_company_name(original_name)
|
|
||||||
print(f"{original_name} -> {masked} (expected: {expected_masked})")
|
|
||||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
|
||||||
|
|
||||||
|
|
||||||
def test_business_name_extraction():
|
|
||||||
"""Test business name extraction from company names"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test business name extraction
|
|
||||||
test_cases = [
|
|
||||||
("上海盒马网络科技有限公司", "盒马"),
|
|
||||||
("丰田通商(上海)有限公司", "丰田通商"),
|
|
||||||
("雅诗兰黛(上海)商贸有限公司", "雅诗兰黛"),
|
|
||||||
("北京百度网讯科技有限公司", "百度"),
|
|
||||||
("腾讯科技(深圳)有限公司", "腾讯"),
|
|
||||||
("律师事务所", "律师事务所"), # Edge case
|
|
||||||
]
|
|
||||||
|
|
||||||
for company_name, expected_business_name in test_cases:
|
|
||||||
business_name = processor._extract_business_name(company_name)
|
|
||||||
print(f"Company: {company_name} -> Business Name: {business_name} (expected: {expected_business_name})")
|
|
||||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_validation_for_business_name():
|
|
||||||
"""Test JSON validation for business name extraction responses"""
|
|
||||||
from app.core.utils.llm_validator import LLMResponseValidator
|
|
||||||
|
|
||||||
# Test valid JSON response
|
|
||||||
valid_response = {
|
|
||||||
"business_name": "盒马",
|
|
||||||
"confidence": 0.9
|
|
||||||
}
|
|
||||||
assert LLMResponseValidator.validate_business_name_extraction(valid_response) == True
|
|
||||||
|
|
||||||
# Test invalid JSON response (missing required field)
|
|
||||||
invalid_response = {
|
|
||||||
"confidence": 0.9
|
|
||||||
}
|
|
||||||
assert LLMResponseValidator.validate_business_name_extraction(invalid_response) == False
|
|
||||||
|
|
||||||
# Test invalid JSON response (wrong type)
|
|
||||||
invalid_response2 = {
|
|
||||||
"business_name": 123,
|
|
||||||
"confidence": 0.9
|
|
||||||
}
|
|
||||||
assert LLMResponseValidator.validate_business_name_extraction(invalid_response2) == False
|
|
||||||
|
|
||||||
|
|
||||||
def test_law_firm_masking():
|
|
||||||
"""Test law firm name masking"""
|
|
||||||
processor = NerProcessor()
|
|
||||||
|
|
||||||
# Test law firm name masking
|
|
||||||
test_cases = [
|
|
||||||
("北京大成律师事务所", "北京D律师事务所"),
|
|
||||||
("上海锦天城律师事务所", "上海JTC律师事务所"),
|
|
||||||
("广东广信君达律师事务所", "广东GXJD律师事务所"),
|
|
||||||
]
|
|
||||||
|
|
||||||
for original_name, expected_masked in test_cases:
|
|
||||||
masked = processor._mask_company_name(original_name)
|
|
||||||
print(f"{original_name} -> {masked} (expected: {expected_masked})")
|
|
||||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
|
||||||
|
|
@ -1,128 +0,0 @@
|
||||||
"""
|
|
||||||
Tests for the refactored NerProcessor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Add the backend directory to the Python path
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|
||||||
|
|
||||||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
|
||||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
|
||||||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
|
||||||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
|
||||||
|
|
||||||
|
|
||||||
def test_chinese_name_masker():
|
|
||||||
"""Test Chinese name masker"""
|
|
||||||
masker = ChineseNameMasker()
|
|
||||||
|
|
||||||
# Test basic masking
|
|
||||||
result1 = masker.mask("李强")
|
|
||||||
assert result1 == "李Q"
|
|
||||||
|
|
||||||
result2 = masker.mask("张韶涵")
|
|
||||||
assert result2 == "张SH"
|
|
||||||
|
|
||||||
result3 = masker.mask("张若宇")
|
|
||||||
assert result3 == "张RY"
|
|
||||||
|
|
||||||
result4 = masker.mask("白锦程")
|
|
||||||
assert result4 == "白JC"
|
|
||||||
|
|
||||||
# Test duplicate handling
|
|
||||||
result5 = masker.mask("李强") # Should get a number
|
|
||||||
assert result5 == "李Q2"
|
|
||||||
|
|
||||||
print(f"Chinese name masking tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_english_name_masker():
|
|
||||||
"""Test English name masker"""
|
|
||||||
masker = EnglishNameMasker()
|
|
||||||
|
|
||||||
result = masker.mask("John Smith")
|
|
||||||
assert result == "J*** S***"
|
|
||||||
|
|
||||||
result2 = masker.mask("Mary Jane Watson")
|
|
||||||
assert result2 == "M*** J*** W***"
|
|
||||||
|
|
||||||
print(f"English name masking tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_masker():
|
|
||||||
"""Test ID masker"""
|
|
||||||
masker = IDMasker()
|
|
||||||
|
|
||||||
# Test ID number
|
|
||||||
result1 = masker.mask("310103198802080000")
|
|
||||||
assert result1 == "310103XXXXXXXXXXXX"
|
|
||||||
assert len(result1) == 18
|
|
||||||
|
|
||||||
# Test social credit code
|
|
||||||
result2 = masker.mask("9133021276453538XT")
|
|
||||||
assert result2 == "913302XXXXXXXXXXXX"
|
|
||||||
assert len(result2) == 18
|
|
||||||
|
|
||||||
print(f"ID masking tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_case_masker():
|
|
||||||
"""Test case masker"""
|
|
||||||
masker = CaseMasker()
|
|
||||||
|
|
||||||
result1 = masker.mask("(2022)京 03 民终 3852 号")
|
|
||||||
assert "***号" in result1
|
|
||||||
|
|
||||||
result2 = masker.mask("(2020)京0105 民初69754 号")
|
|
||||||
assert "***号" in result2
|
|
||||||
|
|
||||||
print(f"Case masking tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_masker_factory():
|
|
||||||
"""Test masker factory"""
|
|
||||||
from app.core.document_handlers.masker_factory import MaskerFactory
|
|
||||||
|
|
||||||
# Test creating maskers
|
|
||||||
chinese_masker = MaskerFactory.create_masker('chinese_name')
|
|
||||||
assert isinstance(chinese_masker, ChineseNameMasker)
|
|
||||||
|
|
||||||
english_masker = MaskerFactory.create_masker('english_name')
|
|
||||||
assert isinstance(english_masker, EnglishNameMasker)
|
|
||||||
|
|
||||||
id_masker = MaskerFactory.create_masker('id')
|
|
||||||
assert isinstance(id_masker, IDMasker)
|
|
||||||
|
|
||||||
case_masker = MaskerFactory.create_masker('case')
|
|
||||||
assert isinstance(case_masker, CaseMasker)
|
|
||||||
|
|
||||||
print(f"Masker factory tests passed")
|
|
||||||
|
|
||||||
|
|
||||||
def test_refactored_processor_initialization():
|
|
||||||
"""Test that the refactored processor can be initialized"""
|
|
||||||
try:
|
|
||||||
processor = NerProcessorRefactored()
|
|
||||||
assert processor is not None
|
|
||||||
assert hasattr(processor, 'maskers')
|
|
||||||
assert len(processor.maskers) > 0
|
|
||||||
print(f"Refactored processor initialization test passed")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Refactored processor initialization failed: {e}")
|
|
||||||
# This might fail if Ollama is not running, which is expected in test environment
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("Running refactored NerProcessor tests...")
|
|
||||||
|
|
||||||
test_chinese_name_masker()
|
|
||||||
test_english_name_masker()
|
|
||||||
test_id_masker()
|
|
||||||
test_case_masker()
|
|
||||||
test_masker_factory()
|
|
||||||
test_refactored_processor_initialization()
|
|
||||||
|
|
||||||
print("All refactored NerProcessor tests completed!")
|
|
||||||
|
|
@ -1,213 +0,0 @@
|
||||||
"""
|
|
||||||
Validation script for the refactored NerProcessor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Add the current directory to the Python path
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
|
||||||
|
|
||||||
def test_imports():
|
|
||||||
"""Test that all modules can be imported"""
|
|
||||||
print("Testing imports...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.base_masker import BaseMasker
|
|
||||||
print("✓ BaseMasker imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import BaseMasker: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
|
||||||
print("✓ Name maskers imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import name maskers: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
|
||||||
print("✓ IDMasker imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import IDMasker: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
|
||||||
print("✓ CaseMasker imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import CaseMasker: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.company_masker import CompanyMasker
|
|
||||||
print("✓ CompanyMasker imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import CompanyMasker: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.address_masker import AddressMasker
|
|
||||||
print("✓ AddressMasker imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import AddressMasker: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.masker_factory import MaskerFactory
|
|
||||||
print("✓ MaskerFactory imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import MaskerFactory: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.extractors.business_name_extractor import BusinessNameExtractor
|
|
||||||
print("✓ BusinessNameExtractor imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import BusinessNameExtractor: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.extractors.address_extractor import AddressExtractor
|
|
||||||
print("✓ AddressExtractor imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import AddressExtractor: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
|
||||||
print("✓ NerProcessorRefactored imported successfully")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Failed to import NerProcessorRefactored: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def test_masker_functionality():
|
|
||||||
"""Test basic masker functionality"""
|
|
||||||
print("\nTesting masker functionality...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker
|
|
||||||
|
|
||||||
masker = ChineseNameMasker()
|
|
||||||
result = masker.mask("李强")
|
|
||||||
assert result == "李Q", f"Expected '李Q', got '{result}'"
|
|
||||||
print("✓ ChineseNameMasker works correctly")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ ChineseNameMasker test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.name_masker import EnglishNameMasker
|
|
||||||
|
|
||||||
masker = EnglishNameMasker()
|
|
||||||
result = masker.mask("John Smith")
|
|
||||||
assert result == "J*** S***", f"Expected 'J*** S***', got '{result}'"
|
|
||||||
print("✓ EnglishNameMasker works correctly")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ EnglishNameMasker test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
|
||||||
|
|
||||||
masker = IDMasker()
|
|
||||||
result = masker.mask("310103198802080000")
|
|
||||||
assert result == "310103XXXXXXXXXXXX", f"Expected '310103XXXXXXXXXXXX', got '{result}'"
|
|
||||||
print("✓ IDMasker works correctly")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ IDMasker test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
|
||||||
|
|
||||||
masker = CaseMasker()
|
|
||||||
result = masker.mask("(2022)京 03 民终 3852 号")
|
|
||||||
assert "***号" in result, f"Expected '***号' in result, got '{result}'"
|
|
||||||
print("✓ CaseMasker works correctly")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ CaseMasker test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def test_factory():
|
|
||||||
"""Test masker factory"""
|
|
||||||
print("\nTesting masker factory...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.masker_factory import MaskerFactory
|
|
||||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker
|
|
||||||
|
|
||||||
masker = MaskerFactory.create_masker('chinese_name')
|
|
||||||
assert isinstance(masker, ChineseNameMasker), f"Expected ChineseNameMasker, got {type(masker)}"
|
|
||||||
print("✓ MaskerFactory works correctly")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ MaskerFactory test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def test_processor_initialization():
|
|
||||||
"""Test processor initialization"""
|
|
||||||
print("\nTesting processor initialization...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
|
||||||
|
|
||||||
processor = NerProcessorRefactored()
|
|
||||||
assert processor is not None, "Processor should not be None"
|
|
||||||
assert hasattr(processor, 'maskers'), "Processor should have maskers attribute"
|
|
||||||
assert len(processor.maskers) > 0, "Processor should have at least one masker"
|
|
||||||
print("✓ NerProcessorRefactored initializes correctly")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ NerProcessorRefactored initialization failed: {e}")
|
|
||||||
# This might fail if Ollama is not running, which is expected
|
|
||||||
print(" (This is expected if Ollama is not running)")
|
|
||||||
return True # Don't fail the validation for this
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main validation function"""
|
|
||||||
print("Validating refactored NerProcessor...")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
success = True
|
|
||||||
|
|
||||||
# Test imports
|
|
||||||
if not test_imports():
|
|
||||||
success = False
|
|
||||||
|
|
||||||
# Test functionality
|
|
||||||
if not test_masker_functionality():
|
|
||||||
success = False
|
|
||||||
|
|
||||||
# Test factory
|
|
||||||
if not test_factory():
|
|
||||||
success = False
|
|
||||||
|
|
||||||
# Test processor initialization
|
|
||||||
if not test_processor_initialization():
|
|
||||||
success = False
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
if success:
|
|
||||||
print("✓ All validation tests passed!")
|
|
||||||
print("The refactored code is working correctly.")
|
|
||||||
else:
|
|
||||||
print("✗ Some validation tests failed.")
|
|
||||||
print("Please check the errors above.")
|
|
||||||
|
|
||||||
return success
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -25,29 +25,6 @@ services:
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
|
|
||||||
# MagicDoc API Service
|
|
||||||
magicdoc-api:
|
|
||||||
build:
|
|
||||||
context: ./magicdoc
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
platform: linux/amd64
|
|
||||||
ports:
|
|
||||||
- "8002:8000"
|
|
||||||
volumes:
|
|
||||||
- ./magicdoc/storage/uploads:/app/storage/uploads
|
|
||||||
- ./magicdoc/storage/processed:/app/storage/processed
|
|
||||||
environment:
|
|
||||||
- PYTHONUNBUFFERED=1
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
start_period: 60s
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Backend API Service
|
# Backend API Service
|
||||||
backend-api:
|
backend-api:
|
||||||
build:
|
build:
|
||||||
|
|
@ -57,18 +34,16 @@ services:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./backend/storage:/app/storage
|
- ./backend/storage:/app/storage
|
||||||
- huggingface_cache:/root/.cache/huggingface
|
- ./backend/legal_doc_masker.db:/app/legal_doc_masker.db
|
||||||
env_file:
|
env_file:
|
||||||
- ./backend/.env
|
- ./backend/.env
|
||||||
environment:
|
environment:
|
||||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||||
- MINERU_API_URL=http://mineru-api:8000
|
- MINERU_API_URL=http://mineru-api:8000
|
||||||
- MAGICDOC_API_URL=http://magicdoc-api:8000
|
|
||||||
depends_on:
|
depends_on:
|
||||||
- redis
|
- redis
|
||||||
- mineru-api
|
- mineru-api
|
||||||
- magicdoc-api
|
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
|
|
||||||
|
|
@ -80,14 +55,13 @@ services:
|
||||||
command: celery -A app.services.file_service worker --loglevel=info
|
command: celery -A app.services.file_service worker --loglevel=info
|
||||||
volumes:
|
volumes:
|
||||||
- ./backend/storage:/app/storage
|
- ./backend/storage:/app/storage
|
||||||
- huggingface_cache:/root/.cache/huggingface
|
- ./backend/legal_doc_masker.db:/app/legal_doc_masker.db
|
||||||
env_file:
|
env_file:
|
||||||
- ./backend/.env
|
- ./backend/.env
|
||||||
environment:
|
environment:
|
||||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||||
- MINERU_API_URL=http://mineru-api:8000
|
- MINERU_API_URL=http://mineru-api:8000
|
||||||
- MAGICDOC_API_URL=http://magicdoc-api:8000
|
|
||||||
depends_on:
|
depends_on:
|
||||||
- redis
|
- redis
|
||||||
- backend-api
|
- backend-api
|
||||||
|
|
@ -129,4 +103,3 @@ networks:
|
||||||
volumes:
|
volumes:
|
||||||
uploads:
|
uploads:
|
||||||
processed:
|
processed:
|
||||||
huggingface_cache:
|
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
def download_json(url):
|
||||||
|
# 下载JSON文件
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status() # 检查请求是否成功
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_modify_json(url, local_filename, modifications):
|
||||||
|
if os.path.exists(local_filename):
|
||||||
|
data = json.load(open(local_filename))
|
||||||
|
config_version = data.get('config_version', '0.0.0')
|
||||||
|
if config_version < '1.2.0':
|
||||||
|
data = download_json(url)
|
||||||
|
else:
|
||||||
|
data = download_json(url)
|
||||||
|
|
||||||
|
# 修改内容
|
||||||
|
for key, value in modifications.items():
|
||||||
|
data[key] = value
|
||||||
|
|
||||||
|
# 保存修改后的内容
|
||||||
|
with open(local_filename, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
mineru_patterns = [
|
||||||
|
# "models/Layout/LayoutLMv3/*",
|
||||||
|
"models/Layout/YOLO/*",
|
||||||
|
"models/MFD/YOLO/*",
|
||||||
|
"models/MFR/unimernet_hf_small_2503/*",
|
||||||
|
"models/OCR/paddleocr_torch/*",
|
||||||
|
# "models/TabRec/TableMaster/*",
|
||||||
|
# "models/TabRec/StructEqTable/*",
|
||||||
|
]
|
||||||
|
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
|
||||||
|
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
|
||||||
|
model_dir = model_dir + '/models'
|
||||||
|
print(f'model_dir is: {model_dir}')
|
||||||
|
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
|
||||||
|
|
||||||
|
# paddleocr_model_dir = model_dir + '/OCR/paddleocr'
|
||||||
|
# user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
|
||||||
|
# if os.path.exists(user_paddleocr_dir):
|
||||||
|
# shutil.rmtree(user_paddleocr_dir)
|
||||||
|
# shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
|
||||||
|
|
||||||
|
json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
|
||||||
|
config_file_name = 'magic-pdf.json'
|
||||||
|
home_dir = os.path.expanduser('~')
|
||||||
|
config_file = os.path.join(home_dir, config_file_name)
|
||||||
|
|
||||||
|
json_mods = {
|
||||||
|
'models-dir': model_dir,
|
||||||
|
'layoutreader-model-dir': layoutreader_model_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
download_and_modify_json(json_url, config_file, json_mods)
|
||||||
|
print(f'The configuration file has been configured successfully, the path is: {config_file}')
|
||||||
|
|
@ -16,9 +16,8 @@ import {
|
||||||
DialogContent,
|
DialogContent,
|
||||||
DialogActions,
|
DialogActions,
|
||||||
Typography,
|
Typography,
|
||||||
Tooltip,
|
|
||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
import { Download as DownloadIcon, Delete as DeleteIcon, Error as ErrorIcon } from '@mui/icons-material';
|
import { Download as DownloadIcon, Delete as DeleteIcon } from '@mui/icons-material';
|
||||||
import { File, FileStatus } from '../types/file';
|
import { File, FileStatus } from '../types/file';
|
||||||
import { api } from '../services/api';
|
import { api } from '../services/api';
|
||||||
|
|
||||||
|
|
@ -173,50 +172,6 @@ const FileList: React.FC<FileListProps> = ({ files, onFileStatusChange }) => {
|
||||||
color={getStatusColor(file.status) as any}
|
color={getStatusColor(file.status) as any}
|
||||||
size="small"
|
size="small"
|
||||||
/>
|
/>
|
||||||
{file.status === FileStatus.FAILED && file.error_message && (
|
|
||||||
<div style={{ marginTop: '4px' }}>
|
|
||||||
<Tooltip
|
|
||||||
title={file.error_message}
|
|
||||||
placement="top-start"
|
|
||||||
arrow
|
|
||||||
sx={{ maxWidth: '400px' }}
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
style={{
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'flex-start',
|
|
||||||
gap: '4px',
|
|
||||||
padding: '4px 8px',
|
|
||||||
backgroundColor: '#ffebee',
|
|
||||||
borderRadius: '4px',
|
|
||||||
border: '1px solid #ffcdd2'
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<ErrorIcon
|
|
||||||
color="error"
|
|
||||||
sx={{ fontSize: '16px', marginTop: '1px', flexShrink: 0 }}
|
|
||||||
/>
|
|
||||||
<Typography
|
|
||||||
variant="caption"
|
|
||||||
color="error"
|
|
||||||
sx={{
|
|
||||||
display: 'block',
|
|
||||||
wordBreak: 'break-word',
|
|
||||||
maxWidth: '300px',
|
|
||||||
lineHeight: '1.2',
|
|
||||||
cursor: 'help',
|
|
||||||
fontWeight: 500
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{file.error_message.length > 50
|
|
||||||
? `${file.error_message.substring(0, 50)}...`
|
|
||||||
: file.error_message
|
|
||||||
}
|
|
||||||
</Typography>
|
|
||||||
</div>
|
|
||||||
</Tooltip>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell>
|
<TableCell>
|
||||||
{new Date(file.created_at).toLocaleString()}
|
{new Date(file.created_at).toLocaleString()}
|
||||||
|
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
FROM python:3.10-slim
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install system dependencies including LibreOffice
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
build-essential \
|
|
||||||
libreoffice \
|
|
||||||
libreoffice-writer \
|
|
||||||
libreoffice-calc \
|
|
||||||
libreoffice-impress \
|
|
||||||
wget \
|
|
||||||
curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Copy requirements and install Python packages first
|
|
||||||
COPY requirements.txt .
|
|
||||||
RUN pip install --upgrade pip
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
# Install fairy-doc after numpy and opencv are installed
|
|
||||||
RUN pip install --no-cache-dir "fairy-doc[cpu]"
|
|
||||||
|
|
||||||
# Copy the application code
|
|
||||||
COPY app/ ./app/
|
|
||||||
|
|
||||||
# Create storage directories
|
|
||||||
RUN mkdir -p storage/uploads storage/processed
|
|
||||||
|
|
||||||
# Expose the port the app runs on
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
# Health check
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
|
||||||
CMD curl -f http://localhost:8000/health || exit 1
|
|
||||||
|
|
||||||
# Command to run the application
|
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
||||||
# MagicDoc API Service
|
|
||||||
|
|
||||||
A FastAPI service that provides document to markdown conversion using the Magic-Doc library. This service is designed to be compatible with the existing Mineru API interface.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- Converts DOC, DOCX, PPT, PPTX, and PDF files to markdown
|
|
||||||
- RESTful API interface compatible with Mineru API
|
|
||||||
- Docker containerization with LibreOffice dependencies
|
|
||||||
- Health check endpoint
|
|
||||||
- File upload support
|
|
||||||
|
|
||||||
## API Endpoints
|
|
||||||
|
|
||||||
### Health Check
|
|
||||||
```
|
|
||||||
GET /health
|
|
||||||
```
|
|
||||||
Returns service health status.
|
|
||||||
|
|
||||||
### File Parse
|
|
||||||
```
|
|
||||||
POST /file_parse
|
|
||||||
```
|
|
||||||
Converts uploaded document to markdown.
|
|
||||||
|
|
||||||
**Parameters:**
|
|
||||||
- `files`: File upload (required)
|
|
||||||
- `output_dir`: Output directory (default: "./output")
|
|
||||||
- `lang_list`: Language list (default: "ch")
|
|
||||||
- `backend`: Backend type (default: "pipeline")
|
|
||||||
- `parse_method`: Parse method (default: "auto")
|
|
||||||
- `formula_enable`: Enable formula processing (default: true)
|
|
||||||
- `table_enable`: Enable table processing (default: true)
|
|
||||||
- `return_md`: Return markdown (default: true)
|
|
||||||
- `return_middle_json`: Return middle JSON (default: false)
|
|
||||||
- `return_model_output`: Return model output (default: false)
|
|
||||||
- `return_content_list`: Return content list (default: false)
|
|
||||||
- `return_images`: Return images (default: false)
|
|
||||||
- `start_page_id`: Start page ID (default: 0)
|
|
||||||
- `end_page_id`: End page ID (default: 99999)
|
|
||||||
|
|
||||||
**Response:**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"markdown": "converted markdown content",
|
|
||||||
"md": "converted markdown content",
|
|
||||||
"content": "converted markdown content",
|
|
||||||
"text": "converted markdown content",
|
|
||||||
"time_cost": 1.23,
|
|
||||||
"filename": "document.docx",
|
|
||||||
"status": "success"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Running with Docker
|
|
||||||
|
|
||||||
### Build and run with docker-compose
|
|
||||||
```bash
|
|
||||||
cd magicdoc
|
|
||||||
docker-compose up --build
|
|
||||||
```
|
|
||||||
|
|
||||||
The service will be available at `http://localhost:8002`
|
|
||||||
|
|
||||||
### Build and run with Docker
|
|
||||||
```bash
|
|
||||||
cd magicdoc
|
|
||||||
docker build -t magicdoc-api .
|
|
||||||
docker run -p 8002:8000 magicdoc-api
|
|
||||||
```
|
|
||||||
|
|
||||||
## Integration with Document Processors
|
|
||||||
|
|
||||||
This service is designed to be compatible with the existing document processors. To use it instead of Mineru API, update the configuration in your document processors:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# In docx_processor.py or pdf_processor.py
|
|
||||||
self.magicdoc_base_url = getattr(settings, 'MAGICDOC_API_URL', 'http://magicdoc-api:8000')
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dependencies
|
|
||||||
|
|
||||||
- Python 3.10
|
|
||||||
- LibreOffice (installed in Docker container)
|
|
||||||
- Magic-Doc library
|
|
||||||
- FastAPI
|
|
||||||
- Uvicorn
|
|
||||||
|
|
||||||
## Storage
|
|
||||||
|
|
||||||
The service creates the following directories:
|
|
||||||
- `storage/uploads/`: For uploaded files
|
|
||||||
- `storage/processed/`: For processed files
|
|
||||||
|
|
@ -1,152 +0,0 @@
|
||||||
# MagicDoc Service Setup Guide
|
|
||||||
|
|
||||||
This guide explains how to set up and use the MagicDoc API service as an alternative to the Mineru API for document processing.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
The MagicDoc service provides a FastAPI-based REST API that converts various document formats (DOC, DOCX, PPT, PPTX, PDF) to markdown using the Magic-Doc library. It's designed to be compatible with your existing document processors.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### 1. Build and Run the Service
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd magicdoc
|
|
||||||
./start.sh
|
|
||||||
```
|
|
||||||
|
|
||||||
Or manually:
|
|
||||||
```bash
|
|
||||||
cd magicdoc
|
|
||||||
docker-compose up --build -d
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Verify the Service
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Check health
|
|
||||||
curl http://localhost:8002/health
|
|
||||||
|
|
||||||
# View API documentation
|
|
||||||
open http://localhost:8002/docs
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Test with Sample Files
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd magicdoc
|
|
||||||
python test_api.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Compatibility
|
|
||||||
|
|
||||||
The MagicDoc API is designed to be compatible with your existing Mineru API interface:
|
|
||||||
|
|
||||||
### Endpoint: `POST /file_parse`
|
|
||||||
|
|
||||||
**Request Format:**
|
|
||||||
- File upload via multipart form data
|
|
||||||
- Same parameters as Mineru API (most are optional)
|
|
||||||
|
|
||||||
**Response Format:**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"markdown": "converted content",
|
|
||||||
"md": "converted content",
|
|
||||||
"content": "converted content",
|
|
||||||
"text": "converted content",
|
|
||||||
"time_cost": 1.23,
|
|
||||||
"filename": "document.docx",
|
|
||||||
"status": "success"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Integration with Existing Processors
|
|
||||||
|
|
||||||
To use MagicDoc instead of Mineru in your existing processors:
|
|
||||||
|
|
||||||
### 1. Update Configuration
|
|
||||||
|
|
||||||
Add to your settings:
|
|
||||||
```python
|
|
||||||
MAGICDOC_API_URL = "http://magicdoc-api:8000" # or http://localhost:8002
|
|
||||||
MAGICDOC_TIMEOUT = 300
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Modify Processors
|
|
||||||
|
|
||||||
Replace Mineru API calls with MagicDoc API calls. See `integration_example.py` for detailed examples.
|
|
||||||
|
|
||||||
### 3. Update Docker Compose
|
|
||||||
|
|
||||||
Add the MagicDoc service to your main docker-compose.yml:
|
|
||||||
```yaml
|
|
||||||
services:
|
|
||||||
magicdoc-api:
|
|
||||||
build:
|
|
||||||
context: ./magicdoc
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
ports:
|
|
||||||
- "8002:8000"
|
|
||||||
volumes:
|
|
||||||
- ./magicdoc/storage:/app/storage
|
|
||||||
environment:
|
|
||||||
- PYTHONUNBUFFERED=1
|
|
||||||
restart: unless-stopped
|
|
||||||
```
|
|
||||||
|
|
||||||
## Service Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
magicdoc/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── main.py # FastAPI application
|
|
||||||
├── Dockerfile # Container definition
|
|
||||||
├── docker-compose.yml # Service orchestration
|
|
||||||
├── requirements.txt # Python dependencies
|
|
||||||
├── README.md # Service documentation
|
|
||||||
├── SETUP.md # This setup guide
|
|
||||||
├── test_api.py # API testing script
|
|
||||||
├── integration_example.py # Integration examples
|
|
||||||
└── start.sh # Startup script
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dependencies
|
|
||||||
|
|
||||||
- **Python 3.10**: Base runtime
|
|
||||||
- **LibreOffice**: Document processing (installed in container)
|
|
||||||
- **Magic-Doc**: Document conversion library
|
|
||||||
- **FastAPI**: Web framework
|
|
||||||
- **Uvicorn**: ASGI server
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Service Won't Start
|
|
||||||
1. Check Docker is running
|
|
||||||
2. Verify port 8002 is available
|
|
||||||
3. Check logs: `docker-compose logs`
|
|
||||||
|
|
||||||
### File Conversion Fails
|
|
||||||
1. Verify LibreOffice is working in container
|
|
||||||
2. Check file format is supported
|
|
||||||
3. Review API logs for errors
|
|
||||||
|
|
||||||
### Integration Issues
|
|
||||||
1. Verify API endpoint URL
|
|
||||||
2. Check network connectivity between services
|
|
||||||
3. Ensure response format compatibility
|
|
||||||
|
|
||||||
## Performance Considerations
|
|
||||||
|
|
||||||
- MagicDoc is generally faster than Mineru for simple documents
|
|
||||||
- LibreOffice dependency adds container size
|
|
||||||
- Consider caching for repeated conversions
|
|
||||||
- Monitor memory usage for large files
|
|
||||||
|
|
||||||
## Security Notes
|
|
||||||
|
|
||||||
- Service runs on internal network
|
|
||||||
- File uploads are temporary
|
|
||||||
- No persistent storage of uploaded files
|
|
||||||
- Consider adding authentication for production use
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
# MagicDoc FastAPI Application
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from magic_doc.docconv import DocConverter, S3Config
|
|
||||||
import tempfile
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
app = FastAPI(title="MagicDoc API", version="1.0.0")
|
|
||||||
|
|
||||||
# Global converter instance
|
|
||||||
converter = DocConverter(s3_config=None)
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check():
|
|
||||||
"""Health check endpoint"""
|
|
||||||
return {"status": "healthy", "service": "magicdoc-api"}
|
|
||||||
|
|
||||||
@app.post("/file_parse")
|
|
||||||
async def parse_file(
|
|
||||||
files: UploadFile = File(...),
|
|
||||||
output_dir: str = Form("./output"),
|
|
||||||
lang_list: str = Form("ch"),
|
|
||||||
backend: str = Form("pipeline"),
|
|
||||||
parse_method: str = Form("auto"),
|
|
||||||
formula_enable: bool = Form(True),
|
|
||||||
table_enable: bool = Form(True),
|
|
||||||
return_md: bool = Form(True),
|
|
||||||
return_middle_json: bool = Form(False),
|
|
||||||
return_model_output: bool = Form(False),
|
|
||||||
return_content_list: bool = Form(False),
|
|
||||||
return_images: bool = Form(False),
|
|
||||||
start_page_id: int = Form(0),
|
|
||||||
end_page_id: int = Form(99999)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Parse document file and convert to markdown
|
|
||||||
Compatible with Mineru API interface
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Processing file: {files.filename}")
|
|
||||||
|
|
||||||
# Create temporary file to save uploaded content
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(files.filename)[1]) as temp_file:
|
|
||||||
shutil.copyfileobj(files.file, temp_file)
|
|
||||||
temp_file_path = temp_file.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Convert file to markdown using magic-doc
|
|
||||||
markdown_content, time_cost = converter.convert(temp_file_path, conv_timeout=300)
|
|
||||||
|
|
||||||
logger.info(f"Successfully converted {files.filename} to markdown in {time_cost:.2f}s")
|
|
||||||
|
|
||||||
# Return response compatible with Mineru API
|
|
||||||
response = {
|
|
||||||
"markdown": markdown_content,
|
|
||||||
"md": markdown_content, # Alternative field name
|
|
||||||
"content": markdown_content, # Alternative field name
|
|
||||||
"text": markdown_content, # Alternative field name
|
|
||||||
"time_cost": time_cost,
|
|
||||||
"filename": files.filename,
|
|
||||||
"status": "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
return JSONResponse(content=response)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Clean up temporary file
|
|
||||||
if os.path.exists(temp_file_path):
|
|
||||||
os.unlink(temp_file_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing file {files.filename}: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def root():
|
|
||||||
"""Root endpoint with service information"""
|
|
||||||
return {
|
|
||||||
"service": "MagicDoc API",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "Document to Markdown conversion service using Magic-Doc",
|
|
||||||
"endpoints": {
|
|
||||||
"health": "/health",
|
|
||||||
"file_parse": "/file_parse"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
|
||||||
magicdoc-api:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
platform: linux/amd64
|
|
||||||
ports:
|
|
||||||
- "8002:8000"
|
|
||||||
volumes:
|
|
||||||
- ./storage/uploads:/app/storage/uploads
|
|
||||||
- ./storage/processed:/app/storage/processed
|
|
||||||
environment:
|
|
||||||
- PYTHONUNBUFFERED=1
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
start_period: 60s
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
uploads:
|
|
||||||
processed:
|
|
||||||
|
|
@ -1,144 +0,0 @@
|
||||||
"""
|
|
||||||
Example of how to integrate MagicDoc API with existing document processors
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Example modification for docx_processor.py
|
|
||||||
# Replace the Mineru API configuration with MagicDoc API configuration
|
|
||||||
|
|
||||||
class DocxDocumentProcessor(DocumentProcessor):
|
|
||||||
def __init__(self, input_path: str, output_path: str):
|
|
||||||
super().__init__()
|
|
||||||
self.input_path = input_path
|
|
||||||
self.output_path = output_path
|
|
||||||
self.output_dir = os.path.dirname(output_path)
|
|
||||||
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
|
|
||||||
|
|
||||||
# Setup work directory for temporary files
|
|
||||||
self.work_dir = os.path.join(
|
|
||||||
os.path.dirname(output_path),
|
|
||||||
".work",
|
|
||||||
os.path.splitext(os.path.basename(input_path))[0]
|
|
||||||
)
|
|
||||||
os.makedirs(self.work_dir, exist_ok=True)
|
|
||||||
|
|
||||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
|
||||||
|
|
||||||
# MagicDoc API configuration (instead of Mineru)
|
|
||||||
self.magicdoc_base_url = getattr(settings, 'MAGICDOC_API_URL', 'http://magicdoc-api:8000')
|
|
||||||
self.magicdoc_timeout = getattr(settings, 'MAGICDOC_TIMEOUT', 300) # 5 minutes timeout
|
|
||||||
|
|
||||||
def _call_magicdoc_api(self, file_path: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Call MagicDoc API to convert DOCX to markdown
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the DOCX file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
API response as dictionary or None if failed
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
url = f"{self.magicdoc_base_url}/file_parse"
|
|
||||||
|
|
||||||
with open(file_path, 'rb') as file:
|
|
||||||
files = {'files': (os.path.basename(file_path), file, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')}
|
|
||||||
|
|
||||||
# Prepare form data - simplified compared to Mineru
|
|
||||||
data = {
|
|
||||||
'output_dir': './output',
|
|
||||||
'lang_list': 'ch',
|
|
||||||
'backend': 'pipeline',
|
|
||||||
'parse_method': 'auto',
|
|
||||||
'formula_enable': True,
|
|
||||||
'table_enable': True,
|
|
||||||
'return_md': True,
|
|
||||||
'return_middle_json': False,
|
|
||||||
'return_model_output': False,
|
|
||||||
'return_content_list': False,
|
|
||||||
'return_images': False,
|
|
||||||
'start_page_id': 0,
|
|
||||||
'end_page_id': 99999
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Calling MagicDoc API for DOCX processing at {url}")
|
|
||||||
response = requests.post(
|
|
||||||
url,
|
|
||||||
files=files,
|
|
||||||
data=data,
|
|
||||||
timeout=self.magicdoc_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
logger.info("Successfully received response from MagicDoc API for DOCX")
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
error_msg = f"MagicDoc API returned status code {response.status_code}: {response.text}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
error_msg = f"MagicDoc API request timed out after {self.magicdoc_timeout} seconds"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
error_msg = f"Error calling MagicDoc API for DOCX: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling MagicDoc API for DOCX: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
|
|
||||||
def read_content(self) -> str:
|
|
||||||
logger.info("Starting DOCX content processing with MagicDoc API")
|
|
||||||
|
|
||||||
# Call MagicDoc API to convert DOCX to markdown
|
|
||||||
magicdoc_response = self._call_magicdoc_api(self.input_path)
|
|
||||||
|
|
||||||
# Extract markdown content from the response
|
|
||||||
markdown_content = self._extract_markdown_from_response(magicdoc_response)
|
|
||||||
|
|
||||||
if not markdown_content:
|
|
||||||
raise Exception("No markdown content found in MagicDoc API response for DOCX")
|
|
||||||
|
|
||||||
logger.info(f"Successfully extracted {len(markdown_content)} characters of markdown content from DOCX")
|
|
||||||
|
|
||||||
# Save the raw markdown content to work directory for reference
|
|
||||||
md_output_path = os.path.join(self.work_dir, f"{self.name_without_suff}.md")
|
|
||||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
|
||||||
file.write(markdown_content)
|
|
||||||
|
|
||||||
logger.info(f"Saved raw markdown content from DOCX to {md_output_path}")
|
|
||||||
|
|
||||||
return markdown_content
|
|
||||||
|
|
||||||
# Configuration changes needed in settings.py:
|
|
||||||
"""
|
|
||||||
# Add these settings to your configuration
|
|
||||||
MAGICDOC_API_URL = "http://magicdoc-api:8000" # or http://localhost:8002 for local development
|
|
||||||
MAGICDOC_TIMEOUT = 300 # 5 minutes timeout
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Docker Compose integration:
|
|
||||||
"""
|
|
||||||
# Add to your main docker-compose.yml
|
|
||||||
services:
|
|
||||||
magicdoc-api:
|
|
||||||
build:
|
|
||||||
context: ./magicdoc
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
ports:
|
|
||||||
- "8002:8000"
|
|
||||||
volumes:
|
|
||||||
- ./magicdoc/storage:/app/storage
|
|
||||||
environment:
|
|
||||||
- PYTHONUNBUFFERED=1
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
start_period: 60s
|
|
||||||
"""
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
fastapi==0.104.1
|
|
||||||
uvicorn[standard]==0.24.0
|
|
||||||
python-multipart==0.0.6
|
|
||||||
# fairy-doc[cpu]==0.1.0
|
|
||||||
pydantic==2.5.0
|
|
||||||
numpy==1.24.3
|
|
||||||
opencv-python==4.8.1.78
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# MagicDoc API Service Startup Script
|
|
||||||
|
|
||||||
echo "Starting MagicDoc API Service..."
|
|
||||||
|
|
||||||
# Check if Docker is running
|
|
||||||
if ! docker info > /dev/null 2>&1; then
|
|
||||||
echo "Error: Docker is not running. Please start Docker first."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Build and start the service
|
|
||||||
echo "Building and starting MagicDoc API service..."
|
|
||||||
docker-compose up --build -d
|
|
||||||
|
|
||||||
# Wait for service to be ready
|
|
||||||
echo "Waiting for service to be ready..."
|
|
||||||
sleep 10
|
|
||||||
|
|
||||||
# Check health
|
|
||||||
echo "Checking service health..."
|
|
||||||
if curl -f http://localhost:8002/health > /dev/null 2>&1; then
|
|
||||||
echo "✅ MagicDoc API service is running successfully!"
|
|
||||||
echo "🌐 Service URL: http://localhost:8002"
|
|
||||||
echo "📖 API Documentation: http://localhost:8002/docs"
|
|
||||||
echo "🔍 Health Check: http://localhost:8002/health"
|
|
||||||
else
|
|
||||||
echo "❌ Service health check failed. Check logs with: docker-compose logs"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "To stop the service, run: docker-compose down"
|
|
||||||
echo "To view logs, run: docker-compose logs -f"
|
|
||||||
|
|
@ -1,92 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test script for MagicDoc API
|
|
||||||
"""
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
def test_health_check(base_url="http://localhost:8002"):
|
|
||||||
"""Test health check endpoint"""
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{base_url}/health")
|
|
||||||
print(f"Health check status: {response.status_code}")
|
|
||||||
print(f"Response: {response.json()}")
|
|
||||||
return response.status_code == 200
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Health check failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_file_parse(base_url="http://localhost:8002", file_path=None):
|
|
||||||
"""Test file parse endpoint"""
|
|
||||||
if not file_path or not os.path.exists(file_path):
|
|
||||||
print(f"File not found: {file_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(file_path, 'rb') as f:
|
|
||||||
files = {'files': (os.path.basename(file_path), f, 'application/octet-stream')}
|
|
||||||
data = {
|
|
||||||
'output_dir': './output',
|
|
||||||
'lang_list': 'ch',
|
|
||||||
'backend': 'pipeline',
|
|
||||||
'parse_method': 'auto',
|
|
||||||
'formula_enable': True,
|
|
||||||
'table_enable': True,
|
|
||||||
'return_md': True,
|
|
||||||
'return_middle_json': False,
|
|
||||||
'return_model_output': False,
|
|
||||||
'return_content_list': False,
|
|
||||||
'return_images': False,
|
|
||||||
'start_page_id': 0,
|
|
||||||
'end_page_id': 99999
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(f"{base_url}/file_parse", files=files, data=data)
|
|
||||||
print(f"File parse status: {response.status_code}")
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
print(f"Success! Converted {len(result.get('markdown', ''))} characters")
|
|
||||||
print(f"Time cost: {result.get('time_cost', 'N/A')}s")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"Error: {response.text}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"File parse failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main test function"""
|
|
||||||
print("Testing MagicDoc API...")
|
|
||||||
|
|
||||||
# Test health check
|
|
||||||
print("\n1. Testing health check...")
|
|
||||||
if not test_health_check():
|
|
||||||
print("Health check failed. Make sure the service is running.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Test file parse (if sample file exists)
|
|
||||||
print("\n2. Testing file parse...")
|
|
||||||
sample_files = [
|
|
||||||
"../sample_doc/20220707_na_decision-2.docx",
|
|
||||||
"../sample_doc/20220707_na_decision-2.pdf",
|
|
||||||
"../sample_doc/short_doc.md"
|
|
||||||
]
|
|
||||||
|
|
||||||
for sample_file in sample_files:
|
|
||||||
if os.path.exists(sample_file):
|
|
||||||
print(f"Testing with {sample_file}...")
|
|
||||||
if test_file_parse(file_path=sample_file):
|
|
||||||
print("File parse test passed!")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"Sample file not found: {sample_file}")
|
|
||||||
|
|
||||||
print("\nTest completed!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Loading…
Reference in New Issue