WIP: Refactor NER processor

oliviamn 2025-07-10 00:14:16 +08:00
parent 1cf3c45cee
commit 1649a9328b
4 changed files with 529 additions and 243 deletions

View File

@@ -1,43 +1,15 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from ..prompts.masking_prompts import get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt, get_ner_project_prompt, get_ner_case_number_prompt
import logging
import json
from ..services.ollama_client import OllamaClient
from ...core.config import settings
from ..utils.json_extractor import LLMJsonExtractor
import re
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
from jsonschema import validate, ValidationError # pip install jsonschema
from .ner_processor import NerProcessor
logger = logging.getLogger(__name__)
class DocumentProcessor(ABC):
# JSON Schema for mapping validation
mapping_schema = {
"type": "object",
"properties": {
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"text": {"type": "string"},
"type": {"type": "string"}
},
"required": ["text", "type"]
}
}
},
"required": ["entities"]
}
def __init__(self):
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
self.max_chunk_size = 1000 # Maximum number of characters per chunk
self.max_retries = 3 # Maximum number of retries for mapping generation
self.ner_processor = NerProcessor()
@abstractmethod
def read_content(self) -> str:
@@ -53,7 +25,6 @@ class DocumentProcessor(ABC):
if not sentence.strip():
continue
# If adding this sentence would exceed the limit, save current chunk and start new one
if len(current_chunk) + len(sentence) > self.max_chunk_size and current_chunk:
chunks.append(current_chunk)
current_chunk = sentence
@@ -63,240 +34,32 @@ class DocumentProcessor(ABC):
else:
current_chunk = sentence
# Add the last chunk if it's not empty
if current_chunk:
chunks.append(current_chunk)
return chunks
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
"""
Validate that the mapping follows the required JSON schema format.
"""
try:
validate(instance=mapping, schema=self.mapping_schema)
return True
except ValidationError as e:
logger.warning(f"Mapping validation error: {e}")
return False
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
"""Process a single entity type with retry logic"""
for attempt in range(self.max_retries):
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw response from LLM: {response}")
# Parse the JSON response into a dictionary
mapping = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
return {}
def _build_mapping(self, chunk: str) -> list[Dict[str, str]]:
"""Build mapping for a single chunk of text with retry logic"""
mapping_pipeline = []
# LLM-based entities
entity_configs = [
(get_ner_name_prompt, "people names"),
(get_ner_company_prompt, "company names"),
(get_ner_address_prompt, "addresses"),
(get_ner_project_prompt, "project names"),
(get_ner_case_number_prompt, "case numbers")
]
for prompt_func, entity_type in entity_configs:
mapping = self._process_entity_type(chunk, prompt_func, entity_type)
if mapping:
mapping_pipeline.append(mapping)
# Regex-based entities
regex_entity_extractors = [
extract_id_number_entities,
extract_social_credit_code_entities
]
for extractor in regex_entity_extractors:
mapping = extractor(chunk)
if mapping:
mapping_pipeline.append(mapping)
return mapping_pipeline
def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
"""Apply the mapping to replace sensitive information"""
masked_text = text
for original, masked in mapping.items():
# Ensure masked value is a string
if isinstance(masked, dict):
# If it's a dict, use the first value or a default
masked = next(iter(masked.values()), "")
elif not isinstance(masked, str):
# If it's not a string, convert to string or use default
masked = str(masked) if masked is not None else ""
masked_text = masked_text.replace(original, masked)
return masked_text
def _get_next_suffix(self, value: str) -> str:
"""Get the next available suffix for a value that already has a suffix"""
# Define the sequence of suffixes
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
# Check if the value already has a suffix
for suffix in suffixes:
if value.endswith(suffix):
# Find the next suffix in the sequence
current_index = suffixes.index(suffix)
if current_index + 1 < len(suffixes):
return value[:-1] + suffixes[current_index + 1]
else:
# If we've used all suffixes, start over with the first one
return value[:-1] + suffixes[0]
# If no suffix found, return the value with the first suffix
return value + '甲'
def _merge_entity_mappings(self, chunk_mappings: list[Dict[str, Any]]) -> list[Dict[str, str]]:
"""
Merge entity mappings from multiple chunks and remove duplicates.
Args:
chunk_mappings: List of mappings returned from LLM, each containing 'entities' list
Returns:
list[Dict[str, str]]: List of unique entities with text and type
"""
# Extract all entities from all chunks
all_entities = []
for mapping in chunk_mappings:
if isinstance(mapping, dict) and 'entities' in mapping:
entities = mapping['entities']
if isinstance(entities, list):
all_entities.extend(entities)
# Remove duplicates based on text content
unique_entities = []
seen_texts = set()
for entity in all_entities:
if isinstance(entity, dict) and 'text' in entity:
text = entity['text'].strip()
if text and text not in seen_texts:
seen_texts.add(text)
unique_entities.append(entity)
logger.info(f"Merged {len(unique_entities)} unique entities")
return unique_entities
def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]]) -> Dict[str, str]:
"""
Generate masked names for unique entities.
Args:
unique_entities: List of unique entities with text and type
Returns:
Dict[str, str]: Mapping from original text to masked version
"""
entity_mapping = {}
used_masked_names = set()
for entity in unique_entities:
original_text = entity['text'].strip()
entity_type = entity.get('type', '')
# Generate masked name based on entity type
if '人名' in entity_type or '英文人名' in entity_type:
# For person names, use 某 + suffix pattern
base_name = '某'
masked_name = base_name
counter = 1
while masked_name in used_masked_names:
if counter <= 10:
# Use 甲乙丙丁... for first 10
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
masked_name = base_name + suffixes[counter - 1]
else:
# Use numbers for additional ones
masked_name = f"{base_name}{counter}"
counter += 1
elif '公司' in entity_type or 'Company' in entity_type:
# For company names, use 某公司 + suffix pattern
base_name = '某公司'
masked_name = base_name
counter = 1
while masked_name in used_masked_names:
if counter <= 10:
# Use 甲乙丙丁... for first 10
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
masked_name = base_name + suffixes[counter - 1]
else:
# Use numbers for additional ones
masked_name = f"{base_name}{counter}"
counter += 1
else:
# For other entity types, use generic pattern
base_name = ''
masked_name = base_name
counter = 1
while masked_name in used_masked_names:
if counter <= 10:
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
masked_name = base_name + suffixes[counter - 1]
else:
masked_name = f"{base_name}{counter}"
counter += 1
entity_mapping[original_text] = masked_name
used_masked_names.add(masked_name)
logger.info(f"Generated masked mapping for {len(entity_mapping)} entities")
return entity_mapping
def process_content(self, content: str) -> str:
"""Process document content by masking sensitive information"""
# Split content into sentences
sentences = content.split("。")
# Split sentences into manageable chunks
chunks = self._split_into_chunks(sentences)
logger.info(f"Split content into {len(chunks)} chunks")
# Build mapping for each chunk
chunk_mappings = []
for i, chunk in enumerate(chunks):
logger.info(f"Processing chunk {i+1}/{len(chunks)}")
chunk_mapping = self._build_mapping(chunk)
logger.info(f"Chunk mapping: {chunk_mapping}")
chunk_mappings.extend(chunk_mapping)
# Merge mappings, remove the duplicate ones
unique_entities = self._merge_entity_mappings(chunk_mappings)
final_mapping = self.ner_processor.process(chunks)
# Generate masked names for unique entities
combined_mapping = self._generate_masked_mapping(unique_entities)
# Apply the combined mapping to the entire content
masked_content = self._apply_mapping(content, combined_mapping)
masked_content = self._apply_mapping(content, final_mapping)
logger.info("Successfully masked content")
return masked_content
@@ -304,4 +67,4 @@ class DocumentProcessor(ABC):
@abstractmethod
def save_content(self, content: str) -> None:
"""Save processed content"""
pass
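For reference, a minimal sketch of a concrete subclass is shown below, assuming a plain-text document and that read_content and save_content are the only abstract methods; the class name, file paths, and usage lines are hypothetical and only illustrate how the abstract methods fit around process_content.

# Hypothetical subclass for illustration; paths and class name are assumptions.
class TxtDocumentProcessor(DocumentProcessor):
    def __init__(self, input_path: str, output_path: str):
        super().__init__()
        self.input_path = input_path
        self.output_path = output_path

    def read_content(self) -> str:
        # Raw text that process_content() will chunk and mask
        with open(self.input_path, "r", encoding="utf-8") as f:
            return f.read()

    def save_content(self, content: str) -> None:
        # Persist the masked text
        with open(self.output_path, "w", encoding="utf-8") as f:
            f.write(content)

# Typical call sequence:
# processor = TxtDocumentProcessor("in.txt", "out.txt")
# processor.save_content(processor.process_content(processor.read_content()))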

View File

@@ -0,0 +1,233 @@
from typing import Any, Dict
from ..prompts.masking_prompts import get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt, get_ner_project_prompt, get_ner_case_number_prompt, get_entity_linkage_prompt
import logging
import json
from ..services.ollama_client import OllamaClient
from ...core.config import settings
from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator
import re
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
logger = logging.getLogger(__name__)
class NerProcessor:
def __init__(self):
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
self.max_retries = 3
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
return LLMResponseValidator.validate_entity_extraction(mapping)
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
for attempt in range(self.max_retries):
try:
formatted_prompt = prompt_func(chunk)
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw response from LLM: {response}")
mapping = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed mapping: {mapping}")
if mapping and self._validate_mapping_format(mapping):
return mapping
else:
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
return {}
def build_mapping(self, chunk: str) -> list[Dict[str, str]]:
mapping_pipeline = []
entity_configs = [
(get_ner_name_prompt, "people names"),
(get_ner_company_prompt, "company names"),
(get_ner_address_prompt, "addresses"),
(get_ner_project_prompt, "project names"),
(get_ner_case_number_prompt, "case numbers")
]
for prompt_func, entity_type in entity_configs:
mapping = self._process_entity_type(chunk, prompt_func, entity_type)
if mapping:
mapping_pipeline.append(mapping)
regex_entity_extractors = [
extract_id_number_entities,
extract_social_credit_code_entities
]
for extractor in regex_entity_extractors:
mapping = extractor(chunk)
if mapping and LLMResponseValidator.validate_regex_entity(mapping):
mapping_pipeline.append(mapping)
elif mapping:
logger.warning(f"Invalid regex entity mapping format: {mapping}")
return mapping_pipeline
def _merge_entity_mappings(self, chunk_mappings: list[Dict[str, Any]]) -> list[Dict[str, str]]:
all_entities = []
for mapping in chunk_mappings:
if isinstance(mapping, dict) and 'entities' in mapping:
entities = mapping['entities']
if isinstance(entities, list):
all_entities.extend(entities)
unique_entities = []
seen_texts = set()
for entity in all_entities:
if isinstance(entity, dict) and 'text' in entity:
text = entity['text'].strip()
if text and text not in seen_texts:
seen_texts.add(text)
unique_entities.append(entity)
logger.info(f"Merged {len(unique_entities)} unique entities")
return unique_entities
def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]]) -> Dict[str, str]:
entity_mapping = {}
used_masked_names = set()
for entity in unique_entities:
original_text = entity['text'].strip()
entity_type = entity.get('type', '')
if '人名' in entity_type or '英文人名' in entity_type:
base_name = '某'
masked_name = base_name
counter = 1
while masked_name in used_masked_names:
if counter <= 10:
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
masked_name = base_name + suffixes[counter - 1]
else:
masked_name = f"{base_name}{counter}"
counter += 1
elif '公司' in entity_type or 'Company' in entity_type:
base_name = '某公司'
masked_name = base_name
counter = 1
while masked_name in used_masked_names:
if counter <= 10:
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
masked_name = base_name + suffixes[counter - 1]
else:
masked_name = f"{base_name}{counter}"
counter += 1
else:
base_name = ''
masked_name = base_name
counter = 1
while masked_name in used_masked_names:
if counter <= 10:
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
masked_name = base_name + suffixes[counter - 1]
else:
masked_name = f"{base_name}{counter}"
counter += 1
entity_mapping[original_text] = masked_name
used_masked_names.add(masked_name)
logger.info(f"Generated masked mapping for {len(entity_mapping)} entities")
return entity_mapping
def _validate_linkage_format(self, linkage: Dict[str, Any]) -> bool:
return LLMResponseValidator.validate_entity_linkage(linkage)
def _create_entity_linkage(self, unique_entities: list[Dict[str, str]]) -> Dict[str, Any]:
linkable_entities = []
for entity in unique_entities:
entity_type = entity.get('type', '')
if any(keyword in entity_type for keyword in ['公司', 'Company', '人名', '英文人名']):
linkable_entities.append(entity)
if not linkable_entities:
logger.info("No linkable entities found")
return {"entity_groups": []}
entities_text = "\n".join([
f"- {entity['text']} (类型: {entity['type']})"
for entity in linkable_entities
])
for attempt in range(self.max_retries):
try:
formatted_prompt = get_entity_linkage_prompt(entities_text)
logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
response = self.ollama_client.generate(formatted_prompt)
logger.info(f"Raw entity linkage response from LLM: {response}")
linkage = LLMJsonExtractor.parse_raw_json_str(response)
logger.info(f"Parsed entity linkage: {linkage}")
if linkage and self._validate_linkage_format(linkage):
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
return linkage
else:
logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
except Exception as e:
logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
if attempt < self.max_retries - 1:
logger.info("Retrying...")
else:
logger.error("Max retries reached for entity linkage, returning empty linkage")
return {"entity_groups": []}
def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]:
updated_mapping = entity_mapping.copy()
for group in entity_linkage.get('entity_groups', []):
group_entities = group.get('entities', [])
if not group_entities:
continue
primary_entity = None
for entity in group_entities:
if entity.get('is_primary', False):
primary_entity = entity
break
if not primary_entity and group_entities:
primary_entity = group_entities[0]
if primary_entity:
primary_text = primary_entity['text']
primary_masked = updated_mapping.get(primary_text)
if primary_masked:
for entity in group_entities:
entity_text = entity['text']
if entity_text in updated_mapping:
updated_mapping[entity_text] = primary_masked
logger.info(f"Linked entity '{entity_text}' to '{primary_text}' with masked name '{primary_masked}'")
return updated_mapping
def process(self, chunks: list[str]) -> Dict[str, str]:
chunk_mappings = []
for i, chunk in enumerate(chunks):
logger.info(f"Processing chunk {i+1}/{len(chunks)}")
chunk_mapping = self.build_mapping(chunk)
logger.info(f"Chunk mapping: {chunk_mapping}")
chunk_mappings.extend(chunk_mapping)
unique_entities = self._merge_entity_mappings(chunk_mappings)
entity_linkage = self._create_entity_linkage(unique_entities)
combined_mapping = self._generate_masked_mapping(unique_entities)
final_mapping = self._apply_entity_linkage_to_mapping(combined_mapping, entity_linkage)
return final_mapping
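To make the data flow of process() concrete, here is a hand-written sketch of the intermediate values for a single chunk. The entity values are invented, and it assumes the masked-name bases and suffixes are 某 / 某公司 plus the 甲乙丙丁… sequence referenced in the comments in document_processor.py.

# Invented sample data; not real LLM output.
chunk_mappings = [
    {"entities": [{"text": "张三", "type": "人名"},
                  {"text": "阿里巴巴集团控股有限公司", "type": "公司名称"}]},
    {"entities": [{"text": "张三", "type": "人名"},  # duplicate of the first chunk, dropped on merge
                  {"text": "阿里巴巴", "type": "公司名称简称"}]},
]
# _merge_entity_mappings   -> 3 unique entities
# _generate_masked_mapping ->
#   {"张三": "某", "阿里巴巴集团控股有限公司": "某公司", "阿里巴巴": "某公司甲"}
# _create_entity_linkage groups the two company mentions with the full name as primary,
# and _apply_entity_linkage_to_mapping rewrites the alias to the primary's masked name:
#   {"张三": "某", "阿里巴巴集团控股有限公司": "某公司", "阿里巴巴": "某公司"}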

View File

@@ -161,4 +161,65 @@ def get_ner_case_number_prompt(text: str) -> str:
请严格按照JSON格式输出结果
""")
return prompt.format(text=text)
def get_entity_linkage_prompt(entities_text: str) -> str:
"""
Returns a prompt that asks the LLM to identify related entities and group them together.
Args:
entities_text (str): The list of entities to be analyzed for linkage
Returns:
str: The formatted prompt that will generate entity linkage information
"""
prompt = textwrap.dedent("""
你是一个专业的法律文本实体关联分析助手。请分析以下实体列表，识别出相互关联的实体（如全称与简称、中文名与英文名等），并将它们分组。
关联规则：
1. 公司名称关联：
- 全称与简称："阿里巴巴集团控股有限公司" 与 "阿里巴巴"
- 中文名与英文名："腾讯科技有限公司" 与 "Tencent Technology Ltd."
- 母公司与子公司："腾讯" 与 "腾讯音乐"
2. 每个组中应指定一个主要实体（is_primary: true），通常是：
- 对于公司，选择最正式的全称
- 对于人名，选择最常用的称呼
待分析实体列表:
{entities_text}
输出格式:
{{
"entity_groups": [
{{
"group_id": "group_1",
"group_type": "公司名称",
"entities": [
{{
"text": "阿里巴巴集团控股有限公司",
"type": "公司名称",
"is_primary": true
}},
{{
"text": "阿里巴巴",
"type": "公司名称简称",
"is_primary": false
}}
]
}}
]
}}
注意事项：
1. 只对确实有关联的实体进行分组
2. 每个实体只能属于一个组
3. 每个组必须有且仅有一个主要实体（is_primary: true）
4. 如果实体之间没有明显关联，不要强制分组
5. group_type 应该是 "公司名称" 或 "人名"
请严格按照JSON格式输出结果。
""")
return prompt.format(entities_text=entities_text)
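For reference, a small sketch of how this prompt is fed, assuming the same bullet format that NerProcessor._create_entity_linkage builds; the sample entities are invented.

# Invented entities; mirrors how NerProcessor._create_entity_linkage assembles entities_text.
entities = [
    {"text": "腾讯科技有限公司", "type": "公司名称"},
    {"text": "Tencent Technology Ltd.", "type": "公司名称"},
]
entities_text = "\n".join(f"- {e['text']} (类型: {e['type']})" for e in entities)
prompt = get_entity_linkage_prompt(entities_text)
# The doubled braces in the template keep the JSON example intact when str.format() runs.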

View File

@@ -0,0 +1,229 @@
import logging
from typing import Any, Dict, Optional
from jsonschema import validate, ValidationError
logger = logging.getLogger(__name__)
class LLMResponseValidator:
"""Validator for LLM JSON responses with different schemas for different entity types"""
# Schema for basic entity extraction responses
ENTITY_EXTRACTION_SCHEMA = {
"type": "object",
"properties": {
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"text": {"type": "string"},
"type": {"type": "string"}
},
"required": ["text", "type"]
}
}
},
"required": ["entities"]
}
# Schema for entity linkage responses
ENTITY_LINKAGE_SCHEMA = {
"type": "object",
"properties": {
"entity_groups": {
"type": "array",
"items": {
"type": "object",
"properties": {
"group_id": {"type": "string"},
"group_type": {"type": "string"},
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"text": {"type": "string"},
"type": {"type": "string"},
"is_primary": {"type": "boolean"}
},
"required": ["text", "type", "is_primary"]
}
}
},
"required": ["group_id", "group_type", "entities"]
}
}
},
"required": ["entity_groups"]
}
# Schema for regex-based entity extraction (from entity_regex.py)
REGEX_ENTITY_SCHEMA = {
"type": "object",
"properties": {
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"text": {"type": "string"},
"type": {"type": "string"}
},
"required": ["text", "type"]
}
}
},
"required": ["entities"]
}
@classmethod
def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool:
"""
Validate entity extraction response from LLM.
Args:
response: The parsed JSON response from LLM
Returns:
bool: True if valid, False otherwise
"""
try:
validate(instance=response, schema=cls.ENTITY_EXTRACTION_SCHEMA)
return True
except ValidationError as e:
logger.warning(f"Entity extraction validation error: {e}")
return False
@classmethod
def validate_entity_linkage(cls, response: Dict[str, Any]) -> bool:
"""
Validate entity linkage response from LLM.
Args:
response: The parsed JSON response from LLM
Returns:
bool: True if valid, False otherwise
"""
try:
validate(instance=response, schema=cls.ENTITY_LINKAGE_SCHEMA)
return cls._validate_linkage_content(response)
except ValidationError as e:
logger.warning(f"Entity linkage validation error: {e}")
return False
@classmethod
def validate_regex_entity(cls, response: Dict[str, Any]) -> bool:
"""
Validate regex-based entity extraction response.
Args:
response: The parsed JSON response from regex extractors
Returns:
bool: True if valid, False otherwise
"""
try:
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
return True
except ValidationError as e:
logger.warning(f"Regex entity validation error: {e}")
return False
@classmethod
def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool:
"""
Additional content validation for entity linkage responses.
Args:
response: The parsed JSON response from LLM
Returns:
bool: True if content is valid, False otherwise
"""
entity_groups = response.get('entity_groups', [])
for group in entity_groups:
# Validate group type
group_type = group.get('group_type', '')
if group_type not in ['公司名称', '人名']:
logger.warning(f"Invalid group_type: {group_type}")
return False
# Validate entities in group
entities = group.get('entities', [])
if not entities:
logger.warning("Empty entity group found")
return False
# Check that exactly one entity is marked as primary
primary_count = sum(1 for entity in entities if entity.get('is_primary', False))
if primary_count != 1:
logger.warning(f"Group must have exactly one primary entity, found {primary_count}")
return False
# Validate entity types within group
for entity in entities:
entity_type = entity.get('type', '')
if group_type == '公司名称' and not any(keyword in entity_type for keyword in ['公司', 'Company']):
logger.warning(f"Company group contains non-company entity: {entity_type}")
return False
elif group_type == '人名' and not any(keyword in entity_type for keyword in ['人名', '英文人名']):
logger.warning(f"Person group contains non-person entity: {entity_type}")
return False
return True
@classmethod
def validate_response_by_type(cls, response: Dict[str, Any], response_type: str) -> bool:
"""
Generic validator that routes to appropriate validation method based on response type.
Args:
response: The parsed JSON response from LLM
response_type: Type of response ('entity_extraction', 'entity_linkage', 'regex_entity')
Returns:
bool: True if valid, False otherwise
"""
validators = {
'entity_extraction': cls.validate_entity_extraction,
'entity_linkage': cls.validate_entity_linkage,
'regex_entity': cls.validate_regex_entity
}
validator = validators.get(response_type)
if not validator:
logger.error(f"Unknown response type: {response_type}")
return False
return validator(response)
@classmethod
def get_validation_errors(cls, response: Dict[str, Any], response_type: str) -> Optional[str]:
"""
Get detailed validation errors for debugging.
Args:
response: The parsed JSON response from LLM
response_type: Type of response
Returns:
Optional[str]: Error message or None if valid
"""
try:
if response_type == 'entity_extraction':
validate(instance=response, schema=cls.ENTITY_EXTRACTION_SCHEMA)
elif response_type == 'entity_linkage':
validate(instance=response, schema=cls.ENTITY_LINKAGE_SCHEMA)
if not cls._validate_linkage_content(response):
return "Content validation failed for entity linkage"
elif response_type == 'regex_entity':
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
else:
return f"Unknown response type: {response_type}"
return None
except ValidationError as e:
return f"Schema validation error: {e}"