feature-ner-keyword-detect #1
@@ -1,11 +1,13 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict
-from ..prompts.masking_prompts import get_ner_name_prompt, get_ner_company_prompt
+from ..prompts.masking_prompts import get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt, get_ner_project_prompt, get_ner_case_number_prompt
 import logging
 import json
 from ..services.ollama_client import OllamaClient
 from ...core.config import settings
 from ..utils.json_extractor import LLMJsonExtractor
+import re
+from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
 
 
 
@@ -73,14 +75,12 @@ class DocumentProcessor(ABC):
 
         return True
 
-    def _build_mapping(self, chunk: str) -> list[Dict[str, str]]:
-        """Build mapping for a single chunk of text with retry logic"""
-        mapping_pipeline = []
-        # Build people name mapping
+    def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
+        """Process a single entity type with retry logic"""
         for attempt in range(self.max_retries):
             try:
-                formatted_prompt = get_ner_name_prompt(chunk)
-                logger.info(f"Calling ollama to generate mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
+                formatted_prompt = prompt_func(chunk)
+                logger.info(f"Calling ollama to generate {entity_type} mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
                 response = self.ollama_client.generate(formatted_prompt)
                 logger.info(f"Raw response from LLM: {response}")
@@ -89,40 +89,45 @@ class DocumentProcessor(ABC):
                 logger.info(f"Parsed mapping: {mapping}")
 
                 if mapping and self._validate_mapping_format(mapping):
-                    mapping_pipeline.append(mapping)
-                    break
+                    return mapping
                 else:
                     logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
             except Exception as e:
-                logger.error(f"Error generating mapping on attempt {attempt + 1}: {e}")
+                logger.error(f"Error generating {entity_type} mapping on attempt {attempt + 1}: {e}")
                 if attempt < self.max_retries - 1:
                     logger.info("Retrying...")
                 else:
-                    logger.error("Max retries reached, returning empty mapping")
-                    return {}
-
-        # Build company name mapping
-        for attempt in range(self.max_retries):
-            try:
-                formatted_prompt = get_ner_company_prompt(chunk)
-                logger.info(f"Calling ollama to generate mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
-                response = self.ollama_client.generate(formatted_prompt)
-                logger.info(f"Raw response from LLM: {response}")
-                mapping = LLMJsonExtractor.parse_raw_json_str(response)
-                logger.info(f"Parsed mapping: {mapping}")
-
-                if mapping and self._validate_mapping_format(mapping):
-                    mapping_pipeline.append(mapping)
-                    break
-                else:
-                    logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
-            except Exception as e:
-                logger.error(f"Error generating mapping on attempt {attempt + 1}: {e}")
-                if attempt < self.max_retries - 1:
-                    logger.info("Retrying...")
-                else:
-                    logger.error("Max retries reached, returning empty mapping")
-                    return {}
+                    logger.error(f"Max retries reached for {entity_type}, returning empty mapping")
+
+        return {}
+
+    def _build_mapping(self, chunk: str) -> list[Dict[str, str]]:
+        """Build mapping for a single chunk of text with retry logic"""
+        mapping_pipeline = []
+
+        # LLM实体 (entities extracted via LLM prompts)
+        entity_configs = [
+            (get_ner_name_prompt, "people names"),
+            (get_ner_company_prompt, "company names"),
+            (get_ner_address_prompt, "addresses"),
+            (get_ner_project_prompt, "project names"),
+            (get_ner_case_number_prompt, "case numbers")
+        ]
+        for prompt_func, entity_type in entity_configs:
+            mapping = self._process_entity_type(chunk, prompt_func, entity_type)
+            if mapping:
+                mapping_pipeline.append(mapping)
+
+        # 正则实体 (entities extracted via regex)
+        regex_entity_extractors = [
+            extract_id_number_entities,
+            extract_social_credit_code_entities
+        ]
+        for extractor in regex_entity_extractors:
+            mapping = extractor(chunk)
+            if mapping:
+                mapping_pipeline.append(mapping)
 
         return mapping_pipeline
 
     def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
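Note: each element the refactored `_build_mapping` appends to `mapping_pipeline` is expected to follow the `{"entities": [...]}` shape requested by the prompts and consumed by `_merge_entity_mappings` further down. A minimal sketch of one chunk's result; the sample names, types, and number are invented:

# Hypothetical _build_mapping output for one chunk (all values made up)
chunk_mapping = [
    {"entities": [{"text": "张三", "type": "人名"}]},                    # from get_ner_name_prompt
    {"entities": [{"text": "某某科技有限公司", "type": "公司"}]},          # from get_ner_company_prompt
    {"entities": [{"text": "110105199001011234", "type": "身份证号"}]},   # from extract_id_number_entities
]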
@@ -158,35 +163,108 @@ class DocumentProcessor(ABC):
         # If no suffix found, return the value with the first suffix
         return value + '甲'
 
-    def _merge_mappings(self, existing: Dict[str, str], new: Dict[str, str]) -> Dict[str, str]:
-        """
-        Merge two mappings following the rules:
-        1. If key exists in existing, keep existing value
-        2. If value exists in existing:
-           - If value ends with a suffix (甲乙丙丁...), add next suffix
-           - If no suffix, add '甲'
-        """
-        result = existing.copy()
-
-        # Get all existing values
-        existing_values = set(result.values())
-
-        for key, value in new.items():
-            if key in result:
-                # Rule 1: Keep existing value if key exists
-                continue
-
-            if value in existing_values:
-                # Rule 2: Handle duplicate values
-                new_value = self._get_next_suffix(value)
-                result[key] = new_value
-                existing_values.add(new_value)
-            else:
-                # No conflict, add as is
-                result[key] = value
-                existing_values.add(value)
-
-        return result
+    def _merge_entity_mappings(self, chunk_mappings: list[Dict[str, Any]]) -> list[Dict[str, str]]:
+        """
+        Merge entity mappings from multiple chunks and remove duplicates.
+
+        Args:
+            chunk_mappings: List of mappings returned from LLM, each containing 'entities' list
+
+        Returns:
+            list[Dict[str, str]]: List of unique entities with text and type
+        """
+        # Extract all entities from all chunks
+        all_entities = []
+        for mapping in chunk_mappings:
+            if isinstance(mapping, dict) and 'entities' in mapping:
+                entities = mapping['entities']
+                if isinstance(entities, list):
+                    all_entities.extend(entities)
+
+        # Remove duplicates based on text content
+        unique_entities = []
+        seen_texts = set()
+
+        for entity in all_entities:
+            if isinstance(entity, dict) and 'text' in entity:
+                text = entity['text'].strip()
+                if text and text not in seen_texts:
+                    seen_texts.add(text)
+                    unique_entities.append(entity)
+
+        logger.info(f"Merged {len(unique_entities)} unique entities")
+        return unique_entities
+
+    def _generate_masked_mapping(self, unique_entities: list[Dict[str, str]]) -> Dict[str, str]:
+        """
+        Generate masked names for unique entities.
+
+        Args:
+            unique_entities: List of unique entities with text and type
+
+        Returns:
+            Dict[str, str]: Mapping from original text to masked version
+        """
+        entity_mapping = {}
+        used_masked_names = set()
+
+        for entity in unique_entities:
+            original_text = entity['text'].strip()
+            entity_type = entity.get('type', '')
+
+            # Generate masked name based on entity type
+            if '人名' in entity_type or '英文人名' in entity_type:
+                # For person names, use 某 + suffix pattern
+                base_name = '某'
+                masked_name = base_name
+                counter = 1
+
+                while masked_name in used_masked_names:
+                    if counter <= 10:
+                        # Use 甲乙丙丁... for first 10
+                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
+                        masked_name = base_name + suffixes[counter - 1]
+                    else:
+                        # Use numbers for additional ones
+                        masked_name = f"{base_name}{counter}"
+                    counter += 1
+
+            elif '公司' in entity_type or 'Company' in entity_type:
+                # For company names, use 某公司 + suffix pattern
+                base_name = '某公司'
+                masked_name = base_name
+                counter = 1
+
+                while masked_name in used_masked_names:
+                    if counter <= 10:
+                        # Use 甲乙丙丁... for first 10
+                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
+                        masked_name = base_name + suffixes[counter - 1]
+                    else:
+                        # Use numbers for additional ones
+                        masked_name = f"{base_name}{counter}"
+                    counter += 1
+            else:
+                # For other entity types, use generic pattern
+                base_name = '某'
+                masked_name = base_name
+                counter = 1
+
+                while masked_name in used_masked_names:
+                    if counter <= 10:
+                        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
+                        masked_name = base_name + suffixes[counter - 1]
+                    else:
+                        masked_name = f"{base_name}{counter}"
+                    counter += 1
+
+            entity_mapping[original_text] = masked_name
+            used_masked_names.add(masked_name)
+
+        logger.info(f"Generated masked mapping for {len(entity_mapping)} entities")
+        return entity_mapping
 
     def process_content(self, content: str) -> str:
         """Process document content by masking sensitive information"""
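Note: the suffix scheme in `_generate_masked_mapping` yields 某, 某甲, 某乙 … 某癸 and then 某11, 某12 … for successive entities of the same kind (and 某公司, 某公司甲, … for companies). A small illustration with invented inputs:

# Illustration only; the input entities are made up.
entities = [
    {"text": "张三", "type": "人名"},
    {"text": "李四", "type": "人名"},
    {"text": "某某建设集团有限公司", "type": "公司"},
]
# _generate_masked_mapping(entities) would produce:
# {"张三": "某", "李四": "某甲", "某某建设集团有限公司": "某公司"}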
@@ -198,21 +276,25 @@ class DocumentProcessor(ABC):
         logger.info(f"Split content into {len(chunks)} chunks")
 
         # Build mapping for each chunk
-        combined_mapping = {}
+        chunk_mappings = []
         for i, chunk in enumerate(chunks):
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
             chunk_mapping = self._build_mapping(chunk)
-            # if chunk_mapping: # Only update if we got a valid mapping
-            #     combined_mapping = self._merge_mappings(combined_mapping, chunk_mapping)
-            # else:
-            #     logger.warning(f"Failed to generate mapping for chunk {i+1}")
+            logger.info(f"Chunk mapping: {chunk_mapping}")
+            chunk_mappings.extend(chunk_mapping)
+
+        # Merge mappings, remove the duplicate ones
+        unique_entities = self._merge_entity_mappings(chunk_mappings)
+
+        # Generate masked names for unique entities
+        combined_mapping = self._generate_masked_mapping(unique_entities)
 
         # Apply the combined mapping to the entire content
-        # masked_content = self._apply_mapping(content, combined_mapping)
+        masked_content = self._apply_mapping(content, combined_mapping)
         logger.info("Successfully masked content")
 
-        # return masked_content
-        return ""
+        return masked_content
 
     @abstractmethod
     def save_content(self, content: str) -> None:
@@ -0,0 +1,18 @@
+import re
+
+def extract_id_number_entities(chunk: str) -> dict:
+    """Extract Chinese ID numbers and return in entity mapping format."""
+    id_pattern = r'\b\d{17}[\dXx]\b'
+    entities = []
+    for match in re.findall(id_pattern, chunk):
+        entities.append({"text": match, "type": "身份证号"})
+    return {"entities": entities} if entities else {}
+
+
+def extract_social_credit_code_entities(chunk: str) -> dict:
+    """Extract social credit codes and return in entity mapping format."""
+    credit_pattern = r'\b[0-9A-Z]{18}\b'
+    entities = []
+    for match in re.findall(credit_pattern, chunk):
+        entities.append({"text": match, "type": "统一社会信用代码"})
+    return {"entities": entities} if entities else {}
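Note: the new extractor helpers (this hunk appears to be the new `entity_regex` module imported above) return the same `{"entities": [...]}` shape as the LLM prompts, so `_build_mapping` can append their output directly. A quick check with a synthetic ID number, not a real one:

# Synthetic example; the ID number is invented.
extract_id_number_entities("申请人身份证号为 11010519900101123X ,经核验无误。")
# -> {"entities": [{"text": "11010519900101123X", "type": "身份证号"}]}

Because an 18-character ID ending in X also satisfies `\b[0-9A-Z]{18}\b`, the same string can be reported a second time by `extract_social_credit_code_entities`; the text-based de-duplication in `_merge_entity_mappings` keeps whichever entity arrives first in the pipeline.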
@@ -79,3 +79,86 @@ def get_ner_company_prompt(text: str) -> str:
     return prompt.format(text=text)
 
 
+def get_ner_address_prompt(text: str) -> str:
+    """
+    Returns a prompt that generates a mapping of original addresses to their masked versions.
+
+    Args:
+        text (str): The input text to be analyzed for masking
+
+    Returns:
+        str: The formatted prompt that will generate a mapping dictionary
+    """
+    prompt = textwrap.dedent("""
+        你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
+
+        实体类别包括:
+        - 地址
+
+        待处理文本:
+        {text}
+
+        输出格式:
+        {{
+            "entities": [
+                {{"text": "原始文本内容", "type": "地址"}},
+                ...
+            ]
+        }}
+
+        请严格按照JSON格式输出结果。
+    """)
+    return prompt.format(text=text)
+
+
+def get_ner_project_prompt(text: str) -> str:
+    """
+    Returns a prompt that generates a mapping of original project names to their masked versions.
+    """
+    prompt = textwrap.dedent("""
+        你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
+
+        实体类别包括:
+        - 项目名
+
+        待处理文本:
+        {text}
+
+        输出格式:
+        {{
+            "entities": [
+                {{"text": "原始文本内容", "type": "项目名"}},
+                ...
+            ]
+        }}
+
+        请严格按照JSON格式输出结果。
+    """)
+    return prompt.format(text=text)
+
+
+def get_ner_case_number_prompt(text: str) -> str:
+    """
+    Returns a prompt that generates a mapping of original case numbers to their masked versions.
+    """
+    prompt = textwrap.dedent("""
+        你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
+
+        实体类别包括:
+        - 案号
+
+        待处理文本:
+        {text}
+
+        输出格式:
+        {{
+            "entities": [
+                {{"text": "原始文本内容", "type": "案号"}},
+                ...
+            ]
+        }}
+
+        请严格按照JSON格式输出结果。
+    """)
+    return prompt.format(text=text)
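Note: per the 输出格式 block in these prompts, a compliant model reply to `get_ner_address_prompt` should carry JSON of the following shape (the address is invented), which `LLMJsonExtractor.parse_raw_json_str` is expected to turn into the dict handled by `_process_entity_type` and `_merge_entity_mappings`:

# Invented example; a compliant reply to get_ner_address_prompt would parse to:
{"entities": [{"text": "北京市朝阳区某某路1号院3号楼", "type": "地址"}]}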