Compare commits
No commits in common. "main" and "feature-seperate-mineru" have entirely different histories.
main
...
feature-se
|
|
@ -0,0 +1,19 @@
|
|||
# Storage paths
|
||||
OBJECT_STORAGE_PATH=/path/to/mounted/object/storage
|
||||
TARGET_DIRECTORY_PATH=/path/to/target/directory
|
||||
|
||||
# Ollama API Configuration
|
||||
OLLAMA_API_URL=https://api.ollama.com
|
||||
OLLAMA_API_KEY=your_api_key_here
|
||||
OLLAMA_MODEL=llama2
|
||||
|
||||
# Application Settings
|
||||
MONITOR_INTERVAL=5
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE=app.log
|
||||
|
||||
# Optional: Additional security settings
|
||||
# MAX_FILE_SIZE=10485760 # 10MB in bytes
|
||||
# ALLOWED_FILE_TYPES=.txt,.doc,.docx,.pdf
|
||||
|
|
@ -1,206 +0,0 @@
|
|||
# Unified Docker Compose Setup
|
||||
|
||||
This project now includes a unified Docker Compose configuration that allows all services (mineru, backend, frontend) to run together and communicate using service names.
|
||||
|
||||
## Architecture
|
||||
|
||||
The unified setup includes the following services:
|
||||
|
||||
- **mineru-api**: Document processing service (port 8001)
|
||||
- **backend-api**: Main API service (port 8000)
|
||||
- **celery-worker**: Background task processor
|
||||
- **redis**: Message broker for Celery
|
||||
- **frontend**: React frontend application (port 3000)
|
||||
|
||||
## Network Configuration
|
||||
|
||||
All services are connected through a custom bridge network called `app-network`, allowing them to communicate using service names:
|
||||
|
||||
- Backend → Mineru: `http://mineru-api:8000`
|
||||
- Frontend → Backend: `http://localhost:8000/api/v1` (external access)
|
||||
- Backend → Redis: `redis://redis:6379/0`
|
||||
|
||||
## Usage
|
||||
|
||||
### Starting all services
|
||||
|
||||
```bash
|
||||
# From the root directory
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### Starting specific services
|
||||
|
||||
```bash
|
||||
# Start only backend and mineru
|
||||
docker-compose up -d backend-api mineru-api redis
|
||||
|
||||
# Start only frontend and backend
|
||||
docker-compose up -d frontend backend-api redis
|
||||
```
|
||||
|
||||
### Stopping services
|
||||
|
||||
```bash
|
||||
# Stop all services
|
||||
docker-compose down
|
||||
|
||||
# Stop and remove volumes
|
||||
docker-compose down -v
|
||||
```
|
||||
|
||||
### Viewing logs
|
||||
|
||||
```bash
|
||||
# View all logs
|
||||
docker-compose logs -f
|
||||
|
||||
# View specific service logs
|
||||
docker-compose logs -f backend-api
|
||||
docker-compose logs -f mineru-api
|
||||
docker-compose logs -f frontend
|
||||
```
|
||||
|
||||
## Building Services
|
||||
|
||||
### Building all services
|
||||
|
||||
```bash
|
||||
# Build all services
|
||||
docker-compose build
|
||||
|
||||
# Build and start all services
|
||||
docker-compose up -d --build
|
||||
```
|
||||
|
||||
### Building individual services
|
||||
|
||||
```bash
|
||||
# Build only backend
|
||||
docker-compose build backend-api
|
||||
|
||||
# Build only frontend
|
||||
docker-compose build frontend
|
||||
|
||||
# Build only mineru
|
||||
docker-compose build mineru-api
|
||||
|
||||
# Build multiple specific services
|
||||
docker-compose build backend-api frontend celery-worker
|
||||
```
|
||||
|
||||
### Building and restarting specific services
|
||||
|
||||
```bash
|
||||
# Build and restart only backend
|
||||
docker-compose build backend-api
|
||||
docker-compose up -d backend-api
|
||||
|
||||
# Or combine in one command
|
||||
docker-compose up -d --build backend-api
|
||||
|
||||
# Build and restart backend and celery worker
|
||||
docker-compose up -d --build backend-api celery-worker
|
||||
```
|
||||
|
||||
### Force rebuild (no cache)
|
||||
|
||||
```bash
|
||||
# Force rebuild all services
|
||||
docker-compose build --no-cache
|
||||
|
||||
# Force rebuild specific service
|
||||
docker-compose build --no-cache backend-api
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
The unified setup uses environment variables from the individual service `.env` files:
|
||||
|
||||
- `./backend/.env` - Backend configuration
|
||||
- `./frontend/.env` - Frontend configuration
|
||||
- `./mineru/.env` - Mineru configuration (if exists)
|
||||
|
||||
### Key Configuration Changes
|
||||
|
||||
1. **Backend Configuration** (`backend/app/core/config.py`):
|
||||
```python
|
||||
MINERU_API_URL: str = "http://mineru-api:8000"
|
||||
```
|
||||
|
||||
2. **Frontend Configuration**:
|
||||
```javascript
|
||||
REACT_APP_API_BASE_URL=http://localhost:8000/api/v1
|
||||
```
|
||||
|
||||
## Service Dependencies
|
||||
|
||||
- `backend-api` depends on `redis` and `mineru-api`
|
||||
- `celery-worker` depends on `redis` and `backend-api`
|
||||
- `frontend` depends on `backend-api`
|
||||
|
||||
## Port Mapping
|
||||
|
||||
- **Frontend**: `http://localhost:3000`
|
||||
- **Backend API**: `http://localhost:8000`
|
||||
- **Mineru API**: `http://localhost:8001`
|
||||
- **Redis**: `localhost:6379`
|
||||
|
||||
## Health Checks
|
||||
|
||||
The mineru-api service includes a health check that verifies the service is running properly.
|
||||
|
||||
## Development vs Production
|
||||
|
||||
For development, you can still use the individual docker-compose files in each service directory. The unified setup is ideal for:
|
||||
|
||||
- Production deployments
|
||||
- End-to-end testing
|
||||
- Simplified development environment
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Service Communication Issues
|
||||
|
||||
If services can't communicate:
|
||||
|
||||
1. Check if all services are running: `docker-compose ps`
|
||||
2. Verify network connectivity: `docker network ls`
|
||||
3. Check service logs: `docker-compose logs [service-name]`
|
||||
|
||||
### Port Conflicts
|
||||
|
||||
If you get port conflicts, you can modify the port mappings in the `docker-compose.yml` file:
|
||||
|
||||
```yaml
|
||||
ports:
|
||||
- "8002:8000" # Change external port
|
||||
```
|
||||
|
||||
### Volume Issues
|
||||
|
||||
Make sure the storage directories exist:
|
||||
|
||||
```bash
|
||||
mkdir -p backend/storage
|
||||
mkdir -p mineru/storage/uploads
|
||||
mkdir -p mineru/storage/processed
|
||||
```
|
||||
|
||||
## Migration from Individual Compose Files
|
||||
|
||||
If you were previously using individual docker-compose files:
|
||||
|
||||
1. Stop all individual services:
|
||||
```bash
|
||||
cd backend && docker-compose down
|
||||
cd ../frontend && docker-compose down
|
||||
cd ../mineru && docker-compose down
|
||||
```
|
||||
|
||||
2. Start the unified setup:
|
||||
```bash
|
||||
cd .. && docker-compose up -d
|
||||
```
|
||||
|
||||
The unified setup maintains the same functionality while providing better service discovery and networking.
|
||||
|
|
@ -1,399 +0,0 @@
|
|||
# Docker Image Migration Guide
|
||||
|
||||
This guide explains how to export your built Docker images, transfer them to another environment, and run them without rebuilding.
|
||||
|
||||
## Overview
|
||||
|
||||
The migration process involves:
|
||||
1. **Export**: Save built images to tar files
|
||||
2. **Transfer**: Copy tar files to target environment
|
||||
3. **Import**: Load images on target environment
|
||||
4. **Run**: Start services with imported images
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Source Environment (where images are built)
|
||||
- Docker installed and running
|
||||
- All services built and working
|
||||
- Sufficient disk space for image export
|
||||
|
||||
### Target Environment (where images will run)
|
||||
- Docker installed and running
|
||||
- Sufficient disk space for image import
|
||||
- Network access to source environment (or USB drive)
|
||||
|
||||
## Step 1: Export Docker Images
|
||||
|
||||
### 1.1 List Current Images
|
||||
|
||||
First, check what images you have:
|
||||
|
||||
```bash
|
||||
docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.ID}}\t{{.Size}}"
|
||||
```
|
||||
|
||||
You should see images like:
|
||||
- `legal-doc-masker-backend-api`
|
||||
- `legal-doc-masker-frontend`
|
||||
- `legal-doc-masker-mineru-api`
|
||||
- `redis:alpine`
|
||||
|
||||
### 1.2 Export Individual Images
|
||||
|
||||
Create a directory for exports:
|
||||
|
||||
```bash
|
||||
mkdir -p docker-images-export
|
||||
cd docker-images-export
|
||||
```
|
||||
|
||||
Export each image:
|
||||
|
||||
```bash
|
||||
# Export backend image
|
||||
docker save legal-doc-masker-backend-api:latest -o backend-api.tar
|
||||
|
||||
# Export frontend image
|
||||
docker save legal-doc-masker-frontend:latest -o frontend.tar
|
||||
|
||||
# Export mineru image
|
||||
docker save legal-doc-masker-mineru-api:latest -o mineru-api.tar
|
||||
|
||||
# Export redis image (if not using official)
|
||||
docker save redis:alpine -o redis.tar
|
||||
```
|
||||
|
||||
### 1.3 Export All Images at Once (Alternative)
|
||||
|
||||
If you want to export all images in one command:
|
||||
|
||||
```bash
|
||||
# Export all project images
|
||||
docker save \
|
||||
legal-doc-masker-backend-api:latest \
|
||||
legal-doc-masker-frontend:latest \
|
||||
legal-doc-masker-mineru-api:latest \
|
||||
redis:alpine \
|
||||
-o legal-doc-masker-all.tar
|
||||
```
|
||||
|
||||
### 1.4 Verify Export Files
|
||||
|
||||
Check the exported files:
|
||||
|
||||
```bash
|
||||
ls -lh *.tar
|
||||
```
|
||||
|
||||
You should see files like:
|
||||
- `backend-api.tar` (~200-500MB)
|
||||
- `frontend.tar` (~100-300MB)
|
||||
- `mineru-api.tar` (~1-3GB)
|
||||
- `redis.tar` (~30-50MB)
|
||||
|
||||
## Step 2: Transfer Images
|
||||
|
||||
### 2.1 Transfer via Network (SCP/RSYNC)
|
||||
|
||||
```bash
|
||||
# Transfer to remote server
|
||||
scp *.tar user@remote-server:/path/to/destination/
|
||||
|
||||
# Or using rsync (more efficient for large files)
|
||||
rsync -avz --progress *.tar user@remote-server:/path/to/destination/
|
||||
```
|
||||
|
||||
### 2.2 Transfer via USB Drive
|
||||
|
||||
```bash
|
||||
# Copy to USB drive
|
||||
cp *.tar /Volumes/USB_DRIVE/docker-images/
|
||||
|
||||
# Or create a compressed archive
|
||||
tar -czf legal-doc-masker-images.tar.gz *.tar
|
||||
cp legal-doc-masker-images.tar.gz /Volumes/USB_DRIVE/
|
||||
```
|
||||
|
||||
### 2.3 Transfer via Cloud Storage
|
||||
|
||||
```bash
|
||||
# Upload to cloud storage (example with AWS S3)
|
||||
aws s3 cp *.tar s3://your-bucket/docker-images/
|
||||
|
||||
# Or using Google Cloud Storage
|
||||
gsutil cp *.tar gs://your-bucket/docker-images/
|
||||
```
|
||||
|
||||
## Step 3: Import Images on Target Environment
|
||||
|
||||
### 3.1 Prepare Target Environment
|
||||
|
||||
```bash
|
||||
# Create directory for images
|
||||
mkdir -p docker-images-import
|
||||
cd docker-images-import
|
||||
|
||||
# Copy images from transfer method
|
||||
# (SCP, USB, or download from cloud storage)
|
||||
```
|
||||
|
||||
### 3.2 Import Individual Images
|
||||
|
||||
```bash
|
||||
# Import backend image
|
||||
docker load -i backend-api.tar
|
||||
|
||||
# Import frontend image
|
||||
docker load -i frontend.tar
|
||||
|
||||
# Import mineru image
|
||||
docker load -i mineru-api.tar
|
||||
|
||||
# Import redis image
|
||||
docker load -i redis.tar
|
||||
```
|
||||
|
||||
### 3.3 Import All Images at Once (if exported together)
|
||||
|
||||
```bash
|
||||
docker load -i legal-doc-masker-all.tar
|
||||
```
|
||||
|
||||
### 3.4 Verify Imported Images
|
||||
|
||||
```bash
|
||||
docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.ID}}\t{{.Size}}"
|
||||
```
|
||||
|
||||
## Step 4: Prepare Target Environment
|
||||
|
||||
### 4.1 Copy Project Files
|
||||
|
||||
Transfer the following files to target environment:
|
||||
|
||||
```bash
|
||||
# Essential files to copy
|
||||
docker-compose.yml
|
||||
DOCKER_COMPOSE_README.md
|
||||
setup-unified-docker.sh
|
||||
|
||||
# Environment files (if they exist)
|
||||
backend/.env
|
||||
frontend/.env
|
||||
mineru/.env
|
||||
|
||||
# Storage directories (if you want to preserve data)
|
||||
backend/storage/
|
||||
mineru/storage/
|
||||
backend/legal_doc_masker.db
|
||||
```
|
||||
|
||||
### 4.2 Create Directory Structure
|
||||
|
||||
```bash
|
||||
# Create necessary directories
|
||||
mkdir -p backend/storage
|
||||
mkdir -p mineru/storage/uploads
|
||||
mkdir -p mineru/storage/processed
|
||||
```
|
||||
|
||||
## Step 5: Run Services
|
||||
|
||||
### 5.1 Start All Services
|
||||
|
||||
```bash
|
||||
# Start all services using imported images
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### 5.2 Verify Services
|
||||
|
||||
```bash
|
||||
# Check service status
|
||||
docker-compose ps
|
||||
|
||||
# Check service logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
### 5.3 Test Endpoints
|
||||
|
||||
```bash
|
||||
# Test frontend
|
||||
curl -I http://localhost:3000
|
||||
|
||||
# Test backend API
|
||||
curl -I http://localhost:8000/api/v1
|
||||
|
||||
# Test mineru API
|
||||
curl -I http://localhost:8001/health
|
||||
```
|
||||
|
||||
## Automation Scripts
|
||||
|
||||
### Export Script
|
||||
|
||||
Create `export-images.sh`:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
echo "🚀 Exporting Docker Images"
|
||||
|
||||
# Create export directory
|
||||
mkdir -p docker-images-export
|
||||
cd docker-images-export
|
||||
|
||||
# Export images
|
||||
echo "📦 Exporting backend-api image..."
|
||||
docker save legal-doc-masker-backend-api:latest -o backend-api.tar
|
||||
|
||||
echo "📦 Exporting frontend image..."
|
||||
docker save legal-doc-masker-frontend:latest -o frontend.tar
|
||||
|
||||
echo "📦 Exporting mineru-api image..."
|
||||
docker save legal-doc-masker-mineru-api:latest -o mineru-api.tar
|
||||
|
||||
echo "📦 Exporting redis image..."
|
||||
docker save redis:alpine -o redis.tar
|
||||
|
||||
# Show file sizes
|
||||
echo "📊 Export complete. File sizes:"
|
||||
ls -lh *.tar
|
||||
|
||||
echo "✅ Images exported successfully!"
|
||||
```
|
||||
|
||||
### Import Script
|
||||
|
||||
Create `import-images.sh`:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
echo "🚀 Importing Docker Images"
|
||||
|
||||
# Check if tar files exist
|
||||
if [ ! -f "backend-api.tar" ]; then
|
||||
echo "❌ backend-api.tar not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Import images
|
||||
echo "📦 Importing backend-api image..."
|
||||
docker load -i backend-api.tar
|
||||
|
||||
echo "📦 Importing frontend image..."
|
||||
docker load -i frontend.tar
|
||||
|
||||
echo "📦 Importing mineru-api image..."
|
||||
docker load -i mineru-api.tar
|
||||
|
||||
echo "📦 Importing redis image..."
|
||||
docker load -i redis.tar
|
||||
|
||||
# Verify imports
|
||||
echo "📊 Imported images:"
|
||||
docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.Size}}" | grep legal-doc-masker
|
||||
|
||||
echo "✅ Images imported successfully!"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Image not found during import**
|
||||
```bash
|
||||
# Check if image exists
|
||||
docker images | grep image-name
|
||||
|
||||
# Re-export if needed
|
||||
docker save image-name:tag -o image-name.tar
|
||||
```
|
||||
|
||||
2. **Port conflicts on target environment**
|
||||
```bash
|
||||
# Check what's using the ports
|
||||
lsof -i :8000
|
||||
lsof -i :8001
|
||||
lsof -i :3000
|
||||
|
||||
# Modify docker-compose.yml if needed
|
||||
ports:
|
||||
- "8002:8000" # Change external port
|
||||
```
|
||||
|
||||
3. **Permission issues**
|
||||
```bash
|
||||
# Fix file permissions
|
||||
chmod +x setup-unified-docker.sh
|
||||
chmod +x export-images.sh
|
||||
chmod +x import-images.sh
|
||||
```
|
||||
|
||||
4. **Storage directory issues**
|
||||
```bash
|
||||
# Create directories with proper permissions
|
||||
sudo mkdir -p backend/storage
|
||||
sudo mkdir -p mineru/storage/uploads
|
||||
sudo mkdir -p mineru/storage/processed
|
||||
sudo chown -R $USER:$USER backend/storage mineru/storage
|
||||
```
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
1. **Compress images for transfer**
|
||||
```bash
|
||||
# Compress before transfer
|
||||
gzip *.tar
|
||||
|
||||
# Decompress on target
|
||||
gunzip *.tar.gz
|
||||
```
|
||||
|
||||
2. **Use parallel transfer**
|
||||
```bash
|
||||
# Transfer multiple files in parallel
|
||||
parallel scp {} user@server:/path/ ::: *.tar
|
||||
```
|
||||
|
||||
3. **Use Docker registry (alternative)**
|
||||
```bash
|
||||
# Push to registry
|
||||
docker tag legal-doc-masker-backend-api:latest your-registry/backend-api:latest
|
||||
docker push your-registry/backend-api:latest
|
||||
|
||||
# Pull on target
|
||||
docker pull your-registry/backend-api:latest
|
||||
```
|
||||
|
||||
## Complete Migration Checklist
|
||||
|
||||
- [ ] Export all Docker images
|
||||
- [ ] Transfer image files to target environment
|
||||
- [ ] Transfer project configuration files
|
||||
- [ ] Import images on target environment
|
||||
- [ ] Create necessary directories
|
||||
- [ ] Start services
|
||||
- [ ] Verify all services are running
|
||||
- [ ] Test all endpoints
|
||||
- [ ] Update any environment-specific configurations
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Secure transfer**: Use encrypted transfer methods (SCP, SFTP)
|
||||
2. **Image verification**: Verify image integrity after transfer
|
||||
3. **Environment isolation**: Ensure target environment is properly secured
|
||||
4. **Access control**: Limit access to Docker daemon on target environment
|
||||
|
||||
## Cost Optimization
|
||||
|
||||
1. **Image size**: Remove unnecessary layers before export
|
||||
2. **Compression**: Use compression for large images
|
||||
3. **Selective transfer**: Only transfer images you need
|
||||
4. **Cleanup**: Remove old images after successful migration
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# Build stage
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first to leverage Docker cache
|
||||
COPY requirements.txt .
|
||||
RUN pip wheel --no-cache-dir --no-deps --wheel-dir /app/wheels -r requirements.txt
|
||||
|
||||
# Final stage
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m -r appuser && \
|
||||
chown appuser:appuser /app
|
||||
|
||||
# Copy wheels from builder
|
||||
COPY --from=builder /app/wheels /wheels
|
||||
COPY --from=builder /app/requirements.txt .
|
||||
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache /wheels/*
|
||||
|
||||
# Copy application code
|
||||
COPY src/ ./src/
|
||||
|
||||
# Create directories for mounted volumes
|
||||
RUN mkdir -p /data/input /data/output && \
|
||||
chown -R appuser:appuser /data
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Environment variables
|
||||
ENV PYTHONPATH=/app \
|
||||
OBJECT_STORAGE_PATH=/data/input \
|
||||
TARGET_DIRECTORY_PATH=/data/output
|
||||
|
||||
# Run the application
|
||||
CMD ["python", "src/main.py"]
|
||||
|
|
@ -1,178 +0,0 @@
|
|||
# Docker Migration Quick Reference
|
||||
|
||||
## 🚀 Quick Migration Process
|
||||
|
||||
### Source Environment (Export)
|
||||
|
||||
```bash
|
||||
# 1. Build images first (if not already built)
|
||||
docker-compose build
|
||||
|
||||
# 2. Export all images
|
||||
./export-images.sh
|
||||
|
||||
# 3. Transfer files to target environment
|
||||
# Option A: SCP
|
||||
scp -r docker-images-export-*/ user@target-server:/path/to/destination/
|
||||
|
||||
# Option B: USB Drive
|
||||
cp -r docker-images-export-*/ /Volumes/USB_DRIVE/
|
||||
|
||||
# Option C: Compressed archive
|
||||
scp legal-doc-masker-images-*.tar.gz user@target-server:/path/to/destination/
|
||||
```
|
||||
|
||||
### Target Environment (Import)
|
||||
|
||||
```bash
|
||||
# 1. Copy project files
|
||||
scp docker-compose.yml user@target-server:/path/to/destination/
|
||||
scp DOCKER_COMPOSE_README.md user@target-server:/path/to/destination/
|
||||
|
||||
# 2. Import images
|
||||
./import-images.sh
|
||||
|
||||
# 3. Start services
|
||||
docker-compose up -d
|
||||
|
||||
# 4. Verify
|
||||
docker-compose ps
|
||||
```
|
||||
|
||||
## 📋 Essential Files to Transfer
|
||||
|
||||
### Required Files
|
||||
- `docker-compose.yml` - Unified compose configuration
|
||||
- `DOCKER_COMPOSE_README.md` - Documentation
|
||||
- `backend/.env` - Backend environment variables
|
||||
- `frontend/.env` - Frontend environment variables
|
||||
- `mineru/.env` - Mineru environment variables (if exists)
|
||||
|
||||
### Optional Files (for data preservation)
|
||||
- `backend/storage/` - Backend storage directory
|
||||
- `mineru/storage/` - Mineru storage directory
|
||||
- `backend/legal_doc_masker.db` - Database file
|
||||
|
||||
## 🔧 Common Commands
|
||||
|
||||
### Export Commands
|
||||
```bash
|
||||
# Manual export
|
||||
docker save legal-doc-masker-backend-api:latest -o backend-api.tar
|
||||
docker save legal-doc-masker-frontend:latest -o frontend.tar
|
||||
docker save legal-doc-masker-mineru-api:latest -o mineru-api.tar
|
||||
docker save redis:alpine -o redis.tar
|
||||
|
||||
# Compress for transfer
|
||||
tar -czf legal-doc-masker-images.tar.gz *.tar
|
||||
```
|
||||
|
||||
### Import Commands
|
||||
```bash
|
||||
# Manual import
|
||||
docker load -i backend-api.tar
|
||||
docker load -i frontend.tar
|
||||
docker load -i mineru-api.tar
|
||||
docker load -i redis.tar
|
||||
|
||||
# Extract compressed archive
|
||||
tar -xzf legal-doc-masker-images.tar.gz
|
||||
```
|
||||
|
||||
### Service Management
|
||||
```bash
|
||||
# Start all services
|
||||
docker-compose up -d
|
||||
|
||||
# Stop all services
|
||||
docker-compose down
|
||||
|
||||
# View logs
|
||||
docker-compose logs -f [service-name]
|
||||
|
||||
# Check status
|
||||
docker-compose ps
|
||||
```
|
||||
|
||||
### Building Individual Services
|
||||
```bash
|
||||
# Build specific service only
|
||||
docker-compose build backend-api
|
||||
docker-compose build frontend
|
||||
docker-compose build mineru-api
|
||||
|
||||
# Build and restart specific service
|
||||
docker-compose up -d --build backend-api
|
||||
|
||||
# Force rebuild (no cache)
|
||||
docker-compose build --no-cache backend-api
|
||||
|
||||
# Using the build script
|
||||
./build-service.sh backend-api --restart
|
||||
./build-service.sh frontend --no-cache
|
||||
./build-service.sh backend-api celery-worker
|
||||
```
|
||||
|
||||
## 🌐 Service URLs
|
||||
|
||||
After successful migration:
|
||||
- **Frontend**: http://localhost:3000
|
||||
- **Backend API**: http://localhost:8000
|
||||
- **Mineru API**: http://localhost:8001
|
||||
|
||||
## ⚠️ Troubleshooting
|
||||
|
||||
### Port Conflicts
|
||||
```bash
|
||||
# Check what's using ports
|
||||
lsof -i :8000
|
||||
lsof -i :8001
|
||||
lsof -i :3000
|
||||
|
||||
# Modify docker-compose.yml if needed
|
||||
ports:
|
||||
- "8002:8000" # Change external port
|
||||
```
|
||||
|
||||
### Permission Issues
|
||||
```bash
|
||||
# Fix script permissions
|
||||
chmod +x export-images.sh
|
||||
chmod +x import-images.sh
|
||||
chmod +x setup-unified-docker.sh
|
||||
|
||||
# Fix directory permissions
|
||||
sudo chown -R $USER:$USER backend/storage mineru/storage
|
||||
```
|
||||
|
||||
### Disk Space Issues
|
||||
```bash
|
||||
# Check available space
|
||||
df -h
|
||||
|
||||
# Clean up Docker
|
||||
docker system prune -a
|
||||
```
|
||||
|
||||
## 📊 Expected File Sizes
|
||||
|
||||
- `backend-api.tar`: ~200-500MB
|
||||
- `frontend.tar`: ~100-300MB
|
||||
- `mineru-api.tar`: ~1-3GB
|
||||
- `redis.tar`: ~30-50MB
|
||||
- `legal-doc-masker-images.tar.gz`: ~1-2GB (compressed)
|
||||
|
||||
## 🔒 Security Notes
|
||||
|
||||
1. Use encrypted transfer (SCP, SFTP) for sensitive environments
|
||||
2. Verify image integrity after transfer
|
||||
3. Update environment variables for target environment
|
||||
4. Ensure proper network security on target environment
|
||||
|
||||
## 📞 Support
|
||||
|
||||
If you encounter issues:
|
||||
1. Check the full `DOCKER_MIGRATION_GUIDE.md`
|
||||
2. Verify all required files are present
|
||||
3. Check Docker logs: `docker-compose logs -f`
|
||||
4. Ensure sufficient disk space and permissions
|
||||
25
backend/.env
25
backend/.env
|
|
@ -1,25 +0,0 @@
|
|||
# Storage paths
|
||||
OBJECT_STORAGE_PATH=/Users/tigeren/Dev/digisky/legal-doc-masker/data/doc_src
|
||||
TARGET_DIRECTORY_PATH=/Users/tigeren/Dev/digisky/legal-doc-masker/data/doc_dest
|
||||
INTERMEDIATE_DIR_PATH=/Users/tigeren/Dev/digisky/legal-doc-masker/data/doc_intermediate
|
||||
|
||||
# Ollama API Configuration
|
||||
# 3060 GPU
|
||||
# OLLAMA_API_URL=http://192.168.2.245:11434
|
||||
# Mac Mini M4
|
||||
OLLAMA_API_URL=http://192.168.2.224:11434
|
||||
|
||||
# OLLAMA_API_KEY=your_api_key_here
|
||||
# OLLAMA_MODEL=qwen3:8b
|
||||
OLLAMA_MODEL=phi4:14b
|
||||
|
||||
# Application Settings
|
||||
MONITOR_INTERVAL=5
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE=app.log
|
||||
|
||||
# Optional: Additional security settings
|
||||
# MAX_FILE_SIZE=10485760 # 10MB in bytes
|
||||
# ALLOWED_FILE_TYPES=.txt,.doc,.docx,.pdf
|
||||
|
|
@ -7,31 +7,18 @@ RUN apt-get update && apt-get install -y \
|
|||
build-essential \
|
||||
libreoffice \
|
||||
wget \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Copy requirements first to leverage Docker cache
|
||||
COPY requirements.txt .
|
||||
RUN pip install huggingface_hub
|
||||
RUN wget https://github.com/opendatalab/MinerU/raw/master/scripts/download_models_hf.py -O download_models_hf.py
|
||||
RUN python download_models_hf.py
|
||||
|
||||
# Upgrade pip and install core dependencies
|
||||
RUN pip install --upgrade pip setuptools wheel
|
||||
|
||||
# Install PyTorch CPU version first (for better caching and smaller size)
|
||||
RUN pip install --no-cache-dir torch==2.7.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# Install the rest of the requirements
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Pre-download NER model during build (larger image but faster startup)
|
||||
# RUN python -c "
|
||||
# from transformers import AutoTokenizer, AutoModelForTokenClassification
|
||||
# model_name = 'uer/roberta-base-finetuned-cluener2020-chinese'
|
||||
# print('Downloading NER model...')
|
||||
# AutoTokenizer.from_pretrained(model_name)
|
||||
# AutoModelForTokenClassification.from_pretrained(model_name)
|
||||
# print('NER model downloaded successfully')
|
||||
# "
|
||||
RUN pip install -U magic-pdf[full]
|
||||
|
||||
|
||||
# Copy the rest of the application
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
# App package
|
||||
|
|
@ -79,49 +79,21 @@ async def download_file(
|
|||
file_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
print(f"=== DOWNLOAD REQUEST ===")
|
||||
print(f"File ID: {file_id}")
|
||||
|
||||
file = db.query(FileModel).filter(FileModel.id == file_id).first()
|
||||
if not file:
|
||||
print(f"❌ File not found for ID: {file_id}")
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
print(f"✅ File found: {file.filename}")
|
||||
print(f"File status: {file.status}")
|
||||
print(f"Original path: {file.original_path}")
|
||||
print(f"Processed path: {file.processed_path}")
|
||||
|
||||
if file.status != FileStatus.SUCCESS:
|
||||
print(f"❌ File not ready for download. Status: {file.status}")
|
||||
raise HTTPException(status_code=400, detail="File is not ready for download")
|
||||
|
||||
if not os.path.exists(file.processed_path):
|
||||
print(f"❌ Processed file not found at: {file.processed_path}")
|
||||
raise HTTPException(status_code=404, detail="Processed file not found")
|
||||
|
||||
print(f"✅ Processed file exists at: {file.processed_path}")
|
||||
|
||||
# Get the original filename without extension and add .md extension
|
||||
original_filename = file.filename
|
||||
filename_without_ext = os.path.splitext(original_filename)[0]
|
||||
download_filename = f"{filename_without_ext}.md"
|
||||
|
||||
print(f"Original filename: {original_filename}")
|
||||
print(f"Filename without extension: {filename_without_ext}")
|
||||
print(f"Download filename: {download_filename}")
|
||||
|
||||
|
||||
response = FileResponse(
|
||||
return FileResponse(
|
||||
path=file.processed_path,
|
||||
filename=download_filename,
|
||||
media_type="text/markdown"
|
||||
filename=file.filename,
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
print(f"Response headers: {dict(response.headers)}")
|
||||
print(f"=== END DOWNLOAD REQUEST ===")
|
||||
|
||||
return response
|
||||
|
||||
@router.websocket("/ws/status/{file_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, file_id: str, db: Session = Depends(get_db)):
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
# Core package
|
||||
|
|
@ -31,21 +31,6 @@ class Settings(BaseSettings):
|
|||
OLLAMA_API_KEY: str = ""
|
||||
OLLAMA_MODEL: str = "llama2"
|
||||
|
||||
# Mineru API settings
|
||||
MINERU_API_URL: str = "http://mineru-api:8000"
|
||||
# MINERU_API_URL: str = "http://host.docker.internal:8001"
|
||||
|
||||
MINERU_TIMEOUT: int = 300 # 5 minutes timeout
|
||||
MINERU_LANG_LIST: list = ["ch"] # Language list for parsing
|
||||
MINERU_BACKEND: str = "pipeline" # Backend to use
|
||||
MINERU_PARSE_METHOD: str = "auto" # Parse method
|
||||
MINERU_FORMULA_ENABLE: bool = True # Enable formula parsing
|
||||
MINERU_TABLE_ENABLE: bool = True # Enable table parsing
|
||||
|
||||
# MagicDoc API settings
|
||||
# MAGICDOC_API_URL: str = "http://magicdoc-api:8000"
|
||||
# MAGICDOC_TIMEOUT: int = 300 # 5 minutes timeout
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = "INFO"
|
||||
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
# Document handlers package
|
||||
|
|
@ -1,15 +1,21 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from ..prompts.masking_prompts import get_masking_mapping_prompt
|
||||
import logging
|
||||
from .ner_processor import NerProcessor
|
||||
import json
|
||||
from ..services.ollama_client import OllamaClient
|
||||
from ...core.config import settings
|
||||
from ..utils.json_extractor import LLMJsonExtractor
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentProcessor(ABC):
|
||||
|
||||
def __init__(self):
|
||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
||||
self.max_chunk_size = 1000 # Maximum number of characters per chunk
|
||||
self.ner_processor = NerProcessor()
|
||||
self.max_retries = 3 # Maximum number of retries for mapping generation
|
||||
|
||||
@abstractmethod
|
||||
def read_content(self) -> str:
|
||||
|
|
@ -25,6 +31,7 @@ class DocumentProcessor(ABC):
|
|||
if not sentence.strip():
|
||||
continue
|
||||
|
||||
# If adding this sentence would exceed the limit, save current chunk and start new one
|
||||
if len(current_chunk) + len(sentence) > self.max_chunk_size and current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = sentence
|
||||
|
|
@ -34,59 +41,152 @@ class DocumentProcessor(ABC):
|
|||
else:
|
||||
current_chunk = sentence
|
||||
|
||||
# Add the last chunk if it's not empty
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
logger.info(f"Split content into {len(chunks)} chunks")
|
||||
|
||||
return chunks
|
||||
|
||||
def _apply_mapping_with_alignment(self, text: str, mapping: Dict[str, str]) -> str:
|
||||
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Apply the mapping to replace sensitive information using character-by-character alignment.
|
||||
|
||||
This method uses the new alignment-based masking to handle spacing issues
|
||||
between NER results and original document text.
|
||||
|
||||
Args:
|
||||
text: Original document text
|
||||
mapping: Dictionary mapping original entity text to masked text
|
||||
Validate that the mapping follows the required format:
|
||||
{
|
||||
"原文1": "脱敏后1",
|
||||
"原文2": "脱敏后2",
|
||||
...
|
||||
}
|
||||
"""
|
||||
if not isinstance(mapping, dict):
|
||||
logger.warning("Mapping is not a dictionary")
|
||||
return False
|
||||
|
||||
Returns:
|
||||
Masked document text
|
||||
"""
|
||||
logger.info(f"Applying entity mapping with alignment to text of length {len(text)}")
|
||||
logger.debug(f"Entity mapping: {mapping}")
|
||||
|
||||
# Use the new alignment-based masking method
|
||||
masked_text = self.ner_processor.apply_entity_masking_with_alignment(text, mapping)
|
||||
|
||||
logger.info("Successfully applied entity masking with alignment")
|
||||
return masked_text
|
||||
# Check if any key or value is not a string
|
||||
for key, value in mapping.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
logger.warning(f"Invalid mapping format - key or value is not a string: {key}: {value}")
|
||||
return False
|
||||
|
||||
# Check if the mapping has any nested structures
|
||||
if any(isinstance(v, (dict, list)) for v in mapping.values()):
|
||||
logger.warning("Invalid mapping format - contains nested structures")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _build_mapping(self, chunk: str) -> Dict[str, str]:
|
||||
"""Build mapping for a single chunk of text with retry logic"""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
formatted_prompt = get_masking_mapping_prompt(chunk)
|
||||
logger.info(f"Calling ollama to generate mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
|
||||
response = self.ollama_client.generate(formatted_prompt)
|
||||
logger.info(f"Raw response from LLM: {response}")
|
||||
|
||||
# Parse the JSON response into a dictionary
|
||||
mapping = LLMJsonExtractor.parse_raw_json_str(response)
|
||||
logger.info(f"Parsed mapping: {mapping}")
|
||||
|
||||
if mapping and self._validate_mapping_format(mapping):
|
||||
return mapping
|
||||
else:
|
||||
logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating mapping on attempt {attempt + 1}: {e}")
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.info("Retrying...")
|
||||
else:
|
||||
logger.error("Max retries reached, returning empty mapping")
|
||||
return {}
|
||||
|
||||
def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
|
||||
"""Apply the mapping to replace sensitive information"""
|
||||
masked_text = text
|
||||
for original, masked in mapping.items():
|
||||
# Ensure masked value is a string
|
||||
if isinstance(masked, dict):
|
||||
# If it's a dict, use the first value or a default
|
||||
masked = next(iter(masked.values()), "某")
|
||||
elif not isinstance(masked, str):
|
||||
# If it's not a string, convert to string or use default
|
||||
masked = str(masked) if masked is not None else "某"
|
||||
masked_text = masked_text.replace(original, masked)
|
||||
return masked_text
|
||||
|
||||
def _get_next_suffix(self, value: str) -> str:
|
||||
"""Get the next available suffix for a value that already has a suffix"""
|
||||
# Define the sequence of suffixes
|
||||
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
|
||||
|
||||
# Check if the value already has a suffix
|
||||
for suffix in suffixes:
|
||||
if value.endswith(suffix):
|
||||
# Find the next suffix in the sequence
|
||||
current_index = suffixes.index(suffix)
|
||||
if current_index + 1 < len(suffixes):
|
||||
return value[:-1] + suffixes[current_index + 1]
|
||||
else:
|
||||
# If we've used all suffixes, start over with the first one
|
||||
return value[:-1] + suffixes[0]
|
||||
|
||||
# If no suffix found, return the value with the first suffix
|
||||
return value + '甲'
|
||||
|
||||
def _merge_mappings(self, existing: Dict[str, str], new: Dict[str, str]) -> Dict[str, str]:
|
||||
"""
|
||||
Legacy method for simple string replacement.
|
||||
Now delegates to the new alignment-based method.
|
||||
Merge two mappings following the rules:
|
||||
1. If key exists in existing, keep existing value
|
||||
2. If value exists in existing:
|
||||
- If value ends with a suffix (甲乙丙丁...), add next suffix
|
||||
- If no suffix, add '甲'
|
||||
"""
|
||||
return self._apply_mapping_with_alignment(text, mapping)
|
||||
result = existing.copy()
|
||||
|
||||
# Get all existing values
|
||||
existing_values = set(result.values())
|
||||
|
||||
for key, value in new.items():
|
||||
if key in result:
|
||||
# Rule 1: Keep existing value if key exists
|
||||
continue
|
||||
|
||||
if value in existing_values:
|
||||
# Rule 2: Handle duplicate values
|
||||
new_value = self._get_next_suffix(value)
|
||||
result[key] = new_value
|
||||
existing_values.add(new_value)
|
||||
else:
|
||||
# No conflict, add as is
|
||||
result[key] = value
|
||||
existing_values.add(value)
|
||||
|
||||
return result
|
||||
|
||||
def process_content(self, content: str) -> str:
|
||||
"""Process document content by masking sensitive information"""
|
||||
# Split content into sentences
|
||||
sentences = content.split("。")
|
||||
|
||||
# Split sentences into manageable chunks
|
||||
chunks = self._split_into_chunks(sentences)
|
||||
logger.info(f"Split content into {len(chunks)} chunks")
|
||||
|
||||
final_mapping = self.ner_processor.process(chunks)
|
||||
logger.info(f"Generated entity mapping with {len(final_mapping)} entities")
|
||||
# Build mapping for each chunk
|
||||
combined_mapping = {}
|
||||
for i, chunk in enumerate(chunks):
|
||||
logger.info(f"Processing chunk {i+1}/{len(chunks)}")
|
||||
chunk_mapping = self._build_mapping(chunk)
|
||||
if chunk_mapping: # Only update if we got a valid mapping
|
||||
combined_mapping = self._merge_mappings(combined_mapping, chunk_mapping)
|
||||
else:
|
||||
logger.warning(f"Failed to generate mapping for chunk {i+1}")
|
||||
|
||||
# Use the new alignment-based masking
|
||||
masked_content = self._apply_mapping_with_alignment(content, final_mapping)
|
||||
logger.info("Successfully masked content using character alignment")
|
||||
# Apply the combined mapping to the entire content
|
||||
masked_content = self._apply_mapping(content, combined_mapping)
|
||||
logger.info("Successfully masked content")
|
||||
|
||||
return masked_content
|
||||
|
||||
@abstractmethod
|
||||
def save_content(self, content: str) -> None:
|
||||
"""Save processed content"""
|
||||
pass
|
||||
pass
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
"""
|
||||
Extractors package for entity component extraction.
|
||||
"""
|
||||
|
||||
from .base_extractor import BaseExtractor
|
||||
from .business_name_extractor import BusinessNameExtractor
|
||||
from .address_extractor import AddressExtractor
|
||||
from .ner_extractor import NERExtractor
|
||||
|
||||
__all__ = [
|
||||
'BaseExtractor',
|
||||
'BusinessNameExtractor',
|
||||
'AddressExtractor',
|
||||
'NERExtractor'
|
||||
]
|
||||
|
|
@ -1,168 +0,0 @@
|
|||
"""
|
||||
Address extractor for address components.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from ...services.ollama_client import OllamaClient
|
||||
from ...utils.json_extractor import LLMJsonExtractor
|
||||
from ...utils.llm_validator import LLMResponseValidator
|
||||
from .base_extractor import BaseExtractor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddressExtractor(BaseExtractor):
|
||||
"""Extractor for address components"""
|
||||
|
||||
def __init__(self, ollama_client: OllamaClient):
|
||||
self.ollama_client = ollama_client
|
||||
self._confidence = 0.5 # Default confidence for regex fallback
|
||||
|
||||
def extract(self, address: str) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Extract address components from address.
|
||||
|
||||
Args:
|
||||
address: The address to extract from
|
||||
|
||||
Returns:
|
||||
Dictionary with address components and confidence, or None if extraction fails
|
||||
"""
|
||||
if not address:
|
||||
return None
|
||||
|
||||
# Try LLM extraction first
|
||||
try:
|
||||
result = self._extract_with_llm(address)
|
||||
if result:
|
||||
self._confidence = result.get('confidence', 0.9)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM extraction failed for {address}: {e}")
|
||||
|
||||
# Fallback to regex extraction
|
||||
result = self._extract_with_regex(address)
|
||||
self._confidence = 0.5 # Lower confidence for regex
|
||||
return result
|
||||
|
||||
def _extract_with_llm(self, address: str) -> Optional[Dict[str, str]]:
|
||||
"""Extract address components using LLM"""
|
||||
prompt = f"""
|
||||
你是一个专业的地址分析助手。请从以下地址中提取需要脱敏的组件,并严格按照JSON格式返回结果。
|
||||
|
||||
地址:{address}
|
||||
|
||||
脱敏规则:
|
||||
1. 保留区级以上地址(省、市、区、县等)
|
||||
2. 路名(路名)需要脱敏:以大写首字母替代
|
||||
3. 门牌号(门牌数字)需要脱敏:以****代替
|
||||
4. 大厦名、小区名需要脱敏:以大写首字母替代
|
||||
|
||||
示例:
|
||||
- 上海市静安区恒丰路66号白云大厦1607室
|
||||
- 路名:恒丰路
|
||||
- 门牌号:66
|
||||
- 大厦名:白云大厦
|
||||
- 小区名:(空)
|
||||
|
||||
- 北京市朝阳区建国路88号SOHO现代城A座1001室
|
||||
- 路名:建国路
|
||||
- 门牌号:88
|
||||
- 大厦名:SOHO现代城
|
||||
- 小区名:(空)
|
||||
|
||||
- 广州市天河区珠江新城花城大道123号富力中心B座2001室
|
||||
- 路名:花城大道
|
||||
- 门牌号:123
|
||||
- 大厦名:富力中心
|
||||
- 小区名:(空)
|
||||
|
||||
请严格按照以下JSON格式输出,不要包含任何其他文字:
|
||||
|
||||
{{
|
||||
"road_name": "提取的路名",
|
||||
"house_number": "提取的门牌号",
|
||||
"building_name": "提取的大厦名",
|
||||
"community_name": "提取的小区名(如果没有则为空字符串)",
|
||||
"confidence": 0.9
|
||||
}}
|
||||
|
||||
注意:
|
||||
- road_name字段必须包含路名(如:恒丰路、建国路等)
|
||||
- house_number字段必须包含门牌号(如:66、88等)
|
||||
- building_name字段必须包含大厦名(如:白云大厦、SOHO现代城等)
|
||||
- community_name字段包含小区名,如果没有则为空字符串
|
||||
- confidence字段是0-1之间的数字,表示提取的置信度
|
||||
- 必须严格按照JSON格式,不要添加任何解释或额外文字
|
||||
"""
|
||||
|
||||
try:
|
||||
# Use the new enhanced generate method with validation
|
||||
parsed_response = self.ollama_client.generate_with_validation(
|
||||
prompt=prompt,
|
||||
response_type='address_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
if parsed_response:
|
||||
logger.info(f"Successfully extracted address components: {parsed_response}")
|
||||
return parsed_response
|
||||
else:
|
||||
logger.warning(f"Failed to extract address components for: {address}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LLM extraction failed: {e}")
|
||||
return None
|
||||
|
||||
def _extract_with_regex(self, address: str) -> Optional[Dict[str, str]]:
|
||||
"""Extract address components using regex patterns"""
|
||||
# Road name pattern: usually ends with "路", "街", "大道", etc.
|
||||
road_pattern = r'([^省市区县]+[路街大道巷弄])'
|
||||
|
||||
# House number pattern: digits + 号
|
||||
house_number_pattern = r'(\d+)号'
|
||||
|
||||
# Building name pattern: usually contains "大厦", "中心", "广场", etc.
|
||||
building_pattern = r'([^号室]+(?:大厦|中心|广场|城|楼|座))'
|
||||
|
||||
# Community name pattern: usually contains "小区", "花园", "苑", etc.
|
||||
community_pattern = r'([^号室]+(?:小区|花园|苑|园|庭))'
|
||||
|
||||
road_name = ""
|
||||
house_number = ""
|
||||
building_name = ""
|
||||
community_name = ""
|
||||
|
||||
# Extract road name
|
||||
road_match = re.search(road_pattern, address)
|
||||
if road_match:
|
||||
road_name = road_match.group(1).strip()
|
||||
|
||||
# Extract house number
|
||||
house_match = re.search(house_number_pattern, address)
|
||||
if house_match:
|
||||
house_number = house_match.group(1)
|
||||
|
||||
# Extract building name
|
||||
building_match = re.search(building_pattern, address)
|
||||
if building_match:
|
||||
building_name = building_match.group(1).strip()
|
||||
|
||||
# Extract community name
|
||||
community_match = re.search(community_pattern, address)
|
||||
if community_match:
|
||||
community_name = community_match.group(1).strip()
|
||||
|
||||
return {
|
||||
"road_name": road_name,
|
||||
"house_number": house_number,
|
||||
"building_name": building_name,
|
||||
"community_name": community_name,
|
||||
"confidence": 0.5 # Lower confidence for regex fallback
|
||||
}
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
"""Return confidence level of extraction"""
|
||||
return self._confidence
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
"""
|
||||
Abstract base class for all extractors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Abstract base class for all extractors"""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Extract components from text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_confidence(self) -> float:
|
||||
"""Return confidence level of extraction"""
|
||||
pass
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
"""
|
||||
Business name extractor for company names.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from ...services.ollama_client import OllamaClient
|
||||
from ...utils.json_extractor import LLMJsonExtractor
|
||||
from ...utils.llm_validator import LLMResponseValidator
|
||||
from .base_extractor import BaseExtractor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BusinessNameExtractor(BaseExtractor):
|
||||
"""Extractor for business names from company names"""
|
||||
|
||||
def __init__(self, ollama_client: OllamaClient):
|
||||
self.ollama_client = ollama_client
|
||||
self._confidence = 0.5 # Default confidence for regex fallback
|
||||
|
||||
def extract(self, company_name: str) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Extract business name from company name.
|
||||
|
||||
Args:
|
||||
company_name: The company name to extract from
|
||||
|
||||
Returns:
|
||||
Dictionary with business name and confidence, or None if extraction fails
|
||||
"""
|
||||
if not company_name:
|
||||
return None
|
||||
|
||||
# Try LLM extraction first
|
||||
try:
|
||||
result = self._extract_with_llm(company_name)
|
||||
if result:
|
||||
self._confidence = result.get('confidence', 0.9)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM extraction failed for {company_name}: {e}")
|
||||
|
||||
# Fallback to regex extraction
|
||||
result = self._extract_with_regex(company_name)
|
||||
self._confidence = 0.5 # Lower confidence for regex
|
||||
return result
|
||||
|
||||
def _extract_with_llm(self, company_name: str) -> Optional[Dict[str, str]]:
|
||||
"""Extract business name using LLM"""
|
||||
prompt = f"""
|
||||
你是一个专业的公司名称分析助手。请从以下公司名称中提取商号(企业字号),并严格按照JSON格式返回结果。
|
||||
|
||||
公司名称:{company_name}
|
||||
|
||||
商号提取规则:
|
||||
1. 公司名通常为:地域+商号+业务/行业+组织类型
|
||||
2. 也有:商号+(地域)+业务/行业+组织类型
|
||||
3. 商号是企业名称中最具识别性的部分,通常是2-4个汉字
|
||||
4. 不要包含地域、行业、组织类型等信息
|
||||
5. 律师事务所的商号通常是地域后的部分
|
||||
|
||||
示例:
|
||||
- 上海盒马网络科技有限公司 -> 盒马
|
||||
- 丰田通商(上海)有限公司 -> 丰田通商
|
||||
- 雅诗兰黛(上海)商贸有限公司 -> 雅诗兰黛
|
||||
- 北京百度网讯科技有限公司 -> 百度
|
||||
- 腾讯科技(深圳)有限公司 -> 腾讯
|
||||
- 北京大成律师事务所 -> 大成
|
||||
|
||||
请严格按照以下JSON格式输出,不要包含任何其他文字:
|
||||
|
||||
{{
|
||||
"business_name": "提取的商号",
|
||||
"confidence": 0.9
|
||||
}}
|
||||
|
||||
注意:
|
||||
- business_name字段必须包含提取的商号
|
||||
- confidence字段是0-1之间的数字,表示提取的置信度
|
||||
- 必须严格按照JSON格式,不要添加任何解释或额外文字
|
||||
"""
|
||||
|
||||
try:
|
||||
# Use the new enhanced generate method with validation
|
||||
parsed_response = self.ollama_client.generate_with_validation(
|
||||
prompt=prompt,
|
||||
response_type='business_name_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
if parsed_response:
|
||||
business_name = parsed_response.get('business_name', '')
|
||||
# Clean business name, keep only Chinese characters
|
||||
business_name = re.sub(r'[^\u4e00-\u9fff]', '', business_name)
|
||||
logger.info(f"Successfully extracted business name: {business_name}")
|
||||
return {
|
||||
'business_name': business_name,
|
||||
'confidence': parsed_response.get('confidence', 0.9)
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Failed to extract business name for: {company_name}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LLM extraction failed: {e}")
|
||||
return None
|
||||
|
||||
def _extract_with_regex(self, company_name: str) -> Optional[Dict[str, str]]:
|
||||
"""Extract business name using regex patterns"""
|
||||
# Handle law firms specially
|
||||
if '律师事务所' in company_name:
|
||||
return self._extract_law_firm_business_name(company_name)
|
||||
|
||||
# Common region prefixes
|
||||
region_prefixes = [
|
||||
'北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安',
|
||||
'天津', '重庆', '青岛', '大连', '宁波', '厦门', '无锡', '长沙', '郑州', '济南',
|
||||
'哈尔滨', '沈阳', '长春', '石家庄', '太原', '呼和浩特', '合肥', '福州', '南昌',
|
||||
'南宁', '海口', '贵阳', '昆明', '兰州', '西宁', '银川', '乌鲁木齐', '拉萨',
|
||||
'香港', '澳门', '台湾'
|
||||
]
|
||||
|
||||
# Common organization type suffixes
|
||||
org_suffixes = [
|
||||
'有限公司', '股份有限公司', '有限责任公司', '股份公司', '集团公司', '集团',
|
||||
'科技公司', '网络公司', '信息技术公司', '软件公司', '互联网公司',
|
||||
'贸易公司', '商贸公司', '进出口公司', '物流公司', '运输公司',
|
||||
'房地产公司', '置业公司', '投资公司', '金融公司', '银行',
|
||||
'保险公司', '证券公司', '基金公司', '信托公司', '租赁公司',
|
||||
'咨询公司', '服务公司', '管理公司', '广告公司', '传媒公司',
|
||||
'教育公司', '培训公司', '医疗公司', '医药公司', '生物公司',
|
||||
'制造公司', '工业公司', '化工公司', '能源公司', '电力公司',
|
||||
'建筑公司', '工程公司', '建设公司', '开发公司', '设计公司',
|
||||
'销售公司', '营销公司', '代理公司', '经销商', '零售商',
|
||||
'连锁公司', '超市', '商场', '百货', '专卖店', '便利店'
|
||||
]
|
||||
|
||||
name = company_name
|
||||
|
||||
# Remove region prefix
|
||||
for region in region_prefixes:
|
||||
if name.startswith(region):
|
||||
name = name[len(region):].strip()
|
||||
break
|
||||
|
||||
# Remove region information in parentheses
|
||||
name = re.sub(r'[((].*?[))]', '', name).strip()
|
||||
|
||||
# Remove organization type suffix
|
||||
for suffix in org_suffixes:
|
||||
if name.endswith(suffix):
|
||||
name = name[:-len(suffix)].strip()
|
||||
break
|
||||
|
||||
# If remaining part is too long, try to extract first 2-4 characters
|
||||
if len(name) > 4:
|
||||
# Try to find a good break point
|
||||
for i in range(2, min(5, len(name))):
|
||||
if name[i] in ['网', '科', '技', '信', '息', '软', '件', '互', '联', '网', '电', '子', '商', '务']:
|
||||
name = name[:i]
|
||||
break
|
||||
|
||||
return {
|
||||
'business_name': name if name else company_name[:2],
|
||||
'confidence': 0.5
|
||||
}
|
||||
|
||||
def _extract_law_firm_business_name(self, law_firm_name: str) -> Optional[Dict[str, str]]:
|
||||
"""Extract business name from law firm names"""
|
||||
# Remove "律师事务所" suffix
|
||||
name = law_firm_name.replace('律师事务所', '').replace('分所', '').strip()
|
||||
|
||||
# Handle region information in parentheses
|
||||
name = re.sub(r'[((].*?[))]', '', name).strip()
|
||||
|
||||
# Common region prefixes
|
||||
region_prefixes = ['北京', '上海', '广州', '深圳', '杭州', '南京', '苏州', '成都', '武汉', '西安']
|
||||
|
||||
for region in region_prefixes:
|
||||
if name.startswith(region):
|
||||
name = name[len(region):].strip()
|
||||
break
|
||||
|
||||
return {
|
||||
'business_name': name,
|
||||
'confidence': 0.5
|
||||
}
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
"""Return confidence level of extraction"""
|
||||
return self._confidence
|
||||
|
|
@ -1,469 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Any, Optional
|
||||
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
||||
from .base_extractor import BaseExtractor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NERExtractor(BaseExtractor):
|
||||
"""
|
||||
Named Entity Recognition extractor using Chinese NER model.
|
||||
Uses the uer/roberta-base-finetuned-cluener2020-chinese model for Chinese NER.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model_checkpoint = "uer/roberta-base-finetuned-cluener2020-chinese"
|
||||
self.tokenizer = None
|
||||
self.model = None
|
||||
self.ner_pipeline = None
|
||||
self._model_initialized = False
|
||||
self.confidence_threshold = 0.95
|
||||
|
||||
# Map CLUENER model labels to our desired categories
|
||||
self.label_map = {
|
||||
'company': '公司名称',
|
||||
'organization': '组织机构名',
|
||||
'name': '人名',
|
||||
'address': '地址'
|
||||
}
|
||||
|
||||
# Don't initialize the model here - use lazy loading
|
||||
|
||||
def _initialize_model(self):
|
||||
"""Initialize the NER model and pipeline"""
|
||||
try:
|
||||
logger.info(f"Loading NER model: {self.model_checkpoint}")
|
||||
|
||||
# Load the tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint)
|
||||
self.model = AutoModelForTokenClassification.from_pretrained(self.model_checkpoint)
|
||||
|
||||
# Create the NER pipeline with proper configuration
|
||||
self.ner_pipeline = pipeline(
|
||||
"ner",
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
aggregation_strategy="simple"
|
||||
)
|
||||
|
||||
# Configure the tokenizer to handle max length
|
||||
if hasattr(self.tokenizer, 'model_max_length'):
|
||||
self.tokenizer.model_max_length = 512
|
||||
|
||||
self._model_initialized = True
|
||||
logger.info("NER model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load NER model: {str(e)}")
|
||||
raise Exception(f"NER model initialization failed: {str(e)}")
|
||||
|
||||
def _split_text_by_sentences(self, text: str) -> List[str]:
|
||||
"""
|
||||
Split text into sentences using Chinese sentence boundaries
|
||||
|
||||
Args:
|
||||
text: The text to split
|
||||
|
||||
Returns:
|
||||
List of sentences
|
||||
"""
|
||||
# Chinese sentence endings: 。!?;\n
|
||||
# Also consider English sentence endings for mixed text
|
||||
sentence_pattern = r'[。!?;\n]+|[.!?;]+'
|
||||
sentences = re.split(sentence_pattern, text)
|
||||
|
||||
# Clean up sentences and filter out empty ones
|
||||
cleaned_sentences = []
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if sentence:
|
||||
cleaned_sentences.append(sentence)
|
||||
|
||||
return cleaned_sentences
|
||||
|
||||
def _is_entity_boundary_safe(self, text: str, position: int) -> bool:
|
||||
"""
|
||||
Check if a position is safe for splitting (won't break entities)
|
||||
|
||||
Args:
|
||||
text: The text to check
|
||||
position: Position to check for safety
|
||||
|
||||
Returns:
|
||||
True if safe to split at this position
|
||||
"""
|
||||
if position <= 0 or position >= len(text):
|
||||
return True
|
||||
|
||||
# Common entity suffixes that indicate incomplete entities
|
||||
entity_suffixes = ['公', '司', '所', '院', '厅', '局', '部', '会', '团', '社', '处', '室', '楼', '号']
|
||||
|
||||
# Check if we're in the middle of a potential entity
|
||||
for suffix in entity_suffixes:
|
||||
# Look for incomplete entity patterns
|
||||
if text[position-1:position+1] in [f'公{suffix}', f'司{suffix}', f'所{suffix}']:
|
||||
return False
|
||||
|
||||
# Check for incomplete company names
|
||||
if text[position-2:position+1] in ['公司', '事务所', '协会', '研究院']:
|
||||
return False
|
||||
|
||||
# Check for incomplete address patterns
|
||||
address_patterns = ['省', '市', '区', '县', '路', '街', '巷', '号', '室']
|
||||
for pattern in address_patterns:
|
||||
if text[position-1:position+1] in [f'省{pattern}', f'市{pattern}', f'区{pattern}', f'县{pattern}']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _create_sentence_chunks(self, sentences: List[str], max_tokens: int = 400) -> List[str]:
|
||||
"""
|
||||
Create chunks from sentences while respecting token limits and entity boundaries
|
||||
|
||||
Args:
|
||||
sentences: List of sentences
|
||||
max_tokens: Maximum tokens per chunk
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
current_token_count = 0
|
||||
|
||||
for sentence in sentences:
|
||||
# Estimate token count for this sentence
|
||||
sentence_tokens = len(self.tokenizer.tokenize(sentence))
|
||||
|
||||
# If adding this sentence would exceed the limit
|
||||
if current_token_count + sentence_tokens > max_tokens and current_chunk:
|
||||
# Check if we can split the sentence to fit better
|
||||
if sentence_tokens > max_tokens // 2: # If sentence is too long
|
||||
# Try to split the sentence at a safe boundary
|
||||
split_sentence = self._split_long_sentence(sentence, max_tokens - current_token_count)
|
||||
if split_sentence:
|
||||
# Add the first part to current chunk
|
||||
current_chunk.append(split_sentence[0])
|
||||
chunks.append(''.join(current_chunk))
|
||||
|
||||
# Start new chunk with remaining parts
|
||||
current_chunk = split_sentence[1:]
|
||||
current_token_count = sum(len(self.tokenizer.tokenize(s)) for s in current_chunk)
|
||||
else:
|
||||
# Finalize current chunk and start new one
|
||||
chunks.append(''.join(current_chunk))
|
||||
current_chunk = [sentence]
|
||||
current_token_count = sentence_tokens
|
||||
else:
|
||||
# Finalize current chunk and start new one
|
||||
chunks.append(''.join(current_chunk))
|
||||
current_chunk = [sentence]
|
||||
current_token_count = sentence_tokens
|
||||
else:
|
||||
# Add sentence to current chunk
|
||||
current_chunk.append(sentence)
|
||||
current_token_count += sentence_tokens
|
||||
|
||||
# Add the last chunk if it has content
|
||||
if current_chunk:
|
||||
chunks.append(''.join(current_chunk))
|
||||
|
||||
return chunks
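# --- Illustrative sketch (not part of the original module) ---
# The core of the chunking above is greedy packing of sentences under a token
# budget; this stand-alone version uses a character count as a stand-in for
# self.tokenizer.tokenize and skips the long-sentence splitting refinement.
from typing import Callable, List

def pack_sentences(sentences: List[str], max_tokens: int,
                   count_tokens: Callable[[str], int]) -> List[str]:
    chunks, current, current_tokens = [], [], 0
    for sentence in sentences:
        tokens = count_tokens(sentence)
        if current and current_tokens + tokens > max_tokens:
            chunks.append(''.join(current))  # budget exceeded: close the chunk
            current, current_tokens = [], 0
        current.append(sentence)
        current_tokens += tokens
    if current:
        chunks.append(''.join(current))
    return chunks

print(pack_sentences(["甲方签署合同。", "乙方收到货款。", "双方无争议。"], 10, len))
# ['甲方签署合同。', '乙方收到货款。', '双方无争议。']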
|
||||
|
||||
def _split_long_sentence(self, sentence: str, max_tokens: int) -> Optional[List[str]]:
|
||||
"""
|
||||
Split a long sentence at safe boundaries
|
||||
|
||||
Args:
|
||||
sentence: The sentence to split
|
||||
max_tokens: Maximum tokens for the first part
|
||||
|
||||
Returns:
|
||||
List of sentence parts, or None if splitting is not possible
|
||||
"""
|
||||
if len(self.tokenizer.tokenize(sentence)) <= max_tokens:
|
||||
return None
|
||||
|
||||
# Try to find safe splitting points
|
||||
# Look for punctuation marks that are safe to split at
|
||||
safe_splitters = [',', ',', ';', ';', '、', ':', ':']
|
||||
|
||||
for splitter in safe_splitters:
|
||||
if splitter in sentence:
|
||||
parts = sentence.split(splitter)
|
||||
current_part = ""
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
test_part = current_part + part + (splitter if i < len(parts) - 1 else "")
|
||||
if len(self.tokenizer.tokenize(test_part)) > max_tokens:
|
||||
if current_part:
|
||||
# Found a safe split point
|
||||
remaining = splitter.join(parts[i:])
|
||||
return [current_part, remaining]
|
||||
break
|
||||
current_part = test_part
|
||||
|
||||
# If no safe split point found, try character-based splitting with entity boundary check
|
||||
target_chars = int(max_tokens / 1.5) # Rough character estimate
|
||||
|
||||
for i in range(target_chars, len(sentence)):
|
||||
if self._is_entity_boundary_safe(sentence, i):
|
||||
part1 = sentence[:i]
|
||||
part2 = sentence[i:]
|
||||
if len(self.tokenizer.tokenize(part1)) <= max_tokens:
|
||||
return [part1, part2]
|
||||
|
||||
return None
|
||||
|
||||
def extract(self, text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract named entities from the given text
|
||||
|
||||
Args:
|
||||
text: The text to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted entities in the format expected by the system
|
||||
"""
|
||||
try:
|
||||
if not text or not text.strip():
|
||||
logger.warning("Empty text provided for NER processing")
|
||||
return {"entities": []}
|
||||
|
||||
# Initialize model if not already done
|
||||
if not self._model_initialized:
|
||||
self._initialize_model()
|
||||
|
||||
logger.info(f"Processing text with NER (length: {len(text)} characters)")
|
||||
|
||||
# Check if text needs chunking
|
||||
if len(text) > 400: # Character-based threshold for chunking
|
||||
logger.info("Text is long, using chunking approach")
|
||||
return self._extract_with_chunking(text)
|
||||
else:
|
||||
logger.info("Text is short, processing directly")
|
||||
return self._extract_single(text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during NER processing: {str(e)}")
|
||||
raise Exception(f"NER processing failed: {str(e)}")
|
||||
|
||||
def _extract_single(self, text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract entities from a single text chunk
|
||||
|
||||
Args:
|
||||
text: The text to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted entities
|
||||
"""
|
||||
try:
|
||||
# Run the NER pipeline - it handles truncation automatically
|
||||
logger.info(f"Running NER pipeline with text: {text}")
|
||||
results = self.ner_pipeline(text)
|
||||
logger.info(f"NER results: {results}")
|
||||
|
||||
# Filter and process entities
|
||||
filtered_entities = []
|
||||
for entity in results:
|
||||
entity_group = entity['entity_group']
|
||||
|
||||
# Only process entities that we care about
|
||||
if entity_group in self.label_map:
|
||||
entity_type = self.label_map[entity_group]
|
||||
entity_text = entity['word']
|
||||
confidence_score = entity['score']
|
||||
|
||||
# Clean up the tokenized text (remove spaces between Chinese characters)
|
||||
cleaned_text = self._clean_tokenized_text(entity_text)
|
||||
|
||||
# Add to our list with both original and cleaned text, but only if the confidence score is above the threshold
# 'address' and 'company' entities with 3 or fewer characters are filtered out after the loop
|
||||
if confidence_score > self.confidence_threshold:
|
||||
filtered_entities.append({
|
||||
"text": cleaned_text, # Clean text for display/processing
|
||||
"tokenized_text": entity_text, # Original tokenized text from model
|
||||
"type": entity_type,
|
||||
"entity_group": entity_group,
|
||||
"confidence": confidence_score
|
||||
})
|
||||
logger.info(f"Filtered entities: {filtered_entities}")
|
||||
# Filter out 'address' and 'company' entities whose text is 3 characters or fewer
|
||||
filtered_entities = [entity for entity in filtered_entities if entity['entity_group'] not in ['address', 'company'] or len(entity['text']) > 3]
|
||||
logger.info(f"Final Filtered entities: {filtered_entities}")
|
||||
|
||||
return {
|
||||
"entities": filtered_entities,
|
||||
"total_count": len(filtered_entities)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during single NER processing: {str(e)}")
|
||||
raise Exception(f"Single NER processing failed: {str(e)}")
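# --- Illustrative sketch (not part of the original module) ---
# The same confidence/length filtering applied to a mocked pipeline output.
# The 'name' group label, the label map and the 0.8 threshold are assumptions
# for this example only; the real values are configured on the class.
label_map = {"name": "人名", "company": "公司名称", "address": "地址"}
confidence_threshold = 0.8

mock_results = [
    {"entity_group": "name", "word": "张 三", "score": 0.97},
    {"entity_group": "company", "word": "某 某", "score": 0.95},              # too short
    {"entity_group": "address", "word": "北 京 市 朝 阳 区", "score": 0.55},  # low score
]

entities = [
    {
        "text": ent["word"].replace(" ", ""),
        "type": label_map[ent["entity_group"]],
        "entity_group": ent["entity_group"],
        "confidence": ent["score"],
    }
    for ent in mock_results
    if ent["entity_group"] in label_map and ent["score"] > confidence_threshold
]
entities = [e for e in entities
            if e["entity_group"] not in ("address", "company") or len(e["text"]) > 3]
print(entities)  # only the 张三 entity survives both filters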
|
||||
|
||||
def _extract_with_chunking(self, text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract entities from long text using sentence-based chunking approach
|
||||
|
||||
Args:
|
||||
text: The text to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted entities
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Using sentence-based chunking for text of length: {len(text)}")
|
||||
|
||||
# Split text into sentences
|
||||
sentences = self._split_text_by_sentences(text)
|
||||
logger.info(f"Split text into {len(sentences)} sentences")
|
||||
|
||||
# Create chunks from sentences
|
||||
chunks = self._create_sentence_chunks(sentences, max_tokens=400)
|
||||
logger.info(f"Created {len(chunks)} chunks from sentences")
|
||||
|
||||
all_entities = []
|
||||
|
||||
# Process each chunk
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Verify chunk won't exceed token limit
|
||||
chunk_tokens = len(self.tokenizer.tokenize(chunk))
|
||||
logger.info(f"Processing chunk {i+1}: {len(chunk)} chars, {chunk_tokens} tokens")
|
||||
|
||||
if chunk_tokens > 512:
|
||||
logger.warning(f"Chunk {i+1} has {chunk_tokens} tokens, truncating")
|
||||
# Truncate the chunk to fit within token limit
|
||||
chunk = self.tokenizer.convert_tokens_to_string(
|
||||
self.tokenizer.tokenize(chunk)[:512]
|
||||
)
|
||||
|
||||
# Extract entities from this chunk
|
||||
chunk_result = self._extract_single(chunk)
|
||||
chunk_entities = chunk_result.get("entities", [])
|
||||
|
||||
all_entities.extend(chunk_entities)
|
||||
logger.info(f"Chunk {i+1} extracted {len(chunk_entities)} entities")
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_entities = []
|
||||
seen_texts = set()
|
||||
|
||||
for entity in all_entities:
|
||||
text = entity['text'].strip()
|
||||
if text and text not in seen_texts:
|
||||
seen_texts.add(text)
|
||||
unique_entities.append(entity)
|
||||
|
||||
logger.info(f"Sentence-based chunking completed: {len(all_entities)} total entities, {len(unique_entities)} unique entities")
|
||||
|
||||
return {
|
||||
"entities": unique_entities,
|
||||
"total_count": len(unique_entities)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during sentence-based chunked NER processing: {str(e)}")
|
||||
raise Exception(f"Sentence-based chunked NER processing failed: {str(e)}")
|
||||
|
||||
def _clean_tokenized_text(self, tokenized_text: str) -> str:
|
||||
"""
|
||||
Clean up tokenized text by removing spaces between Chinese characters
|
||||
|
||||
Args:
|
||||
tokenized_text: Text with spaces between characters (e.g., "北 京 市")
|
||||
|
||||
Returns:
|
||||
Cleaned text without spaces (e.g., "北京市")
|
||||
"""
|
||||
if not tokenized_text:
|
||||
return tokenized_text
|
||||
|
||||
# Remove spaces between Chinese characters
|
||||
# This handles cases like "北 京 市" -> "北京市"
|
||||
cleaned = tokenized_text.replace(" ", "")
|
||||
|
||||
# Normalize any remaining whitespace (effectively a no-op once all spaces are removed above)
|
||||
cleaned = " ".join(cleaned.split())
|
||||
|
||||
return cleaned
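# --- Illustrative sketch (not part of the original module) ---
# What the cleanup above does to a tokenized span; note that for Latin-script
# entities this would also remove legitimate word spaces.
print("北 京 市 朝 阳 区".replace(" ", ""))  # -> 北京市朝阳区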
|
||||
|
||||
def get_entity_summary(self, entities: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a summary of extracted entities by type
|
||||
|
||||
Args:
|
||||
entities: List of extracted entities
|
||||
|
||||
Returns:
|
||||
Summary dictionary with counts by entity type
|
||||
"""
|
||||
summary = {}
|
||||
for entity in entities:
|
||||
entity_type = entity['type']
|
||||
if entity_type not in summary:
|
||||
summary[entity_type] = []
|
||||
summary[entity_type].append(entity['text'])
|
||||
|
||||
# Convert to count format
|
||||
summary_counts = {entity_type: len(texts) for entity_type, texts in summary.items()}
|
||||
|
||||
return {
|
||||
"summary": summary,
|
||||
"counts": summary_counts,
|
||||
"total_entities": len(entities)
|
||||
}
|
||||
|
||||
def extract_and_summarize(self, text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract entities and provide a summary in one call
|
||||
|
||||
Args:
|
||||
text: The text to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary containing entities and summary
|
||||
"""
|
||||
entities_result = self.extract(text)
|
||||
entities = entities_result.get("entities", [])
|
||||
|
||||
summary_result = self.get_entity_summary(entities)
|
||||
|
||||
return {
|
||||
"entities": entities,
|
||||
"summary": summary_result,
|
||||
"total_count": len(entities)
|
||||
}
|
||||
|
||||
def get_confidence(self) -> float:
|
||||
"""
|
||||
Return confidence level of extraction
|
||||
|
||||
Returns:
|
||||
Confidence level as a float between 0.0 and 1.0
|
||||
"""
|
||||
# NER models typically have high confidence for well-trained entities
|
||||
# This is a reasonable default confidence level for NER extraction
|
||||
return 0.85
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the NER model
|
||||
|
||||
Returns:
|
||||
Dictionary containing model information
|
||||
"""
|
||||
return {
|
||||
"model_name": self.model_checkpoint,
|
||||
"model_type": "Chinese NER",
|
||||
"supported_entities": [
|
||||
"人名 (Person Names)",
|
||||
"公司名称 (Company Names)",
|
||||
"组织机构名 (Organization Names)",
|
||||
"地址 (Addresses)"
|
||||
],
|
||||
"description": "Fine-tuned RoBERTa model for Chinese Named Entity Recognition on CLUENER2020 dataset"
|
||||
}
|
||||
|
|
@@ -1,65 +0,0 @@
|
|||
"""
|
||||
Factory for creating maskers.
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, Any
|
||||
from .maskers.base_masker import BaseMasker
|
||||
from .maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
||||
from .maskers.company_masker import CompanyMasker
|
||||
from .maskers.address_masker import AddressMasker
|
||||
from .maskers.id_masker import IDMasker
|
||||
from .maskers.case_masker import CaseMasker
|
||||
from ..services.ollama_client import OllamaClient
|
||||
|
||||
|
||||
class MaskerFactory:
|
||||
"""Factory for creating maskers"""
|
||||
|
||||
_maskers: Dict[str, Type[BaseMasker]] = {
|
||||
'chinese_name': ChineseNameMasker,
|
||||
'english_name': EnglishNameMasker,
|
||||
'company': CompanyMasker,
|
||||
'address': AddressMasker,
|
||||
'id': IDMasker,
|
||||
'case': CaseMasker,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_masker(cls, masker_type: str, ollama_client: OllamaClient = None, config: Dict[str, Any] = None) -> BaseMasker:
|
||||
"""
|
||||
Create a masker of the specified type.
|
||||
|
||||
Args:
|
||||
masker_type: Type of masker to create
|
||||
ollama_client: Ollama client for LLM-based maskers
|
||||
config: Configuration for the masker
|
||||
|
||||
Returns:
|
||||
Instance of the specified masker
|
||||
|
||||
Raises:
|
||||
ValueError: If masker type is unknown
|
||||
"""
|
||||
if masker_type not in cls._maskers:
|
||||
raise ValueError(f"Unknown masker type: {masker_type}")
|
||||
|
||||
masker_class = cls._maskers[masker_type]
|
||||
|
||||
# Handle maskers that need ollama_client
|
||||
if masker_type in ['company', 'address']:
|
||||
if not ollama_client:
|
||||
raise ValueError(f"Ollama client is required for {masker_type} masker")
|
||||
return masker_class(ollama_client)
|
||||
|
||||
# Handle maskers that don't need special parameters
|
||||
return masker_class()
|
||||
|
||||
@classmethod
|
||||
def get_available_maskers(cls) -> list[str]:
|
||||
"""Get list of available masker types"""
|
||||
return list(cls._maskers.keys())
|
||||
|
||||
@classmethod
|
||||
def register_masker(cls, masker_type: str, masker_class: Type[BaseMasker]):
|
||||
"""Register a new masker type"""
|
||||
cls._maskers[masker_type] = masker_class
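# --- Illustrative sketch (not part of the original module) ---
# Typical use of the factory; the import path is a guess based on this layout,
# and the sample inputs are invented.
from processors.masker_factory import MaskerFactory  # illustrative import path

name_masker = MaskerFactory.create_masker('chinese_name')
print(MaskerFactory.get_available_maskers())
# ['chinese_name', 'english_name', 'company', 'address', 'id', 'case']
print(name_masker.mask("李强"))  # -> 李Q

# LLM-backed maskers require a client instance:
# company_masker = MaskerFactory.create_masker('company', ollama_client)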
|
||||
|
|
@@ -1,20 +0,0 @@
|
|||
"""
|
||||
Maskers package for entity masking functionality.
|
||||
"""
|
||||
|
||||
from .base_masker import BaseMasker
|
||||
from .name_masker import ChineseNameMasker, EnglishNameMasker
|
||||
from .company_masker import CompanyMasker
|
||||
from .address_masker import AddressMasker
|
||||
from .id_masker import IDMasker
|
||||
from .case_masker import CaseMasker
|
||||
|
||||
__all__ = [
|
||||
'BaseMasker',
|
||||
'ChineseNameMasker',
|
||||
'EnglishNameMasker',
|
||||
'CompanyMasker',
|
||||
'AddressMasker',
|
||||
'IDMasker',
|
||||
'CaseMasker'
|
||||
]
|
||||
|
|
@@ -1,91 +0,0 @@
|
|||
"""
|
||||
Address masker for addresses.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from pypinyin import pinyin, Style
|
||||
from ...services.ollama_client import OllamaClient
|
||||
from ..extractors.address_extractor import AddressExtractor
|
||||
from .base_masker import BaseMasker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddressMasker(BaseMasker):
|
||||
"""Masker for addresses"""
|
||||
|
||||
def __init__(self, ollama_client: OllamaClient):
|
||||
self.extractor = AddressExtractor(ollama_client)
|
||||
|
||||
def mask(self, address: str, context: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Mask address by replacing components with masked versions.
|
||||
|
||||
Args:
|
||||
address: The address to mask
|
||||
context: Additional context (not used for address masking)
|
||||
|
||||
Returns:
|
||||
Masked address
|
||||
"""
|
||||
if not address:
|
||||
return address
|
||||
|
||||
# Extract address components
|
||||
components = self.extractor.extract(address)
|
||||
if not components:
|
||||
return address
|
||||
|
||||
masked_address = address
|
||||
|
||||
# Replace road name
|
||||
if components.get("road_name"):
|
||||
road_name = components["road_name"]
|
||||
# Get pinyin initials for road name
|
||||
try:
|
||||
pinyin_list = pinyin(road_name, style=Style.NORMAL)
|
||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
||||
masked_address = masked_address.replace(road_name, initials + "路")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get pinyin for road name {road_name}: {e}")
|
||||
# Fallback to first character
|
||||
masked_address = masked_address.replace(road_name, road_name[0].upper() + "路")
|
||||
|
||||
# Replace house number
|
||||
if components.get("house_number"):
|
||||
house_number = components["house_number"]
|
||||
masked_address = masked_address.replace(house_number + "号", "**号")
|
||||
|
||||
# Replace building name
|
||||
if components.get("building_name"):
|
||||
building_name = components["building_name"]
|
||||
# Get pinyin initials for building name
|
||||
try:
|
||||
pinyin_list = pinyin(building_name, style=Style.NORMAL)
|
||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
||||
masked_address = masked_address.replace(building_name, initials)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get pinyin for building name {building_name}: {e}")
|
||||
# Fallback to first character
|
||||
masked_address = masked_address.replace(building_name, building_name[0].upper())
|
||||
|
||||
# Replace community name
|
||||
if components.get("community_name"):
|
||||
community_name = components["community_name"]
|
||||
# Get pinyin initials for community name
|
||||
try:
|
||||
pinyin_list = pinyin(community_name, style=Style.NORMAL)
|
||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
||||
masked_address = masked_address.replace(community_name, initials)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get pinyin for community name {community_name}: {e}")
|
||||
# Fallback to first character
|
||||
masked_address = masked_address.replace(community_name, community_name[0].upper())
|
||||
|
||||
return masked_address
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
return ['地址']
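# --- Illustrative sketch (not part of the original module) ---
# The pinyin-initials rule shared by the road/building/community branches above.
from pypinyin import pinyin, Style

def pinyin_initials(name: str) -> str:
    return ''.join(p[0][0].upper() for p in pinyin(name, style=Style.NORMAL) if p and p[0])

print(pinyin_initials("中山"))      # -> ZS, so a road named 中山 becomes "ZS路"
print(pinyin_initials("幸福小区"))  # -> XFXQ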
|
||||
|
|
@@ -1,24 +0,0 @@
|
|||
"""
|
||||
Abstract base class for all maskers.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class BaseMasker(ABC):
|
||||
"""Abstract base class for all maskers"""
|
||||
|
||||
@abstractmethod
|
||||
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
|
||||
"""Mask the given text according to specific rules"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
pass
|
||||
|
||||
def can_mask(self, entity_type: str) -> bool:
|
||||
"""Check if this masker can handle the given entity type"""
|
||||
return entity_type in self.get_supported_types()
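# --- Illustrative sketch (not part of the original module) ---
# A new masker only has to implement mask() and get_supported_types(); the phone
# entity type and rule below are invented for the example, and BaseMasker refers
# to the class defined above.
from typing import Any, Dict

class PhoneMasker(BaseMasker):
    """Keeps the prefix and the last four digits of a phone number."""

    def mask(self, text: str, context: Dict[str, Any] = None) -> str:
        if not text or len(text) < 8:
            return text
        return text[:3] + "****" + text[-4:]

    def get_supported_types(self) -> list[str]:
        return ['电话号码']

masker = PhoneMasker()
print(masker.can_mask('电话号码'))  # True
print(masker.mask("13812345678"))   # 138****5678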
|
||||
|
|
@@ -1,33 +0,0 @@
|
|||
"""
|
||||
Case masker for case numbers.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
from .base_masker import BaseMasker
|
||||
|
||||
|
||||
class CaseMasker(BaseMasker):
|
||||
"""Masker for case numbers"""
|
||||
|
||||
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Mask case numbers by replacing digits with ***.
|
||||
|
||||
Args:
|
||||
text: The text to mask
|
||||
context: Additional context (not used for case masking)
|
||||
|
||||
Returns:
|
||||
Masked text
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Replace digits with *** while preserving structure
|
||||
masked = re.sub(r'(\d[\d\s]*)(号)', r'***\2', text)
|
||||
return masked
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
return ['案号']
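# --- Illustrative sketch (not part of the original module) ---
# Only the digits immediately before "号" are replaced; the year in parentheses
# stays untouched (the case number is invented).
import re

print(re.sub(r'(\d[\d\s]*)(号)', r'***\2', "(2023)京0105民初12345号"))
# -> (2023)京0105民初***号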
|
||||
|
|
@@ -1,98 +0,0 @@
|
|||
"""
|
||||
Company masker for company names.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from pypinyin import pinyin, Style
|
||||
from ...services.ollama_client import OllamaClient
|
||||
from ..extractors.business_name_extractor import BusinessNameExtractor
|
||||
from .base_masker import BaseMasker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompanyMasker(BaseMasker):
|
||||
"""Masker for company names"""
|
||||
|
||||
def __init__(self, ollama_client: OllamaClient):
|
||||
self.extractor = BusinessNameExtractor(ollama_client)
|
||||
|
||||
def mask(self, company_name: str, context: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Mask company name by replacing business name with letters.
|
||||
|
||||
Args:
|
||||
company_name: The company name to mask
|
||||
context: Additional context (not used for company masking)
|
||||
|
||||
Returns:
|
||||
Masked company name
|
||||
"""
|
||||
if not company_name:
|
||||
return company_name
|
||||
|
||||
# Extract business name
|
||||
extraction_result = self.extractor.extract(company_name)
|
||||
if not extraction_result:
|
||||
return company_name
|
||||
|
||||
business_name = extraction_result.get('business_name', '')
|
||||
if not business_name:
|
||||
return company_name
|
||||
|
||||
# Get pinyin first letter of business name
|
||||
try:
|
||||
pinyin_list = pinyin(business_name, style=Style.NORMAL)
|
||||
first_letter = pinyin_list[0][0][0].upper() if pinyin_list and pinyin_list[0] else 'A'
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get pinyin for {business_name}: {e}")
|
||||
first_letter = 'A'
|
||||
|
||||
# Calculate next two letters
|
||||
if first_letter >= 'Y':
|
||||
# If first letter is Y or Z, use X and Y
|
||||
letters = 'XY'
|
||||
elif first_letter >= 'X':
|
||||
# If first letter is X, use Y and Z
|
||||
letters = 'YZ'
|
||||
else:
|
||||
# Normal case: use next two letters
|
||||
letters = chr(ord(first_letter) + 1) + chr(ord(first_letter) + 2)
|
||||
|
||||
# Replace business name
|
||||
if business_name in company_name:
|
||||
masked_name = company_name.replace(business_name, letters)
|
||||
else:
|
||||
# Try smarter replacement
|
||||
masked_name = self._replace_business_name_in_company(company_name, business_name, letters)
|
||||
|
||||
return masked_name
|
||||
|
||||
def _replace_business_name_in_company(self, company_name: str, business_name: str, letters: str) -> str:
|
||||
"""Smart replacement of business name in company name"""
|
||||
# Try different replacement patterns
|
||||
patterns = [
|
||||
business_name,
|
||||
business_name + '(',
|
||||
business_name + '(',
|
||||
'(' + business_name + ')',
|
||||
'(' + business_name + ')',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
if pattern in company_name:
|
||||
if pattern.endswith('(') or pattern.endswith('('):
|
||||
return company_name.replace(pattern, letters + pattern[-1])
|
||||
elif pattern.startswith('(') or pattern.startswith('('):
|
||||
return company_name.replace(pattern, pattern[0] + letters + pattern[-1])
|
||||
else:
|
||||
return company_name.replace(pattern, letters)
|
||||
|
||||
# If no pattern found, return original
|
||||
return company_name
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
return ['公司名称', 'Company', '英文公司名', 'English Company']
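# --- Illustrative sketch (not part of the original module) ---
# The letter-substitution rule above in isolation: take the first pinyin letter of
# the business name (字号) and replace the name with the next two letters.
from pypinyin import pinyin, Style

def substitute_letters(business_name: str) -> str:
    try:
        py = pinyin(business_name, style=Style.NORMAL)
        first = py[0][0][0].upper() if py and py[0] else 'A'
    except Exception:
        first = 'A'
    if first >= 'Y':
        return 'XY'
    if first >= 'X':
        return 'YZ'
    return chr(ord(first) + 1) + chr(ord(first) + 2)

print(substitute_letters("华为"))  # H -> "IJ", so "华为技术有限公司" would become "IJ技术有限公司"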
|
||||
|
|
@@ -1,39 +0,0 @@
|
|||
"""
|
||||
ID masker for ID numbers and social credit codes.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from .base_masker import BaseMasker
|
||||
|
||||
|
||||
class IDMasker(BaseMasker):
|
||||
"""Masker for ID numbers and social credit codes"""
|
||||
|
||||
def mask(self, text: str, context: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Mask ID numbers and social credit codes.
|
||||
|
||||
Args:
|
||||
text: The text to mask
|
||||
context: Additional context (not used for ID masking)
|
||||
|
||||
Returns:
|
||||
Masked text
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Determine the type based on length and format
|
||||
if len(text) == 18 and text.isdigit():
|
||||
# ID number: keep first 6 digits
|
||||
return text[:6] + 'X' * (len(text) - 6)
|
||||
elif len(text) == 18 and any(c.isalpha() for c in text):
|
||||
# Social credit code: keep first 7 characters
|
||||
return text[:7] + 'X' * (len(text) - 7)
|
||||
else:
|
||||
# Fallback for invalid formats
|
||||
return text
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
return ['身份证号', '社会信用代码']
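# --- Illustrative sketch (not part of the original module) ---
# Both branches above on made-up 18-character inputs.
masker = IDMasker()
print(masker.mask("110101199003071234"))  # all digits -> 110101XXXXXXXXXXXX
print(masker.mask("91110000MA01ABCD12"))  # contains letters -> 9111000XXXXXXXXXXX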
|
||||
|
|
@@ -1,89 +0,0 @@
|
|||
"""
|
||||
Name maskers for Chinese and English names.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from pypinyin import pinyin, Style
|
||||
from .base_masker import BaseMasker
|
||||
|
||||
|
||||
class ChineseNameMasker(BaseMasker):
|
||||
"""Masker for Chinese names"""
|
||||
|
||||
def __init__(self):
|
||||
self.surname_counter = {}
|
||||
|
||||
def mask(self, name: str, context: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Mask Chinese names: keep surname, convert given name to pinyin initials.
|
||||
|
||||
Args:
|
||||
name: The name to mask
|
||||
context: Additional context containing surname_counter
|
||||
|
||||
Returns:
|
||||
Masked name
|
||||
"""
|
||||
if not name or len(name) < 2:
|
||||
return name
|
||||
|
||||
# Use context surname_counter if provided, otherwise use instance counter
|
||||
surname_counter = context.get('surname_counter', self.surname_counter) if context else self.surname_counter
|
||||
|
||||
surname = name[0]
|
||||
given_name = name[1:]
|
||||
|
||||
# Get pinyin initials for given name
|
||||
try:
|
||||
pinyin_list = pinyin(given_name, style=Style.NORMAL)
|
||||
initials = ''.join([p[0][0].upper() for p in pinyin_list if p and p[0]])
|
||||
except Exception:
|
||||
# Fallback to original characters if pinyin fails
|
||||
initials = given_name
|
||||
|
||||
# Initialize surname counter
|
||||
if surname not in surname_counter:
|
||||
surname_counter[surname] = {}
|
||||
|
||||
# Check for duplicate surname and initials combination
|
||||
if initials in surname_counter[surname]:
|
||||
surname_counter[surname][initials] += 1
|
||||
masked_name = f"{surname}{initials}{surname_counter[surname][initials]}"
|
||||
else:
|
||||
surname_counter[surname][initials] = 1
|
||||
masked_name = f"{surname}{initials}"
|
||||
|
||||
return masked_name
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
return ['人名', '律师姓名', '审判人员姓名']
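# --- Illustrative sketch (not part of the original module) ---
# The surname counter disambiguates different people whose masked forms would
# otherwise collide (the names are invented).
masker = ChineseNameMasker()
print(masker.mask("王小明"))  # -> 王XM
print(masker.mask("王晓梅"))  # same surname and initials -> 王XM2
print(masker.mask("李强"))    # -> 李Q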
|
||||
|
||||
|
||||
class EnglishNameMasker(BaseMasker):
|
||||
"""Masker for English names"""
|
||||
|
||||
def mask(self, name: str, context: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Mask English names: convert each word to first letter + ***.
|
||||
|
||||
Args:
|
||||
name: The name to mask
|
||||
context: Additional context (not used for English name masking)
|
||||
|
||||
Returns:
|
||||
Masked name
|
||||
"""
|
||||
if not name:
|
||||
return name
|
||||
|
||||
masked_parts = []
|
||||
for part in name.split():
|
||||
if part:
|
||||
masked_parts.append(part[0] + '***')
|
||||
|
||||
return ' '.join(masked_parts)
|
||||
|
||||
def get_supported_types(self) -> list[str]:
|
||||
"""Return list of entity types this masker supports"""
|
||||
return ['英文人名']
|
||||
File diff suppressed because it is too large
|
|
@@ -1,410 +0,0 @@
|
|||
"""
|
||||
Refactored NerProcessor using the new masker architecture.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from ..prompts.masking_prompts import (
|
||||
get_ner_name_prompt, get_ner_company_prompt, get_ner_address_prompt,
|
||||
get_ner_project_prompt, get_ner_case_number_prompt, get_entity_linkage_prompt
|
||||
)
|
||||
from ..services.ollama_client import OllamaClient
|
||||
from ...core.config import settings
|
||||
from ..utils.json_extractor import LLMJsonExtractor
|
||||
from ..utils.llm_validator import LLMResponseValidator
|
||||
from .regs.entity_regex import extract_id_number_entities, extract_social_credit_code_entities
|
||||
from .masker_factory import MaskerFactory
|
||||
from .maskers.base_masker import BaseMasker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NerProcessorRefactored:
|
||||
"""Refactored NerProcessor using the new masker architecture"""
|
||||
|
||||
def __init__(self):
|
||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
||||
self.max_retries = 3
|
||||
self.maskers = self._initialize_maskers()
|
||||
self.surname_counter = {} # Shared counter for Chinese names
|
||||
|
||||
def _find_entity_alignment(self, entity_text: str, original_document_text: str) -> Optional[Tuple[int, int, str]]:
|
||||
"""
|
||||
Find entity in original document using character-by-character alignment.
|
||||
|
||||
This method handles the case where the original document may have spaces
|
||||
that are not from tokenization, and the entity text may have different
|
||||
spacing patterns.
|
||||
|
||||
Args:
|
||||
entity_text: The entity text to find (may have spaces from tokenization)
|
||||
original_document_text: The original document text (may have spaces)
|
||||
|
||||
Returns:
|
||||
Tuple of (start_pos, end_pos, found_text) or None if not found
|
||||
"""
|
||||
# Remove all spaces from entity text to get clean characters
|
||||
clean_entity = entity_text.replace(" ", "")
|
||||
|
||||
# Create character lists ignoring spaces from both entity and document
|
||||
entity_chars = [c for c in clean_entity]
|
||||
doc_chars = [c for c in original_document_text if c != ' ']
|
||||
|
||||
# Find the sequence in document characters
|
||||
for i in range(len(doc_chars) - len(entity_chars) + 1):
|
||||
if doc_chars[i:i+len(entity_chars)] == entity_chars:
|
||||
# Found match, now map back to original positions
|
||||
return self._map_char_positions_to_original(i, len(entity_chars), original_document_text)
|
||||
|
||||
return None
|
||||
|
||||
def _map_char_positions_to_original(self, clean_start: int, entity_length: int, original_text: str) -> Tuple[int, int, str]:
|
||||
"""
|
||||
Map positions from clean text (without spaces) back to original text positions.
|
||||
|
||||
Args:
|
||||
clean_start: Start position in clean text (without spaces)
|
||||
entity_length: Length of entity in characters
|
||||
original_text: Original document text with spaces
|
||||
|
||||
Returns:
|
||||
Tuple of (start_pos, end_pos, found_text) in original text
|
||||
"""
|
||||
original_pos = 0
|
||||
clean_pos = 0
|
||||
|
||||
# Find the start position in original text
|
||||
while clean_pos < clean_start and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
clean_pos += 1
|
||||
original_pos += 1
|
||||
|
||||
start_pos = original_pos
|
||||
|
||||
# Find the end position by counting non-space characters
|
||||
chars_found = 0
|
||||
while chars_found < entity_length and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
chars_found += 1
|
||||
original_pos += 1
|
||||
|
||||
end_pos = original_pos
|
||||
|
||||
# Extract the actual text from the original document
|
||||
found_text = original_text[start_pos:end_pos]
|
||||
|
||||
return start_pos, end_pos, found_text
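# --- Illustrative sketch (not part of the original module) ---
# A compact, stand-alone version of the space-insensitive lookup performed by
# _find_entity_alignment/_map_char_positions_to_original above: match on the
# non-space character sequence, then map back to positions in the original string.
from typing import Optional, Tuple

def find_ignoring_spaces(entity: str, document: str) -> Optional[Tuple[int, int, str]]:
    target = [c for c in entity if c != ' ']
    positions = [i for i, c in enumerate(document) if c != ' ']
    chars = [document[i] for i in positions]
    for start in range(len(chars) - len(target) + 1):
        if chars[start:start + len(target)] == target:
            begin = positions[start]
            end = positions[start + len(target) - 1] + 1
            return begin, end, document[begin:end]
    return None

print(find_ignoring_spaces("北 京 朝 阳 公 司", "被 告 北 京 朝 阳 公 司 未 到 庭"))
# -> (4, 15, '北 京 朝 阳 公 司')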
|
||||
|
||||
def apply_entity_masking_with_alignment(self, original_document_text: str, entity_mapping: Dict[str, str], mask_char: str = "*") -> str:
|
||||
"""
|
||||
Apply entity masking to original document text using character-by-character alignment.
|
||||
|
||||
This method finds each entity in the original document using alignment and
|
||||
replaces it with the corresponding masked version. It handles multiple
|
||||
occurrences of the same entity by finding all instances before moving
|
||||
to the next entity.
|
||||
|
||||
Args:
|
||||
original_document_text: The original document text to mask
|
||||
entity_mapping: Dictionary mapping original entity text to masked text
|
||||
mask_char: Character to use for masking (default: "*")
|
||||
|
||||
Returns:
|
||||
Masked document text
|
||||
"""
|
||||
masked_document = original_document_text
|
||||
|
||||
# Sort entities by length (longest first) to avoid partial matches
|
||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
||||
|
||||
for entity_text in sorted_entities:
|
||||
masked_text = entity_mapping[entity_text]
|
||||
|
||||
# Skip if masked text is the same as original text (prevents infinite loop)
|
||||
if entity_text == masked_text:
|
||||
logger.debug(f"Skipping entity '{entity_text}' as masked text is identical")
|
||||
continue
|
||||
|
||||
# Find ALL occurrences of this entity in the document
|
||||
# We need to loop until no more matches are found
|
||||
# Add safety counter to prevent infinite loops
|
||||
max_iterations = 100 # Safety limit
|
||||
iteration_count = 0
|
||||
|
||||
while iteration_count < max_iterations:
|
||||
iteration_count += 1
|
||||
|
||||
# Find the entity in the current masked document using alignment
|
||||
alignment_result = self._find_entity_alignment(entity_text, masked_document)
|
||||
|
||||
if alignment_result:
|
||||
start_pos, end_pos, found_text = alignment_result
|
||||
|
||||
# Replace the found text with the masked version
|
||||
masked_document = (
|
||||
masked_document[:start_pos] +
|
||||
masked_text +
|
||||
masked_document[end_pos:]
|
||||
)
|
||||
|
||||
logger.debug(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos} (iteration {iteration_count})")
|
||||
else:
|
||||
# No more occurrences found for this entity, move to next entity
|
||||
logger.debug(f"No more occurrences of '{entity_text}' found in document after {iteration_count} iterations")
|
||||
break
|
||||
|
||||
# Log warning if we hit the safety limit
|
||||
if iteration_count >= max_iterations:
|
||||
logger.warning(f"Reached maximum iterations ({max_iterations}) for entity '{entity_text}', stopping to prevent infinite loop")
|
||||
|
||||
return masked_document
|
||||
|
||||
def _initialize_maskers(self) -> Dict[str, BaseMasker]:
|
||||
"""Initialize all maskers"""
|
||||
maskers = {}
|
||||
|
||||
# Create maskers that don't need ollama_client
|
||||
maskers['chinese_name'] = MaskerFactory.create_masker('chinese_name')
|
||||
maskers['english_name'] = MaskerFactory.create_masker('english_name')
|
||||
maskers['id'] = MaskerFactory.create_masker('id')
|
||||
maskers['case'] = MaskerFactory.create_masker('case')
|
||||
|
||||
# Create maskers that need ollama_client
|
||||
maskers['company'] = MaskerFactory.create_masker('company', self.ollama_client)
|
||||
maskers['address'] = MaskerFactory.create_masker('address', self.ollama_client)
|
||||
|
||||
return maskers
|
||||
|
||||
def _get_masker_for_type(self, entity_type: str) -> Optional[BaseMasker]:
|
||||
"""Get the appropriate masker for the given entity type"""
|
||||
for masker in self.maskers.values():
|
||||
if masker.can_mask(entity_type):
|
||||
return masker
|
||||
return None
|
||||
|
||||
def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
|
||||
"""Validate entity extraction mapping format"""
|
||||
return LLMResponseValidator.validate_entity_extraction(mapping)
|
||||
|
||||
def _process_entity_type(self, chunk: str, prompt_func, entity_type: str) -> Dict[str, str]:
|
||||
"""Process entities of a specific type using LLM"""
|
||||
try:
|
||||
formatted_prompt = prompt_func(chunk)
|
||||
logger.info(f"Calling ollama to generate {entity_type} mapping for chunk: {formatted_prompt}")
|
||||
|
||||
# Use the new enhanced generate method with validation
|
||||
mapping = self.ollama_client.generate_with_validation(
|
||||
prompt=formatted_prompt,
|
||||
response_type='entity_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
logger.info(f"Parsed mapping: {mapping}")
|
||||
|
||||
if mapping and self._validate_mapping_format(mapping):
|
||||
return mapping
|
||||
else:
|
||||
logger.warning(f"Invalid mapping format received for {entity_type}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating {entity_type} mapping: {e}")
|
||||
return {}
|
||||
|
||||
def build_mapping(self, chunk: str) -> List[Dict[str, str]]:
|
||||
"""Build entity mappings from text chunk"""
|
||||
mapping_pipeline = []
|
||||
|
||||
# Process different entity types
|
||||
entity_configs = [
|
||||
(get_ner_name_prompt, "people names"),
|
||||
(get_ner_company_prompt, "company names"),
|
||||
(get_ner_address_prompt, "addresses"),
|
||||
(get_ner_project_prompt, "project names"),
|
||||
(get_ner_case_number_prompt, "case numbers")
|
||||
]
|
||||
|
||||
for prompt_func, entity_type in entity_configs:
|
||||
mapping = self._process_entity_type(chunk, prompt_func, entity_type)
|
||||
if mapping:
|
||||
mapping_pipeline.append(mapping)
|
||||
|
||||
# Process regex-based entities
|
||||
regex_entity_extractors = [
|
||||
extract_id_number_entities,
|
||||
extract_social_credit_code_entities
|
||||
]
|
||||
|
||||
for extractor in regex_entity_extractors:
|
||||
mapping = extractor(chunk)
|
||||
if mapping and LLMResponseValidator.validate_regex_entity(mapping):
|
||||
mapping_pipeline.append(mapping)
|
||||
elif mapping:
|
||||
logger.warning(f"Invalid regex entity mapping format: {mapping}")
|
||||
|
||||
return mapping_pipeline
|
||||
|
||||
def _merge_entity_mappings(self, chunk_mappings: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Merge entity mappings from multiple chunks"""
|
||||
all_entities = []
|
||||
for mapping in chunk_mappings:
|
||||
if isinstance(mapping, dict) and 'entities' in mapping:
|
||||
entities = mapping['entities']
|
||||
if isinstance(entities, list):
|
||||
all_entities.extend(entities)
|
||||
|
||||
unique_entities = []
|
||||
seen_texts = set()
|
||||
|
||||
for entity in all_entities:
|
||||
if isinstance(entity, dict) and 'text' in entity:
|
||||
text = entity['text'].strip()
|
||||
if text and text not in seen_texts:
|
||||
seen_texts.add(text)
|
||||
unique_entities.append(entity)
|
||||
elif text and text in seen_texts:
|
||||
logger.info(f"Duplicate entity found: {entity}")
|
||||
continue
|
||||
|
||||
logger.info(f"Merged {len(unique_entities)} unique entities")
|
||||
return unique_entities
|
||||
|
||||
def _generate_masked_mapping(self, unique_entities: List[Dict[str, str]], linkage: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Generate masked mappings for entities"""
|
||||
entity_mapping = {}
|
||||
used_masked_names = set()
|
||||
group_mask_map = {}
|
||||
|
||||
# Process entity groups from linkage
|
||||
for group in linkage.get('entity_groups', []):
|
||||
group_type = group.get('group_type', '')
|
||||
entities = group.get('entities', [])
|
||||
|
||||
# Handle company groups
|
||||
if any(keyword in group_type for keyword in ['公司', 'Company']):
|
||||
for entity in entities:
|
||||
masker = self._get_masker_for_type('公司名称')
|
||||
if masker:
|
||||
masked = masker.mask(entity['text'])
|
||||
group_mask_map[entity['text']] = masked
|
||||
|
||||
# Handle name groups
|
||||
elif '人名' in group_type:
|
||||
for entity in entities:
|
||||
masker = self._get_masker_for_type('人名')
|
||||
if masker:
|
||||
context = {'surname_counter': self.surname_counter}
|
||||
masked = masker.mask(entity['text'], context)
|
||||
group_mask_map[entity['text']] = masked
|
||||
|
||||
# Handle English name groups
|
||||
elif '英文人名' in group_type:
|
||||
for entity in entities:
|
||||
masker = self._get_masker_for_type('英文人名')
|
||||
if masker:
|
||||
masked = masker.mask(entity['text'])
|
||||
group_mask_map[entity['text']] = masked
|
||||
|
||||
# Process individual entities
|
||||
for entity in unique_entities:
|
||||
text = entity['text']
|
||||
entity_type = entity.get('type', '')
|
||||
|
||||
# Check if entity is in group mapping
|
||||
if text in group_mask_map:
|
||||
entity_mapping[text] = group_mask_map[text]
|
||||
used_masked_names.add(group_mask_map[text])
|
||||
continue
|
||||
|
||||
# Get appropriate masker for entity type
|
||||
masker = self._get_masker_for_type(entity_type)
|
||||
if masker:
|
||||
# Prepare context for maskers that need it
|
||||
context = {}
|
||||
if entity_type in ['人名', '律师姓名', '审判人员姓名']:
|
||||
context['surname_counter'] = self.surname_counter
|
||||
|
||||
masked = masker.mask(text, context)
|
||||
entity_mapping[text] = masked
|
||||
used_masked_names.add(masked)
|
||||
else:
|
||||
# Fallback for unknown entity types
|
||||
base_name = '某'
|
||||
masked = base_name
|
||||
counter = 1
|
||||
while masked in used_masked_names:
|
||||
if counter <= 10:
|
||||
suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']
|
||||
masked = base_name + suffixes[counter - 1]
|
||||
else:
|
||||
masked = f"{base_name}{counter}"
|
||||
counter += 1
|
||||
entity_mapping[text] = masked
|
||||
used_masked_names.add(masked)
|
||||
|
||||
return entity_mapping
|
||||
|
||||
def _validate_linkage_format(self, linkage: Dict[str, Any]) -> bool:
|
||||
"""Validate entity linkage format"""
|
||||
return LLMResponseValidator.validate_entity_linkage(linkage)
|
||||
|
||||
def _create_entity_linkage(self, unique_entities: List[Dict[str, str]]) -> Dict[str, Any]:
|
||||
"""Create entity linkage information"""
|
||||
linkable_entities = []
|
||||
for entity in unique_entities:
|
||||
entity_type = entity.get('type', '')
|
||||
if any(keyword in entity_type for keyword in ['公司', 'Company', '人名', '英文人名']):
|
||||
linkable_entities.append(entity)
|
||||
|
||||
if not linkable_entities:
|
||||
logger.info("No linkable entities found")
|
||||
return {"entity_groups": []}
|
||||
|
||||
entities_text = "\n".join([
|
||||
f"- {entity['text']} (类型: {entity['type']})"
|
||||
for entity in linkable_entities
|
||||
])
|
||||
|
||||
try:
|
||||
formatted_prompt = get_entity_linkage_prompt(entities_text)
|
||||
logger.info(f"Calling ollama to generate entity linkage")
|
||||
|
||||
# Use the new enhanced generate method with validation
|
||||
linkage = self.ollama_client.generate_with_validation(
|
||||
prompt=formatted_prompt,
|
||||
response_type='entity_linkage',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
logger.info(f"Parsed entity linkage: {linkage}")
|
||||
|
||||
if linkage and self._validate_linkage_format(linkage):
|
||||
logger.info(f"Successfully created entity linkage with {len(linkage.get('entity_groups', []))} groups")
|
||||
return linkage
|
||||
else:
|
||||
logger.warning(f"Invalid entity linkage format received")
|
||||
return {"entity_groups": []}
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating entity linkage: {e}")
|
||||
return {"entity_groups": []}
|
||||
|
||||
def process(self, chunks: List[str]) -> Dict[str, str]:
|
||||
"""Main processing method"""
|
||||
chunk_mappings = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
logger.info(f"Processing chunk {i+1}/{len(chunks)}")
|
||||
chunk_mapping = self.build_mapping(chunk)
|
||||
logger.info(f"Chunk mapping: {chunk_mapping}")
|
||||
chunk_mappings.extend(chunk_mapping)
|
||||
|
||||
logger.info(f"Final chunk mappings: {chunk_mappings}")
|
||||
|
||||
unique_entities = self._merge_entity_mappings(chunk_mappings)
|
||||
logger.info(f"Unique entities: {unique_entities}")
|
||||
|
||||
entity_linkage = self._create_entity_linkage(unique_entities)
|
||||
logger.info(f"Entity linkage: {entity_linkage}")
|
||||
|
||||
combined_mapping = self._generate_masked_mapping(unique_entities, entity_linkage)
|
||||
logger.info(f"Combined mapping: {combined_mapping}")
|
||||
|
||||
return combined_mapping
|
||||
|
|
@@ -1,10 +1,13 @@
|
|||
import os
|
||||
import requests
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import docx
|
||||
from ...document_handlers.document_processor import DocumentProcessor
|
||||
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
|
||||
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
||||
from magic_pdf.data.read_api import read_local_office
|
||||
import logging
|
||||
from ...services.ollama_client import OllamaClient
|
||||
from ...config import settings
|
||||
from ...prompts.masking_prompts import get_masking_mapping_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@@ -16,195 +19,47 @@ class DocxDocumentProcessor(DocumentProcessor):
|
|||
self.output_dir = os.path.dirname(output_path)
|
||||
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
|
||||
|
||||
# Setup work directory for temporary files
|
||||
self.work_dir = os.path.join(
|
||||
os.path.dirname(output_path),
|
||||
".work",
|
||||
os.path.splitext(os.path.basename(input_path))[0]
|
||||
)
|
||||
os.makedirs(self.work_dir, exist_ok=True)
|
||||
# Setup output directories
|
||||
self.local_image_dir = os.path.join(self.output_dir, "images")
|
||||
self.image_dir = os.path.basename(self.local_image_dir)
|
||||
os.makedirs(self.local_image_dir, exist_ok=True)
|
||||
|
||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
||||
|
||||
# MagicDoc API configuration (replacing Mineru)
|
||||
self.magicdoc_base_url = getattr(settings, 'MAGICDOC_API_URL', 'http://magicdoc-api:8000')
|
||||
self.magicdoc_timeout = getattr(settings, 'MAGICDOC_TIMEOUT', 300) # 5 minutes timeout
|
||||
# MagicDoc uses simpler parameters, but we keep compatibility with existing interface
|
||||
self.magicdoc_lang_list = getattr(settings, 'MAGICDOC_LANG_LIST', 'ch')
|
||||
self.magicdoc_backend = getattr(settings, 'MAGICDOC_BACKEND', 'pipeline')
|
||||
self.magicdoc_parse_method = getattr(settings, 'MAGICDOC_PARSE_METHOD', 'auto')
|
||||
self.magicdoc_formula_enable = getattr(settings, 'MAGICDOC_FORMULA_ENABLE', True)
|
||||
self.magicdoc_table_enable = getattr(settings, 'MAGICDOC_TABLE_ENABLE', True)
|
||||
|
||||
def _call_magicdoc_api(self, file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call MagicDoc API to convert DOCX to markdown
|
||||
|
||||
Args:
|
||||
file_path: Path to the DOCX file
|
||||
|
||||
Returns:
|
||||
API response as dictionary or None if failed
|
||||
"""
|
||||
try:
|
||||
url = f"{self.magicdoc_base_url}/file_parse"
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
files = {'files': (os.path.basename(file_path), file, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')}
|
||||
|
||||
# Prepare form data according to MagicDoc API specification (compatible with Mineru)
|
||||
data = {
|
||||
'output_dir': './output',
|
||||
'lang_list': self.magicdoc_lang_list,
|
||||
'backend': self.magicdoc_backend,
|
||||
'parse_method': self.magicdoc_parse_method,
|
||||
'formula_enable': self.magicdoc_formula_enable,
|
||||
'table_enable': self.magicdoc_table_enable,
|
||||
'return_md': True,
|
||||
'return_middle_json': False,
|
||||
'return_model_output': False,
|
||||
'return_content_list': False,
|
||||
'return_images': False,
|
||||
'start_page_id': 0,
|
||||
'end_page_id': 99999
|
||||
}
|
||||
|
||||
logger.info(f"Calling MagicDoc API for DOCX processing at {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=self.magicdoc_timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info("Successfully received response from MagicDoc API for DOCX")
|
||||
return result
|
||||
else:
|
||||
error_msg = f"MagicDoc API returned status code {response.status_code}: {response.text}"
|
||||
logger.error(error_msg)
|
||||
# For 400 errors, include more specific information
|
||||
if response.status_code == 400:
|
||||
try:
|
||||
error_data = response.json()
|
||||
if 'error' in error_data:
|
||||
error_msg = f"MagicDoc API error: {error_data['error']}"
|
||||
except:
|
||||
pass
|
||||
raise Exception(error_msg)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
error_msg = f"MagicDoc API request timed out after {self.magicdoc_timeout} seconds"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"Error calling MagicDoc API for DOCX: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling MagicDoc API for DOCX: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
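# --- Illustrative sketch (not part of the original module) ---
# The same multipart /file_parse call made outside the class; the base URL, form
# fields and timeout mirror the code above and should be adjusted to the actual
# deployment.
import os
import requests

def parse_docx(path: str, base_url: str = "http://magicdoc-api:8000") -> dict:
    with open(path, "rb") as fh:
        files = {"files": (os.path.basename(path), fh,
                           "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}
        data = {"return_md": True, "parse_method": "auto", "lang_list": "ch", "backend": "pipeline"}
        response = requests.post(f"{base_url}/file_parse", files=files, data=data, timeout=300)
    response.raise_for_status()
    return response.json()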
|
||||
|
||||
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract markdown content from MagicDoc API response
|
||||
|
||||
Args:
|
||||
response: MagicDoc API response dictionary
|
||||
|
||||
Returns:
|
||||
Extracted markdown content as string
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"MagicDoc API response structure for DOCX: {response}")
|
||||
|
||||
# Try different possible response formats based on MagicDoc API
|
||||
if 'markdown' in response:
|
||||
return response['markdown']
|
||||
elif 'md' in response:
|
||||
return response['md']
|
||||
elif 'content' in response:
|
||||
return response['content']
|
||||
elif 'text' in response:
|
||||
return response['text']
|
||||
elif 'result' in response and isinstance(response['result'], dict):
|
||||
result = response['result']
|
||||
if 'markdown' in result:
|
||||
return result['markdown']
|
||||
elif 'md' in result:
|
||||
return result['md']
|
||||
elif 'content' in result:
|
||||
return result['content']
|
||||
elif 'text' in result:
|
||||
return result['text']
|
||||
elif 'data' in response and isinstance(response['data'], dict):
|
||||
data = response['data']
|
||||
if 'markdown' in data:
|
||||
return data['markdown']
|
||||
elif 'md' in data:
|
||||
return data['md']
|
||||
elif 'content' in data:
|
||||
return data['content']
|
||||
elif 'text' in data:
|
||||
return data['text']
|
||||
elif isinstance(response, list) and len(response) > 0:
|
||||
# If response is a list, try to extract from first item
|
||||
first_item = response[0]
|
||||
if isinstance(first_item, dict):
|
||||
return self._extract_markdown_from_response(first_item)
|
||||
elif isinstance(first_item, str):
|
||||
return first_item
|
||||
else:
|
||||
# If no standard format found, try to extract from the response structure
|
||||
logger.warning("Could not find standard markdown field in MagicDoc response for DOCX")
|
||||
|
||||
# Return the response as string if it's simple, or empty string
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
elif isinstance(response, dict):
|
||||
# Try to find any text-like content
|
||||
for key, value in response.items():
|
||||
if isinstance(value, str) and len(value) > 100: # Likely content
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
# Recursively search in nested dictionaries
|
||||
nested_content = self._extract_markdown_from_response(value)
|
||||
if nested_content:
|
||||
return nested_content
|
||||
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting markdown from MagicDoc response for DOCX: {str(e)}")
|
||||
return ""
|
||||
|
||||
def read_content(self) -> str:
|
||||
logger.info("Starting DOCX content processing with MagicDoc API")
|
||||
|
||||
# Call MagicDoc API to convert DOCX to markdown
|
||||
# This will raise an exception if the API call fails
|
||||
magicdoc_response = self._call_magicdoc_api(self.input_path)
|
||||
|
||||
# Extract markdown content from the response
|
||||
markdown_content = self._extract_markdown_from_response(magicdoc_response)
|
||||
try:
|
||||
# Initialize writers
|
||||
image_writer = FileBasedDataWriter(self.local_image_dir)
|
||||
md_writer = FileBasedDataWriter(self.output_dir)
|
||||
|
||||
# Create Dataset Instance and process
|
||||
ds = read_local_office(self.input_path)[0]
|
||||
pipe_result = ds.apply(doc_analyze, ocr=True).pipe_txt_mode(image_writer)
|
||||
|
||||
# Generate markdown
|
||||
md_content = pipe_result.get_markdown(self.image_dir)
|
||||
pipe_result.dump_md(md_writer, f"{self.name_without_suff}.md", self.image_dir)
|
||||
|
||||
return md_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting DOCX to MD: {e}")
|
||||
raise
|
||||
|
||||
logger.info(f"MagicDoc API response: {markdown_content}")
|
||||
# def process_content(self, content: str) -> str:
|
||||
# logger.info("Processing DOCX content")
|
||||
|
||||
if not markdown_content:
|
||||
raise Exception("No markdown content found in MagicDoc API response for DOCX")
|
||||
# # Split content into sentences and apply masking
|
||||
# sentences = content.split("。")
|
||||
# final_md = ""
|
||||
# for sentence in sentences:
|
||||
# if sentence.strip(): # Only process non-empty sentences
|
||||
# formatted_prompt = get_masking_mapping_prompt(sentence)
|
||||
# logger.info("Calling ollama to generate response, prompt: %s", formatted_prompt)
|
||||
# response = self.ollama_client.generate(formatted_prompt)
|
||||
# logger.info(f"Response generated: {response}")
|
||||
# final_md += response + "。"
|
||||
|
||||
logger.info(f"Successfully extracted {len(markdown_content)} characters of markdown content from DOCX")
|
||||
|
||||
# Save the raw markdown content to work directory for reference
|
||||
md_output_path = os.path.join(self.work_dir, f"{self.name_without_suff}.md")
|
||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
||||
file.write(markdown_content)
|
||||
|
||||
logger.info(f"Saved raw markdown content from DOCX to {md_output_path}")
|
||||
|
||||
return markdown_content
|
||||
# return final_md
|
||||
|
||||
def save_content(self, content: str) -> None:
|
||||
# Ensure output path has .md extension
|
||||
|
|
@@ -212,11 +67,11 @@ class DocxDocumentProcessor(DocumentProcessor):
|
|||
base_name = os.path.splitext(os.path.basename(self.output_path))[0]
|
||||
md_output_path = os.path.join(output_dir, f"{base_name}.md")
|
||||
|
||||
logger.info(f"Saving masked DOCX content to: {md_output_path}")
|
||||
logger.info(f"Saving masked content to: {md_output_path}")
|
||||
try:
|
||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
||||
file.write(content)
|
||||
logger.info(f"Successfully saved masked DOCX content to {md_output_path}")
|
||||
logger.info(f"Successfully saved content to {md_output_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving masked DOCX content: {e}")
|
||||
logger.error(f"Error saving content: {e}")
|
||||
raise
|
||||
|
|
@@ -1,8 +1,12 @@
|
|||
import os
|
||||
import requests
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import PyPDF2
|
||||
from ...document_handlers.document_processor import DocumentProcessor
|
||||
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
|
||||
from magic_pdf.data.dataset import PymuDocDataset
|
||||
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
||||
from magic_pdf.config.enums import SupportedPdfParseMethod
|
||||
from ...prompts.masking_prompts import get_masking_prompt, get_masking_mapping_prompt
|
||||
import logging
|
||||
from ...services.ollama_client import OllamaClient
|
||||
from ...config import settings
|
||||
|
||||
|
|
@@ -16,192 +20,79 @@ class PdfDocumentProcessor(DocumentProcessor):
|
|||
self.output_dir = os.path.dirname(output_path)
|
||||
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
|
||||
|
||||
# Setup work directory for temporary files
|
||||
# Setup output directories
|
||||
self.local_image_dir = os.path.join(self.output_dir, "images")
|
||||
self.image_dir = os.path.basename(self.local_image_dir)
|
||||
os.makedirs(self.local_image_dir, exist_ok=True)
|
||||
|
||||
# Setup work directory under output directory
|
||||
self.work_dir = os.path.join(
|
||||
os.path.dirname(output_path),
|
||||
".work",
|
||||
os.path.splitext(os.path.basename(input_path))[0]
|
||||
)
|
||||
os.makedirs(self.work_dir, exist_ok=True)
|
||||
|
||||
|
||||
self.work_local_image_dir = os.path.join(self.work_dir, "images")
|
||||
self.work_image_dir = os.path.basename(self.work_local_image_dir)
|
||||
os.makedirs(self.work_local_image_dir, exist_ok=True)
|
||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
||||
|
||||
# Mineru API configuration
|
||||
self.mineru_base_url = getattr(settings, 'MINERU_API_URL', 'http://mineru-api:8000')
|
||||
self.mineru_timeout = getattr(settings, 'MINERU_TIMEOUT', 300) # 5 minutes timeout
|
||||
self.mineru_lang_list = getattr(settings, 'MINERU_LANG_LIST', ['ch'])
|
||||
self.mineru_backend = getattr(settings, 'MINERU_BACKEND', 'pipeline')
|
||||
self.mineru_parse_method = getattr(settings, 'MINERU_PARSE_METHOD', 'auto')
|
||||
self.mineru_formula_enable = getattr(settings, 'MINERU_FORMULA_ENABLE', True)
|
||||
self.mineru_table_enable = getattr(settings, 'MINERU_TABLE_ENABLE', True)
|
||||
|
||||
def _call_mineru_api(self, file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call Mineru API to convert PDF to markdown
|
||||
|
||||
Args:
|
||||
file_path: Path to the PDF file
|
||||
|
||||
Returns:
|
||||
API response as dictionary or None if failed
|
||||
"""
|
||||
try:
|
||||
url = f"{self.mineru_base_url}/file_parse"
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
files = {'files': (os.path.basename(file_path), file, 'application/pdf')}
|
||||
|
||||
# Prepare form data according to Mineru API specification
|
||||
data = {
|
||||
'output_dir': './output',
|
||||
'lang_list': self.mineru_lang_list,
|
||||
'backend': self.mineru_backend,
|
||||
'parse_method': self.mineru_parse_method,
|
||||
'formula_enable': self.mineru_formula_enable,
|
||||
'table_enable': self.mineru_table_enable,
|
||||
'return_md': True,
|
||||
'return_middle_json': False,
|
||||
'return_model_output': False,
|
||||
'return_content_list': False,
|
||||
'return_images': False,
|
||||
'start_page_id': 0,
|
||||
'end_page_id': 99999
|
||||
}
|
||||
|
||||
logger.info(f"Calling Mineru API at {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=self.mineru_timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info("Successfully received response from Mineru API")
|
||||
return result
|
||||
else:
|
||||
error_msg = f"Mineru API returned status code {response.status_code}: {response.text}"
|
||||
logger.error(error_msg)
|
||||
# For 400 errors, include more specific information
|
||||
if response.status_code == 400:
|
||||
try:
|
||||
error_data = response.json()
|
||||
if 'error' in error_data:
|
||||
error_msg = f"Mineru API error: {error_data['error']}"
|
||||
except Exception:
|
||||
pass
|
||||
raise Exception(error_msg)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
error_msg = f"Mineru API request timed out after {self.mineru_timeout} seconds"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"Error calling Mineru API: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Mineru API: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def _extract_markdown_from_response(self, response: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract markdown content from Mineru API response
|
||||
|
||||
Args:
|
||||
response: Mineru API response dictionary
|
||||
|
||||
Returns:
|
||||
Extracted markdown content as string
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Mineru API response structure: {response}")
|
||||
|
||||
# Try different possible response formats based on Mineru API
|
||||
if 'markdown' in response:
|
||||
return response['markdown']
|
||||
elif 'md' in response:
|
||||
return response['md']
|
||||
elif 'content' in response:
|
||||
return response['content']
|
||||
elif 'text' in response:
|
||||
return response['text']
|
||||
elif 'result' in response and isinstance(response['result'], dict):
|
||||
result = response['result']
|
||||
if 'markdown' in result:
|
||||
return result['markdown']
|
||||
elif 'md' in result:
|
||||
return result['md']
|
||||
elif 'content' in result:
|
||||
return result['content']
|
||||
elif 'text' in result:
|
||||
return result['text']
|
||||
elif 'data' in response and isinstance(response['data'], dict):
|
||||
data = response['data']
|
||||
if 'markdown' in data:
|
||||
return data['markdown']
|
||||
elif 'md' in data:
|
||||
return data['md']
|
||||
elif 'content' in data:
|
||||
return data['content']
|
||||
elif 'text' in data:
|
||||
return data['text']
|
||||
elif isinstance(response, list) and len(response) > 0:
|
||||
# If response is a list, try to extract from first item
|
||||
first_item = response[0]
|
||||
if isinstance(first_item, dict):
|
||||
return self._extract_markdown_from_response(first_item)
|
||||
elif isinstance(first_item, str):
|
||||
return first_item
|
||||
else:
|
||||
# If no standard format found, try to extract from the response structure
|
||||
logger.warning("Could not find standard markdown field in Mineru response")
|
||||
|
||||
# Return the response as string if it's simple, or empty string
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
elif isinstance(response, dict):
|
||||
# Try to find any text-like content
|
||||
for key, value in response.items():
|
||||
if isinstance(value, str) and len(value) > 100: # Likely content
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
# Recursively search in nested dictionaries
|
||||
nested_content = self._extract_markdown_from_response(value)
|
||||
if nested_content:
|
||||
return nested_content
|
||||
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting markdown from Mineru response: {str(e)}")
|
||||
return ""
|
||||
|
||||
def read_content(self) -> str:
|
||||
logger.info("Starting PDF content processing with Mineru API")
|
||||
logger.info("Starting PDF content processing")
|
||||
|
||||
# Call Mineru API to convert PDF to markdown
|
||||
# This will raise an exception if the API call fails
|
||||
mineru_response = self._call_mineru_api(self.input_path)
|
||||
# Read the PDF file
|
||||
with open(self.input_path, 'rb') as file:
|
||||
content = file.read()
|
||||
|
||||
# Initialize writers
|
||||
image_writer = FileBasedDataWriter(self.work_local_image_dir)
|
||||
md_writer = FileBasedDataWriter(self.work_dir)
|
||||
|
||||
# Create Dataset Instance
|
||||
ds = PymuDocDataset(content)
|
||||
|
||||
# Extract markdown content from the response
|
||||
markdown_content = self._extract_markdown_from_response(mineru_response)
|
||||
logger.info("Classifying PDF type: %s", ds.classify())
|
||||
# Process based on PDF type
|
||||
if ds.classify() == SupportedPdfParseMethod.OCR:
|
||||
infer_result = ds.apply(doc_analyze, ocr=True)
|
||||
pipe_result = infer_result.pipe_ocr_mode(image_writer)
|
||||
else:
|
||||
infer_result = ds.apply(doc_analyze, ocr=False)
|
||||
pipe_result = infer_result.pipe_txt_mode(image_writer)
|
||||
|
||||
if not markdown_content:
|
||||
raise Exception("No markdown content found in Mineru API response")
|
||||
logger.info("Generating all outputs")
|
||||
# Generate all outputs
|
||||
infer_result.draw_model(os.path.join(self.work_dir, f"{self.name_without_suff}_model.pdf"))
|
||||
model_inference_result = infer_result.get_infer_res()
|
||||
|
||||
logger.info(f"Successfully extracted {len(markdown_content)} characters of markdown content")
|
||||
pipe_result.draw_layout(os.path.join(self.work_dir, f"{self.name_without_suff}_layout.pdf"))
|
||||
pipe_result.draw_span(os.path.join(self.work_dir, f"{self.name_without_suff}_spans.pdf"))
|
||||
|
||||
# Save the raw markdown content to work directory for reference
|
||||
md_output_path = os.path.join(self.work_dir, f"{self.name_without_suff}.md")
|
||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
||||
file.write(markdown_content)
|
||||
md_content = pipe_result.get_markdown(self.work_image_dir)
|
||||
pipe_result.dump_md(md_writer, f"{self.name_without_suff}.md", self.work_image_dir)
|
||||
|
||||
logger.info(f"Saved raw markdown content to {md_output_path}")
|
||||
content_list = pipe_result.get_content_list(self.work_image_dir)
|
||||
pipe_result.dump_content_list(md_writer, f"{self.name_without_suff}_content_list.json", self.work_image_dir)
|
||||
|
||||
return markdown_content
|
||||
middle_json = pipe_result.get_middle_json()
|
||||
pipe_result.dump_middle_json(md_writer, f'{self.name_without_suff}_middle.json')
|
||||
|
||||
return md_content
|
||||
|
||||
# def process_content(self, content: str) -> str:
|
||||
# logger.info("Starting content masking process")
|
||||
# sentences = content.split("。")
|
||||
# final_md = ""
|
||||
# for sentence in sentences:
|
||||
# if not sentence.strip(): # Skip empty sentences
|
||||
# continue
|
||||
# formatted_prompt = get_masking_mapping_prompt(sentence)
|
||||
# logger.info("Calling ollama to generate response, prompt: %s", formatted_prompt)
|
||||
# response = self.ollama_client.generate(formatted_prompt)
|
||||
# logger.info(f"Response generated: {response}")
|
||||
# final_md += response + "。"
|
||||
# return final_md
|
||||
|
||||
def save_content(self, content: str) -> None:
|
||||
# Ensure output path has .md extension
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
|
|||
from ...document_handlers.document_processor import DocumentProcessor
|
||||
from ...services.ollama_client import OllamaClient
|
||||
import logging
|
||||
# from ...prompts.masking_prompts import get_masking_prompt
|
||||
from ...prompts.masking_prompts import get_masking_prompt
|
||||
from ...config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@@ -1,27 +0,0 @@
|
|||
import re
|
||||
|
||||
def extract_id_number_entities(chunk: str) -> dict:
|
||||
"""Extract Chinese ID numbers and return in entity mapping format."""
|
||||
id_pattern = r'\b\d{17}[\dXx]\b'
|
||||
entities = []
|
||||
for match in re.findall(id_pattern, chunk):
|
||||
entities.append({"text": match, "type": "身份证号"})
|
||||
return {"entities": entities} if entities else {}
|
||||
|
||||
|
||||
def extract_social_credit_code_entities(chunk: str) -> dict:
|
||||
"""Extract social credit codes and return in entity mapping format."""
|
||||
credit_pattern = r'\b[0-9A-Z]{18}\b'
|
||||
entities = []
|
||||
for match in re.findall(credit_pattern, chunk):
|
||||
entities.append({"text": match, "type": "统一社会信用代码"})
|
||||
return {"entities": entities} if entities else {}
|
||||
|
||||
def extract_case_number_entities(chunk: str) -> dict:
|
||||
"""Extract case numbers and return in entity mapping format."""
|
||||
# Pattern for Chinese case numbers: (2022)京 03 民终 3852 号, (2020)京0105 民初69754 号
|
||||
case_pattern = r'[((]\d{4}[))][^\d]*\d+[^\d]*\d+[^\d]*号'
|
||||
entities = []
|
||||
for match in re.findall(case_pattern, chunk):
|
||||
entities.append({"text": match, "type": "案号"})
|
||||
return {"entities": entities} if entities else {}
|
||||
|
|
@@ -1,7 +1,38 @@
|
|||
import textwrap
|
||||
|
||||
def get_masking_prompt(text: str) -> str:
|
||||
"""
|
||||
Returns the prompt for masking sensitive information in legal documents.
|
||||
|
||||
Args:
|
||||
text (str): The input text to be masked
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt with the input text
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
您是一位专业的法律文档脱敏专家。请按照以下规则对文本进行脱敏处理:
|
||||
|
||||
def get_ner_name_prompt(text: str) -> str:
|
||||
规则:
|
||||
1. 人名:
|
||||
- 两字名改为"姓+某"(如:张三 → 张某)
|
||||
- 三字名改为"姓+某某"(如:张三丰 → 张某某)
|
||||
2. 公司名:
|
||||
- 保留地理位置信息(如:北京、上海等)
|
||||
- 保留公司类型(如:有限公司、股份公司等)
|
||||
- 用"某"替换核心名称
|
||||
3. 保持原文其他部分不变
|
||||
4. 确保脱敏后的文本保持原有的语言流畅性和可读性
|
||||
|
||||
输入文本:
|
||||
{text}
|
||||
|
||||
请直接输出脱敏后的文本,无需解释或其他备注。
|
||||
""")
|
||||
|
||||
return prompt.format(text=text)
|
||||
|
||||
def get_masking_mapping_prompt(text: str) -> str:
|
||||
"""
|
||||
Returns a prompt that generates a mapping of original names/companies to their masked versions.
|
||||
|
||||
|
|
@@ -12,254 +43,39 @@ def get_ner_name_prompt(text: str) -> str:
|
|||
str: The formatted prompt that will generate a mapping dictionary
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
|
||||
您是一位专业的法律文档脱敏专家。请分析文本并生成一个脱敏映射表,遵循以下规则:
|
||||
|
||||
实体类别包括:
|
||||
- 人名 (不包括律师、法官、书记员、检察官等公职人员)
|
||||
- 英文人名
|
||||
|
||||
|
||||
待处理文本:
|
||||
{text}
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"entities": [
|
||||
{{"text": "原始文本内容", "type": "人名"}},
|
||||
{{"text": "原始文本内容", "type": "英文人名"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
|
||||
""")
|
||||
return prompt.format(text=text)
|
||||
|
||||
|
||||
def get_ner_company_prompt(text: str) -> str:
|
||||
"""
|
||||
Returns a prompt that generates a mapping of original companies to their masked versions.
|
||||
|
||||
Args:
|
||||
text (str): The input text to be analyzed for masking
|
||||
规则:
|
||||
1. 人名映射规则:
|
||||
- 对于同一姓氏的不同人名,使用字母区分:
|
||||
* 第一个出现的用"姓+某"(如:张三 → 张某)
|
||||
* 第二个出现的用"姓+某A"(如:张四 → 张某A)
|
||||
* 第三个出现的用"姓+某B"(如:张五 → 张某B)
|
||||
依此类推
|
||||
- 三字名同样遵循此规则(如:张三丰 → 张某某,张四海 → 张某某A)
|
||||
|
||||
2. 公司名映射规则:
|
||||
- 保留地理位置信息(如:北京、上海等)
|
||||
- 保留公司类型(如:有限公司、股份公司等)
|
||||
- 用"某"替换核心名称,但保留首尾字(如:北京智慧科技有限公司 → 北京智某科技有限公司)
|
||||
- 对于多个相似公司名,使用字母区分(如:
|
||||
北京智慧科技有限公司 → 北京某科技有限公司
|
||||
北京智能科技有限公司 → 北京某科技有限公司A)
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt that will generate a mapping dictionary
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
|
||||
3. 公权机关不做脱敏处理(如:公安局、法院、检察院、中国人民银行、银监会及其他未列明的公权机关)
|
||||
|
||||
实体类别包括:
|
||||
- 公司名称
|
||||
- 英文公司名称
|
||||
- Company with English name
|
||||
- 公司名称简称
|
||||
- 公司英文名称简称
|
||||
请分析以下文本,并生成一个JSON格式的映射表,包含所有需要脱敏的名称及其对应的脱敏后的形式:
|
||||
|
||||
{text}
|
||||
|
||||
待处理文本:
|
||||
{text}
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"entities": [
|
||||
{{"text": "原始文本内容", "type": "公司名称"}},
|
||||
{{"text": "原始文本内容", "type": "英文公司名称"}},
|
||||
{{"text": "原始文本内容", "type": "公司名称简称"}},
|
||||
{{"text": "原始文本内容", "type": "公司英文名称简称"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
请直接输出JSON格式的映射表,格式如下:
|
||||
{{
|
||||
"原文1": "脱敏后1",
|
||||
"原文2": "脱敏后2",
|
||||
...
|
||||
}}
|
||||
如无需要输出的映射,请输出空json,如下:
|
||||
{{}}
|
||||
""")
|
||||
return prompt.format(text=text)
|
||||
|
||||
|
||||
def get_ner_address_prompt(text: str) -> str:
|
||||
"""
|
||||
Returns a prompt that generates a mapping of original addresses to their masked versions.
|
||||
|
||||
Args:
|
||||
text (str): The input text to be analyzed for masking
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt that will generate a mapping dictionary
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
|
||||
|
||||
实体类别包括:
|
||||
- 地址
|
||||
|
||||
|
||||
待处理文本:
|
||||
{text}
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"entities": [
|
||||
{{"text": "原始文本内容", "type": "地址"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
""")
|
||||
return prompt.format(text=text)
|
||||
|
||||
|
||||
def get_address_masking_prompt(address: str) -> str:
|
||||
"""
|
||||
Returns a prompt that generates a masked version of an address following specific rules.
|
||||
|
||||
Args:
|
||||
address (str): The original address to be masked
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt that will generate a masked address
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的地址脱敏助手。请对给定的地址进行脱敏处理,遵循以下规则:
|
||||
|
||||
脱敏规则:
|
||||
1. 保留区级以上地址(省、市、区、县)
|
||||
2. 路名以大写首字母替代,例如:恒丰路 -> HF路
|
||||
3. 门牌数字以**代替,例如:66号 -> **号
|
||||
4. 大厦名、小区名以大写首字母替代,例如:白云大厦 -> BY大厦
|
||||
5. 房间号以****代替,例如:1607室 -> ****室
|
||||
|
||||
示例:
|
||||
- 输入:上海市静安区恒丰路66号白云大厦1607室
|
||||
- 输出:上海市静安区HF路**号BY大厦****室
|
||||
|
||||
- 输入:北京市海淀区北小马厂6号1号楼华天大厦1306室
|
||||
- 输出:北京市海淀区北小马厂**号**号楼HT大厦****室
|
||||
|
||||
请严格按照JSON格式输出结果:
|
||||
|
||||
{{
|
||||
"masked_address": "脱敏后的地址"
|
||||
}}
|
||||
|
||||
原始地址:{address}
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
""")
|
||||
return prompt.format(address=address)
|
||||
|
||||
|
||||
def get_ner_project_prompt(text: str) -> str:
|
||||
"""
|
||||
Returns a prompt that generates a mapping of original project names to their masked versions.
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
|
||||
|
||||
实体类别包括:
|
||||
- 项目名(此处项目特指商业、工程、合同等项目)
|
||||
|
||||
待处理文本:
|
||||
{text}
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"entities": [
|
||||
{{"text": "原始文本内容", "type": "项目名"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
""")
|
||||
return prompt.format(text=text)
|
||||
|
||||
|
||||
def get_ner_case_number_prompt(text: str) -> str:
|
||||
"""
|
||||
Returns a prompt that generates a mapping of original case numbers to their masked versions.
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的法律文本实体识别助手。请从以下文本中抽取出所有需要脱敏的敏感信息,并按照指定的类别进行分类。请严格按照JSON格式输出结果。
|
||||
|
||||
实体类别包括:
|
||||
- 案号
|
||||
|
||||
待处理文本:
|
||||
{text}
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"entities": [
|
||||
{{"text": "原始文本内容", "type": "案号"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
""")
|
||||
return prompt.format(text=text)
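As a rough sketch of how these extraction prompts are consumed downstream (the client construction mirrors `PdfDocumentProcessor.__init__` above, and `generate_with_validation` is the enhanced `OllamaClient` method shown later in this diff):

```python
# Sketch only: wiring a NER prompt into the validated Ollama client.
client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)

prompt = get_ner_case_number_prompt("本案案号为(2022)京03民终3852号。")
result = client.generate_with_validation(
    prompt=prompt,
    response_type="entity_extraction",  # checked against ENTITY_EXTRACTION_SCHEMA
    return_parsed=True,
)
for entity in result.get("entities", []):
    print(entity["text"], entity["type"])
```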
|
||||
|
||||
|
||||
def get_entity_linkage_prompt(entities_text: str) -> str:
|
||||
"""
|
||||
Returns a prompt that identifies related entities and groups them together.
|
||||
|
||||
Args:
|
||||
entities_text (str): The list of entities to be analyzed for linkage
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt that will generate entity linkage information
|
||||
"""
|
||||
prompt = textwrap.dedent("""
|
||||
你是一个专业的法律文本实体关联分析助手。请分析以下实体列表,识别出相互关联的实体(如全称与简称、中文名与英文名等),并将它们分组。
|
||||
|
||||
关联规则:
|
||||
1. 公司名称关联:
|
||||
- 全称与简称(如:"阿里巴巴集团控股有限公司" 与 "阿里巴巴")
|
||||
- 中文名与英文名(如:"腾讯科技有限公司" 与 "Tencent Technology Ltd.")
|
||||
- 母公司与子公司(如:"腾讯" 与 "腾讯音乐")
|
||||
|
||||
|
||||
2. 每个组中应指定一个主要实体(is_primary: true),通常是:
|
||||
- 对于公司:选择最正式的全称
|
||||
- 对于人名:选择最常用的称呼
|
||||
|
||||
待分析实体列表:
|
||||
{entities_text}
|
||||
|
||||
输出格式:
|
||||
{{
|
||||
"entity_groups": [
|
||||
{{
|
||||
"group_id": "group_1",
|
||||
"group_type": "公司名称",
|
||||
"entities": [
|
||||
{{
|
||||
"text": "阿里巴巴集团控股有限公司",
|
||||
"type": "公司名称",
|
||||
"is_primary": true
|
||||
}},
|
||||
{{
|
||||
"text": "阿里巴巴",
|
||||
"type": "公司名称简称",
|
||||
"is_primary": false
|
||||
}}
|
||||
]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意事项:
|
||||
1. 只对确实有关联的实体进行分组
|
||||
2. 每个实体只能属于一个组
|
||||
3. 每个组必须有且仅有一个主要实体(is_primary: true)
|
||||
4. 如果实体之间没有明显关联,不要强制分组
|
||||
5. group_type 应该是 "公司名称"
|
||||
|
||||
请严格按照JSON格式输出结果。
|
||||
""")
|
||||
return prompt.format(entities_text=entities_text)
|
||||
return prompt.format(text=text)
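A similar sketch for the linkage prompt, reusing the `client` from the previous example; the entity list is invented purely for illustration:

```python
# Sketch only: the entity list below is a made-up example.
entities_text = (
    '[{"text": "阿里巴巴集团控股有限公司", "type": "公司名称"}, '
    '{"text": "阿里巴巴", "type": "公司名称简称"}]'
)
linkage = client.generate_with_validation(
    prompt=get_entity_linkage_prompt(entities_text),
    response_type="entity_linkage",  # checked by validate_entity_linkage
    return_parsed=True,
)
for group in linkage["entity_groups"]:
    primary = next(e for e in group["entities"] if e["is_primary"])
    print(group["group_id"], "->", primary["text"])
```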
|
||||
|
|
@@ -13,7 +13,7 @@ class DocumentService:
|
|||
processor = DocumentProcessorFactory.create_processor(input_path, output_path)
|
||||
if not processor:
|
||||
logger.error(f"Unsupported file format: {input_path}")
|
||||
raise Exception(f"Unsupported file format: {input_path}")
|
||||
return False
|
||||
|
||||
# Read content
|
||||
content = processor.read_content()
|
||||
|
|
@@ -27,5 +27,4 @@ class DocumentService:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document {input_path}: {str(e)}")
|
||||
# Re-raise the exception so the Celery task can handle it properly
|
||||
raise
|
||||
return False
|
||||
|
|
@@ -1,222 +1,72 @@
|
|||
import requests
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Callable, Union
|
||||
from ..utils.json_extractor import LLMJsonExtractor
|
||||
from ..utils.llm_validator import LLMResponseValidator
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
|
||||
def __init__(self, model_name: str, base_url: str = "http://localhost:11434"):
|
||||
"""Initialize Ollama client.
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the Ollama model to use
|
||||
base_url (str): Ollama server base URL
|
||||
max_retries (int): Maximum number of retries for failed requests
|
||||
host (str): Ollama server host address
|
||||
port (int): Ollama server port
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.max_retries = max_retries
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
|
||||
def generate(self,
|
||||
prompt: str,
|
||||
strip_think: bool = True,
|
||||
validation_schema: Optional[Dict[str, Any]] = None,
|
||||
response_type: Optional[str] = None,
|
||||
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
|
||||
"""Process a document using the Ollama API with optional validation and retry.
|
||||
def generate(self, prompt: str, strip_think: bool = True) -> str:
|
||||
"""Process a document using the Ollama API.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to send to the model
|
||||
strip_think (bool): Whether to strip thinking tags from response
|
||||
validation_schema (Optional[Dict]): JSON schema for validation
|
||||
response_type (Optional[str]): Type of response for validation ('entity_extraction', 'entity_linkage', etc.)
|
||||
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
||||
document_text (str): The text content to process
|
||||
|
||||
Returns:
|
||||
Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON)
|
||||
str: Processed text response from the model
|
||||
|
||||
Raises:
|
||||
RequestException: If the API call fails after all retries
|
||||
ValueError: If validation fails after all retries
|
||||
"""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
raw_response = self._make_api_call(prompt, strip_think)
|
||||
|
||||
# If no validation required, return the response
|
||||
if not validation_schema and not response_type and not return_parsed:
|
||||
return raw_response
|
||||
|
||||
# Parse JSON if needed
|
||||
if return_parsed or validation_schema or response_type:
|
||||
parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
|
||||
if not parsed_response:
|
||||
logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
|
||||
if attempt < self.max_retries - 1:
|
||||
continue
|
||||
else:
|
||||
raise ValueError("Failed to parse JSON response after all retries")
|
||||
|
||||
# Validate if schema or response type provided
|
||||
if validation_schema:
|
||||
if not self._validate_with_schema(parsed_response, validation_schema):
|
||||
logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
|
||||
if attempt < self.max_retries - 1:
|
||||
continue
|
||||
else:
|
||||
raise ValueError("Schema validation failed after all retries")
|
||||
|
||||
if response_type:
|
||||
if not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
|
||||
logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
|
||||
if attempt < self.max_retries - 1:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Response type validation failed after all retries")
|
||||
|
||||
# Return parsed response if requested
|
||||
if return_parsed:
|
||||
return parsed_response
|
||||
else:
|
||||
return raw_response
|
||||
|
||||
return raw_response
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.info("Retrying...")
|
||||
else:
|
||||
logger.error("Max retries reached, raising exception")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.info("Retrying...")
|
||||
else:
|
||||
logger.error("Max retries reached, raising exception")
|
||||
raise
|
||||
|
||||
# This should never be reached, but just in case
|
||||
raise Exception("Unexpected error: max retries exceeded without proper exception handling")
|
||||
|
||||
def generate_with_validation(self,
|
||||
prompt: str,
|
||||
response_type: str,
|
||||
strip_think: bool = True,
|
||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
||||
"""Generate response with automatic validation based on response type.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to send to the model
|
||||
response_type (str): Type of response for validation
|
||||
strip_think (bool): Whether to strip thinking tags from response
|
||||
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
||||
|
||||
Returns:
|
||||
Union[str, Dict[str, Any]]: Validated response from the model
|
||||
"""
|
||||
return self.generate(
|
||||
prompt=prompt,
|
||||
strip_think=strip_think,
|
||||
response_type=response_type,
|
||||
return_parsed=return_parsed
|
||||
)
|
||||
|
||||
def generate_with_schema(self,
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
strip_think: bool = True,
|
||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
||||
"""Generate response with custom schema validation.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to send to the model
|
||||
schema (Dict): JSON schema for validation
|
||||
strip_think (bool): Whether to strip thinking tags from response
|
||||
return_parsed (bool): Whether to return parsed JSON instead of raw string
|
||||
|
||||
Returns:
|
||||
Union[str, Dict[str, Any]]: Validated response from the model
|
||||
"""
|
||||
return self.generate(
|
||||
prompt=prompt,
|
||||
strip_think=strip_think,
|
||||
validation_schema=schema,
|
||||
return_parsed=return_parsed
|
||||
)
|
||||
|
||||
def _make_api_call(self, prompt: str, strip_think: bool) -> str:
|
||||
"""Make the actual API call to Ollama.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to send
|
||||
strip_think (bool): Whether to strip thinking tags
|
||||
|
||||
Returns:
|
||||
str: Raw response from the API
|
||||
"""
|
||||
url = f"{self.base_url}/api/generate"
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
logger.debug(f"Sending request to Ollama API: {url}")
|
||||
response = requests.post(url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
logger.debug(f"Received response from Ollama API: {result}")
|
||||
|
||||
if strip_think:
|
||||
# Remove the "thinking" part from the response
|
||||
# the response is expected to be <think>...</think>response_text
|
||||
# Check if the response contains <think> tag
|
||||
if "<think>" in result.get("response", ""):
|
||||
# Split the response and take the part after </think>
|
||||
response_parts = result["response"].split("</think>")
|
||||
if len(response_parts) > 1:
|
||||
# Return the part after </think>
|
||||
return response_parts[1].strip()
|
||||
else:
|
||||
# If no closing tag, return the full response
|
||||
return result.get("response", "").strip()
|
||||
else:
|
||||
# If no <think> tag, return the full response
|
||||
return result.get("response", "").strip()
|
||||
else:
|
||||
# If strip_think is False, return the full response
|
||||
return result.get("response", "")
|
||||
|
||||
def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
|
||||
"""Validate response against a JSON schema.
|
||||
|
||||
Args:
|
||||
response (Dict): The parsed response to validate
|
||||
schema (Dict): The JSON schema to validate against
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
RequestException: If the API call fails
|
||||
"""
|
||||
try:
|
||||
from jsonschema import validate, ValidationError
|
||||
validate(instance=response, schema=schema)
|
||||
logger.debug(f"Schema validation passed for response: {response}")
|
||||
return True
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Schema validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
except ImportError:
|
||||
logger.error("jsonschema library not available for validation")
|
||||
return False
|
||||
url = f"{self.base_url}/api/generate"
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
logger.debug(f"Sending request to Ollama API: {url}")
|
||||
response = requests.post(url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
logger.debug(f"Received response from Ollama API: {result}")
|
||||
if strip_think:
|
||||
# Remove the "thinking" part from the response
|
||||
# the response is expected to be <think>...</think>response_text
|
||||
# Check if the response contains <think> tag
|
||||
if "<think>" in result.get("response", ""):
|
||||
# Split the response and take the part after </think>
|
||||
response_parts = result["response"].split("</think>")
|
||||
if len(response_parts) > 1:
|
||||
# Return the part after </think>
|
||||
return response_parts[1].strip()
|
||||
else:
|
||||
# If no closing tag, return the full response
|
||||
return result.get("response", "").strip()
|
||||
else:
|
||||
# If no <think> tag, return the full response
|
||||
return result.get("response", "").strip()
|
||||
else:
|
||||
# If strip_think is False, return the full response
|
||||
return result.get("response", "")
|
||||
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Error calling Ollama API: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""Get information about the current model.
|
||||
|
|
|
|||
|
|
@@ -1,369 +0,0 @@
|
|||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from jsonschema import validate, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMResponseValidator:
|
||||
"""Validator for LLM JSON responses with different schemas for different entity types"""
|
||||
|
||||
# Schema for basic entity extraction responses
|
||||
ENTITY_EXTRACTION_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"type": {"type": "string"}
|
||||
},
|
||||
"required": ["text", "type"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["entities"]
|
||||
}
|
||||
|
||||
# Schema for entity linkage responses
|
||||
ENTITY_LINKAGE_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_groups": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"group_id": {"type": "string"},
|
||||
"group_type": {"type": "string"},
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"type": {"type": "string"},
|
||||
"is_primary": {"type": "boolean"}
|
||||
},
|
||||
"required": ["text", "type", "is_primary"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["group_id", "group_type", "entities"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["entity_groups"]
|
||||
}
|
||||
|
||||
# Schema for regex-based entity extraction (from entity_regex.py)
|
||||
REGEX_ENTITY_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entities": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"type": {"type": "string"}
|
||||
},
|
||||
"required": ["text", "type"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["entities"]
|
||||
}
|
||||
|
||||
# Schema for business name extraction responses
|
||||
BUSINESS_NAME_EXTRACTION_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"business_name": {
|
||||
"type": "string",
|
||||
"description": "The extracted business name (商号) from the company name"
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"description": "Confidence level of the extraction (0-1)"
|
||||
}
|
||||
},
|
||||
"required": ["business_name"]
|
||||
}
|
||||
|
||||
# Schema for address extraction responses
|
||||
ADDRESS_EXTRACTION_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"road_name": {
|
||||
"type": "string",
|
||||
"description": "The road name (路名) to be masked"
|
||||
},
|
||||
"house_number": {
|
||||
"type": "string",
|
||||
"description": "The house number (门牌号) to be masked"
|
||||
},
|
||||
"building_name": {
|
||||
"type": "string",
|
||||
"description": "The building name (大厦名) to be masked"
|
||||
},
|
||||
"community_name": {
|
||||
"type": "string",
|
||||
"description": "The community name (小区名) to be masked"
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"description": "Confidence level of the extraction (0-1)"
|
||||
}
|
||||
},
|
||||
"required": ["road_name", "house_number", "building_name", "community_name"]
|
||||
}
|
||||
|
||||
# Schema for address masking responses
|
||||
ADDRESS_MASKING_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"masked_address": {
|
||||
"type": "string",
|
||||
"description": "The masked address following the specified rules"
|
||||
}
|
||||
},
|
||||
"required": ["masked_address"]
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_entity_extraction(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate entity extraction response from LLM.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
validate(instance=response, schema=cls.ENTITY_EXTRACTION_SCHEMA)
|
||||
logger.debug(f"Entity extraction validation passed for response: {response}")
|
||||
return True
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Entity extraction validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def validate_entity_linkage(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate entity linkage response from LLM.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
validate(instance=response, schema=cls.ENTITY_LINKAGE_SCHEMA)
|
||||
content_valid = cls._validate_linkage_content(response)
|
||||
if content_valid:
|
||||
logger.debug(f"Entity linkage validation passed for response: {response}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Entity linkage content validation failed for response: {response}")
|
||||
return False
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Entity linkage validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def validate_regex_entity(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate regex-based entity extraction response.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from regex extractors
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
|
||||
logger.debug(f"Regex entity validation passed for response: {response}")
|
||||
return True
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Regex entity validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def validate_business_name_extraction(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate business name extraction response from LLM.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
validate(instance=response, schema=cls.BUSINESS_NAME_EXTRACTION_SCHEMA)
|
||||
logger.debug(f"Business name extraction validation passed for response: {response}")
|
||||
return True
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Business name extraction validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def validate_address_extraction(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate address extraction response from LLM.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
validate(instance=response, schema=cls.ADDRESS_EXTRACTION_SCHEMA)
|
||||
logger.debug(f"Address extraction validation passed for response: {response}")
|
||||
return True
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Address extraction validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def validate_address_masking(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate address masking response from LLM.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
validate(instance=response, schema=cls.ADDRESS_MASKING_SCHEMA)
|
||||
logger.debug(f"Address masking validation passed for response: {response}")
|
||||
return True
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Address masking validation failed: {e}")
|
||||
logger.warning(f"Response that failed validation: {response}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _validate_linkage_content(cls, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Additional content validation for entity linkage responses.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
|
||||
Returns:
|
||||
bool: True if content is valid, False otherwise
|
||||
"""
|
||||
entity_groups = response.get('entity_groups', [])
|
||||
|
||||
for group in entity_groups:
|
||||
# Validate group type
|
||||
group_type = group.get('group_type', '')
|
||||
if group_type not in ['公司名称', '人名']:
|
||||
logger.warning(f"Invalid group_type: {group_type}")
|
||||
return False
|
||||
|
||||
# Validate entities in group
|
||||
entities = group.get('entities', [])
|
||||
if not entities:
|
||||
logger.warning("Empty entity group found")
|
||||
return False
|
||||
|
||||
# Check that exactly one entity is marked as primary
|
||||
primary_count = sum(1 for entity in entities if entity.get('is_primary', False))
|
||||
if primary_count != 1:
|
||||
logger.warning(f"Group must have exactly one primary entity, found {primary_count}")
|
||||
return False
|
||||
|
||||
# Validate entity types within group
|
||||
for entity in entities:
|
||||
entity_type = entity.get('type', '')
|
||||
if group_type == '公司名称' and not any(keyword in entity_type for keyword in ['公司', 'Company']):
|
||||
logger.warning(f"Company group contains non-company entity: {entity_type}")
|
||||
return False
|
||||
elif group_type == '人名' and not any(keyword in entity_type for keyword in ['人名', '英文人名']):
|
||||
logger.warning(f"Person group contains non-person entity: {entity_type}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def validate_response_by_type(cls, response: Dict[str, Any], response_type: str) -> bool:
|
||||
"""
|
||||
Generic validator that routes to appropriate validation method based on response type.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
response_type: Type of response ('entity_extraction', 'entity_linkage', 'regex_entity')
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
validators = {
|
||||
'entity_extraction': cls.validate_entity_extraction,
|
||||
'entity_linkage': cls.validate_entity_linkage,
|
||||
'regex_entity': cls.validate_regex_entity,
|
||||
'business_name_extraction': cls.validate_business_name_extraction,
|
||||
'address_extraction': cls.validate_address_extraction,
|
||||
'address_masking': cls.validate_address_masking
|
||||
}
|
||||
|
||||
validator = validators.get(response_type)
|
||||
if not validator:
|
||||
logger.error(f"Unknown response type: {response_type}")
|
||||
return False
|
||||
|
||||
return validator(response)
|
||||
|
||||
@classmethod
|
||||
def get_validation_errors(cls, response: Dict[str, Any], response_type: str) -> Optional[str]:
|
||||
"""
|
||||
Get detailed validation errors for debugging.
|
||||
|
||||
Args:
|
||||
response: The parsed JSON response from LLM
|
||||
response_type: Type of response
|
||||
|
||||
Returns:
|
||||
Optional[str]: Error message or None if valid
|
||||
"""
|
||||
try:
|
||||
if response_type == 'entity_extraction':
|
||||
validate(instance=response, schema=cls.ENTITY_EXTRACTION_SCHEMA)
|
||||
elif response_type == 'entity_linkage':
|
||||
validate(instance=response, schema=cls.ENTITY_LINKAGE_SCHEMA)
|
||||
if not cls._validate_linkage_content(response):
|
||||
return "Content validation failed for entity linkage"
|
||||
elif response_type == 'regex_entity':
|
||||
validate(instance=response, schema=cls.REGEX_ENTITY_SCHEMA)
|
||||
elif response_type == 'business_name_extraction':
|
||||
validate(instance=response, schema=cls.BUSINESS_NAME_EXTRACTION_SCHEMA)
|
||||
elif response_type == 'address_extraction':
|
||||
validate(instance=response, schema=cls.ADDRESS_EXTRACTION_SCHEMA)
|
||||
elif response_type == 'address_masking':
|
||||
validate(instance=response, schema=cls.ADDRESS_MASKING_SCHEMA)
|
||||
else:
|
||||
return f"Unknown response type: {response_type}"
|
||||
|
||||
return None
|
||||
except ValidationError as e:
|
||||
return f"Schema validation error: {e}"
|
||||
|
|
@@ -70,7 +70,6 @@ def process_file(file_id: str):
|
|||
output_path = str(settings.PROCESSED_FOLDER / output_filename)
|
||||
|
||||
# Process document with both input and output paths
|
||||
# This will raise an exception if processing fails
|
||||
process_service.process_document(file.original_path, output_path)
|
||||
|
||||
# Update file record with processed path
|
||||
|
|
@@ -82,7 +81,6 @@ def process_file(file_id: str):
|
|||
file.status = FileStatus.FAILED
|
||||
file.error_message = str(e)
|
||||
db.commit()
|
||||
# Re-raise the exception to ensure Celery marks the task as failed
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
|
|
|||
|
|
@@ -1,33 +0,0 @@
|
|||
import pytest
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the backend directory to Python path for imports
|
||||
backend_dir = Path(__file__).parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
# Also add the current directory to ensure imports work
|
||||
current_dir = Path(__file__).parent
|
||||
sys.path.insert(0, str(current_dir))
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data():
|
||||
"""Sample data fixture for testing"""
|
||||
return {
|
||||
"name": "test",
|
||||
"value": 42,
|
||||
"items": [1, 2, 3]
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def test_files_dir():
|
||||
"""Fixture to get the test files directory"""
|
||||
return Path(__file__).parent / "tests"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_environment():
|
||||
"""Setup test environment before each test"""
|
||||
# Add any test environment setup here
|
||||
yield
|
||||
# Add any cleanup here
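A hypothetical test module consuming these fixtures would look like the following; pytest injects each fixture by parameter name:

```python
# Hypothetical tests, shown only to illustrate how the fixtures above are used.
def test_sample_data_shape(sample_data):
    assert sample_data["name"] == "test"
    assert sum(sample_data["items"]) == 6


def test_files_dir_points_at_tests(test_files_dir):
    assert test_files_dir.name == "tests"
```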
|
||||
|
|
@@ -7,6 +7,7 @@ services:
|
|||
- "8000:8000"
|
||||
volumes:
|
||||
- ./storage:/app/storage
|
||||
- ./legal_doc_masker.db:/app/legal_doc_masker.db
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
|
|
@@ -20,6 +21,7 @@ services:
|
|||
command: celery -A app.services.file_service worker --loglevel=info
|
||||
volumes:
|
||||
- ./storage:/app/storage
|
||||
- ./legal_doc_masker.db:/app/legal_doc_masker.db
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
|
|
|
|||
|
|
@@ -1,239 +0,0 @@
|
|||
# 地址脱敏改进文档
|
||||
|
||||
## 问题描述
|
||||
|
||||
原始的地址脱敏方法使用正则表达式和拼音转换来手动处理地址组件,存在以下问题:
|
||||
- 需要手动维护复杂的正则表达式模式
|
||||
- 拼音转换可能失败,需要回退处理
|
||||
- 难以处理复杂的地址格式
|
||||
- 代码维护成本高
|
||||
|
||||
## 解决方案
|
||||
|
||||
### 1. LLM 直接生成脱敏地址
|
||||
|
||||
使用 LLM 直接生成脱敏后的地址,遵循指定的脱敏规则:
|
||||
|
||||
- **保留区级以上地址**:省、市、区、县
|
||||
- **路名缩写**:以大写首字母替代,如:恒丰路 -> HF路
|
||||
- **门牌号脱敏**:数字以**代替,如:66号 -> **号
|
||||
- **大厦名缩写**:以大写首字母替代,如:白云大厦 -> BY大厦
|
||||
- **房间号脱敏**:以****代替,如:1607室 -> ****室
|
||||
|
||||
### 2. 实现架构
|
||||
|
||||
#### 核心组件
|
||||
|
||||
1. **`get_address_masking_prompt()`** - 生成地址脱敏 prompt
|
||||
2. **`_mask_address()`** - 主要的脱敏方法,使用 LLM
|
||||
3. **`_mask_address_fallback()`** - 回退方法,使用原有逻辑
|
||||
|
||||
#### 调用流程
|
||||
|
||||
```
|
||||
输入地址
|
||||
↓
|
||||
生成脱敏 prompt
|
||||
↓
|
||||
调用 Ollama LLM
|
||||
↓
|
||||
解析 JSON 响应
|
||||
↓
|
||||
返回脱敏地址
|
||||
↓
|
||||
失败时使用回退方法
|
||||
```
|
||||
|
||||
### 3. Prompt 设计
|
||||
|
||||
#### 脱敏规则说明
|
||||
```
|
||||
脱敏规则:
|
||||
1. 保留区级以上地址(省、市、区、县)
|
||||
2. 路名以大写首字母替代,例如:恒丰路 -> HF路
|
||||
3. 门牌数字以**代替,例如:66号 -> **号
|
||||
4. 大厦名、小区名以大写首字母替代,例如:白云大厦 -> BY大厦
|
||||
5. 房间号以****代替,例如:1607室 -> ****室
|
||||
```
|
||||
|
||||
#### 示例展示
|
||||
```
|
||||
示例:
|
||||
- 输入:上海市静安区恒丰路66号白云大厦1607室
|
||||
- 输出:上海市静安区HF路**号BY大厦****室
|
||||
|
||||
- 输入:北京市海淀区北小马厂6号1号楼华天大厦1306室
|
||||
- 输出:北京市海淀区北小马厂**号**号楼HT大厦****室
|
||||
```
|
||||
|
||||
#### JSON 输出格式
|
||||
```json
|
||||
{
|
||||
"masked_address": "脱敏后的地址"
|
||||
}
|
||||
```
|
||||
|
||||
## 实现细节
|
||||
|
||||
### 1. 主要方法
|
||||
|
||||
#### `_mask_address(address: str) -> str`
|
||||
```python
|
||||
def _mask_address(self, address: str) -> str:
|
||||
"""
|
||||
对地址进行脱敏处理,使用LLM直接生成脱敏地址
|
||||
"""
|
||||
if not address:
|
||||
return address
|
||||
|
||||
try:
|
||||
# 使用LLM生成脱敏地址
|
||||
prompt = get_address_masking_prompt(address)
|
||||
response = self.ollama_client.generate_with_validation(
|
||||
prompt=prompt,
|
||||
response_type='address_masking',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
if response and isinstance(response, dict) and "masked_address" in response:
|
||||
return response["masked_address"]
|
||||
else:
|
||||
return self._mask_address_fallback(address)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error masking address with LLM: {e}")
|
||||
return self._mask_address_fallback(address)
|
||||
```
|
||||
|
||||
#### `_mask_address_fallback(address: str) -> str`
|
||||
```python
|
||||
def _mask_address_fallback(self, address: str) -> str:
|
||||
"""
|
||||
地址脱敏的回退方法,使用原有的正则表达式和拼音转换逻辑
|
||||
"""
|
||||
# 原有的脱敏逻辑作为回退
|
||||
```
|
||||
|
||||
### 2. Ollama 调用模式
|
||||
|
||||
遵循现有的 Ollama 客户端调用模式,使用验证:
|
||||
|
||||
```python
|
||||
response = self.ollama_client.generate_with_validation(
|
||||
prompt=prompt,
|
||||
response_type='address_masking',
|
||||
return_parsed=True
|
||||
)
|
||||
```
|
||||
|
||||
- `response_type='address_masking'`:指定响应类型进行验证
|
||||
- `return_parsed=True`:返回解析后的 JSON
|
||||
- 自动验证响应格式是否符合 schema
|
||||
|
||||
## 测试结果
|
||||
|
||||
### 测试案例
|
||||
|
||||
| 原始地址 | 期望脱敏结果 |
|
||||
|----------|-------------|
|
||||
| 上海市静安区恒丰路66号白云大厦1607室 | 上海市静安区HF路**号BY大厦****室 |
|
||||
| 北京市海淀区北小马厂6号1号楼华天大厦1306室 | 北京市海淀区北小马厂**号**号楼HT大厦****室 |
|
||||
| 天津市津南区双港镇工业园区优谷产业园5号楼-1505 | 天津市津南区双港镇工业园区优谷产业园**号楼-**** |
|
||||
|
||||
### Prompt 验证
|
||||
|
||||
- ✓ 包含脱敏规则说明
|
||||
- ✓ 提供具体示例
|
||||
- ✓ 指定 JSON 输出格式
|
||||
- ✓ 包含原始地址
|
||||
- ✓ 指定输出字段名
|
||||
|
||||
## 优势
|
||||
|
||||
### 1. 智能化处理
|
||||
- LLM 能够理解复杂的地址格式
|
||||
- 自动处理各种地址变体
|
||||
- 减少手动维护成本
|
||||
|
||||
### 2. 可靠性
|
||||
- 回退机制确保服务可用性
|
||||
- 错误处理和日志记录
|
||||
- 保持向后兼容性
|
||||
|
||||
### 3. 可扩展性
|
||||
- 易于添加新的脱敏规则
|
||||
- 支持多语言地址处理
|
||||
- 可配置的脱敏策略
|
||||
|
||||
### 4. 一致性
|
||||
- 统一的脱敏标准
|
||||
- 可预测的输出格式
|
||||
- 便于测试和验证
|
||||
|
||||
## 性能影响
|
||||
|
||||
### 1. 延迟
|
||||
- LLM 调用增加处理时间
|
||||
- 网络延迟影响响应速度
|
||||
- 回退机制提供快速响应
|
||||
|
||||
### 2. 成本
|
||||
- LLM API 调用成本
|
||||
- 需要稳定的网络连接
|
||||
- 回退机制降低依赖风险
|
||||
|
||||
### 3. 准确性
|
||||
- 显著提高脱敏准确性
|
||||
- 减少人工错误
|
||||
- 更好的地址理解能力
|
||||
|
||||
## 配置参数
|
||||
|
||||
- `response_type`: 响应类型,用于验证 (默认: 'address_masking')
|
||||
- `return_parsed`: 是否返回解析后的 JSON (默认: True)
|
||||
- `max_retries`: 最大重试次数 (默认: 3)
|
||||
|
||||
## 验证 Schema
|
||||
|
||||
地址脱敏响应必须符合以下 JSON schema:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"masked_address": {
|
||||
"type": "string",
|
||||
"description": "The masked address following the specified rules"
|
||||
}
|
||||
},
|
||||
"required": ["masked_address"]
|
||||
}
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
processor = NerProcessor()
|
||||
original_address = "上海市静安区恒丰路66号白云大厦1607室"
|
||||
masked_address = processor._mask_address(original_address)
|
||||
print(f"Original: {original_address}")
|
||||
print(f"Masked: {masked_address}")
|
||||
```
|
||||
|
||||
## 未来改进方向
|
||||
|
||||
1. **缓存机制**:缓存常见地址的脱敏结果
|
||||
2. **批量处理**:支持批量地址脱敏
|
||||
3. **自定义规则**:支持用户自定义脱敏规则
|
||||
4. **多语言支持**:扩展到其他语言的地址处理
|
||||
5. **性能优化**:异步处理和并发调用
|
||||
|
||||
## 相关文件
|
||||
|
||||
- `backend/app/core/document_handlers/ner_processor.py` - 主要实现
|
||||
- `backend/app/core/prompts/masking_prompts.py` - Prompt 函数
|
||||
- `backend/app/core/services/ollama_client.py` - Ollama 客户端
|
||||
- `backend/app/core/utils/llm_validator.py` - 验证 schema 和验证方法
|
||||
- `backend/test_validation_schema.py` - 验证 schema 测试
|
||||
|
|
@@ -1,255 +0,0 @@
|
|||
# OllamaClient Enhancement Summary
|
||||
|
||||
## Overview
|
||||
The `OllamaClient` has been successfully enhanced to support validation and retry mechanisms while maintaining full backward compatibility.
|
||||
|
||||
## Key Enhancements
|
||||
|
||||
### 1. **Enhanced Constructor**
|
||||
```python
|
||||
def __init__(self, model_name: str, base_url: str = "http://localhost:11434", max_retries: int = 3):
|
||||
```
|
||||
- Added `max_retries` parameter for configurable retry attempts
|
||||
- Default retry count: 3 attempts
|
||||
|
||||
### 2. **Enhanced Generate Method**
|
||||
```python
|
||||
def generate(self,
|
||||
prompt: str,
|
||||
strip_think: bool = True,
|
||||
validation_schema: Optional[Dict[str, Any]] = None,
|
||||
response_type: Optional[str] = None,
|
||||
return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
|
||||
```
|
||||
|
||||
**New Parameters:**
|
||||
- `validation_schema`: Custom JSON schema for validation
|
||||
- `response_type`: Predefined response type for validation
|
||||
- `return_parsed`: Return parsed JSON instead of raw string
|
||||
|
||||
**Return Type:**
|
||||
- `Union[str, Dict[str, Any]]`: Can return either raw string or parsed JSON
|
||||
|
||||
### 3. **New Convenience Methods**
|
||||
|
||||
#### `generate_with_validation()`
|
||||
```python
|
||||
def generate_with_validation(self,
|
||||
prompt: str,
|
||||
response_type: str,
|
||||
strip_think: bool = True,
|
||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
||||
```
|
||||
- Uses predefined validation schemas based on response type
|
||||
- Automatically handles retries and validation
|
||||
- Returns parsed JSON by default
|
||||
|
||||
#### `generate_with_schema()`
|
||||
```python
|
||||
def generate_with_schema(self,
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
strip_think: bool = True,
|
||||
return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
|
||||
```
|
||||
- Uses custom JSON schema for validation
|
||||
- Automatically handles retries and validation
|
||||
- Returns parsed JSON by default
|
||||
|
||||
### 4. **Supported Response Types**
|
||||
The following response types are supported for automatic validation:
|
||||
|
||||
- `'entity_extraction'`: Entity extraction responses
|
||||
- `'entity_linkage'`: Entity linkage responses
|
||||
- `'regex_entity'`: Regex-based entity responses
|
||||
- `'business_name_extraction'`: Business name extraction responses
|
||||
- `'address_extraction'`: Address component extraction responses
|
||||
|
||||
## Features
|
||||
|
||||
### 1. **Automatic Retry Mechanism**
|
||||
- Retries failed API calls up to `max_retries` times
|
||||
- Retries on validation failures
|
||||
- Retries on JSON parsing failures
|
||||
- Configurable retry count per client instance
|
||||
|
||||
### 2. **Built-in Validation**
|
||||
- JSON schema validation using `jsonschema` library
|
||||
- Predefined schemas for common response types
|
||||
- Custom schema support for specialized use cases
|
||||
- Detailed validation error logging
|
||||
|
||||
### 3. **Automatic JSON Parsing**
|
||||
- Uses `LLMJsonExtractor.parse_raw_json_str()` for robust JSON extraction (a simplified stand-in is sketched after this list)
|
||||
- Handles malformed JSON responses gracefully
|
||||
- Returns parsed Python dictionaries when requested
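The extractor itself is not part of this diff, so the following is only an assumed, simplified stand-in for what `parse_raw_json_str` does:

```python
import json

# Sketch only: a tolerant JSON parser in the spirit of LLMJsonExtractor.
def parse_raw_json_str(raw: str):
    # Models often wrap JSON in prose or a code fence; keep the outermost
    # {...} block and try to decode it, returning None on failure.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end == -1 or end < start:
        return None
    try:
        return json.loads(raw[start:end + 1])
    except json.JSONDecodeError:
        return None
```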
|
||||
|
||||
### 4. **Backward Compatibility**
|
||||
- All existing code continues to work without changes
|
||||
- Original `generate()` method signature preserved
|
||||
- Default behavior unchanged
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### 1. **Basic Usage (Backward Compatible)**
|
||||
```python
|
||||
client = OllamaClient("llama2")
|
||||
response = client.generate("Hello, world!")
|
||||
# Returns: "Hello, world!"
|
||||
```
|
||||
|
||||
### 2. **With Response Type Validation**
|
||||
```python
|
||||
client = OllamaClient("llama2")
|
||||
result = client.generate_with_validation(
|
||||
prompt="Extract business name from: 上海盒马网络科技有限公司",
|
||||
response_type='business_name_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
# Returns: {"business_name": "盒马", "confidence": 0.9}
|
||||
```
|
||||
|
||||
### 3. **With Custom Schema Validation**
|
||||
```python
|
||||
client = OllamaClient("llama2")
|
||||
custom_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}
|
||||
|
||||
result = client.generate_with_schema(
|
||||
prompt="Generate person info",
|
||||
schema=custom_schema,
|
||||
return_parsed=True
|
||||
)
|
||||
# Returns: {"name": "张三", "age": 30}
|
||||
```
|
||||
|
||||
### 4. **Advanced Usage with All Options**
|
||||
```python
|
||||
client = OllamaClient("llama2", max_retries=5)
|
||||
result = client.generate(
|
||||
prompt="Complex prompt",
|
||||
strip_think=True,
|
||||
validation_schema=custom_schema,
|
||||
return_parsed=True
|
||||
)
|
||||
```
|
||||
|
||||
## Updated Components
|
||||
|
||||
### 1. **Extractors**
|
||||
- `BusinessNameExtractor`: Now uses `generate_with_validation()`
|
||||
- `AddressExtractor`: Now uses `generate_with_validation()`
|
||||
|
||||
### 2. **Processors**
|
||||
- `NerProcessor`: Updated to use enhanced methods
|
||||
- `NerProcessorRefactored`: Updated to use enhanced methods
|
||||
|
||||
### 3. **Benefits in Processors**
|
||||
- Simplified code: No more manual retry loops
|
||||
- Automatic validation: No more manual JSON parsing
|
||||
- Better error handling: Automatic fallback to regex methods
|
||||
- Cleaner code: Reduced boilerplate
|
||||
|
||||
## Error Handling
|
||||
|
||||
### 1. **API Failures**
|
||||
- Automatic retry on network errors
|
||||
- Configurable retry count
|
||||
- Detailed error logging
|
||||
|
||||
### 2. **Validation Failures**
|
||||
- Automatic retry on schema validation failures
|
||||
- Automatic retry on JSON parsing failures
|
||||
- Graceful fallback to alternative methods
|
||||
|
||||
### 3. **Exception Types**
|
||||
- `RequestException`: API call failures after all retries
|
||||
- `ValueError`: Validation failures after all retries
|
||||
- `Exception`: Unexpected errors
|
||||
|
||||
## Testing
|
||||
|
||||
### 1. **Test Coverage**
|
||||
- Initialization with new parameters
|
||||
- Enhanced generate methods
|
||||
- Backward compatibility
|
||||
- Retry mechanism
|
||||
- Validation failure handling
|
||||
- Mock-based testing for reliability
|
||||
|
||||
### 2. **Run Tests**
|
||||
```bash
|
||||
cd backend
|
||||
python3 test_enhanced_ollama_client.py
|
||||
```
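The test file itself is not included in this diff; a mock-based test in its spirit might look like this (the import path follows the "相关文件" listing in the address-masking document and is otherwise an assumption):

```python
# Hypothetical test: assumes parse_raw_json_str returns a falsy value for
# non-JSON text, so the client retries once and then succeeds.
from unittest.mock import patch

from app.core.services.ollama_client import OllamaClient


def test_generate_with_validation_retries_then_succeeds():
    client = OllamaClient("llama2", max_retries=3)
    bad, good = "not json", '{"entities": [{"text": "张三", "type": "人名"}]}'
    with patch.object(client, "_make_api_call", side_effect=[bad, good]) as api:
        result = client.generate_with_validation(
            prompt="extract entities",
            response_type="entity_extraction",
            return_parsed=True,
        )
    assert api.call_count == 2
    assert result["entities"][0]["text"] == "张三"
```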
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### 1. **No Changes Required**
|
||||
Existing code continues to work without modification:
|
||||
```python
|
||||
# This still works exactly the same
|
||||
client = OllamaClient("llama2")
|
||||
response = client.generate("prompt")
|
||||
```
|
||||
|
||||
### 2. **Optional Enhancements**
|
||||
To take advantage of new features:
|
||||
```python
|
||||
# Old way (still works)
|
||||
response = client.generate(prompt)
|
||||
parsed = LLMJsonExtractor.parse_raw_json_str(response)
|
||||
if LLMResponseValidator.validate_entity_extraction(parsed):
|
||||
# use parsed
|
||||
|
||||
# New way (recommended)
|
||||
parsed = client.generate_with_validation(
|
||||
prompt=prompt,
|
||||
response_type='entity_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
# parsed is already validated and ready to use
|
||||
```
|
||||
|
||||
### 3. **Benefits of Migration**
|
||||
- **Reduced Code**: Eliminates manual retry loops
|
||||
- **Better Reliability**: Automatic retry and validation
|
||||
- **Cleaner Code**: Less boilerplate
|
||||
- **Better Error Handling**: Automatic fallbacks
|
||||
|
||||
## Performance Impact
|
||||
|
||||
### 1. **Positive Impact**
|
||||
- Reduced code complexity
|
||||
- Better error recovery
|
||||
- Automatic retry reduces manual intervention
|
||||
|
||||
### 2. **Minimal Overhead**
|
||||
- Validation only occurs when requested
|
||||
- JSON parsing only occurs when needed
|
||||
- Retry mechanism only activates on failures
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### 1. **Potential Additions**
|
||||
- Circuit breaker pattern for API failures
|
||||
- Caching for repeated requests
|
||||
- Async/await support
|
||||
- Streaming response support
|
||||
- Custom retry strategies
|
||||
|
||||
### 2. **Configuration Options**
|
||||
- Per-request retry configuration
|
||||
- Custom validation error handling
|
||||
- Response transformation hooks
|
||||
- Metrics and monitoring
|
||||
|
||||
## Conclusion
|
||||
|
||||
The enhanced `OllamaClient` provides a robust, reliable, and easy-to-use interface for LLM interactions while maintaining full backward compatibility. The new validation and retry mechanisms significantly improve the reliability of LLM-based operations in the NER processing pipeline.
|
||||
|
|
@ -1,202 +0,0 @@
|
|||
# PDF Processor with Mineru API
|
||||
|
||||
## Overview
|
||||
|
||||
The PDF processor has been rewritten to use Mineru's REST API instead of the magic_pdf library. This provides better separation of concerns and allows for more flexible deployment options.
|
||||
|
||||
## Changes Made
|
||||
|
||||
### 1. Removed Dependencies
|
||||
- Removed all `magic_pdf` imports and dependencies
|
||||
- Removed `PyPDF2` direct usage (though kept in requirements for potential other uses)
|
||||
|
||||
### 2. New Implementation
|
||||
- **REST API Integration**: Uses HTTP requests to call Mineru's API
|
||||
- **Configurable Settings**: Mineru API URL and timeout are configurable
|
||||
- **Error Handling**: Comprehensive error handling for network issues, timeouts, and API errors
|
||||
- **Flexible Response Parsing**: Handles multiple possible response formats from Mineru API
|
||||
|
||||
### 3. Configuration
|
||||
|
||||
Add the following settings to your environment or `.env` file:
|
||||
|
||||
```bash
|
||||
# Mineru API Configuration
|
||||
MINERU_API_URL=http://mineru-api:8000
|
||||
MINERU_TIMEOUT=300
|
||||
MINERU_LANG_LIST=["ch"]
|
||||
MINERU_BACKEND=pipeline
|
||||
MINERU_PARSE_METHOD=auto
|
||||
MINERU_FORMULA_ENABLE=true
|
||||
MINERU_TABLE_ENABLE=true
|
||||
```
|
||||
|
||||
### 4. API Endpoint
|
||||
|
||||
The processor expects Mineru to provide a REST API endpoint at `/file_parse` that accepts PDF files via multipart form data and returns JSON with markdown content.
|
||||
|
||||
#### Expected Request Format:
|
||||
```
|
||||
POST /file_parse
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
files: [PDF file]
|
||||
output_dir: ./output
|
||||
lang_list: ["ch"]
|
||||
backend: pipeline
|
||||
parse_method: auto
|
||||
formula_enable: true
|
||||
table_enable: true
|
||||
return_md: true
|
||||
return_middle_json: false
|
||||
return_model_output: false
|
||||
return_content_list: false
|
||||
return_images: false
|
||||
start_page_id: 0
|
||||
end_page_id: 99999
|
||||
```
|
||||
|
||||
#### Expected Response Format:
|
||||
The processor can handle multiple response formats:
|
||||
|
||||
```json
|
||||
{
|
||||
"markdown": "# Document Title\n\nContent here..."
|
||||
}
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```json
|
||||
{
|
||||
"md": "# Document Title\n\nContent here..."
|
||||
}
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "# Document Title\n\nContent here..."
|
||||
}
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"markdown": "# Document Title\n\nContent here..."
|
||||
}
|
||||
}
|
||||
```
|
||||
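Putting the request and response contracts together, here is a minimal sketch of the API call and fallback parsing (using the `requests` library; the field names and fallback order follow the formats above, while the helper itself is illustrative rather than the exact code in `pdf_processor.py`):

```python
import requests

MINERU_API_URL = "http://mineru-api:8000"   # see MINERU_API_URL above
MINERU_TIMEOUT = 300                        # seconds, see MINERU_TIMEOUT above

def parse_pdf_to_markdown(pdf_path: str) -> str:
    """POST a PDF to Mineru's /file_parse endpoint and extract the markdown."""
    with open(pdf_path, "rb") as f:
        response = requests.post(
            f"{MINERU_API_URL}/file_parse",
            files={"files": f},
            data={
                "lang_list": '["ch"]',
                "backend": "pipeline",
                "parse_method": "auto",
                "formula_enable": "true",
                "table_enable": "true",
                "return_md": "true",
            },
            timeout=MINERU_TIMEOUT,
        )
    response.raise_for_status()
    payload = response.json()

    # Try the documented response shapes in order.
    for key in ("markdown", "md", "content"):
        if isinstance(payload.get(key), str):
            return payload[key]
    nested = payload.get("result", {})
    if isinstance(nested, dict) and isinstance(nested.get("markdown"), str):
        return nested["markdown"]
    raise ValueError("No markdown content found in Mineru response")
```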
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from app.core.document_handlers.processors.pdf_processor import PdfDocumentProcessor
|
||||
|
||||
# Create processor instance
|
||||
processor = PdfDocumentProcessor("input.pdf", "output.md")
|
||||
|
||||
# Read and convert PDF to markdown
|
||||
content = processor.read_content()
|
||||
|
||||
# Process content (apply masking)
|
||||
processed_content = processor.process_content(content)
|
||||
|
||||
# Save processed content
|
||||
processor.save_content(processed_content)
|
||||
```
|
||||
|
||||
### Through Document Service
|
||||
|
||||
```python
|
||||
from app.core.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService()
|
||||
success = service.process_document("input.pdf", "output.md")
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run the test script to verify the implementation:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python test_pdf_processor.py
|
||||
```
|
||||
|
||||
Make sure you have:
|
||||
1. A sample PDF file in the `sample_doc/` directory
|
||||
2. Mineru API service running and accessible
|
||||
3. Proper network connectivity between services
|
||||
|
||||
## Error Handling
|
||||
|
||||
The processor handles various error scenarios:
|
||||
|
||||
- **Network Timeouts**: Configurable timeout (default: 5 minutes)
|
||||
- **API Errors**: HTTP status code errors are logged and handled
|
||||
- **Response Parsing**: Multiple fallback strategies for extracting markdown content
|
||||
- **File Operations**: Proper error handling for file reading/writing
|
||||
|
||||
## Logging
|
||||
|
||||
The processor provides detailed logging for debugging:
|
||||
|
||||
- API call attempts and responses
|
||||
- Content extraction results
|
||||
- Error conditions and stack traces
|
||||
- Processing statistics
|
||||
|
||||
## Deployment
|
||||
|
||||
### Docker Compose
|
||||
|
||||
Ensure your Mineru service is running and accessible. The default configuration expects it at `http://mineru-api:8000`.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Set the following environment variables in your deployment:
|
||||
|
||||
```bash
|
||||
MINERU_API_URL=http://your-mineru-service:8000
|
||||
MINERU_TIMEOUT=300
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Connection Refused**: Check if Mineru service is running and accessible
|
||||
2. **Timeout Errors**: Increase `MINERU_TIMEOUT` for large PDF files
|
||||
3. **Empty Content**: Check Mineru API response format and logs
|
||||
4. **Network Issues**: Verify network connectivity between services
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable debug logging to see detailed API interactions:
|
||||
|
||||
```python
|
||||
import logging
|
||||
logging.getLogger('app.core.document_handlers.processors.pdf_processor').setLevel(logging.DEBUG)
|
||||
```
|
||||
|
||||
## Migration from magic_pdf
|
||||
|
||||
If you were previously using magic_pdf:
|
||||
|
||||
1. **No Code Changes Required**: The interface remains the same
|
||||
2. **Configuration Update**: Add Mineru API settings
|
||||
3. **Service Dependencies**: Ensure Mineru service is running
|
||||
4. **Testing**: Run the test script to verify functionality
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **Timeout**: Large PDFs may require longer timeouts
|
||||
- **Memory**: The processor loads the entire PDF into memory for API calls
|
||||
- **Network**: API calls add network latency to processing time
|
||||
- **Caching**: Consider implementing caching for frequently processed documents
|
||||
|
|
@ -1,166 +0,0 @@
|
|||
# NerProcessor Refactoring Summary
|
||||
|
||||
## Overview
|
||||
The `ner_processor.py` file has been successfully refactored from a monolithic 729-line class into a modular, maintainable architecture following SOLID principles.
|
||||
|
||||
## New Architecture
|
||||
|
||||
### Directory Structure
|
||||
```
|
||||
backend/app/core/document_handlers/
|
||||
├── ner_processor.py # Original file (unchanged)
|
||||
├── ner_processor_refactored.py # New refactored version
|
||||
├── masker_factory.py # Factory for creating maskers
|
||||
├── maskers/
|
||||
│ ├── __init__.py
|
||||
│ ├── base_masker.py # Abstract base class
|
||||
│ ├── name_masker.py # Chinese/English name masking
|
||||
│ ├── company_masker.py # Company name masking
|
||||
│ ├── address_masker.py # Address masking
|
||||
│ ├── id_masker.py # ID/social credit code masking
|
||||
│ └── case_masker.py # Case number masking
|
||||
├── extractors/
|
||||
│ ├── __init__.py
|
||||
│ ├── base_extractor.py # Abstract base class
|
||||
│ ├── business_name_extractor.py # Business name extraction
|
||||
│ └── address_extractor.py # Address component extraction
|
||||
└── validators/ # (Placeholder for future use)
|
||||
```
|
||||
|
||||
## Key Components
|
||||
|
||||
### 1. Base Classes
|
||||
- **`BaseMasker`**: Abstract base class for all maskers
|
||||
- **`BaseExtractor`**: Abstract base class for all extractors
|
||||
|
||||
### 2. Maskers
|
||||
- **`ChineseNameMasker`**: Handles Chinese name masking (surname + pinyin initials)
|
||||
- **`EnglishNameMasker`**: Handles English name masking (first letter + ***)
|
||||
- **`CompanyMasker`**: Handles company name masking (business name replacement)
|
||||
- **`AddressMasker`**: Handles address masking (component replacement)
|
||||
- **`IDMasker`**: Handles ID and social credit code masking
|
||||
- **`CaseMasker`**: Handles case number masking
|
||||
|
||||
### 3. Extractors
|
||||
- **`BusinessNameExtractor`**: Extracts business names from company names using LLM + regex fallback
|
||||
- **`AddressExtractor`**: Extracts address components using LLM + regex fallback
|
||||
|
||||
### 4. Factory
|
||||
- **`MaskerFactory`**: Creates maskers with proper dependencies (see the sketch at the end of this section)
|
||||
|
||||
### 5. Refactored Processor
|
||||
- **`NerProcessorRefactored`**: Main orchestrator using the new architecture
|
||||
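To make the shape of the new architecture concrete, here is a minimal sketch of a base class, one masker, and the factory injecting a dependency. Method names such as `mask()`, the registry key, and the constructor signatures are assumptions for illustration, not the exact interfaces in `maskers/` and `masker_factory.py`:

```python
from abc import ABC, abstractmethod

class BaseMasker(ABC):
    """Abstract base class every masker implements (sketch)."""

    @abstractmethod
    def mask(self, text: str) -> str:
        ...

class CaseMasker(BaseMasker):
    """Handles case number masking (sketch; real rules live in maskers/case_masker.py)."""

    def mask(self, text: str) -> str:
        return "****" if text else text  # placeholder rule

class MaskerFactory:
    """Creates maskers with their dependencies injected (sketch)."""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client  # e.g. an OllamaClient for LLM-backed maskers

    def create(self, entity_type: str) -> BaseMasker:
        registry = {"case_number": CaseMasker}  # other entity types omitted
        return registry[entity_type]()

# The refactored processor asks the factory for one masker per entity type.
masker = MaskerFactory().create("case_number")
print(masker.mask("some case number"))
```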
|
||||
## Benefits Achieved
|
||||
|
||||
### 1. Single Responsibility Principle
|
||||
- Each class has one clear responsibility
|
||||
- Maskers only handle masking logic
|
||||
- Extractors only handle extraction logic
|
||||
- Processor only handles orchestration
|
||||
|
||||
### 2. Open/Closed Principle
|
||||
- Easy to add new maskers without modifying existing code
|
||||
- New entity types can be supported by creating new maskers
|
||||
|
||||
### 3. Dependency Injection
|
||||
- Dependencies are injected rather than hardcoded
|
||||
- Easier to test and mock
|
||||
|
||||
### 4. Better Testing
|
||||
- Each component can be tested in isolation
|
||||
- Mock dependencies easily
|
||||
|
||||
### 5. Code Reusability
|
||||
- Maskers can be used independently
|
||||
- Common functionality shared through base classes
|
||||
|
||||
### 6. Maintainability
|
||||
- Changes to one masking rule don't affect others
|
||||
- Clear separation of concerns
|
||||
|
||||
## Migration Strategy
|
||||
|
||||
### Phase 1: ✅ Complete
|
||||
- Created base classes and interfaces
|
||||
- Extracted all maskers
|
||||
- Created extractors
|
||||
- Created factory pattern
|
||||
- Created refactored processor
|
||||
|
||||
### Phase 2: Testing (Next)
|
||||
- Run validation script: `python3 validate_refactoring.py`
|
||||
- Run existing tests to ensure compatibility
|
||||
- Create comprehensive unit tests for each component
|
||||
|
||||
### Phase 3: Integration (Future)
|
||||
- Replace original processor with refactored version
|
||||
- Update imports throughout the codebase
|
||||
- Remove old code
|
||||
|
||||
### Phase 4: Enhancement (Future)
|
||||
- Add configuration management
|
||||
- Add more extractors as needed
|
||||
- Add validation components
|
||||
|
||||
## Testing
|
||||
|
||||
### Validation Script
|
||||
Run the validation script to test the refactored code:
|
||||
```bash
|
||||
cd backend
|
||||
python3 validate_refactoring.py
|
||||
```
|
||||
|
||||
### Unit Tests
|
||||
Run the unit tests for the refactored components:
|
||||
```bash
|
||||
cd backend
|
||||
python3 -m pytest tests/test_refactored_ner_processor.py -v
|
||||
```
|
||||
|
||||
## Current Status
|
||||
|
||||
✅ **Completed:**
|
||||
- All maskers extracted and implemented
|
||||
- All extractors created
|
||||
- Factory pattern implemented
|
||||
- Refactored processor created
|
||||
- Validation script created
|
||||
- Unit tests created
|
||||
|
||||
🔄 **Next Steps:**
|
||||
- Test the refactored code
|
||||
- Ensure all existing functionality works
|
||||
- Replace original processor when ready
|
||||
|
||||
## File Comparison
|
||||
|
||||
| Metric | Original | Refactored |
|
||||
|--------|----------|------------|
|
||||
| Main Class Lines | 729 | ~200 |
|
||||
| Number of Classes | 1 | 10+ |
|
||||
| Responsibilities | Multiple | Single |
|
||||
| Testability | Low | High |
|
||||
| Maintainability | Low | High |
|
||||
| Extensibility | Low | High |
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
The refactored code maintains full backward compatibility:
|
||||
- All existing masking rules are preserved
|
||||
- All existing functionality works the same
|
||||
- The public API remains unchanged
|
||||
- The original `ner_processor.py` is untouched
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. **Configuration Management**: Centralized configuration for masking rules
|
||||
2. **Validation Framework**: Dedicated validation components
|
||||
3. **Performance Optimization**: Caching and optimization strategies
|
||||
4. **Monitoring**: Metrics and logging for each component
|
||||
5. **Plugin System**: Dynamic loading of new maskers and extractors
|
||||
|
||||
## Conclusion
|
||||
|
||||
The refactoring successfully transforms the monolithic `NerProcessor` into a modular, maintainable, and extensible architecture while preserving all existing functionality. The new architecture follows SOLID principles and provides a solid foundation for future enhancements.
|
||||
|
|
@ -1,130 +0,0 @@
|
|||
# Sentence-Based Chunking Improvements

## Problem Description

In the original NER extraction pipeline we found that some entities were being truncated, for example:

- `"丰复久信公"` (should be `"丰复久信营销科技有限公司"`)
- `"康达律师事"` (should be `"北京市康达律师事务所"`)

These truncations were caused by the original character-count-based chunking strategy, which did not take entity integrity into account.

## Solution

### 1. Sentence-Based Chunking Strategy

We implemented an intelligent sentence-based chunking strategy with the following key features:

- **Natural boundary splitting**: split on Chinese sentence terminators (。!?;\n) and English sentence terminators (.!?;)
- **Entity integrity protection**: avoid splitting in the middle of an entity name
- **Smart length control**: chunk by token count rather than character count

### 2. Entity Boundary Safety Check

The `_is_entity_boundary_safe()` method checks whether a candidate split point is safe:

```python
def _is_entity_boundary_safe(self, text: str, position: int) -> bool:
    # Check common entity suffixes
    entity_suffixes = ['公', '司', '所', '院', '厅', '局', '部', '会', '团', '社', '处', '室', '楼', '号']

    # Check for incomplete entity patterns
    if text[position-2:position+1] in ['公司', '事务所', '协会', '研究院']:
        return False

    # Check address patterns
    address_patterns = ['省', '市', '区', '县', '路', '街', '巷', '号', '室']
    # ...
```

### 3. Intelligent Splitting of Long Sentences

For sentences that exceed the token limit, an intelligent splitting strategy is applied:

1. **Punctuation splitting**: prefer splitting at commas, semicolons, and other punctuation
2. **Entity boundary splitting**: if punctuation splitting is not feasible, split at a safe entity boundary
3. **Forced splitting**: only as a last resort, fall back to character-level forced splitting

## Implementation Details

### Core Methods

1. **`_split_text_by_sentences()`**: splits the text into sentences
2. **`_create_sentence_chunks()`**: builds chunks from sentences
3. **`_split_long_sentence()`**: intelligently splits long sentences
4. **`_is_entity_boundary_safe()`**: checks whether a split point is safe

### Chunking Flow

```
Input text
    ↓
Split into sentences
    ↓
Estimate token count
    ↓
Create sentence chunks
    ↓
Check entity boundaries
    ↓
Output final chunks
```
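
As a concrete illustration of `_split_text_by_sentences()`, here is a minimal sketch that splits on the terminators listed above; the regular expression and function shape are assumptions for illustration, not the exact implementation in `ner_extractor.py`:

```python
import re
from typing import List

# Assumed pattern: split on the Chinese (。!?;) and English (.!?;) sentence
# terminators described above, keeping each terminator with its sentence.
SENTENCE_PATTERN = re.compile(r'[^。!?;.!?;\n]+[。!?;.!?;]?')

def split_text_by_sentences(text: str) -> List[str]:
    """Return non-empty sentence fragments in their original order."""
    return [m.group().strip() for m in SENTENCE_PATTERN.finditer(text) if m.group().strip()]

print(split_text_by_sentences("上诉人李淼因合同纠纷。法定代表人李淼!委托代理人李淼?"))
# ['上诉人李淼因合同纠纷。', '法定代表人李淼!', '委托代理人李淼?']
```

In the extractor, the resulting sentences are then packed into chunks of at most `max_tokens` tokens before the entity-boundary checks are applied.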

## Test Results

### Before vs. After

| Metric | Before | After |
|--------|--------|-------|
| Truncated entities | Many | Significantly fewer |
| Entity integrity | Frequently broken | Preserved |
| Chunking quality | Character-based | Semantics-based |

### Test Cases

1. **The `"丰复久信公"` problem**:
   - Before: `"丰复久信公"` (truncated)
   - After: `"北京丰复久信营销科技有限公司"` (complete)

2. **Long sentence handling**:
   - Before: could truncate in the middle of an entity
   - After: splits at sentence boundaries or other safe positions

## Configuration Parameters

- `max_tokens`: maximum number of tokens per chunk (default: 400)
- `confidence_threshold`: entity confidence threshold (default: 0.95)
- `sentence_pattern`: regular expression used for sentence splitting

## Usage Example

```python
from app.core.document_handlers.extractors.ner_extractor import NERExtractor

extractor = NERExtractor()
result = extractor.extract(long_text)

# Entities in the result will be more complete
entities = result.get("entities", [])
for entity in entities:
    print(f"{entity['text']} ({entity['type']})")
```

## Performance Impact

- **Memory usage**: slightly higher (sentence-split results need to be stored)
- **Processing speed**: essentially unchanged (sentence splitting is fast)
- **Accuracy**: significantly improved (fewer truncated entities)

## Future Improvements

1. **Smarter entity recognition**: use pretrained models to identify entity boundaries
2. **Dynamic chunk size**: adjust the chunk size based on text complexity
3. **Multilingual support**: extend the chunking strategy to other languages
4. **Caching**: cache sentence-splitting results to improve performance

## Related Files

- `backend/app/core/document_handlers/extractors/ner_extractor.py` - main implementation
- `backend/test_improved_chunking.py` - test script
- `backend/test_truncation_fix.py` - truncation-issue tests
- `backend/test_chunking_logic.py` - chunking-logic tests
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
# Test Setup Guide
|
||||
|
||||
This document explains how to set up and run tests for the legal-doc-masker backend.
|
||||
|
||||
## Test Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── tests/
|
||||
│ ├── __init__.py
|
||||
│ ├── test_ner_processor.py
|
||||
│ ├── test1.py
|
||||
│ └── test.txt
|
||||
├── conftest.py
|
||||
├── pytest.ini
|
||||
└── run_tests.py
|
||||
```
|
||||
|
||||
## VS Code Configuration
|
||||
|
||||
The `.vscode/settings.json` file has been configured to:
|
||||
|
||||
1. **Set pytest as the test framework**: `"python.testing.pytestEnabled": true`
|
||||
2. **Point to the correct test directory**: `"python.testing.pytestArgs": ["backend/tests"]`
|
||||
3. **Set the working directory**: `"python.testing.cwd": "${workspaceFolder}/backend"`
|
||||
4. **Configure Python interpreter**: Points to backend virtual environment
|
||||
|
||||
## Running Tests
|
||||
|
||||
### From VS Code Test Explorer
|
||||
1. Open the Command Palette (Ctrl+Shift+P) and run "Python: Configure Tests"
|
||||
2. Select "pytest" as the test framework
|
||||
3. Select "backend/tests" as the test directory
|
||||
4. Tests should now appear in the Test Explorer
|
||||
|
||||
### From Command Line
|
||||
```bash
|
||||
# From the project root
|
||||
cd backend
|
||||
python -m pytest tests/ -v
|
||||
|
||||
# Or use the test runner script
|
||||
python run_tests.py
|
||||
```
|
||||
|
||||
### From VS Code Terminal
|
||||
```bash
|
||||
# Make sure you're in the backend directory
|
||||
cd backend
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
## Test Configuration
|
||||
|
||||
### pytest.ini
|
||||
- **testpaths**: Points to the `tests` directory
|
||||
- **python_files**: Looks for files starting with `test_` or ending with `_test.py`
|
||||
- **python_functions**: Looks for functions starting with `test_`
|
||||
- **markers**: Defines test markers for categorization
|
||||
|
||||
### conftest.py
|
||||
- **Path setup**: Adds backend directory to Python path
|
||||
- **Fixtures**: Provides common test fixtures (see the sketch below)
|
||||
- **Environment setup**: Handles test environment initialization
|
||||
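A minimal `conftest.py` consistent with these points might look as follows; the `sample_data` fixture matches the fixture used in the test examples below, and the rest is a simplified sketch rather than the repository's exact file:

```python
import os
import sys

import pytest

# Path setup: make the backend package importable regardless of where pytest runs.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

@pytest.fixture
def sample_data():
    """Common fixture used by tests such as test_with_fixture below."""
    return {"name": "test", "value": 42}
```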
|
||||
## Troubleshooting
|
||||
|
||||
### Tests Not Discovered
|
||||
1. **Check VS Code settings**: Ensure `python.testing.pytestArgs` points to `backend/tests`
|
||||
2. **Verify working directory**: Ensure `python.testing.cwd` is set to `${workspaceFolder}/backend`
|
||||
3. **Check Python interpreter**: Make sure it points to the backend virtual environment
|
||||
|
||||
### Import Errors
|
||||
1. **Check conftest.py**: Ensures backend directory is in Python path
|
||||
2. **Verify __init__.py**: Tests directory should have an `__init__.py` file
|
||||
3. **Check relative imports**: Use absolute imports from the backend root
|
||||
|
||||
### Virtual Environment Issues
|
||||
1. **Create virtual environment**: `python -m venv .venv`
|
||||
2. **Activate environment**:
|
||||
- Windows: `.venv\Scripts\activate`
|
||||
- Unix/macOS: `source .venv/bin/activate`
|
||||
3. **Install dependencies**: `pip install -r requirements.txt`
|
||||
|
||||
## Test Examples
|
||||
|
||||
### Simple Test
|
||||
```python
|
||||
def test_simple_assertion():
|
||||
"""Simple test to verify pytest is working"""
|
||||
assert 1 == 1
|
||||
assert 2 + 2 == 4
|
||||
```
|
||||
|
||||
### Test with Fixture
|
||||
```python
|
||||
def test_with_fixture(sample_data):
|
||||
"""Test using a fixture"""
|
||||
assert sample_data["name"] == "test"
|
||||
assert sample_data["value"] == 42
|
||||
```
|
||||
|
||||
### Integration Test
|
||||
```python
|
||||
def test_ner_processor():
|
||||
"""Test NER processor functionality"""
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
processor = NerProcessor()
|
||||
# Test implementation...
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test naming**: Use descriptive test names starting with `test_`
|
||||
2. **Test isolation**: Each test should be independent
|
||||
3. **Use fixtures**: For common setup and teardown
|
||||
4. **Add markers**: Use `@pytest.mark.slow` for slow tests (see the example below)
|
||||
5. **Documentation**: Add docstrings to explain test purpose
|
||||
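For example, a slow test can be marked and then deselected on quick runs; the marker name matches the `slow` marker registered in `pytest.ini`, while the test body is purely illustrative:

```python
import time

import pytest

@pytest.mark.slow
def test_full_document_pipeline():
    """Illustrative slow test; deselect with: pytest -m "not slow"."""
    time.sleep(2)  # stand-in for expensive processing
    assert True
```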
|
|
@ -1,15 +0,0 @@
|
|||
[tool:pytest]
|
||||
testpaths = tests
|
||||
pythonpath = .
|
||||
python_files = test_*.py *_test.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
-v
|
||||
--tb=short
|
||||
--strict-markers
|
||||
--disable-warnings
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
integration: marks tests as integration tests
|
||||
unit: marks tests as unit tests
|
||||
|
|
@ -28,13 +28,4 @@ requests==2.28.1
|
|||
python-docx>=0.8.11
|
||||
PyPDF2>=3.0.0
|
||||
pandas>=2.0.0
|
||||
# magic-pdf[full]
|
||||
jsonschema>=4.20.0
|
||||
|
||||
# Chinese text processing
|
||||
pypinyin>=0.50.0
|
||||
|
||||
# NER and ML dependencies
|
||||
# torch is installed separately in Dockerfile for CPU optimization
|
||||
transformers>=4.30.0
|
||||
tokenizers>=0.13.0
|
||||
magic-pdf[full]
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
# Tests package
|
||||
|
|
@ -1,130 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Debug script to understand the position mapping issue after masking.
|
||||
"""
|
||||
|
||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
||||
"""Simplified version of the alignment method for testing"""
|
||||
clean_entity = entity_text.replace(" ", "")
|
||||
doc_chars = [c for c in original_document_text if c != ' ']
|
||||
|
||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
||||
return None
|
||||
|
||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
||||
"""Simplified version of position mapping for testing"""
|
||||
original_pos = 0
|
||||
clean_pos = 0
|
||||
|
||||
while clean_pos < clean_start and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
clean_pos += 1
|
||||
original_pos += 1
|
||||
|
||||
start_pos = original_pos
|
||||
|
||||
chars_found = 0
|
||||
while chars_found < entity_length and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
chars_found += 1
|
||||
original_pos += 1
|
||||
|
||||
end_pos = original_pos
|
||||
found_text = original_text[start_pos:end_pos]
|
||||
|
||||
return start_pos, end_pos, found_text
|
||||
|
||||
def debug_position_issue():
|
||||
"""Debug the position mapping issue"""
|
||||
|
||||
print("Debugging Position Mapping Issue")
|
||||
print("=" * 50)
|
||||
|
||||
# Test document
|
||||
original_doc = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
||||
entity = "李淼"
|
||||
masked_text = "李M"
|
||||
|
||||
print(f"Original document: '{original_doc}'")
|
||||
print(f"Entity to mask: '{entity}'")
|
||||
print(f"Masked text: '{masked_text}'")
|
||||
print()
|
||||
|
||||
# First occurrence
|
||||
print("=== First Occurrence ===")
|
||||
result1 = find_entity_alignment(entity, original_doc)
|
||||
if result1:
|
||||
start1, end1, found1 = result1
|
||||
print(f"Found at positions {start1}-{end1}: '{found1}'")
|
||||
|
||||
# Apply first mask
|
||||
masked_doc = original_doc[:start1] + masked_text + original_doc[end1:]
|
||||
print(f"After first mask: '{masked_doc}'")
|
||||
print(f"Length changed from {len(original_doc)} to {len(masked_doc)}")
|
||||
|
||||
# Try to find second occurrence in the masked document
|
||||
print("\n=== Second Occurrence (in masked document) ===")
|
||||
result2 = find_entity_alignment(entity, masked_doc)
|
||||
if result2:
|
||||
start2, end2, found2 = result2
|
||||
print(f"Found at positions {start2}-{end2}: '{found2}'")
|
||||
|
||||
# Apply second mask
|
||||
masked_doc2 = masked_doc[:start2] + masked_text + masked_doc[end2:]
|
||||
print(f"After second mask: '{masked_doc2}'")
|
||||
|
||||
# Try to find third occurrence
|
||||
print("\n=== Third Occurrence (in double-masked document) ===")
|
||||
result3 = find_entity_alignment(entity, masked_doc2)
|
||||
if result3:
|
||||
start3, end3, found3 = result3
|
||||
print(f"Found at positions {start3}-{end3}: '{found3}'")
|
||||
else:
|
||||
print("No third occurrence found")
|
||||
else:
|
||||
print("No second occurrence found")
|
||||
else:
|
||||
print("No first occurrence found")
|
||||
|
||||
def debug_infinite_loop():
|
||||
"""Debug the infinite loop issue"""
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Debugging Infinite Loop Issue")
|
||||
print("=" * 50)
|
||||
|
||||
# Test document that causes infinite loop
|
||||
original_doc = "上诉人李淼因合同纠纷,法定代表人李淼。北京丰复久信营销科技有限公司,丰复久信公司。"
|
||||
entity = "丰复久信公司"
|
||||
masked_text = "丰复久信公司" # Same text (no change)
|
||||
|
||||
print(f"Original document: '{original_doc}'")
|
||||
print(f"Entity to mask: '{entity}'")
|
||||
print(f"Masked text: '{masked_text}' (same as original)")
|
||||
print()
|
||||
|
||||
# This will cause infinite loop because we're replacing with the same text
|
||||
print("=== This will cause infinite loop ===")
|
||||
print("Because we're replacing '丰复久信公司' with '丰复久信公司'")
|
||||
print("The document doesn't change, so we keep finding the same position")
|
||||
|
||||
# Show what happens
|
||||
masked_doc = original_doc
|
||||
for i in range(3): # Limit to 3 iterations for demo
|
||||
result = find_entity_alignment(entity, masked_doc)
|
||||
if result:
|
||||
start, end, found = result
|
||||
print(f"Iteration {i+1}: Found at positions {start}-{end}: '{found}'")
|
||||
|
||||
# Apply mask (but it's the same text)
|
||||
masked_doc = masked_doc[:start] + masked_text + masked_doc[end:]
|
||||
print(f"After mask: '{masked_doc}'")
|
||||
else:
|
||||
print(f"Iteration {i+1}: No occurrence found")
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
debug_position_issue()
|
||||
debug_infinite_loop()
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test file for address masking functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the backend directory to the Python path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
|
||||
def test_address_masking():
|
||||
"""Test address masking with the new rules"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test cases based on the requirements
|
||||
test_cases = [
|
||||
("上海市静安区恒丰路66号白云大厦1607室", "上海市静安区HF路**号BY大厦****室"),
|
||||
("北京市朝阳区建国路88号SOHO现代城A座1001室", "北京市朝阳区JG路**号SOHO现代城A座****室"),
|
||||
("广州市天河区珠江新城花城大道123号富力中心B座2001室", "广州市天河区珠江新城HC大道**号FL中心B座****室"),
|
||||
("深圳市南山区科技园南区深南大道9988号腾讯大厦T1栋15楼", "深圳市南山区科技园南区SN大道**号TX大厦T1栋**楼"),
|
||||
]
|
||||
|
||||
for original_address, expected_masked in test_cases:
|
||||
masked = processor._mask_address(original_address)
|
||||
print(f"Original: {original_address}")
|
||||
print(f"Masked: {masked}")
|
||||
print(f"Expected: {expected_masked}")
|
||||
print("-" * 50)
|
||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
||||
|
||||
|
||||
def test_address_component_extraction():
|
||||
"""Test address component extraction"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test address component extraction
|
||||
test_cases = [
|
||||
("上海市静安区恒丰路66号白云大厦1607室", {
|
||||
"road_name": "恒丰路",
|
||||
"house_number": "66",
|
||||
"building_name": "白云大厦",
|
||||
"community_name": ""
|
||||
}),
|
||||
("北京市朝阳区建国路88号SOHO现代城A座1001室", {
|
||||
"road_name": "建国路",
|
||||
"house_number": "88",
|
||||
"building_name": "SOHO现代城",
|
||||
"community_name": ""
|
||||
}),
|
||||
]
|
||||
|
||||
for address, expected_components in test_cases:
|
||||
components = processor._extract_address_components(address)
|
||||
print(f"Address: {address}")
|
||||
print(f"Extracted components: {components}")
|
||||
print(f"Expected: {expected_components}")
|
||||
print("-" * 50)
|
||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
||||
|
||||
|
||||
def test_regex_fallback():
|
||||
"""Test regex fallback for address extraction"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test regex extraction (fallback method)
|
||||
test_address = "上海市静安区恒丰路66号白云大厦1607室"
|
||||
components = processor._extract_address_components_with_regex(test_address)
|
||||
|
||||
print(f"Address: {test_address}")
|
||||
print(f"Regex extracted components: {components}")
|
||||
|
||||
# Basic validation
|
||||
assert "road_name" in components
|
||||
assert "house_number" in components
|
||||
assert "building_name" in components
|
||||
assert "community_name" in components
|
||||
assert "confidence" in components
|
||||
|
||||
|
||||
def test_json_validation_for_address():
|
||||
"""Test JSON validation for address extraction responses"""
|
||||
from app.core.utils.llm_validator import LLMResponseValidator
|
||||
|
||||
# Test valid JSON response
|
||||
valid_response = {
|
||||
"road_name": "恒丰路",
|
||||
"house_number": "66",
|
||||
"building_name": "白云大厦",
|
||||
"community_name": "",
|
||||
"confidence": 0.9
|
||||
}
|
||||
assert LLMResponseValidator.validate_address_extraction(valid_response) == True
|
||||
|
||||
# Test invalid JSON response (missing required field)
|
||||
invalid_response = {
|
||||
"road_name": "恒丰路",
|
||||
"house_number": "66",
|
||||
"building_name": "白云大厦",
|
||||
"confidence": 0.9
|
||||
}
|
||||
assert LLMResponseValidator.validate_address_extraction(invalid_response) == False
|
||||
|
||||
# Test invalid JSON response (wrong type)
|
||||
invalid_response2 = {
|
||||
"road_name": 123,
|
||||
"house_number": "66",
|
||||
"building_name": "白云大厦",
|
||||
"community_name": "",
|
||||
"confidence": 0.9
|
||||
}
|
||||
assert LLMResponseValidator.validate_address_extraction(invalid_response2) == False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing Address Masking Functionality")
|
||||
print("=" * 50)
|
||||
|
||||
test_regex_fallback()
|
||||
print()
|
||||
test_json_validation_for_address()
|
||||
print()
|
||||
test_address_component_extraction()
|
||||
print()
|
||||
test_address_masking()
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
import pytest
|
||||
|
||||
def test_basic_discovery():
|
||||
"""Basic test to verify pytest discovery is working"""
|
||||
assert True
|
||||
|
||||
def test_import_works():
|
||||
"""Test that we can import from the app module"""
|
||||
try:
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
assert NerProcessor is not None
|
||||
except ImportError as e:
|
||||
pytest.fail(f"Failed to import NerProcessor: {e}")
|
||||
|
||||
def test_simple_math():
|
||||
"""Simple math test"""
|
||||
assert 1 + 1 == 2
|
||||
assert 2 * 3 == 6
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for character-by-character alignment functionality.
|
||||
This script demonstrates how the alignment handles different spacing patterns
|
||||
between entity text and original document text.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), 'backend'))
|
||||
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
def main():
|
||||
"""Test the character alignment functionality."""
|
||||
processor = NerProcessor()
|
||||
|
||||
print("Testing Character-by-Character Alignment")
|
||||
print("=" * 50)
|
||||
|
||||
# Test the alignment functionality
|
||||
processor.test_character_alignment()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Testing Entity Masking with Alignment")
|
||||
print("=" * 50)
|
||||
|
||||
# Test entity masking with alignment
|
||||
original_document = "上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。法定代表人:郭东军,执行董事、经理。委托诉讼代理人:周大海,北京市康达律师事务所律师。"
|
||||
|
||||
# Example entity mapping (from your NER results)
|
||||
entity_mapping = {
|
||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
||||
"郭东军": "郭DJ",
|
||||
"周大海": "周DH",
|
||||
"北京市康达律师事务所": "北京市KD律师事务所"
|
||||
}
|
||||
|
||||
print(f"Original document: {original_document}")
|
||||
print(f"Entity mapping: {entity_mapping}")
|
||||
|
||||
# Apply masking with alignment
|
||||
masked_document = processor.apply_entity_masking_with_alignment(
|
||||
original_document,
|
||||
entity_mapping
|
||||
)
|
||||
|
||||
print(f"Masked document: {masked_document}")
|
||||
|
||||
# Test with document that has spaces
|
||||
print("\n" + "=" * 50)
|
||||
print("Testing with Document Containing Spaces")
|
||||
print("=" * 50)
|
||||
|
||||
spaced_document = "上诉人(原审原告):北京 丰复久信 营销科技 有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。法定代表人:郭 东 军,执行董事、经理。"
|
||||
|
||||
print(f"Spaced document: {spaced_document}")
|
||||
|
||||
masked_spaced_document = processor.apply_entity_masking_with_alignment(
|
||||
spaced_document,
|
||||
entity_mapping
|
||||
)
|
||||
|
||||
print(f"Masked spaced document: {masked_spaced_document}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
"""
|
||||
Test file for the enhanced OllamaClient with validation and retry mechanisms.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
# Add the current directory to the Python path
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
def test_ollama_client_initialization():
|
||||
"""Test OllamaClient initialization with new parameters"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
|
||||
# Test with default parameters
|
||||
client = OllamaClient("test-model")
|
||||
assert client.model_name == "test-model"
|
||||
assert client.base_url == "http://localhost:11434"
|
||||
assert client.max_retries == 3
|
||||
|
||||
# Test with custom parameters
|
||||
client = OllamaClient("test-model", "http://custom:11434", 5)
|
||||
assert client.model_name == "test-model"
|
||||
assert client.base_url == "http://custom:11434"
|
||||
assert client.max_retries == 5
|
||||
|
||||
print("✓ OllamaClient initialization tests passed")
|
||||
|
||||
|
||||
def test_generate_with_validation():
|
||||
"""Test generate_with_validation method"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
|
||||
# Mock the API response
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"response": '{"business_name": "测试公司", "confidence": 0.9}'
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
client = OllamaClient("test-model")
|
||||
|
||||
# Test with business name extraction validation
|
||||
result = client.generate_with_validation(
|
||||
prompt="Extract business name from: 测试公司",
|
||||
response_type='business_name_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result.get('business_name') == '测试公司'
|
||||
assert result.get('confidence') == 0.9
|
||||
|
||||
print("✓ generate_with_validation test passed")
|
||||
|
||||
|
||||
def test_generate_with_schema():
|
||||
"""Test generate_with_schema method"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
|
||||
# Define a custom schema
|
||||
custom_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}
|
||||
|
||||
# Mock the API response
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"response": '{"name": "张三", "age": 30}'
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
client = OllamaClient("test-model")
|
||||
|
||||
# Test with custom schema validation
|
||||
result = client.generate_with_schema(
|
||||
prompt="Generate person info",
|
||||
schema=custom_schema,
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result.get('name') == '张三'
|
||||
assert result.get('age') == 30
|
||||
|
||||
print("✓ generate_with_schema test passed")
|
||||
|
||||
|
||||
def test_backward_compatibility():
|
||||
"""Test backward compatibility with original generate method"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
|
||||
# Mock the API response
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"response": "Simple text response"
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
client = OllamaClient("test-model")
|
||||
|
||||
# Test original generate method (should still work)
|
||||
result = client.generate("Simple prompt")
|
||||
assert result == "Simple text response"
|
||||
|
||||
# Test with strip_think=False
|
||||
result = client.generate("Simple prompt", strip_think=False)
|
||||
assert result == "Simple text response"
|
||||
|
||||
print("✓ Backward compatibility tests passed")
|
||||
|
||||
|
||||
def test_retry_mechanism():
|
||||
"""Test retry mechanism for failed requests"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
import requests
|
||||
|
||||
# Mock failed requests followed by success
|
||||
mock_failed_response = Mock()
|
||||
mock_failed_response.raise_for_status.side_effect = requests.exceptions.RequestException("Connection failed")
|
||||
|
||||
mock_success_response = Mock()
|
||||
mock_success_response.json.return_value = {
|
||||
"response": "Success response"
|
||||
}
|
||||
mock_success_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', side_effect=[mock_failed_response, mock_success_response]):
|
||||
client = OllamaClient("test-model", max_retries=2)
|
||||
|
||||
# Should retry and eventually succeed
|
||||
result = client.generate("Test prompt")
|
||||
assert result == "Success response"
|
||||
|
||||
print("✓ Retry mechanism test passed")
|
||||
|
||||
|
||||
def test_validation_failure():
|
||||
"""Test validation failure handling"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
|
||||
# Mock API response with invalid JSON
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"response": "Invalid JSON response"
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
client = OllamaClient("test-model", max_retries=2)
|
||||
|
||||
try:
|
||||
# This should fail validation and retry
|
||||
result = client.generate_with_validation(
|
||||
prompt="Test prompt",
|
||||
response_type='business_name_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
# If we get here, it means validation failed and retries were exhausted
|
||||
print("✓ Validation failure handling test passed")
|
||||
except ValueError as e:
|
||||
# Expected behavior - validation failed after retries
|
||||
assert "Failed to parse JSON response after all retries" in str(e)
|
||||
print("✓ Validation failure handling test passed")
|
||||
|
||||
|
||||
def test_enhanced_methods():
|
||||
"""Test the new enhanced methods"""
|
||||
from app.core.services.ollama_client import OllamaClient
|
||||
|
||||
# Mock the API response
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"response": '{"entities": [{"text": "张三", "type": "人名"}]}'
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
client = OllamaClient("test-model")
|
||||
|
||||
# Test generate_with_validation
|
||||
result = client.generate_with_validation(
|
||||
prompt="Extract entities",
|
||||
response_type='entity_extraction',
|
||||
return_parsed=True
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert 'entities' in result
|
||||
assert len(result['entities']) == 1
|
||||
assert result['entities'][0]['text'] == '张三'
|
||||
|
||||
print("✓ Enhanced methods tests passed")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Testing enhanced OllamaClient...")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
test_ollama_client_initialization()
|
||||
test_generate_with_validation()
|
||||
test_generate_with_schema()
|
||||
test_backward_compatibility()
|
||||
test_retry_mechanism()
|
||||
test_validation_failure()
|
||||
test_enhanced_methods()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("✓ All enhanced OllamaClient tests passed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,186 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Final test to verify the fix handles multiple occurrences and prevents infinite loops.
|
||||
"""
|
||||
|
||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
||||
"""Simplified version of the alignment method for testing"""
|
||||
clean_entity = entity_text.replace(" ", "")
|
||||
doc_chars = [c for c in original_document_text if c != ' ']
|
||||
|
||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
||||
return None
|
||||
|
||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
||||
"""Simplified version of position mapping for testing"""
|
||||
original_pos = 0
|
||||
clean_pos = 0
|
||||
|
||||
while clean_pos < clean_start and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
clean_pos += 1
|
||||
original_pos += 1
|
||||
|
||||
start_pos = original_pos
|
||||
|
||||
chars_found = 0
|
||||
while chars_found < entity_length and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
chars_found += 1
|
||||
original_pos += 1
|
||||
|
||||
end_pos = original_pos
|
||||
found_text = original_text[start_pos:end_pos]
|
||||
|
||||
return start_pos, end_pos, found_text
|
||||
|
||||
def apply_entity_masking_with_alignment_fixed(original_document_text: str, entity_mapping: dict):
|
||||
"""Fixed implementation that handles multiple occurrences and prevents infinite loops"""
|
||||
masked_document = original_document_text
|
||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
||||
|
||||
for entity_text in sorted_entities:
|
||||
masked_text = entity_mapping[entity_text]
|
||||
|
||||
# Skip if masked text is the same as original text (prevents infinite loop)
|
||||
if entity_text == masked_text:
|
||||
print(f"Skipping entity '{entity_text}' as masked text is identical")
|
||||
continue
|
||||
|
||||
# Find ALL occurrences of this entity in the document
|
||||
# Add safety counter to prevent infinite loops
|
||||
max_iterations = 100 # Safety limit
|
||||
iteration_count = 0
|
||||
|
||||
while iteration_count < max_iterations:
|
||||
iteration_count += 1
|
||||
|
||||
# Find the entity in the current masked document using alignment
|
||||
alignment_result = find_entity_alignment(entity_text, masked_document)
|
||||
|
||||
if alignment_result:
|
||||
start_pos, end_pos, found_text = alignment_result
|
||||
|
||||
# Replace the found text with the masked version
|
||||
masked_document = (
|
||||
masked_document[:start_pos] +
|
||||
masked_text +
|
||||
masked_document[end_pos:]
|
||||
)
|
||||
|
||||
print(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos} (iteration {iteration_count})")
|
||||
else:
|
||||
# No more occurrences found for this entity, move to next entity
|
||||
print(f"No more occurrences of '{entity_text}' found in document after {iteration_count} iterations")
|
||||
break
|
||||
|
||||
# Log warning if we hit the safety limit
|
||||
if iteration_count >= max_iterations:
|
||||
print(f"WARNING: Reached maximum iterations ({max_iterations}) for entity '{entity_text}', stopping to prevent infinite loop")
|
||||
|
||||
return masked_document
|
||||
|
||||
def test_final_fix():
|
||||
"""Test the final fix with various scenarios"""
|
||||
|
||||
print("Testing Final Fix for Multiple Occurrences and Infinite Loop Prevention")
|
||||
print("=" * 70)
|
||||
|
||||
# Test case 1: Multiple occurrences of the same entity (should work)
|
||||
print("\nTest Case 1: Multiple occurrences of same entity")
|
||||
test_document_1 = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
||||
entity_mapping_1 = {"李淼": "李M"}
|
||||
|
||||
print(f"Original: {test_document_1}")
|
||||
result_1 = apply_entity_masking_with_alignment_fixed(test_document_1, entity_mapping_1)
|
||||
print(f"Result: {result_1}")
|
||||
|
||||
remaining_1 = result_1.count("李淼")
|
||||
expected_1 = "上诉人李M因合同纠纷,法定代表人李M,委托代理人李M。"
|
||||
|
||||
if result_1 == expected_1 and remaining_1 == 0:
|
||||
print("✅ PASS: All occurrences masked correctly")
|
||||
else:
|
||||
print(f"❌ FAIL: Expected '{expected_1}', got '{result_1}'")
|
||||
print(f" Remaining '李淼' occurrences: {remaining_1}")
|
||||
|
||||
# Test case 2: Entity with same masked text (should skip to prevent infinite loop)
|
||||
print("\nTest Case 2: Entity with same masked text (should skip)")
|
||||
test_document_2 = "上诉人李淼因合同纠纷,法定代表人李淼。北京丰复久信营销科技有限公司,丰复久信公司。"
|
||||
entity_mapping_2 = {
|
||||
"李淼": "李M",
|
||||
"丰复久信公司": "丰复久信公司" # Same text - should be skipped
|
||||
}
|
||||
|
||||
print(f"Original: {test_document_2}")
|
||||
result_2 = apply_entity_masking_with_alignment_fixed(test_document_2, entity_mapping_2)
|
||||
print(f"Result: {result_2}")
|
||||
|
||||
remaining_2_li = result_2.count("李淼")
|
||||
remaining_2_company = result_2.count("丰复久信公司")
|
||||
|
||||
if remaining_2_li == 0 and remaining_2_company == 1: # Company should remain unmasked
|
||||
print("✅ PASS: Infinite loop prevented, only different text masked")
|
||||
else:
|
||||
print(f"❌ FAIL: Remaining '李淼': {remaining_2_li}, '丰复久信公司': {remaining_2_company}")
|
||||
|
||||
# Test case 3: Mixed spacing scenarios
|
||||
print("\nTest Case 3: Mixed spacing scenarios")
|
||||
test_document_3 = "上诉人李 淼因合同纠纷,法定代表人李淼,委托代理人李 淼。"
|
||||
entity_mapping_3 = {"李 淼": "李M", "李淼": "李M"}
|
||||
|
||||
print(f"Original: {test_document_3}")
|
||||
result_3 = apply_entity_masking_with_alignment_fixed(test_document_3, entity_mapping_3)
|
||||
print(f"Result: {result_3}")
|
||||
|
||||
remaining_3 = result_3.count("李淼") + result_3.count("李 淼")
|
||||
|
||||
if remaining_3 == 0:
|
||||
print("✅ PASS: Mixed spacing handled correctly")
|
||||
else:
|
||||
print(f"❌ FAIL: Remaining occurrences: {remaining_3}")
|
||||
|
||||
# Test case 4: Complex document with real examples
|
||||
print("\nTest Case 4: Complex document with real examples")
|
||||
test_document_4 = """上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
||||
法定代表人:郭东军,执行董事、经理。
|
||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
||||
委托诉讼代理人:王乃哲,北京市康达律师事务所律师。
|
||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
||||
法定代表人:王欢子,总经理。
|
||||
委托诉讼代理人:魏鑫,北京市昊衡律师事务所律师。"""
|
||||
|
||||
entity_mapping_4 = {
|
||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
||||
"郭东军": "郭DJ",
|
||||
"周大海": "周DH",
|
||||
"王乃哲": "王NZ",
|
||||
"中研智创区块链技术有限公司": "中研智创区块链技术有限公司", # Same text - should be skipped
|
||||
"王欢子": "王HZ",
|
||||
"魏鑫": "魏X",
|
||||
"北京市康达律师事务所": "北京市KD律师事务所",
|
||||
"北京市昊衡律师事务所": "北京市HH律师事务所"
|
||||
}
|
||||
|
||||
print(f"Original length: {len(test_document_4)} characters")
|
||||
result_4 = apply_entity_masking_with_alignment_fixed(test_document_4, entity_mapping_4)
|
||||
print(f"Result length: {len(result_4)} characters")
|
||||
|
||||
# Check that entities were masked correctly
|
||||
unmasked_entities = []
|
||||
for entity in entity_mapping_4.keys():
|
||||
if entity in result_4 and entity != entity_mapping_4[entity]: # Skip if masked text is same
|
||||
unmasked_entities.append(entity)
|
||||
|
||||
if not unmasked_entities:
|
||||
print("✅ PASS: All entities masked correctly in complex document")
|
||||
else:
|
||||
print(f"❌ FAIL: Unmasked entities: {unmasked_entities}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Final Fix Verification Completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_final_fix()
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test to verify the fix for multiple occurrence issue in apply_entity_masking_with_alignment.
|
||||
"""
|
||||
|
||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
||||
"""Simplified version of the alignment method for testing"""
|
||||
clean_entity = entity_text.replace(" ", "")
|
||||
doc_chars = [c for c in original_document_text if c != ' ']
|
||||
|
||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
||||
return None
|
||||
|
||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
||||
"""Simplified version of position mapping for testing"""
|
||||
original_pos = 0
|
||||
clean_pos = 0
|
||||
|
||||
while clean_pos < clean_start and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
clean_pos += 1
|
||||
original_pos += 1
|
||||
|
||||
start_pos = original_pos
|
||||
|
||||
chars_found = 0
|
||||
while chars_found < entity_length and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
chars_found += 1
|
||||
original_pos += 1
|
||||
|
||||
end_pos = original_pos
|
||||
found_text = original_text[start_pos:end_pos]
|
||||
|
||||
return start_pos, end_pos, found_text
|
||||
|
||||
def apply_entity_masking_with_alignment_fixed(original_document_text: str, entity_mapping: dict):
|
||||
"""Fixed implementation that handles multiple occurrences"""
|
||||
masked_document = original_document_text
|
||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
||||
|
||||
for entity_text in sorted_entities:
|
||||
masked_text = entity_mapping[entity_text]
|
||||
|
||||
# Find ALL occurrences of this entity in the document
|
||||
# We need to loop until no more matches are found
|
||||
while True:
|
||||
# Find the entity in the current masked document using alignment
|
||||
alignment_result = find_entity_alignment(entity_text, masked_document)
|
||||
|
||||
if alignment_result:
|
||||
start_pos, end_pos, found_text = alignment_result
|
||||
|
||||
# Replace the found text with the masked version
|
||||
masked_document = (
|
||||
masked_document[:start_pos] +
|
||||
masked_text +
|
||||
masked_document[end_pos:]
|
||||
)
|
||||
|
||||
print(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos}")
|
||||
else:
|
||||
# No more occurrences found for this entity, move to next entity
|
||||
print(f"No more occurrences of '{entity_text}' found in document")
|
||||
break
|
||||
|
||||
return masked_document
|
||||
|
||||
def test_fix_verification():
|
||||
"""Test to verify the fix works correctly"""
|
||||
|
||||
print("Testing Fix for Multiple Occurrence Issue")
|
||||
print("=" * 60)
|
||||
|
||||
# Test case 1: Multiple occurrences of the same entity
|
||||
print("\nTest Case 1: Multiple occurrences of same entity")
|
||||
test_document_1 = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
||||
entity_mapping_1 = {"李淼": "李M"}
|
||||
|
||||
print(f"Original: {test_document_1}")
|
||||
result_1 = apply_entity_masking_with_alignment_fixed(test_document_1, entity_mapping_1)
|
||||
print(f"Result: {result_1}")
|
||||
|
||||
remaining_1 = result_1.count("李淼")
|
||||
expected_1 = "上诉人李M因合同纠纷,法定代表人李M,委托代理人李M。"
|
||||
|
||||
if result_1 == expected_1 and remaining_1 == 0:
|
||||
print("✅ PASS: All occurrences masked correctly")
|
||||
else:
|
||||
print(f"❌ FAIL: Expected '{expected_1}', got '{result_1}'")
|
||||
print(f" Remaining '李淼' occurrences: {remaining_1}")
|
||||
|
||||
# Test case 2: Multiple entities with multiple occurrences
|
||||
print("\nTest Case 2: Multiple entities with multiple occurrences")
|
||||
test_document_2 = "上诉人李淼因合同纠纷,法定代表人李淼。北京丰复久信营销科技有限公司,丰复久信公司。"
|
||||
entity_mapping_2 = {
|
||||
"李淼": "李M",
|
||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
||||
"丰复久信公司": "丰复久信公司"
|
||||
}
|
||||
|
||||
print(f"Original: {test_document_2}")
|
||||
result_2 = apply_entity_masking_with_alignment_fixed(test_document_2, entity_mapping_2)
|
||||
print(f"Result: {result_2}")
|
||||
|
||||
remaining_2_li = result_2.count("李淼")
|
||||
remaining_2_company = result_2.count("北京丰复久信营销科技有限公司")
|
||||
|
||||
if remaining_2_li == 0 and remaining_2_company == 0:
|
||||
print("✅ PASS: All entities masked correctly")
|
||||
else:
|
||||
print(f"❌ FAIL: Remaining '李淼': {remaining_2_li}, '北京丰复久信营销科技有限公司': {remaining_2_company}")
|
||||
|
||||
# Test case 3: Mixed spacing scenarios
|
||||
print("\nTest Case 3: Mixed spacing scenarios")
|
||||
test_document_3 = "上诉人李 淼因合同纠纷,法定代表人李淼,委托代理人李 淼。"
|
||||
entity_mapping_3 = {"李 淼": "李M", "李淼": "李M"}
|
||||
|
||||
print(f"Original: {test_document_3}")
|
||||
result_3 = apply_entity_masking_with_alignment_fixed(test_document_3, entity_mapping_3)
|
||||
print(f"Result: {result_3}")
|
||||
|
||||
remaining_3 = result_3.count("李淼") + result_3.count("李 淼")
|
||||
|
||||
if remaining_3 == 0:
|
||||
print("✅ PASS: Mixed spacing handled correctly")
|
||||
else:
|
||||
print(f"❌ FAIL: Remaining occurrences: {remaining_3}")
|
||||
|
||||
# Test case 4: Complex document with real examples
|
||||
print("\nTest Case 4: Complex document with real examples")
|
||||
test_document_4 = """上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
||||
法定代表人:郭东军,执行董事、经理。
|
||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
||||
委托诉讼代理人:王乃哲,北京市康达律师事务所律师。
|
||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
||||
法定代表人:王欢子,总经理。
|
||||
委托诉讼代理人:魏鑫,北京市昊衡律师事务所律师。"""
|
||||
|
||||
entity_mapping_4 = {
|
||||
"北京丰复久信营销科技有限公司": "北京JO营销科技有限公司",
|
||||
"郭东军": "郭DJ",
|
||||
"周大海": "周DH",
|
||||
"王乃哲": "王NZ",
|
||||
"中研智创区块链技术有限公司": "中研智创区块链技术有限公司",
|
||||
"王欢子": "王HZ",
|
||||
"魏鑫": "魏X",
|
||||
"北京市康达律师事务所": "北京市KD律师事务所",
|
||||
"北京市昊衡律师事务所": "北京市HH律师事务所"
|
||||
}
|
||||
|
||||
print(f"Original length: {len(test_document_4)} characters")
|
||||
result_4 = apply_entity_masking_with_alignment_fixed(test_document_4, entity_mapping_4)
|
||||
print(f"Result length: {len(result_4)} characters")
|
||||
|
||||
# Check that all entities were masked
|
||||
unmasked_entities = []
|
||||
for entity in entity_mapping_4.keys():
|
||||
if entity in result_4:
|
||||
unmasked_entities.append(entity)
|
||||
|
||||
if not unmasked_entities:
|
||||
print("✅ PASS: All entities masked in complex document")
|
||||
else:
|
||||
print(f"❌ FAIL: Unmasked entities: {unmasked_entities}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Fix Verification Completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_fix_verification()
|
||||
|
|
@ -1,169 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test file for ID and social credit code masking functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the backend directory to the Python path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
|
||||
def test_id_number_masking():
|
||||
"""Test ID number masking with the new rules"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test cases based on the requirements
|
||||
test_cases = [
|
||||
("310103198802080000", "310103XXXXXXXXXXXX"),
|
||||
("110101199001011234", "110101XXXXXXXXXXXX"),
|
||||
("440301199505151234", "440301XXXXXXXXXXXX"),
|
||||
("320102198712345678", "320102XXXXXXXXXXXX"),
|
||||
("12345", "12345"), # Edge case: too short
|
||||
]
|
||||
|
||||
for original_id, expected_masked in test_cases:
|
||||
# Create a mock entity for testing
|
||||
entity = {'text': original_id, 'type': '身份证号'}
|
||||
unique_entities = [entity]
|
||||
linkage = {'entity_groups': []}
|
||||
|
||||
# Test the masking through the full pipeline
|
||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
||||
masked = mapping.get(original_id, original_id)
|
||||
|
||||
print(f"Original ID: {original_id}")
|
||||
print(f"Masked ID: {masked}")
|
||||
print(f"Expected: {expected_masked}")
|
||||
print(f"Match: {masked == expected_masked}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def test_social_credit_code_masking():
|
||||
"""Test social credit code masking with the new rules"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test cases based on the requirements
|
||||
test_cases = [
|
||||
("9133021276453538XT", "913302XXXXXXXXXXXX"),
|
||||
("91110000100000000X", "9111000XXXXXXXXXXX"),
|
||||
("914403001922038216", "9144030XXXXXXXXXXX"),
|
||||
("91310000132209458G", "9131000XXXXXXXXXXX"),
|
||||
("123456", "123456"), # Edge case: too short
|
||||
]
|
||||
|
||||
for original_code, expected_masked in test_cases:
|
||||
# Create a mock entity for testing
|
||||
entity = {'text': original_code, 'type': '社会信用代码'}
|
||||
unique_entities = [entity]
|
||||
linkage = {'entity_groups': []}
|
||||
|
||||
# Test the masking through the full pipeline
|
||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
||||
masked = mapping.get(original_code, original_code)
|
||||
|
||||
print(f"Original Code: {original_code}")
|
||||
print(f"Masked Code: {masked}")
|
||||
print(f"Expected: {expected_masked}")
|
||||
print(f"Match: {masked == expected_masked}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def test_edge_cases():
|
||||
"""Test edge cases for ID and social credit code masking"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test edge cases
|
||||
edge_cases = [
|
||||
("", ""), # Empty string
|
||||
("123", "123"), # Too short for ID
|
||||
("123456", "123456"), # Too short for social credit code
|
||||
("123456789012345678901234567890", "123456XXXXXXXXXXXXXXXXXX"), # Very long ID
|
||||
]
|
||||
|
||||
for original, expected in edge_cases:
|
||||
# Test ID number
|
||||
entity_id = {'text': original, 'type': '身份证号'}
|
||||
mapping_id = processor._generate_masked_mapping([entity_id], {'entity_groups': []})
|
||||
masked_id = mapping_id.get(original, original)
|
||||
|
||||
# Test social credit code
|
||||
entity_code = {'text': original, 'type': '社会信用代码'}
|
||||
mapping_code = processor._generate_masked_mapping([entity_code], {'entity_groups': []})
|
||||
masked_code = mapping_code.get(original, original)
|
||||
|
||||
print(f"Original: {original}")
|
||||
print(f"ID Masked: {masked_id}")
|
||||
print(f"Code Masked: {masked_code}")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def test_mixed_entities():
|
||||
"""Test masking with mixed entity types"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Create mixed entities
|
||||
entities = [
|
||||
{'text': '310103198802080000', 'type': '身份证号'},
|
||||
{'text': '9133021276453538XT', 'type': '社会信用代码'},
|
||||
{'text': '李强', 'type': '人名'},
|
||||
{'text': '上海盒马网络科技有限公司', 'type': '公司名称'},
|
||||
]
|
||||
|
||||
linkage = {'entity_groups': []}
|
||||
|
||||
# Test the masking through the full pipeline
|
||||
mapping = processor._generate_masked_mapping(entities, linkage)
|
||||
|
||||
print("Mixed Entities Test:")
|
||||
print("=" * 30)
|
||||
for entity in entities:
|
||||
original = entity['text']
|
||||
entity_type = entity['type']
|
||||
masked = mapping.get(original, original)
|
||||
print(f"{entity_type}: {original} -> {masked}")
|
||||
|
||||
def test_id_masking():
|
||||
"""Test ID number and social credit code masking"""
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test ID number masking
|
||||
id_entity = {'text': '310103198802080000', 'type': '身份证号'}
|
||||
id_mapping = processor._generate_masked_mapping([id_entity], {'entity_groups': []})
|
||||
masked_id = id_mapping.get('310103198802080000', '')
|
||||
|
||||
# Test social credit code masking
|
||||
code_entity = {'text': '9133021276453538XT', 'type': '社会信用代码'}
|
||||
code_mapping = processor._generate_masked_mapping([code_entity], {'entity_groups': []})
|
||||
masked_code = code_mapping.get('9133021276453538XT', '')
|
||||
|
||||
# Verify the masking rules
|
||||
assert masked_id.startswith('310103') # First 6 digits preserved
|
||||
assert masked_id.endswith('XXXXXXXXXXXX') # Rest masked with X
|
||||
assert len(masked_id) == 18 # Total length preserved
|
||||
|
||||
assert masked_code.startswith('913302')  # First 6 characters preserved
|
||||
assert masked_code.endswith('XXXXXXXXXXXX') # Rest masked with X
|
||||
assert len(masked_code) == 18 # Total length preserved
|
||||
|
||||
print(f"ID masking: 310103198802080000 -> {masked_id}")
|
||||
print(f"Code masking: 9133021276453538XT -> {masked_code}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing ID and Social Credit Code Masking")
|
||||
print("=" * 50)
|
||||
|
||||
test_id_number_masking()
|
||||
print()
|
||||
test_social_credit_code_masking()
|
||||
print()
|
||||
test_edge_cases()
|
||||
print()
|
||||
test_mixed_entities()
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test to verify the multiple occurrence issue in apply_entity_masking_with_alignment.
|
||||
"""
|
||||
|
||||
def find_entity_alignment(entity_text: str, original_document_text: str):
|
||||
"""Simplified version of the alignment method for testing"""
|
||||
clean_entity = entity_text.replace(" ", "")
|
||||
doc_chars = [c for c in original_document_text if c != ' ']
|
||||
|
||||
for i in range(len(doc_chars) - len(clean_entity) + 1):
|
||||
if doc_chars[i:i+len(clean_entity)] == list(clean_entity):
|
||||
return map_char_positions_to_original(i, len(clean_entity), original_document_text)
|
||||
return None
|
||||
|
||||
def map_char_positions_to_original(clean_start: int, entity_length: int, original_text: str):
|
||||
"""Simplified version of position mapping for testing"""
|
||||
original_pos = 0
|
||||
clean_pos = 0
|
||||
|
||||
while clean_pos < clean_start and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
clean_pos += 1
|
||||
original_pos += 1
|
||||
|
||||
start_pos = original_pos
|
||||
|
||||
chars_found = 0
|
||||
while chars_found < entity_length and original_pos < len(original_text):
|
||||
if original_text[original_pos] != ' ':
|
||||
chars_found += 1
|
||||
original_pos += 1
|
||||
|
||||
end_pos = original_pos
|
||||
found_text = original_text[start_pos:end_pos]
|
||||
|
||||
return start_pos, end_pos, found_text
|
||||
|
||||
def apply_entity_masking_with_alignment_current(original_document_text: str, entity_mapping: dict):
|
||||
"""Current implementation with the bug"""
|
||||
masked_document = original_document_text
|
||||
sorted_entities = sorted(entity_mapping.keys(), key=len, reverse=True)
|
||||
|
||||
for entity_text in sorted_entities:
|
||||
masked_text = entity_mapping[entity_text]
|
||||
|
||||
# Find the entity in the original document using alignment
|
||||
alignment_result = find_entity_alignment(entity_text, masked_document)
|
||||
|
||||
if alignment_result:
|
||||
start_pos, end_pos, found_text = alignment_result
|
||||
|
||||
# Replace the found text with the masked version
|
||||
masked_document = (
|
||||
masked_document[:start_pos] +
|
||||
masked_text +
|
||||
masked_document[end_pos:]
|
||||
)
|
||||
|
||||
print(f"Masked entity '{entity_text}' -> '{masked_text}' at positions {start_pos}-{end_pos}")
|
||||
else:
|
||||
print(f"Could not find entity '{entity_text}' in document for masking")
|
||||
|
||||
return masked_document
|
||||
|
||||
def test_multiple_occurrences():
|
||||
"""Test the multiple occurrence issue"""
|
||||
|
||||
print("Testing Multiple Occurrence Issue")
|
||||
print("=" * 50)
|
||||
|
||||
# Test document with multiple occurrences of the same entity
|
||||
test_document = "上诉人李淼因合同纠纷,法定代表人李淼,委托代理人李淼。"
|
||||
entity_mapping = {
|
||||
"李淼": "李M"
|
||||
}
|
||||
|
||||
print(f"Original document: {test_document}")
|
||||
print(f"Entity mapping: {entity_mapping}")
|
||||
print(f"Expected: All 3 occurrences of '李淼' should be masked")
|
||||
|
||||
# Test current implementation
|
||||
result = apply_entity_masking_with_alignment_current(test_document, entity_mapping)
|
||||
print(f"Current result: {result}")
|
||||
|
||||
# Count remaining occurrences
|
||||
remaining_count = result.count("李淼")
|
||||
print(f"Remaining '李淼' occurrences: {remaining_count}")
|
||||
|
||||
if remaining_count > 0:
|
||||
print("❌ ISSUE CONFIRMED: Multiple occurrences are not being masked!")
|
||||
else:
|
||||
print("✅ No issue found (unexpected)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_multiple_occurrences()
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for NER extractor integration
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Add the backend directory to the Python path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
|
||||
|
||||
from app.core.document_handlers.extractors.ner_extractor import NERExtractor
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_ner_extractor():
|
||||
"""Test the NER extractor directly"""
|
||||
print("🧪 Testing NER Extractor")
|
||||
print("=" * 50)
|
||||
|
||||
# Sample legal text
|
||||
text_to_analyze = """
|
||||
上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
||||
法定代表人:郭东军,执行董事、经理。
|
||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
||||
法定代表人:王欢子,总经理。
|
||||
"""
|
||||
|
||||
try:
|
||||
# Test NER extractor
|
||||
print("1. Testing NER Extractor...")
|
||||
ner_extractor = NERExtractor()
|
||||
|
||||
# Get model info
|
||||
model_info = ner_extractor.get_model_info()
|
||||
print(f" Model: {model_info['model_name']}")
|
||||
print(f" Supported entities: {model_info['supported_entities']}")
|
||||
|
||||
# Extract entities
|
||||
result = ner_extractor.extract_and_summarize(text_to_analyze)
|
||||
|
||||
print(f"\n2. Extraction Results:")
|
||||
print(f" Total entities found: {result['total_count']}")
|
||||
|
||||
for entity in result['entities']:
|
||||
print(f" - '{entity['text']}' ({entity['type']}) - Confidence: {entity['confidence']:.4f}")
|
||||
|
||||
print(f"\n3. Summary:")
|
||||
for entity_type, texts in result['summary']['summary'].items():
|
||||
print(f" {entity_type}: {len(texts)} entities")
|
||||
for text in texts:
|
||||
print(f" - {text}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ NER Extractor test failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_ner_processor():
|
||||
"""Test the NER processor integration"""
|
||||
print("\n🧪 Testing NER Processor Integration")
|
||||
print("=" * 50)
|
||||
|
||||
# Sample legal text
|
||||
text_to_analyze = """
|
||||
上诉人(原审原告):北京丰复久信营销科技有限公司,住所地北京市海淀区北小马厂6号1号楼华天大厦1306室。
|
||||
法定代表人:郭东军,执行董事、经理。
|
||||
委托诉讼代理人:周大海,北京市康达律师事务所律师。
|
||||
被上诉人(原审被告):中研智创区块链技术有限公司,住所地天津市津南区双港镇工业园区优谷产业园5号楼-1505。
|
||||
法定代表人:王欢子,总经理。
|
||||
"""
|
||||
|
||||
try:
|
||||
# Test NER processor
|
||||
print("1. Testing NER Processor...")
|
||||
ner_processor = NerProcessor()
|
||||
|
||||
# Test NER-only extraction
|
||||
print("2. Testing NER-only entity extraction...")
|
||||
ner_entities = ner_processor.extract_entities_with_ner(text_to_analyze)
|
||||
print(f" Extracted {len(ner_entities)} entities with NER model")
|
||||
|
||||
for entity in ner_entities:
|
||||
print(f" - '{entity['text']}' ({entity['type']}) - Confidence: {entity['confidence']:.4f}")
|
||||
|
||||
# Test NER-only processing
|
||||
print("\n3. Testing NER-only document processing...")
|
||||
chunks = [text_to_analyze] # Single chunk for testing
|
||||
mapping = ner_processor.process_ner_only(chunks)
|
||||
|
||||
print(f" Generated {len(mapping)} masking mappings")
|
||||
for original, masked in mapping.items():
|
||||
print(f" '{original}' -> '{masked}'")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ NER Processor test failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("🧪 NER Integration Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
# Test 1: NER Extractor
|
||||
extractor_success = test_ner_extractor()
|
||||
|
||||
# Test 2: NER Processor Integration
|
||||
processor_success = test_ner_processor()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("📊 Test Summary:")
|
||||
print(f" NER Extractor: {'✅' if extractor_success else '❌'}")
|
||||
print(f" NER Processor: {'✅' if processor_success else '❌'}")
|
||||
|
||||
if extractor_success and processor_success:
|
||||
print("\n🎉 All tests passed! NER integration is working correctly.")
|
||||
print("\nNext steps:")
|
||||
print("1. The NER extractor is ready to use in the document processing pipeline")
|
||||
print("2. You can use process_ner_only() for ML-based entity extraction")
|
||||
print("3. The existing process() method now includes NER extraction")
|
||||
else:
|
||||
print("\n⚠️ Some tests failed. Please check the error messages above.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,275 +0,0 @@
|
|||
import pytest
|
||||
from app.core.document_handlers.ner_processor import NerProcessor
|
||||
|
||||
def test_generate_masked_mapping():
|
||||
processor = NerProcessor()
|
||||
unique_entities = [
|
||||
{'text': '李强', 'type': '人名'},
|
||||
{'text': '李强', 'type': '人名'}, # Duplicate to test numbering
|
||||
{'text': '王小明', 'type': '人名'},
|
||||
{'text': 'Acme Manufacturing Inc.', 'type': '英文公司名', 'industry': 'manufacturing'},
|
||||
{'text': 'Google LLC', 'type': '英文公司名'},
|
||||
{'text': 'A公司', 'type': '公司名称'},
|
||||
{'text': 'B公司', 'type': '公司名称'},
|
||||
{'text': 'John Smith', 'type': '英文人名'},
|
||||
{'text': 'Elizabeth Windsor', 'type': '英文人名'},
|
||||
{'text': '华梦龙光伏项目', 'type': '项目名'},
|
||||
{'text': '案号12345', 'type': '案号'},
|
||||
{'text': '310101198802080000', 'type': '身份证号'},
|
||||
{'text': '9133021276453538XT', 'type': '社会信用代码'},
|
||||
]
|
||||
linkage = {
|
||||
'entity_groups': [
|
||||
{
|
||||
'group_id': 'g1',
|
||||
'group_type': '公司名称',
|
||||
'entities': [
|
||||
{'text': 'A公司', 'type': '公司名称', 'is_primary': True},
|
||||
{'text': 'B公司', 'type': '公司名称', 'is_primary': False},
|
||||
]
|
||||
},
|
||||
{
|
||||
'group_id': 'g2',
|
||||
'group_type': '人名',
|
||||
'entities': [
|
||||
{'text': '李强', 'type': '人名', 'is_primary': True},
|
||||
{'text': '李强', 'type': '人名', 'is_primary': False},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
||||
# 人名 (person names) - updated for the new Chinese name masking rules
|
||||
assert mapping['李强'] == '李Q'
|
||||
assert mapping['王小明'] == '王XM'
|
||||
# 英文公司名 (English company names)
|
||||
assert mapping['Acme Manufacturing Inc.'] == 'MANUFACTURING'
|
||||
assert mapping['Google LLC'] == 'COMPANY'
|
||||
# 公司名同组 (company names in the same linkage group) - updated for the new company masking rules
|
||||
# Note: The exact results may vary due to LLM extraction
|
||||
assert '公司' in mapping['A公司'] or mapping['A公司'] != 'A公司'
|
||||
assert '公司' in mapping['B公司'] or mapping['B公司'] != 'B公司'
|
||||
# 英文人名 (English person names)
|
||||
assert mapping['John Smith'] == 'J*** S***'
|
||||
assert mapping['Elizabeth Windsor'] == 'E*** W***'
|
||||
# 项目名 (project names)
|
||||
assert mapping['华梦龙光伏项目'].endswith('项目')
|
||||
# 案号 (case numbers)
|
||||
assert mapping['案号12345'] == '***'
|
||||
# 身份证号 (ID numbers)
|
||||
assert mapping['310101198802080000'] == 'XXXXXX'
|
||||
# 社会信用代码 (social credit codes)
|
||||
assert mapping['9133021276453538XT'] == 'XXXXXXXX'
|
||||
|
||||
|
||||
def test_chinese_name_pinyin_masking():
|
||||
"""Test Chinese name masking with pinyin functionality"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test basic Chinese name masking
|
||||
test_cases = [
|
||||
("李强", "李Q"),
|
||||
("张韶涵", "张SH"),
|
||||
("张若宇", "张RY"),
|
||||
("白锦程", "白JC"),
|
||||
("王小明", "王XM"),
|
||||
("陈志强", "陈ZQ"),
|
||||
]
|
||||
|
||||
surname_counter = {}
|
||||
|
||||
for original_name, expected_masked in test_cases:
|
||||
masked = processor._mask_chinese_name(original_name, surname_counter)
|
||||
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
|
||||
|
||||
# Test duplicate handling
|
||||
duplicate_test_cases = [
|
||||
("李强", "李Q"),
|
||||
("李强", "李Q2"), # Should be numbered
|
||||
("李倩", "李Q3"), # Should be numbered
|
||||
("张韶涵", "张SH"),
|
||||
("张韶涵", "张SH2"), # Should be numbered
|
||||
("张若宇", "张RY"), # Different initials, should not be numbered
|
||||
]
|
||||
|
||||
surname_counter = {} # Reset counter
|
||||
|
||||
for original_name, expected_masked in duplicate_test_cases:
|
||||
masked = processor._mask_chinese_name(original_name, surname_counter)
|
||||
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
|
||||
|
||||
# Test edge cases
|
||||
edge_cases = [
|
||||
("", ""), # Empty string
|
||||
("李", "李"), # Single character
|
||||
("李强强", "李QQ"), # Multiple characters with same pinyin
|
||||
]
|
||||
|
||||
surname_counter = {} # Reset counter
|
||||
|
||||
for original_name, expected_masked in edge_cases:
|
||||
masked = processor._mask_chinese_name(original_name, surname_counter)
|
||||
assert masked == expected_masked, f"Expected {expected_masked}, got {masked} for {original_name}"
|
||||
|
||||
|
||||
def test_chinese_name_integration():
|
||||
"""Test Chinese name masking integrated with the full mapping process"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test Chinese names in the full mapping context
|
||||
unique_entities = [
|
||||
{'text': '李强', 'type': '人名'},
|
||||
{'text': '张韶涵', 'type': '人名'},
|
||||
{'text': '张若宇', 'type': '人名'},
|
||||
{'text': '白锦程', 'type': '人名'},
|
||||
{'text': '李强', 'type': '人名'}, # Duplicate
|
||||
{'text': '张韶涵', 'type': '人名'}, # Duplicate
|
||||
]
|
||||
|
||||
linkage = {
|
||||
'entity_groups': [
|
||||
{
|
||||
'group_id': 'g1',
|
||||
'group_type': '人名',
|
||||
'entities': [
|
||||
{'text': '李强', 'type': '人名', 'is_primary': True},
|
||||
{'text': '张韶涵', 'type': '人名', 'is_primary': True},
|
||||
{'text': '张若宇', 'type': '人名', 'is_primary': True},
|
||||
{'text': '白锦程', 'type': '人名', 'is_primary': True},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
mapping = processor._generate_masked_mapping(unique_entities, linkage)
|
||||
|
||||
# Verify the mapping results
|
||||
assert mapping['李强'] == '李Q'
|
||||
assert mapping['张韶涵'] == '张SH'
|
||||
assert mapping['张若宇'] == '张RY'
|
||||
assert mapping['白锦程'] == '白JC'
|
||||
|
||||
# Check that duplicates are handled correctly
|
||||
# The second occurrence should be numbered
|
||||
assert '李Q2' in mapping.values() or '张SH2' in mapping.values()
|
||||
|
||||
|
||||
def test_lawyer_and_judge_names():
|
||||
"""Test that lawyer and judge names follow the same Chinese name rules"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test lawyer and judge names
|
||||
test_entities = [
|
||||
{'text': '王律师', 'type': '律师姓名'},
|
||||
{'text': '李法官', 'type': '审判人员姓名'},
|
||||
{'text': '张检察官', 'type': '检察官姓名'},
|
||||
]
|
||||
|
||||
linkage = {
|
||||
'entity_groups': [
|
||||
{
|
||||
'group_id': 'g1',
|
||||
'group_type': '律师姓名',
|
||||
'entities': [{'text': '王律师', 'type': '律师姓名', 'is_primary': True}]
|
||||
},
|
||||
{
|
||||
'group_id': 'g2',
|
||||
'group_type': '审判人员姓名',
|
||||
'entities': [{'text': '李法官', 'type': '审判人员姓名', 'is_primary': True}]
|
||||
},
|
||||
{
|
||||
'group_id': 'g3',
|
||||
'group_type': '检察官姓名',
|
||||
'entities': [{'text': '张检察官', 'type': '检察官姓名', 'is_primary': True}]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
mapping = processor._generate_masked_mapping(test_entities, linkage)
|
||||
|
||||
# These should follow the same Chinese name masking rules
|
||||
assert mapping['王律师'] == '王L'
|
||||
assert mapping['李法官'] == '李F'
|
||||
assert mapping['张检察官'] == '张JC'
|
||||
|
||||
|
||||
def test_company_name_masking():
|
||||
"""Test company name masking with business name extraction"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test basic company name masking
|
||||
test_cases = [
|
||||
("上海盒马网络科技有限公司", "上海JO网络科技有限公司"),
|
||||
("丰田通商(上海)有限公司", "HVVU(上海)有限公司"),
|
||||
("雅诗兰黛(上海)商贸有限公司", "AUNF(上海)商贸有限公司"),
|
||||
("北京百度网讯科技有限公司", "北京BC网讯科技有限公司"),
|
||||
("腾讯科技(深圳)有限公司", "TU科技(深圳)有限公司"),
|
||||
("阿里巴巴集团控股有限公司", "阿里巴巴集团控股有限公司"), # 商号可能无法正确提取
|
||||
]
|
||||
|
||||
for original_name, expected_masked in test_cases:
|
||||
masked = processor._mask_company_name(original_name)
|
||||
print(f"{original_name} -> {masked} (expected: {expected_masked})")
|
||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
||||
|
||||
|
||||
def test_business_name_extraction():
|
||||
"""Test business name extraction from company names"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test business name extraction
|
||||
test_cases = [
|
||||
("上海盒马网络科技有限公司", "盒马"),
|
||||
("丰田通商(上海)有限公司", "丰田通商"),
|
||||
("雅诗兰黛(上海)商贸有限公司", "雅诗兰黛"),
|
||||
("北京百度网讯科技有限公司", "百度"),
|
||||
("腾讯科技(深圳)有限公司", "腾讯"),
|
||||
("律师事务所", "律师事务所"), # Edge case
|
||||
]
|
||||
|
||||
for company_name, expected_business_name in test_cases:
|
||||
business_name = processor._extract_business_name(company_name)
|
||||
print(f"Company: {company_name} -> Business Name: {business_name} (expected: {expected_business_name})")
|
||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
||||
|
||||
|
||||
def test_json_validation_for_business_name():
|
||||
"""Test JSON validation for business name extraction responses"""
|
||||
from app.core.utils.llm_validator import LLMResponseValidator
|
||||
|
||||
# Test valid JSON response
|
||||
valid_response = {
|
||||
"business_name": "盒马",
|
||||
"confidence": 0.9
|
||||
}
|
||||
assert LLMResponseValidator.validate_business_name_extraction(valid_response) == True
|
||||
|
||||
# Test invalid JSON response (missing required field)
|
||||
invalid_response = {
|
||||
"confidence": 0.9
|
||||
}
|
||||
assert LLMResponseValidator.validate_business_name_extraction(invalid_response) == False
|
||||
|
||||
# Test invalid JSON response (wrong type)
|
||||
invalid_response2 = {
|
||||
"business_name": 123,
|
||||
"confidence": 0.9
|
||||
}
|
||||
assert LLMResponseValidator.validate_business_name_extraction(invalid_response2) == False
|
||||
|
||||
|
||||
def test_law_firm_masking():
|
||||
"""Test law firm name masking"""
|
||||
processor = NerProcessor()
|
||||
|
||||
# Test law firm name masking
|
||||
test_cases = [
|
||||
("北京大成律师事务所", "北京D律师事务所"),
|
||||
("上海锦天城律师事务所", "上海JTC律师事务所"),
|
||||
("广东广信君达律师事务所", "广东GXJD律师事务所"),
|
||||
]
|
||||
|
||||
for original_name, expected_masked in test_cases:
|
||||
masked = processor._mask_company_name(original_name)
|
||||
print(f"{original_name} -> {masked} (expected: {expected_masked})")
|
||||
# Note: The exact results may vary due to LLM extraction, so we'll just print for verification
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
"""
|
||||
Tests for the refactored NerProcessor.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the backend directory to the Python path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
||||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
||||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
||||
|
||||
|
||||
def test_chinese_name_masker():
|
||||
"""Test Chinese name masker"""
|
||||
masker = ChineseNameMasker()
|
||||
|
||||
# Test basic masking
|
||||
result1 = masker.mask("李强")
|
||||
assert result1 == "李Q"
|
||||
|
||||
result2 = masker.mask("张韶涵")
|
||||
assert result2 == "张SH"
|
||||
|
||||
result3 = masker.mask("张若宇")
|
||||
assert result3 == "张RY"
|
||||
|
||||
result4 = masker.mask("白锦程")
|
||||
assert result4 == "白JC"
|
||||
|
||||
# Test duplicate handling
|
||||
result5 = masker.mask("李强") # Should get a number
|
||||
assert result5 == "李Q2"
|
||||
|
||||
print(f"Chinese name masking tests passed")
|
||||
|
||||
|
||||
def test_english_name_masker():
|
||||
"""Test English name masker"""
|
||||
masker = EnglishNameMasker()
|
||||
|
||||
result = masker.mask("John Smith")
|
||||
assert result == "J*** S***"
|
||||
|
||||
result2 = masker.mask("Mary Jane Watson")
|
||||
assert result2 == "M*** J*** W***"
|
||||
|
||||
print(f"English name masking tests passed")
|
||||
|
||||
|
||||
def test_id_masker():
|
||||
"""Test ID masker"""
|
||||
masker = IDMasker()
|
||||
|
||||
# Test ID number
|
||||
result1 = masker.mask("310103198802080000")
|
||||
assert result1 == "310103XXXXXXXXXXXX"
|
||||
assert len(result1) == 18
|
||||
|
||||
# Test social credit code
|
||||
result2 = masker.mask("9133021276453538XT")
|
||||
assert result2 == "913302XXXXXXXXXXXX"
|
||||
assert len(result2) == 18
|
||||
|
||||
print(f"ID masking tests passed")
|
||||
|
||||
|
||||
def test_case_masker():
|
||||
"""Test case masker"""
|
||||
masker = CaseMasker()
|
||||
|
||||
result1 = masker.mask("(2022)京 03 民终 3852 号")
|
||||
assert "***号" in result1
|
||||
|
||||
result2 = masker.mask("(2020)京0105 民初69754 号")
|
||||
assert "***号" in result2
|
||||
|
||||
print(f"Case masking tests passed")
|
||||
|
||||
|
||||
def test_masker_factory():
|
||||
"""Test masker factory"""
|
||||
from app.core.document_handlers.masker_factory import MaskerFactory
|
||||
|
||||
# Test creating maskers
|
||||
chinese_masker = MaskerFactory.create_masker('chinese_name')
|
||||
assert isinstance(chinese_masker, ChineseNameMasker)
|
||||
|
||||
english_masker = MaskerFactory.create_masker('english_name')
|
||||
assert isinstance(english_masker, EnglishNameMasker)
|
||||
|
||||
id_masker = MaskerFactory.create_masker('id')
|
||||
assert isinstance(id_masker, IDMasker)
|
||||
|
||||
case_masker = MaskerFactory.create_masker('case')
|
||||
assert isinstance(case_masker, CaseMasker)
|
||||
|
||||
print(f"Masker factory tests passed")
|
||||
|
||||
|
||||
def test_refactored_processor_initialization():
|
||||
"""Test that the refactored processor can be initialized"""
|
||||
try:
|
||||
processor = NerProcessorRefactored()
|
||||
assert processor is not None
|
||||
assert hasattr(processor, 'maskers')
|
||||
assert len(processor.maskers) > 0
|
||||
print(f"Refactored processor initialization test passed")
|
||||
except Exception as e:
|
||||
print(f"Refactored processor initialization failed: {e}")
|
||||
# This might fail if Ollama is not running, which is expected in test environment
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running refactored NerProcessor tests...")
|
||||
|
||||
test_chinese_name_masker()
|
||||
test_english_name_masker()
|
||||
test_id_masker()
|
||||
test_case_masker()
|
||||
test_masker_factory()
|
||||
test_refactored_processor_initialization()
|
||||
|
||||
print("All refactored NerProcessor tests completed!")
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
"""
|
||||
Validation script for the refactored NerProcessor.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the current directory to the Python path
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
def test_imports():
|
||||
"""Test that all modules can be imported"""
|
||||
print("Testing imports...")
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.base_masker import BaseMasker
|
||||
print("✓ BaseMasker imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import BaseMasker: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker, EnglishNameMasker
|
||||
print("✓ Name maskers imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import name maskers: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
||||
print("✓ IDMasker imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import IDMasker: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
||||
print("✓ CaseMasker imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import CaseMasker: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.company_masker import CompanyMasker
|
||||
print("✓ CompanyMasker imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import CompanyMasker: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.address_masker import AddressMasker
|
||||
print("✓ AddressMasker imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import AddressMasker: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.masker_factory import MaskerFactory
|
||||
print("✓ MaskerFactory imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import MaskerFactory: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.extractors.business_name_extractor import BusinessNameExtractor
|
||||
print("✓ BusinessNameExtractor imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import BusinessNameExtractor: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.extractors.address_extractor import AddressExtractor
|
||||
print("✓ AddressExtractor imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import AddressExtractor: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
||||
print("✓ NerProcessorRefactored imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to import NerProcessorRefactored: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_masker_functionality():
|
||||
"""Test basic masker functionality"""
|
||||
print("\nTesting masker functionality...")
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker
|
||||
|
||||
masker = ChineseNameMasker()
|
||||
result = masker.mask("李强")
|
||||
assert result == "李Q", f"Expected '李Q', got '{result}'"
|
||||
print("✓ ChineseNameMasker works correctly")
|
||||
except Exception as e:
|
||||
print(f"✗ ChineseNameMasker test failed: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.name_masker import EnglishNameMasker
|
||||
|
||||
masker = EnglishNameMasker()
|
||||
result = masker.mask("John Smith")
|
||||
assert result == "J*** S***", f"Expected 'J*** S***', got '{result}'"
|
||||
print("✓ EnglishNameMasker works correctly")
|
||||
except Exception as e:
|
||||
print(f"✗ EnglishNameMasker test failed: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.id_masker import IDMasker
|
||||
|
||||
masker = IDMasker()
|
||||
result = masker.mask("310103198802080000")
|
||||
assert result == "310103XXXXXXXXXXXX", f"Expected '310103XXXXXXXXXXXX', got '{result}'"
|
||||
print("✓ IDMasker works correctly")
|
||||
except Exception as e:
|
||||
print(f"✗ IDMasker test failed: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.maskers.case_masker import CaseMasker
|
||||
|
||||
masker = CaseMasker()
|
||||
result = masker.mask("(2022)京 03 民终 3852 号")
|
||||
assert "***号" in result, f"Expected '***号' in result, got '{result}'"
|
||||
print("✓ CaseMasker works correctly")
|
||||
except Exception as e:
|
||||
print(f"✗ CaseMasker test failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_factory():
|
||||
"""Test masker factory"""
|
||||
print("\nTesting masker factory...")
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.masker_factory import MaskerFactory
|
||||
from app.core.document_handlers.maskers.name_masker import ChineseNameMasker
|
||||
|
||||
masker = MaskerFactory.create_masker('chinese_name')
|
||||
assert isinstance(masker, ChineseNameMasker), f"Expected ChineseNameMasker, got {type(masker)}"
|
||||
print("✓ MaskerFactory works correctly")
|
||||
except Exception as e:
|
||||
print(f"✗ MaskerFactory test failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_processor_initialization():
|
||||
"""Test processor initialization"""
|
||||
print("\nTesting processor initialization...")
|
||||
|
||||
try:
|
||||
from app.core.document_handlers.ner_processor_refactored import NerProcessorRefactored
|
||||
|
||||
processor = NerProcessorRefactored()
|
||||
assert processor is not None, "Processor should not be None"
|
||||
assert hasattr(processor, 'maskers'), "Processor should have maskers attribute"
|
||||
assert len(processor.maskers) > 0, "Processor should have at least one masker"
|
||||
print("✓ NerProcessorRefactored initializes correctly")
|
||||
except Exception as e:
|
||||
print(f"✗ NerProcessorRefactored initialization failed: {e}")
|
||||
# This might fail if Ollama is not running, which is expected
|
||||
print(" (This is expected if Ollama is not running)")
|
||||
return True # Don't fail the validation for this
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Main validation function"""
|
||||
print("Validating refactored NerProcessor...")
|
||||
print("=" * 50)
|
||||
|
||||
success = True
|
||||
|
||||
# Test imports
|
||||
if not test_imports():
|
||||
success = False
|
||||
|
||||
# Test functionality
|
||||
if not test_masker_functionality():
|
||||
success = False
|
||||
|
||||
# Test factory
|
||||
if not test_factory():
|
||||
success = False
|
||||
|
||||
# Test processor initialization
|
||||
if not test_processor_initialization():
|
||||
success = False
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
if success:
|
||||
print("✓ All validation tests passed!")
|
||||
print("The refactored code is working correctly.")
|
||||
else:
|
||||
print("✗ Some validation tests failed.")
|
||||
print("Please check the errors above.")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Mineru API Service
|
||||
mineru-api:
|
||||
build:
|
||||
context: ./mineru
|
||||
dockerfile: Dockerfile
|
||||
platform: linux/arm64
|
||||
ports:
|
||||
- "8001:8000"
|
||||
volumes:
|
||||
- ./mineru/storage/uploads:/app/storage/uploads
|
||||
- ./mineru/storage/processed:/app/storage/processed
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- MINERU_MODEL_SOURCE=local
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# MagicDoc API Service
|
||||
magicdoc-api:
|
||||
build:
|
||||
context: ./magicdoc
|
||||
dockerfile: Dockerfile
|
||||
platform: linux/amd64
|
||||
ports:
|
||||
- "8002:8000"
|
||||
volumes:
|
||||
- ./magicdoc/storage/uploads:/app/storage/uploads
|
||||
- ./magicdoc/storage/processed:/app/storage/processed
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# Backend API Service
|
||||
backend-api:
|
||||
build:
|
||||
context: ./backend
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./backend/storage:/app/storage
|
||||
- huggingface_cache:/root/.cache/huggingface
|
||||
env_file:
|
||||
- ./backend/.env
|
||||
environment:
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
- MINERU_API_URL=http://mineru-api:8000
|
||||
- MAGICDOC_API_URL=http://magicdoc-api:8000
|
||||
depends_on:
|
||||
- redis
|
||||
- mineru-api
|
||||
- magicdoc-api
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# Celery Worker
|
||||
celery-worker:
|
||||
build:
|
||||
context: ./backend
|
||||
dockerfile: Dockerfile
|
||||
command: celery -A app.services.file_service worker --loglevel=info
|
||||
volumes:
|
||||
- ./backend/storage:/app/storage
|
||||
- huggingface_cache:/root/.cache/huggingface
|
||||
env_file:
|
||||
- ./backend/.env
|
||||
environment:
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
- MINERU_API_URL=http://mineru-api:8000
|
||||
- MAGICDOC_API_URL=http://magicdoc-api:8000
|
||||
depends_on:
|
||||
- redis
|
||||
- backend-api
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# Redis Service
|
||||
redis:
|
||||
image: redis:alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
# Frontend Service
|
||||
frontend:
|
||||
build:
|
||||
context: ./frontend
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- REACT_APP_API_BASE_URL=http://localhost:8000/api/v1
|
||||
ports:
|
||||
- "3000:80"
|
||||
env_file:
|
||||
- ./frontend/.env
|
||||
environment:
|
||||
- NODE_ENV=production
|
||||
- REACT_APP_API_BASE_URL=http://localhost:8000/api/v1
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- backend-api
|
||||
networks:
|
||||
- app-network
|
||||
|
||||
networks:
|
||||
app-network:
|
||||
driver: bridge
|
||||
|
||||
volumes:
|
||||
uploads:
|
||||
processed:
|
||||
huggingface_cache:
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import json
|
||||
import shutil
|
||||
import os
|
||||
|
||||
import requests
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
def download_json(url):
|
||||
# Download the JSON file
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()  # Raise if the request failed
|
||||
return response.json()
|
||||
|
||||
|
||||
def download_and_modify_json(url, local_filename, modifications):
|
||||
if os.path.exists(local_filename):
|
||||
data = json.load(open(local_filename))
|
||||
config_version = data.get('config_version', '0.0.0')
|
||||
if config_version < '1.2.0':
|
||||
data = download_json(url)
|
||||
else:
|
||||
data = download_json(url)
|
||||
|
||||
# Apply the requested modifications
|
||||
for key, value in modifications.items():
|
||||
data[key] = value
|
||||
|
||||
# Save the modified content
|
||||
with open(local_filename, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mineru_patterns = [
|
||||
# "models/Layout/LayoutLMv3/*",
|
||||
"models/Layout/YOLO/*",
|
||||
"models/MFD/YOLO/*",
|
||||
"models/MFR/unimernet_hf_small_2503/*",
|
||||
"models/OCR/paddleocr_torch/*",
|
||||
# "models/TabRec/TableMaster/*",
|
||||
# "models/TabRec/StructEqTable/*",
|
||||
]
|
||||
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
|
||||
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
|
||||
model_dir = model_dir + '/models'
|
||||
print(f'model_dir is: {model_dir}')
|
||||
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
|
||||
|
||||
# paddleocr_model_dir = model_dir + '/OCR/paddleocr'
|
||||
# user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
|
||||
# if os.path.exists(user_paddleocr_dir):
|
||||
# shutil.rmtree(user_paddleocr_dir)
|
||||
# shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
|
||||
|
||||
json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
|
||||
config_file_name = 'magic-pdf.json'
|
||||
home_dir = os.path.expanduser('~')
|
||||
config_file = os.path.join(home_dir, config_file_name)
|
||||
|
||||
json_mods = {
|
||||
'models-dir': model_dir,
|
||||
'layoutreader-model-dir': layoutreader_model_dir,
|
||||
}
|
||||
|
||||
download_and_modify_json(json_url, config_file, json_mods)
|
||||
print(f'The configuration file has been configured successfully, the path is: {config_file}')
|
||||
export-images.sh
|
|
@ -1,168 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Docker Image Export Script
|
||||
# Exports all project Docker images for migration to another environment
|
||||
|
||||
set -e
|
||||
|
||||
echo "🚀 Legal Document Masker - Docker Image Export"
|
||||
echo "=============================================="
|
||||
|
||||
# Function to check if Docker is running
|
||||
check_docker() {
|
||||
if ! docker info > /dev/null 2>&1; then
|
||||
echo "❌ Docker is not running. Please start Docker and try again."
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Docker is running"
|
||||
}
|
||||
|
||||
# Function to check if images exist
|
||||
check_images() {
|
||||
echo "🔍 Checking for required images..."
|
||||
|
||||
local missing_images=()
|
||||
|
||||
if ! docker images | grep -q "legal-doc-masker-backend-api"; then
|
||||
missing_images+=("legal-doc-masker-backend-api")
|
||||
fi
|
||||
|
||||
if ! docker images | grep -q "legal-doc-masker-frontend"; then
|
||||
missing_images+=("legal-doc-masker-frontend")
|
||||
fi
|
||||
|
||||
if ! docker images | grep -q "legal-doc-masker-mineru-api"; then
|
||||
missing_images+=("legal-doc-masker-mineru-api")
|
||||
fi
|
||||
|
||||
if ! docker images | grep -q "redis:alpine"; then
|
||||
missing_images+=("redis:alpine")
|
||||
fi
|
||||
|
||||
if [ ${#missing_images[@]} -ne 0 ]; then
|
||||
echo "❌ Missing images: ${missing_images[*]}"
|
||||
echo "Please build the images first using: docker-compose build"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ All required images found"
|
||||
}
|
||||
|
||||
# Function to create export directory
|
||||
create_export_dir() {
|
||||
local export_dir="docker-images-export-$(date +%Y%m%d-%H%M%S)"
|
||||
mkdir -p "$export_dir"
|
||||
cd "$export_dir"
|
||||
echo "📁 Created export directory: $export_dir"
|
||||
echo "$export_dir"
|
||||
}
|
||||
|
||||
# Function to export images
|
||||
export_images() {
|
||||
local export_dir="$1"
|
||||
|
||||
echo "📦 Exporting Docker images..."
|
||||
|
||||
# Export backend image
|
||||
echo " 📦 Exporting backend-api image..."
|
||||
docker save legal-doc-masker-backend-api:latest -o backend-api.tar
|
||||
|
||||
# Export frontend image
|
||||
echo " 📦 Exporting frontend image..."
|
||||
docker save legal-doc-masker-frontend:latest -o frontend.tar
|
||||
|
||||
# Export mineru image
|
||||
echo " 📦 Exporting mineru-api image..."
|
||||
docker save legal-doc-masker-mineru-api:latest -o mineru-api.tar
|
||||
|
||||
# Export redis image
|
||||
echo " 📦 Exporting redis image..."
|
||||
docker save redis:alpine -o redis.tar
|
||||
|
||||
echo "✅ All images exported successfully!"
|
||||
}
|
||||
|
||||
# Function to show export summary
|
||||
show_summary() {
|
||||
echo ""
|
||||
echo "📊 Export Summary:"
|
||||
echo "=================="
|
||||
ls -lh *.tar
|
||||
|
||||
echo ""
|
||||
echo "📋 Files to transfer:"
|
||||
echo "===================="
|
||||
for file in *.tar; do
|
||||
echo " - $file"
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "💾 Total size: $(du -sh . | cut -f1)"
|
||||
}
|
||||
|
||||
# Function to create compressed archive
|
||||
create_archive() {
|
||||
echo ""
|
||||
echo "🗜️ Creating compressed archive..."
|
||||
|
||||
local archive_name="legal-doc-masker-images-$(date +%Y%m%d-%H%M%S).tar.gz"
|
||||
tar -czf "$archive_name" *.tar
|
||||
|
||||
echo "✅ Created archive: $archive_name"
|
||||
echo "📊 Archive size: $(du -sh "$archive_name" | cut -f1)"
|
||||
|
||||
echo ""
|
||||
echo "📋 Transfer options:"
|
||||
echo "==================="
|
||||
echo "1. Transfer individual .tar files"
|
||||
echo "2. Transfer compressed archive: $archive_name"
|
||||
}
|
||||
|
||||
# Function to show transfer instructions
|
||||
show_transfer_instructions() {
|
||||
echo ""
|
||||
echo "📤 Transfer Instructions:"
|
||||
echo "========================"
|
||||
echo ""
|
||||
echo "Option 1: Transfer individual files"
|
||||
echo "-----------------------------------"
|
||||
echo "scp *.tar user@target-server:/path/to/destination/"
|
||||
echo ""
|
||||
echo "Option 2: Transfer compressed archive"
|
||||
echo "-------------------------------------"
|
||||
echo "scp legal-doc-masker-images-*.tar.gz user@target-server:/path/to/destination/"
|
||||
echo ""
|
||||
echo "Option 3: USB Drive"
|
||||
echo "-------------------"
|
||||
echo "cp *.tar /Volumes/USB_DRIVE/docker-images/"
|
||||
echo "cp legal-doc-masker-images-*.tar.gz /Volumes/USB_DRIVE/"
|
||||
echo ""
|
||||
echo "Option 4: Cloud Storage"
|
||||
echo "----------------------"
|
||||
echo "aws s3 cp *.tar s3://your-bucket/docker-images/"
|
||||
echo "aws s3 cp legal-doc-masker-images-*.tar.gz s3://your-bucket/docker-images/"
|
||||
}
|
||||
|
||||
# Main execution
|
||||
main() {
|
||||
check_docker
|
||||
check_images
|
||||
|
||||
    local export_dir=$(create_export_dir)
    cd "$export_dir"
|
||||
export_images "$export_dir"
|
||||
show_summary
|
||||
create_archive
|
||||
show_transfer_instructions
|
||||
|
||||
echo ""
|
||||
echo "🎉 Export completed successfully!"
|
||||
echo "📁 Export location: $(pwd)"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo "1. Transfer the files to your target environment"
|
||||
echo "2. Use import-images.sh on the target environment"
|
||||
echo "3. Copy docker-compose.yml and other config files"
|
||||
}
|
||||
|
||||
# Run main function
|
||||
main "$@"
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
# REACT_APP_API_BASE_URL=http://192.168.2.203:8000/api/v1
|
||||
REACT_APP_API_BASE_URL=http://localhost:8000/api/v1
|
||||
|
|
@ -16,9 +16,8 @@ import {
|
|||
DialogContent,
|
||||
DialogActions,
|
||||
Typography,
|
||||
Tooltip,
|
||||
} from '@mui/material';
|
||||
import { Download as DownloadIcon, Delete as DeleteIcon, Error as ErrorIcon } from '@mui/icons-material';
|
||||
import { Download as DownloadIcon, Delete as DeleteIcon } from '@mui/icons-material';
|
||||
import { File, FileStatus } from '../types/file';
|
||||
import { api } from '../services/api';
|
||||
|
||||
|
|
@ -48,37 +47,15 @@ const FileList: React.FC<FileListProps> = ({ files, onFileStatusChange }) => {
|
|||
|
||||
const handleDownload = async (fileId: string) => {
|
||||
try {
|
||||
console.log('=== FRONTEND DOWNLOAD START ===');
|
||||
console.log('File ID:', fileId);
|
||||
|
||||
const file = files.find((f) => f.id === fileId);
|
||||
console.log('File object:', file);
|
||||
|
||||
const blob = await api.downloadFile(fileId);
|
||||
console.log('Blob received:', blob);
|
||||
console.log('Blob type:', blob.type);
|
||||
console.log('Blob size:', blob.size);
|
||||
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
|
||||
// Match backend behavior: change extension to .md
|
||||
const originalFilename = file?.filename || 'downloaded-file';
|
||||
const filenameWithoutExt = originalFilename.replace(/\.[^/.]+$/, ''); // Remove extension
|
||||
const downloadFilename = `${filenameWithoutExt}.md`;
|
||||
|
||||
console.log('Original filename:', originalFilename);
|
||||
console.log('Filename without extension:', filenameWithoutExt);
|
||||
console.log('Download filename:', downloadFilename);
|
||||
|
||||
a.download = downloadFilename;
|
||||
a.download = files.find((f) => f.id === fileId)?.filename || 'downloaded-file';
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
window.URL.revokeObjectURL(url);
|
||||
document.body.removeChild(a);
|
||||
|
||||
console.log('=== FRONTEND DOWNLOAD END ===');
|
||||
} catch (error) {
|
||||
console.error('Error downloading file:', error);
|
||||
}
|
||||
|
|
@ -173,50 +150,6 @@ const FileList: React.FC<FileListProps> = ({ files, onFileStatusChange }) => {
|
|||
color={getStatusColor(file.status) as any}
|
||||
size="small"
|
||||
/>
|
||||
{file.status === FileStatus.FAILED && file.error_message && (
|
||||
<div style={{ marginTop: '4px' }}>
|
||||
<Tooltip
|
||||
title={file.error_message}
|
||||
placement="top-start"
|
||||
arrow
|
||||
sx={{ maxWidth: '400px' }}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'flex-start',
|
||||
gap: '4px',
|
||||
padding: '4px 8px',
|
||||
backgroundColor: '#ffebee',
|
||||
borderRadius: '4px',
|
||||
border: '1px solid #ffcdd2'
|
||||
}}
|
||||
>
|
||||
<ErrorIcon
|
||||
color="error"
|
||||
sx={{ fontSize: '16px', marginTop: '1px', flexShrink: 0 }}
|
||||
/>
|
||||
<Typography
|
||||
variant="caption"
|
||||
color="error"
|
||||
sx={{
|
||||
display: 'block',
|
||||
wordBreak: 'break-word',
|
||||
maxWidth: '300px',
|
||||
lineHeight: '1.2',
|
||||
cursor: 'help',
|
||||
fontWeight: 500
|
||||
}}
|
||||
>
|
||||
{file.error_message.length > 50
|
||||
? `${file.error_message.substring(0, 50)}...`
|
||||
: file.error_message
|
||||
}
|
||||
</Typography>
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{new Date(file.created_at).toLocaleString()}
|
||||
|
|
|
|||
import-images.sh
|
|
@ -1,232 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Docker Image Import Script
|
||||
# Imports Docker images on target environment for migration
|
||||
|
||||
set -e
|
||||
|
||||
echo "🚀 Legal Document Masker - Docker Image Import"
|
||||
echo "=============================================="
|
||||
|
||||
# Function to check if Docker is running
|
||||
check_docker() {
|
||||
if ! docker info > /dev/null 2>&1; then
|
||||
echo "❌ Docker is not running. Please start Docker and try again."
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Docker is running"
|
||||
}
|
||||
|
||||
# Function to check for tar files
|
||||
check_tar_files() {
|
||||
echo "🔍 Checking for Docker image files..."
|
||||
|
||||
local missing_files=()
|
||||
|
||||
if [ ! -f "backend-api.tar" ]; then
|
||||
missing_files+=("backend-api.tar")
|
||||
fi
|
||||
|
||||
if [ ! -f "frontend.tar" ]; then
|
||||
missing_files+=("frontend.tar")
|
||||
fi
|
||||
|
||||
if [ ! -f "mineru-api.tar" ]; then
|
||||
missing_files+=("mineru-api.tar")
|
||||
fi
|
||||
|
||||
if [ ! -f "redis.tar" ]; then
|
||||
missing_files+=("redis.tar")
|
||||
fi
|
||||
|
||||
if [ ${#missing_files[@]} -ne 0 ]; then
|
||||
echo "❌ Missing files: ${missing_files[*]}"
|
||||
echo ""
|
||||
echo "Please ensure all .tar files are in the current directory."
|
||||
echo "If you have a compressed archive, extract it first:"
|
||||
echo " tar -xzf legal-doc-masker-images-*.tar.gz"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ All required files found"
|
||||
}
|
||||
|
||||
# Function to check available disk space
|
||||
check_disk_space() {
|
||||
echo "💾 Checking available disk space..."
|
||||
|
||||
local required_space=0
|
||||
for file in *.tar; do
|
||||
local file_size=$(stat -f%z "$file" 2>/dev/null || stat -c%s "$file" 2>/dev/null || echo 0)
|
||||
required_space=$((required_space + file_size))
|
||||
done
|
||||
|
||||
local available_space=$(df . | awk 'NR==2 {print $4}')
|
||||
available_space=$((available_space * 1024)) # Convert to bytes
|
||||
|
||||
if [ $required_space -gt $available_space ]; then
|
||||
echo "❌ Insufficient disk space"
|
||||
echo "Required: $(numfmt --to=iec $required_space)"
|
||||
echo "Available: $(numfmt --to=iec $available_space)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ Sufficient disk space available"
|
||||
}
|
||||
|
||||
# Function to import images
|
||||
import_images() {
|
||||
echo "📦 Importing Docker images..."
|
||||
|
||||
# Import backend image
|
||||
echo " 📦 Importing backend-api image..."
|
||||
docker load -i backend-api.tar
|
||||
|
||||
# Import frontend image
|
||||
echo " 📦 Importing frontend image..."
|
||||
docker load -i frontend.tar
|
||||
|
||||
# Import mineru image
|
||||
echo " 📦 Importing mineru-api image..."
|
||||
docker load -i mineru-api.tar
|
||||
|
||||
# Import redis image
|
||||
echo " 📦 Importing redis image..."
|
||||
docker load -i redis.tar
|
||||
|
||||
echo "✅ All images imported successfully!"
|
||||
}
|
||||
|
||||
# Function to verify imported images
|
||||
verify_images() {
|
||||
echo "🔍 Verifying imported images..."
|
||||
|
||||
local missing_images=()
|
||||
|
||||
if ! docker images | grep -q "legal-doc-masker-backend-api"; then
|
||||
missing_images+=("legal-doc-masker-backend-api")
|
||||
fi
|
||||
|
||||
if ! docker images | grep -q "legal-doc-masker-frontend"; then
|
||||
missing_images+=("legal-doc-masker-frontend")
|
||||
fi
|
||||
|
||||
if ! docker images | grep -q "legal-doc-masker-mineru-api"; then
|
||||
missing_images+=("legal-doc-masker-mineru-api")
|
||||
fi
|
||||
|
||||
if ! docker images | grep -q "redis:alpine"; then
|
||||
missing_images+=("redis:alpine")
|
||||
fi
|
||||
|
||||
if [ ${#missing_images[@]} -ne 0 ]; then
|
||||
echo "❌ Missing imported images: ${missing_images[*]}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ All images verified successfully!"
|
||||
}
|
||||
|
||||
# Function to show imported images
|
||||
show_imported_images() {
|
||||
echo ""
|
||||
echo "📊 Imported Images:"
|
||||
echo "==================="
|
||||
docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.Size}}" | grep legal-doc-masker
|
||||
docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.Size}}" | grep redis
|
||||
}
|
||||
|
||||
# Function to create necessary directories
|
||||
create_directories() {
|
||||
echo ""
|
||||
echo "📁 Creating necessary directories..."
|
||||
|
||||
mkdir -p backend/storage
|
||||
mkdir -p mineru/storage/uploads
|
||||
mkdir -p mineru/storage/processed
|
||||
|
||||
echo "✅ Directories created"
|
||||
}
|
||||
|
||||
# Function to check for required files
|
||||
check_required_files() {
|
||||
echo ""
|
||||
echo "🔍 Checking for required configuration files..."
|
||||
|
||||
local missing_files=()
|
||||
|
||||
if [ ! -f "docker-compose.yml" ]; then
|
||||
missing_files+=("docker-compose.yml")
|
||||
fi
|
||||
|
||||
if [ ! -f "DOCKER_COMPOSE_README.md" ]; then
|
||||
missing_files+=("DOCKER_COMPOSE_README.md")
|
||||
fi
|
||||
|
||||
if [ ${#missing_files[@]} -ne 0 ]; then
|
||||
echo "⚠️ Missing files: ${missing_files[*]}"
|
||||
echo "Please copy these files from the source environment:"
|
||||
echo " - docker-compose.yml"
|
||||
echo " - DOCKER_COMPOSE_README.md"
|
||||
echo " - backend/.env (if exists)"
|
||||
echo " - frontend/.env (if exists)"
|
||||
echo " - mineru/.env (if exists)"
|
||||
else
|
||||
echo "✅ All required configuration files found"
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to show next steps
|
||||
show_next_steps() {
|
||||
echo ""
|
||||
echo "🎉 Import completed successfully!"
|
||||
echo ""
|
||||
echo "📋 Next Steps:"
|
||||
echo "=============="
|
||||
echo ""
|
||||
echo "1. Copy configuration files (if not already present):"
|
||||
echo " - docker-compose.yml"
|
||||
echo " - backend/.env"
|
||||
echo " - frontend/.env"
|
||||
echo " - mineru/.env"
|
||||
echo ""
|
||||
echo "2. Start the services:"
|
||||
echo " docker-compose up -d"
|
||||
echo ""
|
||||
echo "3. Verify services are running:"
|
||||
echo " docker-compose ps"
|
||||
echo ""
|
||||
echo "4. Test the endpoints:"
|
||||
echo " - Frontend: http://localhost:3000"
|
||||
echo " - Backend API: http://localhost:8000"
|
||||
echo " - Mineru API: http://localhost:8001"
|
||||
echo ""
|
||||
echo "5. View logs if needed:"
|
||||
echo " docker-compose logs -f [service-name]"
|
||||
}
|
||||
|
||||
# Function to handle compressed archive
|
||||
handle_compressed_archive() {
|
||||
if ls legal-doc-masker-images-*.tar.gz 1> /dev/null 2>&1; then
|
||||
echo "🗜️ Found compressed archive, extracting..."
|
||||
tar -xzf legal-doc-masker-images-*.tar.gz
|
||||
echo "✅ Archive extracted"
|
||||
fi
|
||||
}
|
||||
|
||||
# Main execution
|
||||
main() {
|
||||
check_docker
|
||||
handle_compressed_archive
|
||||
check_tar_files
|
||||
check_disk_space
|
||||
import_images
|
||||
verify_images
|
||||
show_imported_images
|
||||
create_directories
|
||||
check_required_files
|
||||
show_next_steps
|
||||
}
|
||||
|
||||
# Run main function
|
||||
main "$@"
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
FROM python:3.10-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies including LibreOffice
|
||||
RUN apt-get update && apt-get install -y \
|
||||
build-essential \
|
||||
libreoffice \
|
||||
libreoffice-writer \
|
||||
libreoffice-calc \
|
||||
libreoffice-impress \
|
||||
wget \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements and install Python packages first
|
||||
COPY requirements.txt .
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Install fairy-doc after numpy and opencv are installed
|
||||
RUN pip install --no-cache-dir "fairy-doc[cpu]"
|
||||
|
||||
# Copy the application code
|
||||
COPY app/ ./app/
|
||||
|
||||
# Create storage directories
|
||||
RUN mkdir -p storage/uploads storage/processed
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# Command to run the application
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
# MagicDoc API Service
|
||||
|
||||
A FastAPI service that provides document-to-markdown conversion using the Magic-Doc library. It is designed to be compatible with the existing Mineru API interface.
|
||||
|
||||
## Features
|
||||
|
||||
- Converts DOC, DOCX, PPT, PPTX, and PDF files to markdown
|
||||
- RESTful API interface compatible with Mineru API
|
||||
- Docker containerization with LibreOffice dependencies
|
||||
- Health check endpoint
|
||||
- File upload support
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Health Check
|
||||
```
|
||||
GET /health
|
||||
```
|
||||
Returns service health status.
|
||||
|
||||
### File Parse
|
||||
```
|
||||
POST /file_parse
|
||||
```
|
||||
Converts uploaded document to markdown.
|
||||
|
||||
**Parameters:**
|
||||
- `files`: File upload (required)
|
||||
- `output_dir`: Output directory (default: "./output")
|
||||
- `lang_list`: Language list (default: "ch")
|
||||
- `backend`: Backend type (default: "pipeline")
|
||||
- `parse_method`: Parse method (default: "auto")
|
||||
- `formula_enable`: Enable formula processing (default: true)
|
||||
- `table_enable`: Enable table processing (default: true)
|
||||
- `return_md`: Return markdown (default: true)
|
||||
- `return_middle_json`: Return middle JSON (default: false)
|
||||
- `return_model_output`: Return model output (default: false)
|
||||
- `return_content_list`: Return content list (default: false)
|
||||
- `return_images`: Return images (default: false)
|
||||
- `start_page_id`: Start page ID (default: 0)
|
||||
- `end_page_id`: End page ID (default: 99999)
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"markdown": "converted markdown content",
|
||||
"md": "converted markdown content",
|
||||
"content": "converted markdown content",
|
||||
"text": "converted markdown content",
|
||||
"time_cost": 1.23,
|
||||
"filename": "document.docx",
|
||||
"status": "success"
|
||||
}
|
||||
```
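
For quick reference, a minimal client sketch using `requests` (a sketch, not part of the service code; it assumes the service is reachable on port 8002 as in the Docker setup below and relies on the parameter defaults listed above):

```python
import requests

# Upload a document and read back the converted markdown.
# Only the `files` field is required; all other form fields fall back to their defaults.
with open("document.docx", "rb") as f:
    response = requests.post(
        "http://localhost:8002/file_parse",
        files={"files": ("document.docx", f)},
    )

response.raise_for_status()
result = response.json()
print(result["status"], f"{result['time_cost']:.2f}s")
print(result["markdown"][:200])
```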
|
||||
|
||||
## Running with Docker
|
||||
|
||||
### Build and run with docker-compose
|
||||
```bash
|
||||
cd magicdoc
|
||||
docker-compose up --build
|
||||
```
|
||||
|
||||
The service will be available at `http://localhost:8002`
|
||||
|
||||
### Build and run with Docker
|
||||
```bash
|
||||
cd magicdoc
|
||||
docker build -t magicdoc-api .
|
||||
docker run -p 8002:8000 magicdoc-api
|
||||
```
|
||||
|
||||
## Integration with Document Processors
|
||||
|
||||
This service is designed to be compatible with the existing document processors. To use it instead of the Mineru API, update the configuration in your document processors:
|
||||
|
||||
```python
|
||||
# In docx_processor.py or pdf_processor.py
|
||||
self.magicdoc_base_url = getattr(settings, 'MAGICDOC_API_URL', 'http://magicdoc-api:8000')
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- Python 3.10
|
||||
- LibreOffice (installed in Docker container)
|
||||
- Magic-Doc library
|
||||
- FastAPI
|
||||
- Uvicorn
|
||||
|
||||
## Storage
|
||||
|
||||
The service creates the following directories:
|
||||
- `storage/uploads/`: For uploaded files
|
||||
- `storage/processed/`: For processed files
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
# MagicDoc Service Setup Guide
|
||||
|
||||
This guide explains how to set up and use the MagicDoc API service as an alternative to the Mineru API for document processing.
|
||||
|
||||
## Overview
|
||||
|
||||
The MagicDoc service provides a FastAPI-based REST API that converts various document formats (DOC, DOCX, PPT, PPTX, PDF) to markdown using the Magic-Doc library. It's designed to be compatible with your existing document processors.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Build and Run the Service
|
||||
|
||||
```bash
|
||||
cd magicdoc
|
||||
./start.sh
|
||||
```
|
||||
|
||||
Or manually:
|
||||
```bash
|
||||
cd magicdoc
|
||||
docker-compose up --build -d
|
||||
```
|
||||
|
||||
### 2. Verify the Service
|
||||
|
||||
```bash
|
||||
# Check health
|
||||
curl http://localhost:8002/health
|
||||
|
||||
# View API documentation
|
||||
open http://localhost:8002/docs
|
||||
```
|
||||
|
||||
### 3. Test with Sample Files
|
||||
|
||||
```bash
|
||||
cd magicdoc
|
||||
python test_api.py
|
||||
```
|
||||
|
||||
## API Compatibility
|
||||
|
||||
The MagicDoc API is designed to be compatible with your existing Mineru API interface:
|
||||
|
||||
### Endpoint: `POST /file_parse`
|
||||
|
||||
**Request Format:**
|
||||
- File upload via multipart form data
|
||||
- Same parameters as Mineru API (most are optional)
|
||||
|
||||
**Response Format:**
|
||||
```json
|
||||
{
|
||||
"markdown": "converted content",
|
||||
"md": "converted content",
|
||||
"content": "converted content",
|
||||
"text": "converted content",
|
||||
"time_cost": 1.23,
|
||||
"filename": "document.docx",
|
||||
"status": "success"
|
||||
}
|
||||
```
|
||||
|
||||
## Integration with Existing Processors
|
||||
|
||||
To use MagicDoc instead of Mineru in your existing processors:
|
||||
|
||||
### 1. Update Configuration
|
||||
|
||||
Add to your settings:
|
||||
```python
|
||||
MAGICDOC_API_URL = "http://magicdoc-api:8000" # or http://localhost:8002
|
||||
MAGICDOC_TIMEOUT = 300
|
||||
```
|
||||
|
||||
### 2. Modify Processors
|
||||
|
||||
Replace Mineru API calls with MagicDoc API calls. See `integration_example.py` for detailed examples.
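
A condensed sketch of the replacement call (the helper name `convert_with_magicdoc` is hypothetical; the defaults mirror the settings added in step 1, and real processors should keep their existing retry and error handling):

```python
import os
import requests

def convert_with_magicdoc(
    file_path: str,
    base_url: str = "http://magicdoc-api:8000",  # settings.MAGICDOC_API_URL in a real processor
    timeout: int = 300,                          # settings.MAGICDOC_TIMEOUT
) -> str:
    """Send a document to the MagicDoc /file_parse endpoint and return the markdown."""
    with open(file_path, "rb") as f:
        response = requests.post(
            f"{base_url}/file_parse",
            files={"files": (os.path.basename(file_path), f)},
            timeout=timeout,
        )
    response.raise_for_status()
    return response.json()["markdown"]
```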
|
||||
|
||||
### 3. Update Docker Compose
|
||||
|
||||
Add the MagicDoc service to your main docker-compose.yml:
|
||||
```yaml
|
||||
services:
|
||||
magicdoc-api:
|
||||
build:
|
||||
context: ./magicdoc
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8002:8000"
|
||||
volumes:
|
||||
- ./magicdoc/storage:/app/storage
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
restart: unless-stopped
|
||||
```
|
||||
|
||||
## Service Architecture
|
||||
|
||||
```
|
||||
magicdoc/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ └── main.py # FastAPI application
|
||||
├── Dockerfile # Container definition
|
||||
├── docker-compose.yml # Service orchestration
|
||||
├── requirements.txt # Python dependencies
|
||||
├── README.md # Service documentation
|
||||
├── SETUP.md # This setup guide
|
||||
├── test_api.py # API testing script
|
||||
├── integration_example.py # Integration examples
|
||||
└── start.sh # Startup script
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- **Python 3.10**: Base runtime
|
||||
- **LibreOffice**: Document processing (installed in container)
|
||||
- **Magic-Doc**: Document conversion library
|
||||
- **FastAPI**: Web framework
|
||||
- **Uvicorn**: ASGI server
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Service Won't Start
|
||||
1. Check that Docker is running
|
||||
2. Verify port 8002 is available
|
||||
3. Check logs: `docker-compose logs`
|
||||
|
||||
### File Conversion Fails
|
||||
1. Verify LibreOffice is working inside the container
|
||||
2. Check file format is supported
|
||||
3. Review API logs for errors
|
||||
|
||||
### Integration Issues
|
||||
1. Verify the API endpoint URL
|
||||
2. Check network connectivity between services
|
||||
3. Ensure response format compatibility
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- MagicDoc is generally faster than Mineru for simple documents
|
||||
- The LibreOffice dependency increases the container image size
|
||||
- Consider caching for repeated conversions
|
||||
- Monitor memory usage for large files
|
||||
|
||||
## Security Notes
|
||||
|
||||
- The service runs on an internal network
|
||||
- File uploads are temporary
|
||||
- No persistent storage of uploaded files
|
||||
- Consider adding authentication for production use
|
||||
|
|
@ -1 +0,0 @@
|
|||
# MagicDoc FastAPI Application
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from magic_doc.docconv import DocConverter, S3Config
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="MagicDoc API", version="1.0.0")
|
||||
|
||||
# Global converter instance
|
||||
converter = DocConverter(s3_config=None)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "service": "magicdoc-api"}
|
||||
|
||||
@app.post("/file_parse")
|
||||
async def parse_file(
|
||||
files: UploadFile = File(...),
|
||||
output_dir: str = Form("./output"),
|
||||
lang_list: str = Form("ch"),
|
||||
backend: str = Form("pipeline"),
|
||||
parse_method: str = Form("auto"),
|
||||
formula_enable: bool = Form(True),
|
||||
table_enable: bool = Form(True),
|
||||
return_md: bool = Form(True),
|
||||
return_middle_json: bool = Form(False),
|
||||
return_model_output: bool = Form(False),
|
||||
return_content_list: bool = Form(False),
|
||||
return_images: bool = Form(False),
|
||||
start_page_id: int = Form(0),
|
||||
end_page_id: int = Form(99999)
|
||||
):
|
||||
"""
|
||||
Parse document file and convert to markdown
|
||||
Compatible with Mineru API interface
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Processing file: {files.filename}")
|
||||
|
||||
# Create temporary file to save uploaded content
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(files.filename)[1]) as temp_file:
|
||||
shutil.copyfileobj(files.file, temp_file)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Convert file to markdown using magic-doc
|
||||
markdown_content, time_cost = converter.convert(temp_file_path, conv_timeout=300)
|
||||
|
||||
logger.info(f"Successfully converted {files.filename} to markdown in {time_cost:.2f}s")
|
||||
|
||||
# Return response compatible with Mineru API
|
||||
response = {
|
||||
"markdown": markdown_content,
|
||||
"md": markdown_content, # Alternative field name
|
||||
"content": markdown_content, # Alternative field name
|
||||
"text": markdown_content, # Alternative field name
|
||||
"time_cost": time_cost,
|
||||
"filename": files.filename,
|
||||
"status": "success"
|
||||
}
|
||||
|
||||
return JSONResponse(content=response)
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if os.path.exists(temp_file_path):
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {files.filename}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with service information"""
|
||||
return {
|
||||
"service": "MagicDoc API",
|
||||
"version": "1.0.0",
|
||||
"description": "Document to Markdown conversion service using Magic-Doc",
|
||||
"endpoints": {
|
||||
"health": "/health",
|
||||
"file_parse": "/file_parse"
|
||||
}
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
magicdoc-api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
platform: linux/amd64
|
||||
ports:
|
||||
- "8002:8000"
|
||||
volumes:
|
||||
- ./storage/uploads:/app/storage/uploads
|
||||
- ./storage/processed:/app/storage/processed
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
|
||||
volumes:
|
||||
uploads:
|
||||
processed:
|
||||
|
|
@ -1,144 +0,0 @@
|
|||
"""
|
||||
Example of how to integrate MagicDoc API with existing document processors
|
||||
"""
|
||||
|
||||
# Example modification for docx_processor.py
|
||||
# Replace the Mineru API configuration with MagicDoc API configuration
|
||||
|
||||
class DocxDocumentProcessor(DocumentProcessor):
|
||||
def __init__(self, input_path: str, output_path: str):
|
||||
super().__init__()
|
||||
self.input_path = input_path
|
||||
self.output_path = output_path
|
||||
self.output_dir = os.path.dirname(output_path)
|
||||
self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]
|
||||
|
||||
# Setup work directory for temporary files
|
||||
self.work_dir = os.path.join(
|
||||
os.path.dirname(output_path),
|
||||
".work",
|
||||
os.path.splitext(os.path.basename(input_path))[0]
|
||||
)
|
||||
os.makedirs(self.work_dir, exist_ok=True)
|
||||
|
||||
self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
|
||||
|
||||
# MagicDoc API configuration (instead of Mineru)
|
||||
self.magicdoc_base_url = getattr(settings, 'MAGICDOC_API_URL', 'http://magicdoc-api:8000')
|
||||
self.magicdoc_timeout = getattr(settings, 'MAGICDOC_TIMEOUT', 300) # 5 minutes timeout
|
||||
|
||||
def _call_magicdoc_api(self, file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call MagicDoc API to convert DOCX to markdown
|
||||
|
||||
Args:
|
||||
file_path: Path to the DOCX file
|
||||
|
||||
Returns:
|
||||
API response as dictionary or None if failed
|
||||
"""
|
||||
try:
|
||||
url = f"{self.magicdoc_base_url}/file_parse"
|
||||
|
||||
with open(file_path, 'rb') as file:
|
||||
files = {'files': (os.path.basename(file_path), file, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')}
|
||||
|
||||
# Prepare form data - simplified compared to Mineru
|
||||
data = {
|
||||
'output_dir': './output',
|
||||
'lang_list': 'ch',
|
||||
'backend': 'pipeline',
|
||||
'parse_method': 'auto',
|
||||
'formula_enable': True,
|
||||
'table_enable': True,
|
||||
'return_md': True,
|
||||
'return_middle_json': False,
|
||||
'return_model_output': False,
|
||||
'return_content_list': False,
|
||||
'return_images': False,
|
||||
'start_page_id': 0,
|
||||
'end_page_id': 99999
|
||||
}
|
||||
|
||||
logger.info(f"Calling MagicDoc API for DOCX processing at {url}")
|
||||
response = requests.post(
|
||||
url,
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=self.magicdoc_timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info("Successfully received response from MagicDoc API for DOCX")
|
||||
return result
|
||||
else:
|
||||
error_msg = f"MagicDoc API returned status code {response.status_code}: {response.text}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
error_msg = f"MagicDoc API request timed out after {self.magicdoc_timeout} seconds"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_msg = f"Error calling MagicDoc API for DOCX: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling MagicDoc API for DOCX: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def read_content(self) -> str:
|
||||
logger.info("Starting DOCX content processing with MagicDoc API")
|
||||
|
||||
# Call MagicDoc API to convert DOCX to markdown
|
||||
magicdoc_response = self._call_magicdoc_api(self.input_path)
|
||||
|
||||
# Extract markdown content from the response
|
||||
markdown_content = self._extract_markdown_from_response(magicdoc_response)
|
||||
|
||||
if not markdown_content:
|
||||
raise Exception("No markdown content found in MagicDoc API response for DOCX")
|
||||
|
||||
logger.info(f"Successfully extracted {len(markdown_content)} characters of markdown content from DOCX")
|
||||
|
||||
# Save the raw markdown content to work directory for reference
|
||||
md_output_path = os.path.join(self.work_dir, f"{self.name_without_suff}.md")
|
||||
with open(md_output_path, 'w', encoding='utf-8') as file:
|
||||
file.write(markdown_content)
|
||||
|
||||
logger.info(f"Saved raw markdown content from DOCX to {md_output_path}")
|
||||
|
||||
return markdown_content
|
||||
|
||||
# Configuration changes needed in settings.py:
|
||||
"""
|
||||
# Add these settings to your configuration
|
||||
MAGICDOC_API_URL = "http://magicdoc-api:8000" # or http://localhost:8002 for local development
|
||||
MAGICDOC_TIMEOUT = 300 # 5 minutes timeout
|
||||
"""
|
||||
|
||||
# Docker Compose integration:
|
||||
"""
|
||||
# Add to your main docker-compose.yml
|
||||
services:
|
||||
magicdoc-api:
|
||||
build:
|
||||
context: ./magicdoc
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8002:8000"
|
||||
volumes:
|
||||
- ./magicdoc/storage:/app/storage
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
"""
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
python-multipart==0.0.6
|
||||
# fairy-doc[cpu]==0.1.0
|
||||
pydantic==2.5.0
|
||||
numpy==1.24.3
|
||||
opencv-python==4.8.1.78
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# MagicDoc API Service Startup Script
|
||||
|
||||
echo "Starting MagicDoc API Service..."
|
||||
|
||||
# Check if Docker is running
|
||||
if ! docker info > /dev/null 2>&1; then
|
||||
echo "Error: Docker is not running. Please start Docker first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Build and start the service
|
||||
echo "Building and starting MagicDoc API service..."
|
||||
docker-compose up --build -d
|
||||
|
||||
# Wait for service to be ready
|
||||
echo "Waiting for service to be ready..."
|
||||
sleep 10
|
||||
|
||||
# Check health
|
||||
echo "Checking service health..."
|
||||
if curl -f http://localhost:8002/health > /dev/null 2>&1; then
|
||||
echo "✅ MagicDoc API service is running successfully!"
|
||||
echo "🌐 Service URL: http://localhost:8002"
|
||||
echo "📖 API Documentation: http://localhost:8002/docs"
|
||||
echo "🔍 Health Check: http://localhost:8002/health"
|
||||
else
|
||||
echo "❌ Service health check failed. Check logs with: docker-compose logs"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "To stop the service, run: docker-compose down"
|
||||
echo "To view logs, run: docker-compose logs -f"
|
||||
|
|
@ -1,92 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for MagicDoc API
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
|
||||
def test_health_check(base_url="http://localhost:8002"):
|
||||
"""Test health check endpoint"""
|
||||
try:
|
||||
response = requests.get(f"{base_url}/health")
|
||||
print(f"Health check status: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
def test_file_parse(base_url="http://localhost:8002", file_path=None):
|
||||
"""Test file parse endpoint"""
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
files = {'files': (os.path.basename(file_path), f, 'application/octet-stream')}
|
||||
data = {
|
||||
'output_dir': './output',
|
||||
'lang_list': 'ch',
|
||||
'backend': 'pipeline',
|
||||
'parse_method': 'auto',
|
||||
'formula_enable': True,
|
||||
'table_enable': True,
|
||||
'return_md': True,
|
||||
'return_middle_json': False,
|
||||
'return_model_output': False,
|
||||
'return_content_list': False,
|
||||
'return_images': False,
|
||||
'start_page_id': 0,
|
||||
'end_page_id': 99999
|
||||
}
|
||||
|
||||
response = requests.post(f"{base_url}/file_parse", files=files, data=data)
|
||||
print(f"File parse status: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print(f"Success! Converted {len(result.get('markdown', ''))} characters")
|
||||
print(f"Time cost: {result.get('time_cost', 'N/A')}s")
|
||||
return True
|
||||
else:
|
||||
print(f"Error: {response.text}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"File parse failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
print("Testing MagicDoc API...")
|
||||
|
||||
# Test health check
|
||||
print("\n1. Testing health check...")
|
||||
if not test_health_check():
|
||||
print("Health check failed. Make sure the service is running.")
|
||||
return
|
||||
|
||||
# Test file parse (if sample file exists)
|
||||
print("\n2. Testing file parse...")
|
||||
sample_files = [
|
||||
"../sample_doc/20220707_na_decision-2.docx",
|
||||
"../sample_doc/20220707_na_decision-2.pdf",
|
||||
"../sample_doc/short_doc.md"
|
||||
]
|
||||
|
||||
for sample_file in sample_files:
|
||||
if os.path.exists(sample_file):
|
||||
print(f"Testing with {sample_file}...")
|
||||
if test_file_parse(file_path=sample_file):
|
||||
print("File parse test passed!")
|
||||
break
|
||||
else:
|
||||
print(f"Sample file not found: {sample_file}")
|
||||
|
||||
print("\nTest completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
FROM python:3.12-slim
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
|
@ -8,39 +8,27 @@ RUN apt-get update && apt-get install -y \
|
|||
libreoffice \
|
||||
wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
RUN pip install --upgrade pip
|
||||
RUN pip install uv
|
||||
|
||||
# Configure uv and install mineru
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
RUN uv pip install --system -U "mineru[core]"
|
||||
|
||||
|
||||
|
||||
# Copy requirements first to leverage Docker cache
|
||||
# COPY requirements.txt .
|
||||
# RUN pip install huggingface_hub
|
||||
# RUN wget https://github.com/opendatalab/MinerU/raw/master/scripts/download_models_hf.py -O download_models_hf.py
|
||||
# RUN wget https://raw.githubusercontent.com/opendatalab/MinerU/refs/heads/release-1.3.1/scripts/download_models_hf.py -O download_models_hf.py
|
||||
|
||||
# RUN python download_models_hf.py
|
||||
RUN mineru-models-download -s modelscope -m pipeline
|
||||
COPY requirements.txt .
|
||||
RUN pip install huggingface_hub
|
||||
RUN wget https://github.com/opendatalab/MinerU/raw/master/scripts/download_models_hf.py -O download_models_hf.py
|
||||
RUN python download_models_hf.py
|
||||
|
||||
|
||||
|
||||
|
||||
# RUN pip install --no-cache-dir -r requirements.txt
|
||||
# RUN pip install -U magic-pdf[full]
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
RUN pip install -U magic-pdf[full]
|
||||
|
||||
|
||||
# Copy the rest of the application
|
||||
# COPY . .
|
||||
COPY . .
|
||||
|
||||
# Create storage directories
|
||||
# RUN mkdir -p storage/uploads storage/processed
|
||||
RUN mkdir -p storage/uploads storage/processed
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 8000
|
||||
|
||||
# Command to run the application
|
||||
CMD ["mineru-api", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
# Mineru API Documentation
|
||||
|
||||
This document describes the FastAPI interface for the Mineru document parsing service.
|
||||
|
||||
## Overview
|
||||
|
||||
The Mineru API provides endpoints for parsing documents (PDFs, images) using advanced OCR and layout analysis. It supports both pipeline and VLM backends for different use cases.
|
||||
|
||||
## Base URL
|
||||
|
||||
```
|
||||
http://localhost:8000/api/v1/mineru
|
||||
```
|
||||
|
||||
## Endpoints
|
||||
|
||||
### 1. Health Check
|
||||
|
||||
**GET** `/health`
|
||||
|
||||
Check if the Mineru service is running.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"service": "mineru"
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Parse Document
|
||||
|
||||
**POST** `/parse`
|
||||
|
||||
Parse a document using Mineru's advanced parsing capabilities.
|
||||
|
||||
**Parameters:**
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `file` | File | Required | The document file to parse (PDF, PNG, JPEG, JPG) |
|
||||
| `lang` | string | "ch" | Language option ('ch', 'en', 'korean', 'japan', etc.) |
|
||||
| `backend` | string | "pipeline" | Backend for parsing ('pipeline', 'vlm-transformers', 'vlm-sglang-engine', 'vlm-sglang-client') |
|
||||
| `method` | string | "auto" | Method for parsing ('auto', 'txt', 'ocr') |
|
||||
| `server_url` | string | null | Server URL for vlm-sglang-client backend |
|
||||
| `start_page_id` | int | 0 | Start page ID for parsing |
|
||||
| `end_page_id` | int | null | End page ID for parsing |
|
||||
| `formula_enable` | boolean | true | Enable formula parsing |
|
||||
| `table_enable` | boolean | true | Enable table parsing |
|
||||
| `draw_layout_bbox` | boolean | true | Whether to draw layout bounding boxes |
|
||||
| `draw_span_bbox` | boolean | true | Whether to draw span bounding boxes |
|
||||
| `dump_md` | boolean | true | Whether to dump markdown files |
|
||||
| `dump_middle_json` | boolean | true | Whether to dump middle JSON files |
|
||||
| `dump_model_output` | boolean | true | Whether to dump model output files |
|
||||
| `dump_orig_pdf` | boolean | true | Whether to dump original PDF files |
|
||||
| `dump_content_list` | boolean | true | Whether to dump content list files |
|
||||
| `make_md_mode` | string | "MM_MD" | The mode for making markdown content |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"file_name": "document_name",
|
||||
"outputs": {
|
||||
"markdown": "/path/to/document_name.md",
|
||||
"middle_json": "/path/to/document_name_middle.json",
|
||||
"model_output": "/path/to/document_name_model.json",
|
||||
"content_list": "/path/to/document_name_content_list.json",
|
||||
"original_pdf": "/path/to/document_name_origin.pdf",
|
||||
"layout_pdf": "/path/to/document_name_layout.pdf",
|
||||
"span_pdf": "/path/to/document_name_span.pdf"
|
||||
},
|
||||
"output_directory": "/path/to/output/directory"
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Download Processed File
|
||||
|
||||
**GET** `/download/{file_path}`
|
||||
|
||||
Download a processed file from the Mineru output directory.
|
||||
|
||||
**Parameters:**
|
||||
- `file_path`: Path to the file relative to the mineru output directory
|
||||
|
||||
**Response:** File download
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Python Example
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# Parse a document
|
||||
with open('document.pdf', 'rb') as f:
|
||||
files = {'file': ('document.pdf', f, 'application/pdf')}
|
||||
params = {
|
||||
'lang': 'ch',
|
||||
'backend': 'pipeline',
|
||||
'method': 'auto',
|
||||
'formula_enable': True,
|
||||
'table_enable': True
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://localhost:8000/api/v1/mineru/parse',
|
||||
files=files,
|
||||
params=params
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print(f"Parsed successfully: {result['file_name']}")
|
||||
|
||||
# Download the markdown file
|
||||
md_path = result['outputs']['markdown']
|
||||
download_response = requests.get(
|
||||
f'http://localhost:8000/api/v1/mineru/download/{md_path}'
|
||||
)
|
||||
|
||||
with open('output.md', 'wb') as f:
|
||||
f.write(download_response.content)
|
||||
```
|
||||
|
||||
### cURL Example
|
||||
|
||||
```bash
|
||||
# Parse a document
|
||||
curl -X POST "http://localhost:8000/api/v1/mineru/parse" \
|
||||
-F "file=@document.pdf" \
|
||||
-F "lang=ch" \
|
||||
-F "backend=pipeline" \
|
||||
-F "method=auto"
|
||||
|
||||
# Download a processed file
|
||||
curl -X GET "http://localhost:8000/api/v1/mineru/download/path/to/file.md" \
|
||||
-o downloaded_file.md
|
||||
```
|
||||
|
||||
## Backend Options
|
||||
|
||||
### Pipeline Backend
|
||||
- **Use case**: General purpose, more robust
|
||||
- **Advantages**: Better for complex layouts, supports multiple languages
|
||||
- **Command**: `backend=pipeline`
|
||||
|
||||
### VLM Backends
|
||||
- **vlm-transformers**: General purpose VLM
|
||||
- **vlm-sglang-engine**: Faster engine-based approach
|
||||
- **vlm-sglang-client**: Fastest client-based approach (requires server_url)
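
For example, a request that targets the `vlm-sglang-client` backend also has to supply `server_url`. A minimal sketch (the sglang server address here is an assumption for illustration):

```python
import requests

# Parse with the client-based VLM backend; other parameters keep their defaults.
with open("document.pdf", "rb") as f:
    response = requests.post(
        "http://localhost:8000/api/v1/mineru/parse",
        files={"file": ("document.pdf", f, "application/pdf")},
        params={
            "backend": "vlm-sglang-client",
            "server_url": "http://sglang-server:30000",  # assumed address of the sglang server
        },
    )
print(response.json()["status"])
```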
|
||||
|
||||
## Language Support
|
||||
|
||||
Supported languages for the pipeline backend:
|
||||
- `ch`: Chinese (Simplified)
|
||||
- `en`: English
|
||||
- `korean`: Korean
|
||||
- `japan`: Japanese
|
||||
- `chinese_cht`: Chinese (Traditional)
|
||||
- `ta`: Tamil
|
||||
- `te`: Telugu
|
||||
- `ka`: Kannada
|
||||
|
||||
## Output Files
|
||||
|
||||
The API generates various output files depending on the parameters:
|
||||
|
||||
1. **Markdown** (`.md`): Structured text content
|
||||
2. **Middle JSON** (`.json`): Intermediate parsing results
|
||||
3. **Model Output** (`.json` or `.txt`): Raw model predictions
|
||||
4. **Content List** (`.json`): Structured content list
|
||||
5. **Original PDF**: Copy of the input file
|
||||
6. **Layout PDF**: PDF with layout bounding boxes
|
||||
7. **Span PDF**: PDF with span bounding boxes
|
||||
|
||||
## Error Handling
|
||||
|
||||
The API returns appropriate HTTP status codes:
|
||||
|
||||
- `200`: Success
|
||||
- `400`: Bad request (invalid parameters, unsupported file type)
|
||||
- `404`: File not found
|
||||
- `500`: Internal server error
|
||||
|
||||
Error responses include a detail message explaining the issue.
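
A client can branch on these codes and surface the `detail` message, for instance (a minimal sketch):

```python
import requests

with open("scan.png", "rb") as f:
    response = requests.post(
        "http://localhost:8000/api/v1/mineru/parse",
        files={"file": ("scan.png", f, "image/png")},
    )

if response.status_code == 200:
    print("Markdown written to:", response.json()["outputs"]["markdown"])
elif response.status_code in (400, 404):
    # Validation problems and missing files carry a human-readable detail message
    print("Request rejected:", response.json()["detail"])
else:
    print("Server error:", response.json().get("detail", response.text))
```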
|
||||
|
||||
## Testing
|
||||
|
||||
Use the provided test script to verify the API:
|
||||
|
||||
```bash
|
||||
python test_mineru_api.py
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The API creates unique output directories for each request to avoid conflicts
|
||||
- Temporary files are automatically cleaned up after processing
|
||||
- File downloads are restricted to the processed folder for security
|
||||
- Large files may take time to process depending on the backend and document complexity
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
# Legal Document Masker API
|
||||
|
||||
This is the backend API for the Legal Document Masking system. It provides endpoints for file upload, processing status tracking, and file download.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.8+
|
||||
- Redis (for Celery)
|
||||
|
||||
## File Storage
|
||||
|
||||
Files are stored in the following structure:
|
||||
```
|
||||
backend/
|
||||
├── storage/
|
||||
│ ├── uploads/ # Original uploaded files
|
||||
│ └── processed/ # Masked/processed files
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
### Option 1: Local Development
|
||||
|
||||
1. Create a virtual environment:
|
||||
```bash
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. Set up environment variables:
|
||||
Create a `.env` file in the backend directory with the following variables:
|
||||
```env
|
||||
SECRET_KEY=your-secret-key-here
|
||||
```
|
||||
|
||||
The database (SQLite) will be automatically created when you first run the application.
|
||||
|
||||
4. Start Redis (required for Celery):
|
||||
```bash
|
||||
redis-server
|
||||
```
|
||||
|
||||
5. Start Celery worker:
|
||||
```bash
|
||||
celery -A app.services.file_service worker --loglevel=info
|
||||
```
|
||||
|
||||
6. Start the FastAPI server:
|
||||
```bash
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
### Option 2: Docker Deployment
|
||||
|
||||
1. Build and start the services:
|
||||
```bash
|
||||
docker-compose up --build
|
||||
```
|
||||
|
||||
This will start:
|
||||
- FastAPI server on port 8000
|
||||
- Celery worker for background processing
|
||||
- Redis for task queue
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once the server is running, you can access:
|
||||
- Swagger UI: `http://localhost:8000/docs`
|
||||
- ReDoc: `http://localhost:8000/redoc`
|
||||
|
||||
## API Endpoints
|
||||
|
||||
- `POST /api/v1/files/upload` - Upload a new file
|
||||
- `GET /api/v1/files` - List all files
|
||||
- `GET /api/v1/files/{file_id}` - Get file details
|
||||
- `GET /api/v1/files/{file_id}/download` - Download processed file
|
||||
- `WS /api/v1/files/ws/status/{file_id}` - WebSocket for real-time status updates
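
A rough sketch of exercising the upload and listing endpoints with `requests` (the multipart field name `file` and the response fields are assumptions; check the Swagger UI above for the exact schema):

```python
import requests

BASE = "http://localhost:8000/api/v1"

# Upload a document for masking (field name "file" is assumed)
with open("contract.docx", "rb") as f:
    upload = requests.post(f"{BASE}/files/upload", files={"file": f})
upload.raise_for_status()
print("Uploaded:", upload.json())

# List all files and their processing status
for item in requests.get(f"{BASE}/files").json():
    print(item)
```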
|
||||
|
||||
## Development
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
### Code Style
|
||||
The project uses Black for code formatting:
|
||||
```bash
|
||||
black .
|
||||
```
|
||||
|
||||
### Docker Commands
|
||||
|
||||
- Start services: `docker-compose up`
|
||||
- Start in background: `docker-compose up -d`
|
||||
- Stop services: `docker-compose down`
|
||||
- View logs: `docker-compose logs -f`
|
||||
- Rebuild: `docker-compose up --build`
|
||||
|
|
@ -0,0 +1,329 @@
|
|||
from fastapi import APIRouter, HTTPException, UploadFile, File, BackgroundTasks
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import List, Optional
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import json
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
from loguru import logger
|
||||
|
||||
from ...core.config import settings
|
||||
|
||||
# Import mineru functions
|
||||
from mineru.cli.common import convert_pdf_bytes_to_bytes_by_pypdfium2, prepare_env, read_fn
|
||||
from mineru.data.data_reader_writer import FileBasedDataWriter
|
||||
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
|
||||
from mineru.utils.enum_class import MakeMode
|
||||
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
|
||||
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
|
||||
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
|
||||
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
|
||||
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class MineruParseRequest:
|
||||
def __init__(
|
||||
self,
|
||||
lang: str = "ch",
|
||||
backend: str = "pipeline",
|
||||
method: str = "auto",
|
||||
server_url: Optional[str] = None,
|
||||
start_page_id: int = 0,
|
||||
end_page_id: Optional[int] = None,
|
||||
formula_enable: bool = True,
|
||||
table_enable: bool = True,
|
||||
draw_layout_bbox: bool = True,
|
||||
draw_span_bbox: bool = True,
|
||||
dump_md: bool = True,
|
||||
dump_middle_json: bool = True,
|
||||
dump_model_output: bool = True,
|
||||
dump_orig_pdf: bool = True,
|
||||
dump_content_list: bool = True,
|
||||
make_md_mode: str = "MM_MD"
|
||||
):
|
||||
self.lang = lang
|
||||
self.backend = backend
|
||||
self.method = method
|
||||
self.server_url = server_url
|
||||
self.start_page_id = start_page_id
|
||||
self.end_page_id = end_page_id
|
||||
self.formula_enable = formula_enable
|
||||
self.table_enable = table_enable
|
||||
self.draw_layout_bbox = draw_layout_bbox
|
||||
self.draw_span_bbox = draw_span_bbox
|
||||
self.dump_md = dump_md
|
||||
self.dump_middle_json = dump_middle_json
|
||||
self.dump_model_output = dump_model_output
|
||||
self.dump_orig_pdf = dump_orig_pdf
|
||||
self.dump_content_list = dump_content_list
|
||||
self.make_md_mode = MakeMode.MM_MD if make_md_mode == "MM_MD" else MakeMode.CONTENT_LIST
|
||||
|
||||
async def process_mineru_document(
|
||||
file: UploadFile,
|
||||
request: MineruParseRequest,
|
||||
output_dir: Path
|
||||
) -> dict:
|
||||
"""Process a single document using Mineru"""
|
||||
try:
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
# Create temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_file_path = Path(temp_file.name)
|
||||
|
||||
try:
|
||||
# Prepare environment
|
||||
file_name = Path(file.filename).stem
|
||||
local_image_dir, local_md_dir = prepare_env(output_dir, file_name, request.method)
|
||||
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
|
||||
|
||||
# Convert PDF bytes if needed
|
||||
if request.backend == "pipeline":
|
||||
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(
|
||||
content, request.start_page_id, request.end_page_id
|
||||
)
|
||||
|
||||
# Analyze document
|
||||
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = pipeline_doc_analyze(
|
||||
[new_pdf_bytes], [request.lang],
|
||||
parse_method=request.method,
|
||||
formula_enable=request.formula_enable,
|
||||
table_enable=request.table_enable
|
||||
)
|
||||
|
||||
# Process results
|
||||
model_list = infer_results[0]
|
||||
images_list = all_image_lists[0]
|
||||
pdf_doc = all_pdf_docs[0]
|
||||
_lang = lang_list[0]
|
||||
_ocr_enable = ocr_enabled_list[0]
|
||||
|
||||
middle_json = pipeline_result_to_middle_json(
|
||||
model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, request.formula_enable
|
||||
)
|
||||
|
||||
pdf_info = middle_json["pdf_info"]
|
||||
|
||||
# Generate outputs
|
||||
outputs = {}
|
||||
|
||||
if request.draw_layout_bbox:
|
||||
draw_layout_bbox(pdf_info, new_pdf_bytes, local_md_dir, f"{file_name}_layout.pdf")
|
||||
outputs["layout_pdf"] = str(local_md_dir / f"{file_name}_layout.pdf")
|
||||
|
||||
if request.draw_span_bbox:
|
||||
draw_span_bbox(pdf_info, new_pdf_bytes, local_md_dir, f"{file_name}_span.pdf")
|
||||
outputs["span_pdf"] = str(local_md_dir / f"{file_name}_span.pdf")
|
||||
|
||||
if request.dump_orig_pdf:
|
||||
md_writer.write(f"{file_name}_origin.pdf", new_pdf_bytes)
|
||||
outputs["original_pdf"] = str(local_md_dir / f"{file_name}_origin.pdf")
|
||||
|
||||
if request.dump_md:
|
||||
image_dir = str(os.path.basename(local_image_dir))
|
||||
md_content_str = pipeline_union_make(pdf_info, request.make_md_mode, image_dir)
|
||||
md_writer.write_string(f"{file_name}.md", md_content_str)
|
||||
outputs["markdown"] = str(local_md_dir / f"{file_name}.md")
|
||||
|
||||
if request.dump_content_list:
|
||||
image_dir = str(os.path.basename(local_image_dir))
|
||||
content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
|
||||
md_writer.write_string(
|
||||
f"{file_name}_content_list.json",
|
||||
json.dumps(content_list, ensure_ascii=False, indent=4)
|
||||
)
|
||||
outputs["content_list"] = str(local_md_dir / f"{file_name}_content_list.json")
|
||||
|
||||
if request.dump_middle_json:
|
||||
md_writer.write_string(
|
||||
f"{file_name}_middle.json",
|
||||
json.dumps(middle_json, ensure_ascii=False, indent=4)
|
||||
)
|
||||
outputs["middle_json"] = str(local_md_dir / f"{file_name}_middle.json")
|
||||
|
||||
if request.dump_model_output:
|
||||
md_writer.write_string(
|
||||
f"{file_name}_model.json",
|
||||
json.dumps(model_list, ensure_ascii=False, indent=4)
|
||||
)
|
||||
outputs["model_output"] = str(local_md_dir / f"{file_name}_model.json")
|
||||
|
||||
else:
|
||||
# VLM backend
|
||||
if request.backend.startswith("vlm-"):
|
||||
backend = request.backend[4:]
|
||||
|
||||
middle_json, infer_result = vlm_doc_analyze(
|
||||
content, image_writer=image_writer,
|
||||
backend=backend, server_url=request.server_url
|
||||
)
|
||||
|
||||
pdf_info = middle_json["pdf_info"]
|
||||
|
||||
# Generate outputs for VLM
|
||||
outputs = {}
|
||||
|
||||
if request.draw_layout_bbox:
|
||||
draw_layout_bbox(pdf_info, content, local_md_dir, f"{file_name}_layout.pdf")
|
||||
outputs["layout_pdf"] = str(local_md_dir / f"{file_name}_layout.pdf")
|
||||
|
||||
if request.dump_orig_pdf:
|
||||
md_writer.write(f"{file_name}_origin.pdf", content)
|
||||
outputs["original_pdf"] = str(local_md_dir / f"{file_name}_origin.pdf")
|
||||
|
||||
if request.dump_md:
|
||||
image_dir = str(os.path.basename(local_image_dir))
|
||||
md_content_str = vlm_union_make(pdf_info, request.make_md_mode, image_dir)
|
||||
md_writer.write_string(f"{file_name}.md", md_content_str)
|
||||
outputs["markdown"] = str(local_md_dir / f"{file_name}.md")
|
||||
|
||||
if request.dump_content_list:
|
||||
image_dir = str(os.path.basename(local_image_dir))
|
||||
content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
|
||||
md_writer.write_string(
|
||||
f"{file_name}_content_list.json",
|
||||
json.dumps(content_list, ensure_ascii=False, indent=4)
|
||||
)
|
||||
outputs["content_list"] = str(local_md_dir / f"{file_name}_content_list.json")
|
||||
|
||||
if request.dump_middle_json:
|
||||
md_writer.write_string(
|
||||
f"{file_name}_middle.json",
|
||||
json.dumps(middle_json, ensure_ascii=False, indent=4)
|
||||
)
|
||||
outputs["middle_json"] = str(local_md_dir / f"{file_name}_middle.json")
|
||||
|
||||
if request.dump_model_output:
|
||||
model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
|
||||
md_writer.write_string(f"{file_name}_model_output.txt", model_output)
|
||||
outputs["model_output"] = str(local_md_dir / f"{file_name}_model_output.txt")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_name": file_name,
|
||||
"outputs": outputs,
|
||||
"output_directory": str(local_md_dir)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if temp_file_path.exists():
|
||||
temp_file_path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing document: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
|
||||
|
||||
@router.post("/parse")
|
||||
async def parse_document(
|
||||
file: UploadFile = File(...),
|
||||
lang: str = "ch",
|
||||
backend: str = "pipeline",
|
||||
method: str = "auto",
|
||||
server_url: Optional[str] = None,
|
||||
start_page_id: int = 0,
|
||||
end_page_id: Optional[int] = None,
|
||||
formula_enable: bool = True,
|
||||
table_enable: bool = True,
|
||||
draw_layout_bbox: bool = True,
|
||||
draw_span_bbox: bool = True,
|
||||
dump_md: bool = True,
|
||||
dump_middle_json: bool = True,
|
||||
dump_model_output: bool = True,
|
||||
dump_orig_pdf: bool = True,
|
||||
dump_content_list: bool = True,
|
||||
make_md_mode: str = "MM_MD"
|
||||
):
|
||||
"""
|
||||
Parse a document using Mineru API
|
||||
|
||||
Parameters:
|
||||
- file: The document file to parse (PDF, image, etc.)
|
||||
- lang: Language option (default: 'ch')
|
||||
- backend: Backend for parsing ('pipeline', 'vlm-transformers', 'vlm-sglang-engine', 'vlm-sglang-client')
|
||||
- method: Method for parsing ('auto', 'txt', 'ocr')
|
||||
- server_url: Server URL for vlm-sglang-client backend
|
||||
- start_page_id: Start page ID for parsing
|
||||
- end_page_id: End page ID for parsing
|
||||
- formula_enable: Enable formula parsing
|
||||
- table_enable: Enable table parsing
|
||||
- draw_layout_bbox: Whether to draw layout bounding boxes
|
||||
- draw_span_bbox: Whether to draw span bounding boxes
|
||||
- dump_md: Whether to dump markdown files
|
||||
- dump_middle_json: Whether to dump middle JSON files
|
||||
- dump_model_output: Whether to dump model output files
|
||||
- dump_orig_pdf: Whether to dump original PDF files
|
||||
- dump_content_list: Whether to dump content list files
|
||||
- make_md_mode: The mode for making markdown content
|
||||
"""
|
||||
|
||||
# Validate file type
|
||||
allowed_extensions = {".pdf", ".png", ".jpeg", ".jpg"}
|
||||
file_extension = Path(file.filename).suffix.lower()
|
||||
if file_extension not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File type not allowed. Allowed types: {', '.join(allowed_extensions)}"
|
||||
)
|
||||
|
||||
# Create request object
|
||||
request = MineruParseRequest(
|
||||
lang=lang,
|
||||
backend=backend,
|
||||
method=method,
|
||||
server_url=server_url,
|
||||
start_page_id=start_page_id,
|
||||
end_page_id=end_page_id,
|
||||
formula_enable=formula_enable,
|
||||
table_enable=table_enable,
|
||||
draw_layout_bbox=draw_layout_bbox,
|
||||
draw_span_bbox=draw_span_bbox,
|
||||
dump_md=dump_md,
|
||||
dump_middle_json=dump_middle_json,
|
||||
dump_model_output=dump_model_output,
|
||||
dump_orig_pdf=dump_orig_pdf,
|
||||
dump_content_list=dump_content_list,
|
||||
make_md_mode=make_md_mode
|
||||
)
|
||||
|
||||
# Create output directory
|
||||
output_dir = settings.PROCESSED_FOLDER / "mineru" / str(uuid.uuid4())
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process document
|
||||
result = await process_mineru_document(file, request, output_dir)
|
||||
|
||||
return result
|
||||
|
||||
@router.get("/download/{file_path:path}")
|
||||
async def download_processed_file(file_path: str):
|
||||
"""Download a processed file from the mineru output directory"""
|
||||
try:
|
||||
# Construct the full path
|
||||
full_path = settings.PROCESSED_FOLDER / "mineru" / file_path
|
||||
|
||||
# Security check: ensure the path is within the processed folder
|
||||
if not str(full_path).startswith(str(settings.PROCESSED_FOLDER)):
|
||||
raise HTTPException(status_code=400, detail="Invalid file path")
|
||||
|
||||
if not full_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=full_path.name,
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error downloading file: {str(e)}")
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint for mineru service"""
|
||||
return {"status": "healthy", "service": "mineru"}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# API Settings
|
||||
API_V1_STR: str = "/api/v1"
|
||||
PROJECT_NAME: str = "Legal Document Masker API"
|
||||
|
||||
# Security
|
||||
SECRET_KEY: str = "your-secret-key-here" # Change in production
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||
|
||||
# Database
|
||||
BASE_DIR: Path = Path(__file__).parent.parent.parent
|
||||
DATABASE_URL: str = f"sqlite:///{BASE_DIR}/storage/legal_doc_masker.db"
|
||||
|
||||
# File Storage
|
||||
UPLOAD_FOLDER: Path = BASE_DIR / "storage" / "uploads"
|
||||
PROCESSED_FOLDER: Path = BASE_DIR / "storage" / "processed"
|
||||
MAX_FILE_SIZE: int = 50 * 1024 * 1024 # 50MB
|
||||
ALLOWED_EXTENSIONS: set = {"pdf", "docx", "doc", "md"}
|
||||
|
||||
# Celery
|
||||
CELERY_BROKER_URL: str = "redis://redis:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://redis:6379/0"
|
||||
|
||||
# Ollama API settings
|
||||
OLLAMA_API_URL: str = "https://api.ollama.com"
|
||||
OLLAMA_API_KEY: str = ""
|
||||
OLLAMA_MODEL: str = "llama2"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = "INFO"
|
||||
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
LOG_DATE_FORMAT: str = "%Y-%m-%d %H:%M:%S"
|
||||
LOG_FILE: str = "app.log"
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "allow"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Create storage directories if they don't exist
|
||||
self.UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
|
||||
self.PROCESSED_FOLDER.mkdir(parents=True, exist_ok=True)
|
||||
# Create storage directory for database
|
||||
(self.BASE_DIR / "storage").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
settings = Settings()
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
import logging.config
|
||||
# from config.settings import settings
|
||||
from .settings import settings
|
||||
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"standard": {
|
||||
"format": settings.LOG_FORMAT,
|
||||
"datefmt": settings.LOG_DATE_FORMAT
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "standard",
|
||||
"level": settings.LOG_LEVEL,
|
||||
"stream": "ext://sys.stdout"
|
||||
},
|
||||
"file": {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "standard",
|
||||
"level": settings.LOG_LEVEL,
|
||||
"filename": settings.LOG_FILE,
|
||||
"mode": "a",
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"": { # root logger
|
||||
"handlers": ["console", "file"],
|
||||
"level": settings.LOG_LEVEL,
|
||||
"propagate": True
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def setup_logging():
|
||||
"""Initialize logging configuration"""
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from .config import settings
|
||||
|
||||
# Create SQLite engine with check_same_thread=False for FastAPI
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
class Document:
|
||||
def __init__(self, file_path):
|
||||
self.file_path = file_path
|
||||
self.content = ""
|
||||
|
||||
def load(self):
|
||||
with open(self.file_path, 'r', encoding='utf-8') as file:  # explicit encoding for Chinese legal text
|
||||
self.content = file.read()
|
||||
|
||||
def save(self, target_path):
|
||||
with open(target_path, 'w', encoding='utf-8') as file:
|
||||
file.write(self.content)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from .document_processor import DocumentProcessor
|
||||
from .processors import (
|
||||
TxtDocumentProcessor,
|
||||
DocxDocumentProcessor,
|
||||
PdfDocumentProcessor,
|
||||
MarkdownDocumentProcessor
|
||||
)
|
||||
|
||||
class DocumentProcessorFactory:
|
||||
@staticmethod
|
||||
def create_processor(input_path: str, output_path: str) -> Optional[DocumentProcessor]:
|
||||
file_extension = os.path.splitext(input_path)[1].lower()
|
||||
|
||||
processors = {
|
||||
'.txt': TxtDocumentProcessor,
|
||||
'.docx': DocxDocumentProcessor,
|
||||
'.doc': DocxDocumentProcessor,
|
||||
'.pdf': PdfDocumentProcessor,
|
||||
'.md': MarkdownDocumentProcessor,
|
||||
'.markdown': MarkdownDocumentProcessor
|
||||
}
|
||||
|
||||
processor_class = processors.get(file_extension)
|
||||
if processor_class:
|
||||
return processor_class(input_path, output_path)
|
||||
return None
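
# Example usage (sketch): pick a processor by extension, then read, mask, and save
# processor = DocumentProcessorFactory.create_processor("input.docx", "output.docx")
# if processor:
#     masked = processor.process_content(processor.read_content())
#     processor.save_content(masked)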
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from ..prompts.masking_prompts import get_masking_mapping_prompt
|
||||
import logging
|
||||
import json
|
||||
from ..services.ollama_client import OllamaClient
|
||||
from ...core.config import settings
|
||||
from ..utils.json_extractor import LLMJsonExtractor
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentProcessor(ABC):
|
||||
def __init__(self):
|
||||
        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)
        self.max_chunk_size = 1000  # Maximum number of characters per chunk
        self.max_retries = 3  # Maximum number of retries for mapping generation

    @abstractmethod
    def read_content(self) -> str:
        """Read document content"""
        pass

    def _split_into_chunks(self, sentences: list[str]) -> list[str]:
        """Split sentences into chunks that don't exceed max_chunk_size"""
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if not sentence.strip():
                continue

            # If adding this sentence would exceed the limit, save the current chunk and start a new one
            if len(current_chunk) + len(sentence) > self.max_chunk_size and current_chunk:
                chunks.append(current_chunk)
                current_chunk = sentence
            else:
                if current_chunk:
                    current_chunk += "。" + sentence
                else:
                    current_chunk = sentence

        # Add the last chunk if it's not empty
        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def _validate_mapping_format(self, mapping: Dict[str, Any]) -> bool:
        """
        Validate that the mapping follows the required format
        (keys are the original strings, values are their masked replacements):
        {
            "原文1": "脱敏后1",
            "原文2": "脱敏后2",
            ...
        }
        """
        if not isinstance(mapping, dict):
            logger.warning("Mapping is not a dictionary")
            return False

        # Check if any key or value is not a string
        for key, value in mapping.items():
            if not isinstance(key, str) or not isinstance(value, str):
                logger.warning(f"Invalid mapping format - key or value is not a string: {key}: {value}")
                return False

        # Check if the mapping has any nested structures
        if any(isinstance(v, (dict, list)) for v in mapping.values()):
            logger.warning("Invalid mapping format - contains nested structures")
            return False

        return True

    def _build_mapping(self, chunk: str) -> Dict[str, str]:
        """Build mapping for a single chunk of text with retry logic"""
        for attempt in range(self.max_retries):
            try:
                formatted_prompt = get_masking_mapping_prompt(chunk)
                logger.info(f"Calling ollama to generate mapping for chunk (attempt {attempt + 1}/{self.max_retries}): {formatted_prompt}")
                response = self.ollama_client.generate(formatted_prompt)
                logger.info(f"Raw response from LLM: {response}")

                # Parse the JSON response into a dictionary
                mapping = LLMJsonExtractor.parse_raw_json_str(response)
                logger.info(f"Parsed mapping: {mapping}")

                if mapping and self._validate_mapping_format(mapping):
                    return mapping
                else:
                    logger.warning(f"Invalid mapping format received on attempt {attempt + 1}, retrying...")
            except Exception as e:
                logger.error(f"Error generating mapping on attempt {attempt + 1}: {e}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error("Max retries reached, returning empty mapping")
                    return {}

        # Every attempt returned an invalid mapping without raising: fall back to an empty mapping
        logger.error("Max retries reached without a valid mapping, returning empty mapping")
        return {}

    def _apply_mapping(self, text: str, mapping: Dict[str, str]) -> str:
        """Apply the mapping to replace sensitive information"""
        masked_text = text
        for original, masked in mapping.items():
            # Ensure masked value is a string
            if isinstance(masked, dict):
                # If it's a dict, use the first value or a default
                masked = next(iter(masked.values()), "某")
            elif not isinstance(masked, str):
                # If it's not a string, convert to string or use default
                masked = str(masked) if masked is not None else "某"
            masked_text = masked_text.replace(original, masked)
        return masked_text

    def _get_next_suffix(self, value: str) -> str:
        """Get the next available suffix for a value that already has a suffix"""
        # Define the sequence of suffixes
        suffixes = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']

        # Check if the value already has a suffix
        for suffix in suffixes:
            if value.endswith(suffix):
                # Find the next suffix in the sequence
                current_index = suffixes.index(suffix)
                if current_index + 1 < len(suffixes):
                    return value[:-1] + suffixes[current_index + 1]
                else:
                    # If we've used all suffixes, start over with the first one
                    return value[:-1] + suffixes[0]

        # If no suffix found, return the value with the first suffix
        return value + '甲'

    def _merge_mappings(self, existing: Dict[str, str], new: Dict[str, str]) -> Dict[str, str]:
        """
        Merge two mappings following the rules:
        1. If key exists in existing, keep existing value
        2. If value exists in existing:
           - If value ends with a suffix (甲乙丙丁...), add next suffix
           - If no suffix, add '甲'
        """
        result = existing.copy()

        # Get all existing values
        existing_values = set(result.values())

        for key, value in new.items():
            if key in result:
                # Rule 1: Keep existing value if key exists
                continue

            if value in existing_values:
                # Rule 2: Handle duplicate values
                new_value = self._get_next_suffix(value)
                result[key] = new_value
                existing_values.add(new_value)
            else:
                # No conflict, add as is
                result[key] = value
                existing_values.add(value)

        return result

    def process_content(self, content: str) -> str:
        """Process document content by masking sensitive information"""
        # Split content into sentences
        sentences = content.split("。")

        # Split sentences into manageable chunks
        chunks = self._split_into_chunks(sentences)
        logger.info(f"Split content into {len(chunks)} chunks")

        # Build mapping for each chunk
        combined_mapping = {}
        for i, chunk in enumerate(chunks):
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            chunk_mapping = self._build_mapping(chunk)
            if chunk_mapping:  # Only update if we got a valid mapping
                combined_mapping = self._merge_mappings(combined_mapping, chunk_mapping)
            else:
                logger.warning(f"Failed to generate mapping for chunk {i+1}")

        # Apply the combined mapping to the entire content
        masked_content = self._apply_mapping(content, combined_mapping)
        logger.info("Successfully masked content")

        return masked_content

    @abstractmethod
    def save_content(self, content: str) -> None:
        """Save processed content"""
        pass

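The suffix rules in `_get_next_suffix` and `_merge_mappings` are easiest to see on a concrete case. The snippet below is a minimal standalone sketch that mirrors that logic for illustration only; it is not part of the diff, and the names ("张三", "天元公司", …) are made up.

```python
# Minimal standalone sketch of the merge rules above (illustration only, not project code).
SUFFIXES = ['甲', '乙', '丙', '丁', '戊', '己', '庚', '辛', '壬', '癸']

def next_suffix(value: str) -> str:
    """Mirror of _get_next_suffix: advance the trailing suffix, or append '甲'."""
    for i, s in enumerate(SUFFIXES):
        if value.endswith(s):
            return value[:-1] + SUFFIXES[(i + 1) % len(SUFFIXES)]
    return value + '甲'

def merge(existing: dict, new: dict) -> dict:
    """Mirror of _merge_mappings: keep existing keys, de-duplicate masked values."""
    result = dict(existing)
    used = set(result.values())
    for key, value in new.items():
        if key in result:        # Rule 1: an already-mapped original keeps its replacement
            continue
        if value in used:        # Rule 2: a duplicate replacement gets the next suffix
            value = next_suffix(value)
        result[key] = value
        used.add(value)
    return result

# Two chunks both propose "某公司"; the second occurrence becomes "某公司甲".
print(merge({"张三": "李某", "天元公司": "某公司"}, {"启明公司": "某公司", "张三": "王某"}))
# -> {'张三': '李某', '天元公司': '某公司', '启明公司': '某公司甲'}
```

This keeps replacements globally unique across chunks, so two different companies are never collapsed into the same masked name when the per-chunk mappings are combined.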
@ -0,0 +1,6 @@
from .txt_processor import TxtDocumentProcessor
from .docx_processor import DocxDocumentProcessor
from .pdf_processor import PdfDocumentProcessor
from .md_processor import MarkdownDocumentProcessor

__all__ = ['TxtDocumentProcessor', 'DocxDocumentProcessor', 'PdfDocumentProcessor', 'MarkdownDocumentProcessor']

@ -0,0 +1,77 @@
import os
import docx
from ...document_handlers.document_processor import DocumentProcessor
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_office
import logging
from ...services.ollama_client import OllamaClient
from ...config import settings
from ...prompts.masking_prompts import get_masking_mapping_prompt

logger = logging.getLogger(__name__)

class DocxDocumentProcessor(DocumentProcessor):
    def __init__(self, input_path: str, output_path: str):
        super().__init__()  # Call parent class's __init__
        self.input_path = input_path
        self.output_path = output_path
        self.output_dir = os.path.dirname(output_path)
        self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]

        # Setup output directories
        self.local_image_dir = os.path.join(self.output_dir, "images")
        self.image_dir = os.path.basename(self.local_image_dir)
        os.makedirs(self.local_image_dir, exist_ok=True)

        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)

    def read_content(self) -> str:
        try:
            # Initialize writers
            image_writer = FileBasedDataWriter(self.local_image_dir)
            md_writer = FileBasedDataWriter(self.output_dir)

            # Create Dataset Instance and process
            ds = read_local_office(self.input_path)[0]
            pipe_result = ds.apply(doc_analyze, ocr=True).pipe_txt_mode(image_writer)

            # Generate markdown
            md_content = pipe_result.get_markdown(self.image_dir)
            pipe_result.dump_md(md_writer, f"{self.name_without_suff}.md", self.image_dir)

            return md_content
        except Exception as e:
            logger.error(f"Error converting DOCX to MD: {e}")
            raise

    # def process_content(self, content: str) -> str:
    #     logger.info("Processing DOCX content")

    #     # Split content into sentences and apply masking
    #     sentences = content.split("。")
    #     final_md = ""
    #     for sentence in sentences:
    #         if sentence.strip():  # Only process non-empty sentences
    #             formatted_prompt = get_masking_mapping_prompt(sentence)
    #             logger.info("Calling ollama to generate response, prompt: %s", formatted_prompt)
    #             response = self.ollama_client.generate(formatted_prompt)
    #             logger.info(f"Response generated: {response}")
    #             final_md += response + "。"

    #     return final_md

    def save_content(self, content: str) -> None:
        # Ensure output path has .md extension
        output_dir = os.path.dirname(self.output_path)
        base_name = os.path.splitext(os.path.basename(self.output_path))[0]
        md_output_path = os.path.join(output_dir, f"{base_name}.md")

        logger.info(f"Saving masked content to: {md_output_path}")
        try:
            with open(md_output_path, 'w', encoding='utf-8') as file:
                file.write(content)
            logger.info(f"Successfully saved content to {md_output_path}")
        except Exception as e:
            logger.error(f"Error saving content: {e}")
            raise

@ -0,0 +1,39 @@
import os
from ...document_handlers.document_processor import DocumentProcessor
from ...services.ollama_client import OllamaClient
import logging
from ...config import settings

logger = logging.getLogger(__name__)

class MarkdownDocumentProcessor(DocumentProcessor):
    def __init__(self, input_path: str, output_path: str):
        super().__init__()  # Call parent class's __init__
        self.input_path = input_path
        self.output_path = output_path
        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)

    def read_content(self) -> str:
        """Read markdown content from file"""
        try:
            with open(self.input_path, 'r', encoding='utf-8') as file:
                content = file.read()
            logger.info(f"Successfully read markdown content from {self.input_path}")
            return content
        except Exception as e:
            logger.error(f"Error reading markdown file {self.input_path}: {e}")
            raise

    def save_content(self, content: str) -> None:
        """Save processed markdown content"""
        try:
            # Ensure output directory exists
            output_dir = os.path.dirname(self.output_path)
            os.makedirs(output_dir, exist_ok=True)

            with open(self.output_path, 'w', encoding='utf-8') as file:
                file.write(content)
            logger.info(f"Successfully saved masked content to {self.output_path}")
        except Exception as e:
            logger.error(f"Error saving content to {self.output_path}: {e}")
            raise

@ -0,0 +1,105 @@
import os
import PyPDF2
from ...document_handlers.document_processor import DocumentProcessor
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
from ...prompts.masking_prompts import get_masking_prompt, get_masking_mapping_prompt
import logging
from ...services.ollama_client import OllamaClient
from ...config import settings

logger = logging.getLogger(__name__)

class PdfDocumentProcessor(DocumentProcessor):
    def __init__(self, input_path: str, output_path: str):
        super().__init__()  # Call parent class's __init__
        self.input_path = input_path
        self.output_path = output_path
        self.output_dir = os.path.dirname(output_path)
        self.name_without_suff = os.path.splitext(os.path.basename(input_path))[0]

        # Setup output directories
        self.local_image_dir = os.path.join(self.output_dir, "images")
        self.image_dir = os.path.basename(self.local_image_dir)
        os.makedirs(self.local_image_dir, exist_ok=True)

        # Setup work directory under output directory
        self.work_dir = os.path.join(
            os.path.dirname(output_path),
            ".work",
            os.path.splitext(os.path.basename(input_path))[0]
        )
        os.makedirs(self.work_dir, exist_ok=True)

        self.work_local_image_dir = os.path.join(self.work_dir, "images")
        self.work_image_dir = os.path.basename(self.work_local_image_dir)
        os.makedirs(self.work_local_image_dir, exist_ok=True)
        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)

    def read_content(self) -> str:
        logger.info("Starting PDF content processing")

        # Read the PDF file
        with open(self.input_path, 'rb') as file:
            content = file.read()

        # Initialize writers
        image_writer = FileBasedDataWriter(self.work_local_image_dir)
        md_writer = FileBasedDataWriter(self.work_dir)

        # Create Dataset Instance
        ds = PymuDocDataset(content)

        logger.info("Classifying PDF type: %s", ds.classify())
        # Process based on PDF type
        if ds.classify() == SupportedPdfParseMethod.OCR:
            infer_result = ds.apply(doc_analyze, ocr=True)
            pipe_result = infer_result.pipe_ocr_mode(image_writer)
        else:
            infer_result = ds.apply(doc_analyze, ocr=False)
            pipe_result = infer_result.pipe_txt_mode(image_writer)

        logger.info("Generating all outputs")
        # Generate all outputs (intermediate artifacts go into the .work directory)
        infer_result.draw_model(os.path.join(self.work_dir, f"{self.name_without_suff}_model.pdf"))
        model_inference_result = infer_result.get_infer_res()

        pipe_result.draw_layout(os.path.join(self.work_dir, f"{self.name_without_suff}_layout.pdf"))
        pipe_result.draw_span(os.path.join(self.work_dir, f"{self.name_without_suff}_spans.pdf"))

        md_content = pipe_result.get_markdown(self.work_image_dir)
        pipe_result.dump_md(md_writer, f"{self.name_without_suff}.md", self.work_image_dir)

        content_list = pipe_result.get_content_list(self.work_image_dir)
        pipe_result.dump_content_list(md_writer, f"{self.name_without_suff}_content_list.json", self.work_image_dir)

        middle_json = pipe_result.get_middle_json()
        pipe_result.dump_middle_json(md_writer, f'{self.name_without_suff}_middle.json')

        return md_content

    # def process_content(self, content: str) -> str:
    #     logger.info("Starting content masking process")
    #     sentences = content.split("。")
    #     final_md = ""
    #     for sentence in sentences:
    #         if not sentence.strip():  # Skip empty sentences
    #             continue
    #         formatted_prompt = get_masking_mapping_prompt(sentence)
    #         logger.info("Calling ollama to generate response, prompt: %s", formatted_prompt)
    #         response = self.ollama_client.generate(formatted_prompt)
    #         logger.info(f"Response generated: {response}")
    #         final_md += response + "。"
    #     return final_md

    def save_content(self, content: str) -> None:
        # Ensure output path has .md extension
        output_dir = os.path.dirname(self.output_path)
        base_name = os.path.splitext(os.path.basename(self.output_path))[0]
        md_output_path = os.path.join(output_dir, f"{base_name}.md")

        logger.info(f"Saving masked content to: {md_output_path}")
        with open(md_output_path, 'w', encoding='utf-8') as file:
            file.write(content)

@ -0,0 +1,28 @@
from ...document_handlers.document_processor import DocumentProcessor
from ...services.ollama_client import OllamaClient
import logging
from ...prompts.masking_prompts import get_masking_prompt
from ...config import settings

logger = logging.getLogger(__name__)

class TxtDocumentProcessor(DocumentProcessor):
    def __init__(self, input_path: str, output_path: str):
        super().__init__()
        self.input_path = input_path
        self.output_path = output_path
        self.ollama_client = OllamaClient(model_name=settings.OLLAMA_MODEL, base_url=settings.OLLAMA_API_URL)

    def read_content(self) -> str:
        with open(self.input_path, 'r', encoding='utf-8') as file:
            return file.read()

    # def process_content(self, content: str) -> str:
    #     formatted_prompt = get_masking_prompt(content)
    #     response = self.ollama_client.generate(formatted_prompt)
    #     logger.debug(f"Processed content: {response}")
    #     return response

    def save_content(self, content: str) -> None:
        with open(self.output_path, 'w', encoding='utf-8') as file:
            file.write(content)

@ -0,0 +1,30 @@
import logging
from ..document_handlers.document_factory import DocumentProcessorFactory
from ..services.ollama_client import OllamaClient

logger = logging.getLogger(__name__)

class DocumentService:
    def __init__(self):
        pass

    def process_document(self, input_path: str, output_path: str) -> bool:
        try:
            processor = DocumentProcessorFactory.create_processor(input_path, output_path)
            if not processor:
                logger.error(f"Unsupported file format: {input_path}")
                return False

            # Read content
            content = processor.read_content()

            # Process with Ollama
            masked_content = processor.process_content(content)

            # Save processed content
            processor.save_content(masked_content)
            return True

        except Exception as e:
            logger.error(f"Error processing document {input_path}: {str(e)}")
            return False

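A minimal usage sketch of the service above, assuming the relative imports resolve to a package layout like `app.services.document_service` (the diff does not pin this down) and using illustrative file paths:

```python
# Illustrative sketch only: the module path and the example paths are assumptions.
from app.services.document_service import DocumentService

service = DocumentService()
ok = service.process_document(
    "storage/uploads/contract.docx",    # source file picked up from the uploads volume
    "storage/processed/contract.md",    # masked Markdown written alongside other outputs
)
print("masked" if ok else "failed")
```

The factory picks a processor by file extension, so the same two-argument call covers .txt, .docx, .pdf, and .md inputs.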
@ -0,0 +1,20 @@
import os

def read_file(file_path):
    with open(file_path, 'r') as file:
        return file.read()

def write_file(file_path, content):
    with open(file_path, 'w') as file:
        file.write(content)

def file_exists(file_path):
    return os.path.isfile(file_path)

def delete_file(file_path):
    if file_exists(file_path):
        os.remove(file_path)

def list_files_in_directory(directory_path):
    return [f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))]

@ -0,0 +1,39 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings
from .api.endpoints import mineru
from .core.database import engine, Base

# Create database tables
Base.metadata.create_all(bind=engine)

app = FastAPI(
    title=settings.PROJECT_NAME,
    openapi_url=f"{settings.API_V1_STR}/openapi.json"
)

# Set up CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
# app.include_router(
#     files.router,
#     prefix=f"{settings.API_V1_STR}/files",
#     tags=["files"]
# )

app.include_router(
    mineru.router,
    prefix=f"{settings.API_V1_STR}/mineru",
    tags=["mineru"]
)

@app.get("/")
async def root():
    return {"message": "Welcome to Legal Document Masker API"}

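A quick way to sanity-check the root route above is FastAPI's TestClient. The sketch assumes the application is importable as `app.main` and that the database configured in `.core.database` is reachable (importing the module runs `create_all`):

```python
# Smoke test sketch; the module path app.main is an assumption about the package layout.
from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)

def test_root():
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Welcome to Legal Document Masker API"}
```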
@ -0,0 +1,22 @@
from sqlalchemy import Column, String, DateTime, Text
from datetime import datetime
import uuid
from ..core.database import Base

class FileStatus(str):
    NOT_STARTED = "not_started"
    PROCESSING = "processing"
    SUCCESS = "success"
    FAILED = "failed"

class File(Base):
    __tablename__ = "files"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    filename = Column(String(255), nullable=False)
    original_path = Column(String(255), nullable=False)
    processed_path = Column(String(255))
    status = Column(String(20), nullable=False, default=FileStatus.NOT_STARTED)
    error_message = Column(Text)
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
    updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow)

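A sketch of inserting a row with this model. It assumes `..core.database` also exposes a `SessionLocal` session factory next to `engine`/`Base`, which is a common pattern but is not shown in this diff; the module paths and file names are likewise assumptions.

```python
# Illustrative sketch only: SessionLocal and the module paths are assumptions.
from app.core.database import SessionLocal
from app.models.file import File, FileStatus

db = SessionLocal()
try:
    record = File(
        filename="contract.pdf",
        original_path="storage/uploads/contract.pdf",
        status=FileStatus.PROCESSING,
    )
    db.add(record)
    db.commit()
    db.refresh(record)   # the UUID id and timestamp defaults are filled in at flush time
    print(record.id, record.status)
finally:
    db.close()
```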
@ -0,0 +1,21 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Optional
from uuid import UUID

class FileBase(BaseModel):
    filename: str
    status: str
    error_message: Optional[str] = None

class FileResponse(FileBase):
    id: UUID
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True

class FileList(BaseModel):
    files: list[FileResponse]
    total: int

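Since `from_attributes = True` implies Pydantic v2 semantics, an ORM row can be turned into these schemas with `model_validate`. The sketch below builds a `File` instance with explicit values purely for illustration; module paths are assumptions.

```python
# Sketch of serializing an ORM row with the schemas above; paths and values are illustrative.
import uuid
from datetime import datetime, timezone

from app.models.file import File, FileStatus
from app.schemas.file import FileResponse, FileList

row = File(
    id=str(uuid.uuid4()),
    filename="contract.pdf",
    original_path="storage/uploads/contract.pdf",
    status=FileStatus.SUCCESS,
    created_at=datetime.now(timezone.utc),
    updated_at=datetime.now(timezone.utc),
)

payload = FileResponse.model_validate(row)   # reads ORM attributes via from_attributes
listing = FileList(files=[payload], total=1)
print(listing.model_dump_json())
```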
@ -1,27 +1,37 @@
 version: '3.8'
 
 services:
-  mineru-api:
-    build:
-      context: .
-      dockerfile: Dockerfile
-    platform: linux/arm64
+  api:
+    build: .
     ports:
-      - "8001:8000"
+      - "8000:8000"
     volumes:
-      - ./storage/uploads:/app/storage/uploads
-      - ./storage/processed:/app/storage/processed
+      - ./storage:/app/storage
+      - ./legal_doc_masker.db:/app/legal_doc_masker.db
+    env_file:
+      - .env
     environment:
-      - PYTHONUNBUFFERED=1
-      - MINERU_MODEL_SOURCE=local
-    restart: unless-stopped
-    healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
-      interval: 30s
-      timeout: 10s
-      retries: 3
-      start_period: 60s
+      - CELERY_BROKER_URL=redis://redis:6379/0
+      - CELERY_RESULT_BACKEND=redis://redis:6379/0
+    depends_on:
+      - redis
 
-volumes:
-  uploads:
-  processed:
+  celery_worker:
+    build: .
+    command: celery -A app.services.file_service worker --loglevel=info
+    volumes:
+      - ./storage:/app/storage
+      - ./legal_doc_masker.db:/app/legal_doc_masker.db
+    env_file:
+      - .env
+    environment:
+      - CELERY_BROKER_URL=redis://redis:6379/0
+      - CELERY_RESULT_BACKEND=redis://redis:6379/0
+    depends_on:
+      - redis
+      - api
+
+  redis:
+    image: redis:alpine
+    ports:
+      - "6379:6379"

@ -0,0 +1,6 @@
{
  "name": "mineru",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {}
}
Some files were not shown because too many files have changed in this diff.