将提示词抽象出来

This commit is contained in:
oliviamn 2025-05-06 00:13:19 +08:00
parent 815427a509
commit 7d0be5aa8a
3 changed files with 63 additions and 32 deletions

View File

@ -5,6 +5,12 @@ from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedData
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
from prompts.masking_prompts import get_masking_prompt
import logging
from services.ollama_client import OllamaClient
from config.settings import settings
logger = logging.getLogger(__name__)
class PdfDocumentProcessor(DocumentProcessor):
def __init__(self, input_path: str, output_path: str):
@ -29,12 +35,16 @@ class PdfDocumentProcessor(DocumentProcessor):
self.work_local_image_dir = os.path.join(self.work_dir, "images")
self.work_image_dir = os.path.basename(self.work_local_image_dir)
os.makedirs(self.work_local_image_dir, exist_ok=True)
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
def read_content(self) -> bytes:
    """Return the raw bytes of the file at ``self.input_path``.

    The file is opened in binary mode, so no text decoding is applied.
    """
    with open(self.input_path, 'rb') as source:
        raw_bytes = source.read()
    return raw_bytes
def process_content(self, content: bytes) -> dict:
logger.info("Starting PDF content processing")
# Initialize writers
image_writer = FileBasedDataWriter(self.work_local_image_dir)
md_writer = FileBasedDataWriter(self.work_dir)
@ -42,6 +52,7 @@ class PdfDocumentProcessor(DocumentProcessor):
# Create Dataset Instance
ds = PymuDocDataset(content)
logger.info("Classifying PDF type: %s", ds.classify())
# Process based on PDF type
if ds.classify() == SupportedPdfParseMethod.OCR:
infer_result = ds.apply(doc_analyze, ocr=True)
@ -49,7 +60,8 @@ class PdfDocumentProcessor(DocumentProcessor):
else:
infer_result = ds.apply(doc_analyze, ocr=False)
pipe_result = infer_result.pipe_txt_mode(image_writer)
logger.info("Generating all outputs")
# Generate all outputs
infer_result.draw_model(os.path.join(self.work_dir, f"{self.name_without_suff}_model.pdf"))
model_inference_result = infer_result.get_infer_res()
@ -66,16 +78,21 @@ class PdfDocumentProcessor(DocumentProcessor):
middle_json = pipe_result.get_middle_json()
pipe_result.dump_middle_json(md_writer, f'{self.name_without_suff}_middle.json')
return md_content
logger.info("Masking content")
formatted_prompt = get_masking_prompt(md_content)
logger.info("Calling ollama to generate response")
response = self.ollama_client.generate(formatted_prompt)
logger.info("Response generated")
return response
return {
'markdown': md_content,
'content_list': content_list,
'middle_json': middle_json,
'model_inference': model_inference_result
}
def save_content(self, content: dict) -> None:
# Content is already saved during processing
with open(self.output_path, 'w', encoding='utf-8') as file:
def save_content(self, content: str) -> None:
    """Write *content* as UTF-8 text to a ``.md`` file.

    The target path is derived from ``self.output_path``: same directory,
    same base name, with whatever extension it had replaced by ``.md``.

    Args:
        content: The text to write.
    """
    # Split into directory and file name, then swap the extension for .md.
    target_dir, file_name = os.path.split(self.output_path)
    stem, _ext = os.path.splitext(file_name)
    md_output_path = os.path.join(target_dir, f"{stem}.md")

    logger.info(f"Saving masked content to: {md_output_path}")
    with open(md_output_path, 'w', encoding='utf-8') as sink:
        sink.write(content)

View File

@ -1,7 +1,7 @@
from document_handlers.document_processor import DocumentProcessor
from services.ollama_client import OllamaClient
import textwrap
import logging
from prompts.masking_prompts import get_masking_prompt
from config.settings import settings
logger = logging.getLogger(__name__)
@ -16,27 +16,8 @@ class TxtDocumentProcessor(DocumentProcessor):
return file.read()
def process_content(self, content: str) -> str:
    """Mask sensitive information in *content* using the Ollama client.

    The masking prompt is built by ``get_masking_prompt`` and passed to
    ``self.ollama_client.generate``; the generated response is returned
    unchanged.

    Args:
        content: The text to be masked.

    Returns:
        The response produced by ``self.ollama_client.generate``.
    """
    # The prompt template is owned by get_masking_prompt. An inline copy
    # that used to be built here was overwritten on the very next line,
    # so it was dead work and has been dropped.
    formatted_prompt = get_masking_prompt(content)
    response = self.ollama_client.generate(formatted_prompt)
    # Lazy %-style arguments: the message is only rendered when DEBUG is on.
    logger.debug("Processed content: %s", response)
    return response

View File

@ -0,0 +1,33 @@
import textwrap
def get_masking_prompt(text: str) -> str:
    """Build the prompt asking the model to mask a legal document.

    The instruction template is dedented first and the input is substituted
    into its ``{text}`` placeholder afterwards, so the caller's text is
    inserted verbatim (its own indentation and braces are not touched).

    Args:
        text: The input text to be masked.

    Returns:
        The full prompt with *text* embedded.
    """
    # NOTE(review): as seen here the template carries no punctuation and the
    # third company-name rule contains an empty pair of quotes; this may be
    # an extraction artifact -- confirm against the upstream file.
    template = textwrap.dedent("""
    您是一位专业的法律文档脱敏专家请按照以下规则对文本进行脱敏处理
    规则
    1. 人名
    - 两字名改为"姓+某"张三 张某
    - 三字名改为"姓+某某"张三丰 张某某
    2. 公司名
    - 保留地理位置信息北京上海等
    - 保留公司类型有限公司股份公司等
    - ""替换核心名称
    3. 保持原文其他部分不变
    4. 确保脱敏后的文本保持原有的语言流畅性和可读性
    输入文本
    {text}
    请直接输出脱敏后的文本无需解释或其他备注
    """)
    return template.format(text=text)