# legal-doc-masker/backend/app/core/services/ollama_client.py
import requests
import logging
from typing import Dict, Any, Optional, Callable, Union
from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator
logger = logging.getLogger(__name__)
class OllamaClient:
    """Client for the Ollama HTTP API with optional JSON parsing, validation, and retries."""

    def __init__(self, model_name: str, base_url: str = "http://localhost:11434",
                 max_retries: int = 3, timeout: Optional[float] = None):
        """Initialize Ollama client.

        Args:
            model_name (str): Name of the Ollama model to use
            base_url (str): Ollama server base URL
            max_retries (int): Maximum number of retries for failed requests
            timeout (Optional[float]): Per-request timeout in seconds.
                Defaults to None (wait indefinitely), preserving prior behavior;
                set a value to avoid hanging on an unresponsive server.
        """
        self.model_name = model_name
        self.base_url = base_url
        self.max_retries = max_retries
        self.timeout = timeout
        self.headers = {"Content-Type": "application/json"}

    def generate(self,
                 prompt: str,
                 strip_think: bool = True,
                 validation_schema: Optional[Dict[str, Any]] = None,
                 response_type: Optional[str] = None,
                 return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
        """Process a document using the Ollama API with optional validation and retry.

        Args:
            prompt (str): The prompt to send to the model
            strip_think (bool): Whether to strip thinking tags from response
            validation_schema (Optional[Dict]): JSON schema for validation
            response_type (Optional[str]): Type of response for validation
                ('entity_extraction', 'entity_linkage', etc.)
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON)

        Raises:
            RequestException: If the API call fails after all retries
            ValueError: If parsing or validation fails after all retries
        """
        for attempt in range(self.max_retries):
            try:
                raw_response = self._make_api_call(prompt, strip_think)

                # Fast path: nothing to parse or validate.
                if not (validation_schema or response_type or return_parsed):
                    return raw_response

                parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
                if not parsed_response:
                    logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
                    if attempt < self.max_retries - 1:
                        continue
                    raise ValueError("Failed to parse JSON response after all retries")

                # Optional custom-schema validation.
                if validation_schema and not self._validate_with_schema(parsed_response, validation_schema):
                    logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
                    if attempt < self.max_retries - 1:
                        continue
                    raise ValueError("Schema validation failed after all retries")

                # Optional response-type validation.
                if response_type and not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
                    logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
                    if attempt < self.max_retries - 1:
                        continue
                    raise ValueError("Response type validation failed after all retries")

                return parsed_response if return_parsed else raw_response
            except requests.exceptions.RequestException as e:
                logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error("Max retries reached, raising exception")
                    raise
            except Exception as e:
                # Also catches the ValueErrors raised above on the final attempt,
                # which are logged and re-raised here.
                logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error("Max retries reached, raising exception")
                    raise

        # Unreachable: every path above either returns or raises.
        raise RuntimeError("Unexpected error: max retries exceeded without proper exception handling")

    def generate_with_validation(self,
                                 prompt: str,
                                 response_type: str,
                                 strip_think: bool = True,
                                 return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
        """Generate response with automatic validation based on response type.

        Args:
            prompt (str): The prompt to send to the model
            response_type (str): Type of response for validation
            strip_think (bool): Whether to strip thinking tags from response
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Validated response from the model
        """
        return self.generate(
            prompt=prompt,
            strip_think=strip_think,
            response_type=response_type,
            return_parsed=return_parsed
        )

    def generate_with_schema(self,
                             prompt: str,
                             schema: Dict[str, Any],
                             strip_think: bool = True,
                             return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
        """Generate response with custom schema validation.

        Args:
            prompt (str): The prompt to send to the model
            schema (Dict): JSON schema for validation
            strip_think (bool): Whether to strip thinking tags from response
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Validated response from the model
        """
        return self.generate(
            prompt=prompt,
            strip_think=strip_think,
            validation_schema=schema,
            return_parsed=return_parsed
        )

    def _make_api_call(self, prompt: str, strip_think: bool) -> str:
        """Make the actual API call to Ollama.

        Args:
            prompt (str): The prompt to send
            strip_think (bool): Whether to strip thinking tags

        Returns:
            str: Raw response from the API

        Raises:
            RequestException: On HTTP errors, connection failures, or timeout.
        """
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False
        }
        logger.debug(f"Sending request to Ollama API: {url}")
        response = requests.post(url, json=payload, headers=self.headers,
                                 timeout=self.timeout)
        response.raise_for_status()
        result = response.json()
        logger.debug(f"Received response from Ollama API: {result}")

        response_text = result.get("response", "")
        if not strip_think:
            return response_text

        # Reasoning models prefix output as "<think>...</think>answer";
        # keep only the text after the first closing tag.
        if "<think>" in response_text:
            response_parts = response_text.split("</think>")
            if len(response_parts) > 1:
                return response_parts[1].strip()
            # No closing tag: fall back to the full response.
        return response_text.strip()

    def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
        """Validate response against a JSON schema.

        Args:
            response (Dict): The parsed response to validate
            schema (Dict): The JSON schema to validate against

        Returns:
            bool: True if valid, False otherwise (including when the optional
                jsonschema dependency is not installed).
        """
        try:
            from jsonschema import validate, ValidationError
            validate(instance=response, schema=schema)
            logger.debug(f"Schema validation passed for response: {response}")
            return True
        except ValidationError as e:
            logger.warning(f"Schema validation failed: {e}")
            logger.warning(f"Response that failed validation: {response}")
            return False
        except ImportError:
            logger.error("jsonschema library not available for validation")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model.

        Returns:
            Dict[str, Any]: Model information from the /api/show endpoint

        Raises:
            RequestException: If the API call fails
        """
        try:
            url = f"{self.base_url}/api/show"
            # NOTE(review): "name" is the legacy /api/show field; newer Ollama
            # versions also accept "model" — kept as-is to preserve behavior.
            payload = {"name": self.model_name}
            response = requests.post(url, json=payload, headers=self.headers,
                                     timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error getting model info: {str(e)}")
            raise