129 lines
3.6 KiB
Python
129 lines
3.6 KiB
Python
"""
Tests for the refactored NerProcessor.
"""
|
||
|
||
import pytest
|
||
import sys
|
||
import os
|
||
|
||
# Add the backend directory to the Python path
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||
|
||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
||
|
||
|
||
def test_chinese_name_masker():
    """Check ChineseNameMasker output for several names and one repeated name.

    Every case runs against the same masker instance, in order. The final
    case masks "李强" a second time and expects a numeric suffix, so the
    sequence of calls is part of what is being tested.
    """
    masker = ChineseNameMasker()

    # (input name, expected masked value), evaluated top to bottom.
    cases = [
        ("李强", "李Q"),
        ("张韶涵", "张SH"),
        ("张若宇", "张RY"),
        ("白锦程", "白JC"),
        # Duplicate of the first input: the repeat should get a number.
        ("李强", "李Q2"),
    ]
    for name, expected in cases:
        masked = masker.mask(name)
        assert masked == expected, (
            f"mask({name!r}) returned {masked!r}, expected {expected!r}"
        )

    print("Chinese name masking tests passed")
|
||
|
||
|
||
def test_english_name_masker():
    """Check EnglishNameMasker output for a two-word and a three-word name."""
    masker = EnglishNameMasker()

    # (input name, expected masked value)
    cases = [
        ("John Smith", "J*** S***"),
        ("Mary Jane Watson", "M*** J*** W***"),
    ]
    for name, expected in cases:
        masked = masker.mask(name)
        assert masked == expected, (
            f"mask({name!r}) returned {masked!r}, expected {expected!r}"
        )

    print("English name masking tests passed")
|
||
|
||
|
||
def test_id_masker():
    """Check IDMasker output for an ID number and a social credit code.

    Each masked value must equal the expected string and be 18 characters
    long.
    """
    masker = IDMasker()

    # (raw value, expected masked value)
    cases = [
        ("310103198802080000", "310103XXXXXXXXXXXX"),  # ID number
        ("9133021276453538XT", "913302XXXXXXXXXXXX"),  # social credit code
    ]
    for raw, expected in cases:
        masked = masker.mask(raw)
        assert masked == expected, (
            f"mask({raw!r}) returned {masked!r}, expected {expected!r}"
        )
        assert len(masked) == 18, (
            f"mask({raw!r}) returned {len(masked)} characters, expected 18"
        )

    print("ID masking tests passed")
|
||
|
||
|
||
def test_case_masker():
    """Check that CaseMasker output for two case numbers contains "***号"."""
    masker = CaseMasker()

    case_numbers = [
        "(2022)京 03 民终 3852 号",
        "(2020)京0105 民初69754 号",
    ]
    for case_number in case_numbers:
        masked = masker.mask(case_number)
        # Only the masked tail is checked; the rest of the output is not pinned.
        assert "***号" in masked, (
            f"mask({case_number!r}) returned {masked!r}, which lacks '***号'"
        )

    print("Case masking tests passed")
|
||
|
||
|
||
def test_masker_factory():
    """Check that MaskerFactory.create_masker returns the right masker types.

    Each factory key is requested once and the result is checked with
    ``isinstance`` against the masker class imported at module level.
    """
    # Imported here, as in the original test, rather than at module level.
    from app.core.document_handlers.masker_factory import MaskerFactory

    # factory key -> class the created masker must be an instance of
    expected_types = {
        'chinese_name': ChineseNameMasker,
        'english_name': EnglishNameMasker,
        'id': IDMasker,
        'case': CaseMasker,
    }
    for key, masker_cls in expected_types.items():
        created = MaskerFactory.create_masker(key)
        assert isinstance(created, masker_cls), (
            f"create_masker({key!r}) returned {type(created).__name__}, "
            f"expected {masker_cls.__name__}"
        )

    print("Masker factory tests passed")
|
||
|
||
|
||
def test_refactored_processor_initialization():
    """Check that NerProcessorRefactored builds and exposes its maskers.

    Construction is best-effort: if it raises, the failure is printed and the
    test returns without asserting anything. The original note says this can
    happen when Ollama is not running, which is expected in the test
    environment.

    Once construction succeeds, the assertions run outside the ``try`` block.
    A processor without a non-empty ``maskers`` attribute therefore fails the
    test instead of being reported as an initialization problem.
    """
    try:
        processor = NerProcessorRefactored()
    except Exception as e:
        print(f"Refactored processor initialization failed: {e}")
        # This might fail if Ollama is not running, which is expected in test environment
        return

    assert processor is not None
    assert hasattr(processor, 'maskers')
    assert len(processor.maskers) > 0
    print("Refactored processor initialization test passed")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run each test function once, in a fixed order.
    # An exception from any of them stops the run before the final message.
    print("Running refactored NerProcessor tests...")

    ordered_tests = (
        test_chinese_name_masker,
        test_english_name_masker,
        test_id_masker,
        test_case_masker,
        test_masker_factory,
        test_refactored_processor_initialization,
    )
    for run_test in ordered_tests:
        run_test()

    print("All refactored NerProcessor tests completed!")
|