WIP: Refactor NER processor
This commit is contained in:
parent 1cf3c45cee
commit 1649a9328b
@@ -1,43 +1,15 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from ..prompts.masking_prompts import get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt, get_ner_project_prompt, get_ner_case_number_prompt
import logging
import json
from ..services.ollama_client import OllamaClient
from ...core.config import settings
from ..utils.json_extractor import LLMJsonExtractor
import re
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
from jsonschema import validate, ValidationError  # pip install jsonschema


from .ner_processor import NerProcessor

logger = logging.getLogger(__name__)

class DocumentProcessor(ABC):
    # JSON Schema for mapping validation
    mapping_schema = {
        "type": "object",
        "properties": {
            "entities": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string"},
                        "type": {"type": "string"}
                    },
                    "required": ["text", "type"]
                }
            }
        },
        "required": ["entities"]
    }

    def __init__(self):
        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
        self.max_chunk_size = 1000  # Maximum number of characters per chunk
        self.max_retries = 3  # Maximum number of retries for mapping generation
        self.ner_processor = NerProcessor()

    @abstractmethod
    def read_content(self) -> str:
@@ -53,7 +25,6 @@ class DocumentProcessor(ABC):
            if not sentence.strip():
                continue

            # If adding this sentence would exceed the limit, save current chunk and start new one
            if len(current_chunk) + len(sentence) > self.max_chunk_size and current_chunk:
                chunks.append(current_chunk)
                current_chunk = sentence
@@ -63,240 +34,32 @@ class DocumentProcessor(ABC):
            else:
                current_chunk = sentence

        # Add the last chunk if it's not empty
        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
        """
        Validate that the mapping follows the required JSON schema format.
        """
        try:
            validate(instance=mapping, schema=self.mapping_schema)
            return True
        except ValidationError as e:
            logger.warning(f"Mapping validation error: {e}")
            return False

    def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
        """Process a single entity type with retry logic"""
        for attempt in range(self.max_retries):
            try:
                formatted_prompt = prompt_func(chunk)
                logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
                response = self.ollama_client.generate(formatted_prompt)
                logger.info(f"Raw response from LLM: {response}")

                # Parse the JSON response into a dictionary
                mapping = LLMJsonExtractor.parse_raw_json_str(response)
                logger.info(f"Parsed mapping: {mapping}")

                if mapping and self._validate_mapping_format(mapping):
                    return mapping
                else:
                    logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
            except Exception as e:
                logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error(f"Max retries reached for {entity_type}, returning empty mapping")

        return {}

    def _build_mapping(self, chunk: str) -> list[Dict[str, str]]:
        """Build mapping for a single chunk of text with retry logic"""
        mapping_pipeline = []

        # LLM-extracted entities
        entity_configs = [
            (get_ner_name_prompt, "people names"),
            (get_ner_company_prompt, "company names"),
            (get_ner_address_prompt, "addresses"),
            (get_ner_project_prompt, "project names"),
            (get_ner_case_number_prompt, "case numbers")
        ]
        for prompt_func, entity_type in entity_configs:
            mapping = self._process_entity_type(chunk, prompt_func, entity_type)
            if mapping:
                mapping_pipeline.append(mapping)

        # Regex-extracted entities
        regex_entity_extractors = [
            extract_id_number_entities,
            extract_social_credit_code_entities
        ]
        for extractor in regex_entity_extractors:
            mapping = extractor(chunk)
            if mapping:
                mapping_pipeline.append(mapping)

        return mapping_pipeline

    def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
        """Apply the mapping to replace sensitive information"""
        masked_text = text
        for original, masked in mapping.items():
            # Ensure masked value is a string
            if isinstance(masked, dict):
                # If it's a dict, use the first value or a default
                masked = next(iter(masked.values()), "某")
            elif not isinstance(masked, str):
                # If it's not a string, convert to string or use default
                masked = str(masked) if masked is not None else "某"
            masked_text = masked_text.replace(original, masked)
        return masked_text

    def _get_next_suffix(self, value: str) -> str:
        """Get the next available suffix for a value that already has a suffix"""
        # Define the sequence of suffixes
        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']

        # Check if the value already has a suffix
        for suffix in suffixes:
            if value.endswith(suffix):
                # Find the next suffix in the sequence
                current_index = suffixes.index(suffix)
                if current_index + 1 < len(suffixes):
                    return value[:-1] + suffixes[current_index + 1]
                else:
                    # If we've used all suffixes, start over with the first one
                    return value[:-1] + suffixes[0]

        # If no suffix found, return the value with the first suffix
        return value + '甲'

    def _merge_entity_mappings(self, chunk_mappings: list[Dict[str, Any]]) -> list[Dict[str, str]]:
        """
        Merge entity mappings from multiple chunks and remove duplicates.

        Args:
            chunk_mappings: List of mappings returned from LLM, each containing 'entities' list

        Returns:
            list[Dict[str, str]]: List of unique entities with text and type
        """
        # Extract all entities from all chunks
        all_entities = []
        for mapping in chunk_mappings:
            if isinstance(mapping, dict) and 'entities' in mapping:
                entities = mapping['entities']
                if isinstance(entities, list):
                    all_entities.extend(entities)

        # Remove duplicates based on text content
        unique_entities = []
        seen_texts = set()

        for entity in all_entities:
            if isinstance(entity, dict) and 'text' in entity:
                text = entity['text'].strip()
                if text and text not in seen_texts:
                    seen_texts.add(text)
                    unique_entities.append(entity)

        logger.info(f"Merged {len(unique_entities)} unique entities")
        return unique_entities

    def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]]) -> Dict[str, str]:
        """
        Generate masked names for unique entities.

        Args:
            unique_entities: List of unique entities with text and type

        Returns:
            Dict[str, str]: Mapping from original text to masked version
        """
        entity_mapping = {}
        used_masked_names = set()

        for entity in unique_entities:
            original_text = entity['text'].strip()
            entity_type = entity.get('type', '')

            # Generate masked name based on entity type
            if '人名' in entity_type or '英文人名' in entity_type:
                # For person names, use 某 + suffix pattern
                base_name = '某'
                masked_name = base_name
                counter = 1

                while masked_name in used_masked_names:
                    if counter <= 10:
                        # Use 甲乙丙丁... for first 10
                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
                        masked_name = base_name + suffixes[counter - 1]
                    else:
                        # Use numbers for additional ones
                        masked_name = f"{base_name}{counter}"
                    counter += 1

            elif '公司' in entity_type or 'Company' in entity_type:
                # For company names, use 某公司 + suffix pattern
                base_name = '某公司'
                masked_name = base_name
                counter = 1

                while masked_name in used_masked_names:
                    if counter <= 10:
                        # Use 甲乙丙丁... for first 10
                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
                        masked_name = base_name + suffixes[counter - 1]
                    else:
                        # Use numbers for additional ones
                        masked_name = f"{base_name}{counter}"
                    counter += 1
            else:
                # For other entity types, use generic pattern
                base_name = '某'
                masked_name = base_name
                counter = 1

                while masked_name in used_masked_names:
                    if counter <= 10:
                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
                        masked_name = base_name + suffixes[counter - 1]
                    else:
                        masked_name = f"{base_name}{counter}"
                    counter += 1

            entity_mapping[original_text] = masked_name
            used_masked_names.add(masked_name)

        logger.info(f"Generated masked mapping for {len(entity_mapping)} entities")
        return entity_mapping

    def process_content(self, content: str) -> str:
        """Process document content by masking sensitive information"""
        # Split content into sentences
        sentences = content.split("。")

        # Split sentences into manageable chunks
        chunks = self._split_into_chunks(sentences)
        logger.info(f"Split content into {len(chunks)} chunks")

        # Build mapping for each chunk
        chunk_mappings = []
        for i, chunk in enumerate(chunks):
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            chunk_mapping = self._build_mapping(chunk)

            logger.info(f"Chunk mapping: {chunk_mapping}")
            chunk_mappings.extend(chunk_mapping)

        # Merge mappings, remove the duplicate ones
        unique_entities = self._merge_entity_mappings(chunk_mappings)
        final_mapping = self.ner_processor.process(chunks)

        # Generate masked names for unique entities
        combined_mapping = self._generate_masked_mapping(unique_entities)

        # Apply the combined mapping to the entire content
        masked_content = self._apply_mapping(content, combined_mapping)
        masked_content = self._apply_mapping(content, final_mapping)
        logger.info("Successfully masked content")

        return masked_content
@@ -304,4 +67,4 @@ class DocumentProcessor(ABC):
    @abstractmethod
    def save_content(self, content: str) -> None:
        """Save processed content"""
        pass
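For illustration, a minimal concrete subclass sketch (hypothetical, not part of this commit; the class name and file paths are made up) showing how the abstract read_content/save_content hooks of DocumentProcessor could be satisfied for plain-text files:

# Hypothetical example subclass; DocumentProcessor is the ABC shown above.
class TxtDocumentProcessor(DocumentProcessor):
    def __init__(self, in_path: str, out_path: str):
        super().__init__()
        self.in_path = in_path
        self.out_path = out_path

    def read_content(self) -> str:
        # Read the raw text to be masked
        with open(self.in_path, encoding="utf-8") as f:
            return f.read()

    def save_content(self, content: str) -> None:
        # Persist the masked text
        with open(self.out_path, "w", encoding="utf-8") as f:
            f.write(content)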
ner_processor.py (new file)
@@ -0,0 +1,233 @@
from typing import Any, Dict
from ..prompts.masking_prompts import get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt, get_ner_project_prompt, get_ner_case_number_prompt, get_entity_linkage_prompt
import logging
import json
from ..services.ollama_client import OllamaClient
from ...core.config import settings
from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator
import re
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities

logger = logging.getLogger(__name__)

class NerProcessor:
    def __init__(self):
        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
        self.max_retries = 3

    def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
        return LLMResponseValidator.validate_entity_extraction(mapping)

    def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
        for attempt in range(self.max_retries):
            try:
                formatted_prompt = prompt_func(chunk)
                logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
                response = self.ollama_client.generate(formatted_prompt)
                logger.info(f"Raw response from LLM: {response}")

                mapping = LLMJsonExtractor.parse_raw_json_str(response)
                logger.info(f"Parsed mapping: {mapping}")

                if mapping and self._validate_mapping_format(mapping):
                    return mapping
                else:
                    logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
            except Exception as e:
                logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error(f"Max retries reached for {entity_type}, returning empty mapping")

        return {}

    def build_mapping(self, chunk: str) -> list[Dict[str, str]]:
        mapping_pipeline = []

        entity_configs = [
            (get_ner_name_prompt, "people names"),
            (get_ner_company_prompt, "company names"),
            (get_ner_address_prompt, "addresses"),
            (get_ner_project_prompt, "project names"),
            (get_ner_case_number_prompt, "case numbers")
        ]
        for prompt_func, entity_type in entity_configs:
            mapping = self._process_entity_type(chunk, prompt_func, entity_type)
            if mapping:
                mapping_pipeline.append(mapping)

        regex_entity_extractors = [
            extract_id_number_entities,
            extract_social_credit_code_entities
        ]
        for extractor in regex_entity_extractors:
            mapping = extractor(chunk)
            if mapping and LLMResponseValidator.validate_regex_entity(mapping):
                mapping_pipeline.append(mapping)
            elif mapping:
                logger.warning(f"Invalid regex entity mapping format: {mapping}")

        return mapping_pipeline

    def _merge_entity_mappings(self, chunk_mappings: list[Dict[str, Any]]) -> list[Dict[str, str]]:
        all_entities = []
        for mapping in chunk_mappings:
            if isinstance(mapping, dict) and 'entities' in mapping:
                entities = mapping['entities']
                if isinstance(entities, list):
                    all_entities.extend(entities)

        unique_entities = []
        seen_texts = set()

        for entity in all_entities:
            if isinstance(entity, dict) and 'text' in entity:
                text = entity['text'].strip()
                if text and text not in seen_texts:
                    seen_texts.add(text)
                    unique_entities.append(entity)

        logger.info(f"Merged {len(unique_entities)} unique entities")
        return unique_entities

    def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]]) -> Dict[str, str]:
        entity_mapping = {}
        used_masked_names = set()

        for entity in unique_entities:
            original_text = entity['text'].strip()
            entity_type = entity.get('type', '')

            if '人名' in entity_type or '英文人名' in entity_type:
                base_name = '某'
                masked_name = base_name
                counter = 1

                while masked_name in used_masked_names:
                    if counter <= 10:
                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
                        masked_name = base_name + suffixes[counter - 1]
                    else:
                        masked_name = f"{base_name}{counter}"
                    counter += 1

            elif '公司' in entity_type or 'Company' in entity_type:
                base_name = '某公司'
                masked_name = base_name
                counter = 1

                while masked_name in used_masked_names:
                    if counter <= 10:
                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
                        masked_name = base_name + suffixes[counter - 1]
                    else:
                        masked_name = f"{base_name}{counter}"
                    counter += 1
            else:
                base_name = '某'
                masked_name = base_name
                counter = 1

                while masked_name in used_masked_names:
                    if counter <= 10:
                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
                        masked_name = base_name + suffixes[counter - 1]
                    else:
                        masked_name = f"{base_name}{counter}"
                    counter += 1

            entity_mapping[original_text] = masked_name
            used_masked_names.add(masked_name)

        logger.info(f"Generated masked mapping for {len(entity_mapping)} entities")
        return entity_mapping

    def _validate_linkage_format(self, linkage: Dict[str, Any]) -> bool:
        return LLMResponseValidator.validate_entity_linkage(linkage)

    def _create_entity_linkage(self, unique_entities: list[Dict[str, str]]) -> Dict[str, Any]:
        linkable_entities = []
        for entity in unique_entities:
            entity_type = entity.get('type', '')
            if any(keyword in entity_type for keyword in ['公司', 'Company', '人名', '英文人名']):
                linkable_entities.append(entity)

        if not linkable_entities:
            logger.info("No linkable entities found")
            return {"entity_groups": []}

        entities_text = "\n".join([
            f"- {entity['text']} (类型: {entity['type']})"
            for entity in linkable_entities
        ])

        for attempt in range(self.max_retries):
            try:
                formatted_prompt = get_entity_linkage_prompt(entities_text)
                logger.info(f"Calling ollama to generate entity linkage (attempt {attempt + 1}/{self.max_retries})")
                response = self.ollama_client.generate(formatted_prompt)
                logger.info(f"Raw entity linkage response from LLM: {response}")

                linkage = LLMJsonExtractor.parse_raw_json_str(response)
                logger.info(f"Parsed entity linkage: {linkage}")

                if linkage and self._validate_linkage_format(linkage):
                    logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
                    return linkage
                else:
                    logger.warning(f"Invalid entity linkage format received on attempt {attempt + 1}, retrying...")
            except Exception as e:
                logger.error(f"Error generating entity linkage on attempt {attempt + 1}: {e}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error("Max retries reached for entity linkage, returning empty linkage")

        return {"entity_groups": []}

    def _apply_entity_linkage_to_mapping(self, entity_mapping: Dict[str, str], entity_linkage: Dict[str, Any]) -> Dict[str, str]:
        updated_mapping = entity_mapping.copy()

        for group in entity_linkage.get('entity_groups', []):
            group_entities = group.get('entities', [])
            if not group_entities:
                continue

            primary_entity = None
            for entity in group_entities:
                if entity.get('is_primary', False):
                    primary_entity = entity
                    break

            if not primary_entity and group_entities:
                primary_entity = group_entities[0]

            if primary_entity:
                primary_text = primary_entity['text']
                primary_masked = updated_mapping.get(primary_text)

                if primary_masked:
                    for entity in group_entities:
                        entity_text = entity['text']
                        if entity_text in updated_mapping:
                            updated_mapping[entity_text] = primary_masked
                            logger.info(f"Linked entity '{entity_text}' to '{primary_text}' with masked name '{primary_masked}'")

        return updated_mapping

    def process(self, chunks: list[str]) -> Dict[str, str]:
        chunk_mappings = []
        for i, chunk in enumerate(chunks):
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            chunk_mapping = self.build_mapping(chunk)
            logger.info(f"Chunk mapping: {chunk_mapping}")
            chunk_mappings.extend(chunk_mapping)

        unique_entities = self._merge_entity_mappings(chunk_mappings)
        entity_linkage = self._create_entity_linkage(unique_entities)
        combined_mapping = self._generate_masked_mapping(unique_entities)
        final_mapping = self._apply_entity_linkage_to_mapping(combined_mapping, entity_linkage)

        return final_mapping
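For orientation, a minimal usage sketch of the new pipeline (hypothetical caller; the package path app.processors.ner_processor and the function name mask_text are assumptions, not part of this commit):

# Hypothetical driver: run NER over pre-split chunks, then apply the mapping.
from app.processors.ner_processor import NerProcessor  # assumed import path

def mask_text(content: str, chunks: list[str]) -> str:
    processor = NerProcessor()
    # process() merges per-chunk entities, assigns masked names, and applies
    # entity linkage, returning {original_text: masked_text}.
    mapping = processor.process(chunks)
    masked = content
    for original, masked_value in mapping.items():
        masked = masked.replace(original, masked_value)
    return masked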
masking_prompts.py
@@ -161,4 +161,65 @@ def get_ner_case_number_prompt(text: str) -> str:

    请严格按照JSON格式输出结果。
    """)
    return prompt.format(text=text)


def get_entity_linkage_prompt(entities_text: str) -> str:
    """
    Returns a prompt that identifies related entities and groups them together.

    Args:
        entities_text (str): The list of entities to be analyzed for linkage

    Returns:
        str: The formatted prompt that will generate entity linkage information
    """
    prompt = textwrap.dedent("""
    你是一个专业的法律文本实体关联分析助手。请分析以下实体列表,识别出相互关联的实体(如全称与简称、中文名与英文名等),并将它们分组。

    关联规则:
    1. 公司名称关联:
       - 全称与简称(如:"阿里巴巴集团控股有限公司" 与 "阿里巴巴")
       - 中文名与英文名(如:"腾讯科技有限公司" 与 "Tencent Technology Ltd.")
       - 母公司与子公司(如:"腾讯" 与 "腾讯音乐")

    2. 每个组中应指定一个主要实体(is_primary: true),通常是:
       - 对于公司:选择最正式的全称
       - 对于人名:选择最常用的称呼

    待分析实体列表:
    {entities_text}

    输出格式:
    {{
        "entity_groups": [
            {{
                "group_id": "group_1",
                "group_type": "公司名称",
                "entities": [
                    {{
                        "text": "阿里巴巴集团控股有限公司",
                        "type": "公司名称",
                        "is_primary": true
                    }},
                    {{
                        "text": "阿里巴巴",
                        "type": "公司名称简称",
                        "is_primary": false
                    }}
                ]
            }}
        ]
    }}

    注意事项:
    1. 只对确实有关联的实体进行分组
    2. 每个实体只能属于一个组
    3. 每个组必须有且仅有一个主要实体(is_primary: true)
    4. 如果实体之间没有明显关联,不要强制分组
    5. group_type 应该是 "公司名称"

    请严格按照JSON格式输出结果。
    """)
    return prompt.format(entities_text=entities_text)
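For reference, a small sketch (hypothetical entity values, drawn from the prompt's own example) of the entities_text input this prompt expects, matching the "- text (类型: type)" lines built in NerProcessor._create_entity_linkage:

# Illustrative input for get_entity_linkage_prompt; values are hypothetical.
entities_text = "\n".join([
    "- 阿里巴巴集团控股有限公司 (类型: 公司名称)",
    "- 阿里巴巴 (类型: 公司名称简称)",
])
prompt = get_entity_linkage_prompt(entities_text)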
llm_validator.py (new file)
@@ -0,0 +1,229 @@
import logging
from typing import Any, Dict, Optional
from jsonschema import validate, ValidationError

logger = logging.getLogger(__name__)


class LLMResponseValidator:
    """Validator for LLM JSON responses with different schemas for different entity types"""

    # Schema for basic entity extraction responses
    ENTITY_EXTRACTION_SCHEMA = {
        "type": "object",
        "properties": {
            "entities": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string"},
                        "type": {"type": "string"}
                    },
                    "required": ["text", "type"]
                }
            }
        },
        "required": ["entities"]
    }

    # Schema for entity linkage responses
    ENTITY_LINKAGE_SCHEMA = {
        "type": "object",
        "properties": {
            "entity_groups": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "group_id": {"type": "string"},
                        "group_type": {"type": "string"},
                        "entities": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "text": {"type": "string"},
                                    "type": {"type": "string"},
                                    "is_primary": {"type": "boolean"}
                                },
                                "required": ["text", "type", "is_primary"]
                            }
                        }
                    },
                    "required": ["group_id", "group_type", "entities"]
                }
            }
        },
        "required": ["entity_groups"]
    }

    # Schema for regex-based entity extraction (from entity_regex.py)
    REGEX_ENTITY_SCHEMA = {
        "type": "object",
        "properties": {
            "entities": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string"},
                        "type": {"type": "string"}
                    },
                    "required": ["text", "type"]
                }
            }
        },
        "required": ["entities"]
    }

    @classmethod
    def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool:
        """
        Validate entity extraction response from LLM.

        Args:
            response: The parsed JSON response from LLM

        Returns:
            bool: True if valid, False otherwise
        """
        try:
            validate(instance=response, schema=cls.ENTITY_EXTRACTION_SCHEMA)
            return True
        except ValidationError as e:
            logger.warning(f"Entity extraction validation error: {e}")
            return False

    @classmethod
    def validate_entity_linkage(cls, response: Dict[str, Any]) -> bool:
        """
        Validate entity linkage response from LLM.

        Args:
            response: The parsed JSON response from LLM

        Returns:
            bool: True if valid, False otherwise
        """
        try:
            validate(instance=response, schema=cls.ENTITY_LINKAGE_SCHEMA)
            return cls._validate_linkage_content(response)
        except ValidationError as e:
            logger.warning(f"Entity linkage validation error: {e}")
            return False

    @classmethod
    def validate_regex_entity(cls, response: Dict[str, Any]) -> bool:
        """
        Validate regex-based entity extraction response.

        Args:
            response: The parsed JSON response from regex extractors

        Returns:
            bool: True if valid, False otherwise
        """
        try:
            validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
            return True
        except ValidationError as e:
            logger.warning(f"Regex entity validation error: {e}")
            return False

    @classmethod
    def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool:
        """
        Additional content validation for entity linkage responses.

        Args:
            response: The parsed JSON response from LLM

        Returns:
            bool: True if content is valid, False otherwise
        """
        entity_groups = response.get('entity_groups', [])

        for group in entity_groups:
            # Validate group type
            group_type = group.get('group_type', '')
            if group_type not in ['公司名称', '人名']:
                logger.warning(f"Invalid group_type: {group_type}")
                return False

            # Validate entities in group
            entities = group.get('entities', [])
            if not entities:
                logger.warning("Empty entity group found")
                return False

            # Check that exactly one entity is marked as primary
            primary_count = sum(1 for entity in entities if entity.get('is_primary', False))
            if primary_count != 1:
                logger.warning(f"Group must have exactly one primary entity, found {primary_count}")
                return False

            # Validate entity types within group
            for entity in entities:
                entity_type = entity.get('type', '')
                if group_type == '公司名称' and not any(keyword in entity_type for keyword in ['公司', 'Company']):
                    logger.warning(f"Company group contains non-company entity: {entity_type}")
                    return False
                elif group_type == '人名' and not any(keyword in entity_type for keyword in ['人名', '英文人名']):
                    logger.warning(f"Person group contains non-person entity: {entity_type}")
                    return False

        return True

    @classmethod
    def validate_response_by_type(cls, response: Dict[str, Any], response_type: str) -> bool:
        """
        Generic validator that routes to appropriate validation method based on response type.

        Args:
            response: The parsed JSON response from LLM
            response_type: Type of response ('entity_extraction', 'entity_linkage', 'regex_entity')

        Returns:
            bool: True if valid, False otherwise
        """
        validators = {
            'entity_extraction': cls.validate_entity_extraction,
            'entity_linkage': cls.validate_entity_linkage,
            'regex_entity': cls.validate_regex_entity
        }

        validator = validators.get(response_type)
        if not validator:
            logger.error(f"Unknown response type: {response_type}")
            return False

        return validator(response)

    @classmethod
    def get_validation_errors(cls, response: Dict[str, Any], response_type: str) -> Optional[str]:
        """
        Get detailed validation errors for debugging.

        Args:
            response: The parsed JSON response from LLM
            response_type: Type of response

        Returns:
            Optional[str]: Error message or None if valid
        """
        try:
            if response_type == 'entity_extraction':
                validate(instance=response, schema=cls.ENTITY_EXTRACTION_SCHEMA)
            elif response_type == 'entity_linkage':
                validate(instance=response, schema=cls.ENTITY_LINKAGE_SCHEMA)
                if not cls._validate_linkage_content(response):
                    return "Content validation failed for entity linkage"
            elif response_type == 'regex_entity':
                validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
            else:
                return f"Unknown response type: {response_type}"

            return None
        except ValidationError as e:
            return f"Schema validation error: {e}"
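A quick sanity-check sketch (hand-written sample data; the import path app.utils.llm_validator is an assumption) exercising the schemas above:

# Illustrative validation calls; sample payloads mirror the schema fields.
from app.utils.llm_validator import LLMResponseValidator  # assumed import path

sample_extraction = {"entities": [{"text": "阿里巴巴", "type": "公司名称简称"}]}
sample_linkage = {
    "entity_groups": [{
        "group_id": "group_1",
        "group_type": "公司名称",
        "entities": [
            {"text": "阿里巴巴集团控股有限公司", "type": "公司名称", "is_primary": True},
            {"text": "阿里巴巴", "type": "公司名称简称", "is_primary": False},
        ],
    }]
}
assert LLMResponseValidator.validate_entity_extraction(sample_extraction)
assert LLMResponseValidator.validate_entity_linkage(sample_linkage)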