import requests
|
|
import logging
|
|
from typing import Dict, Any, Optional, Callable, Union
|
|
from ..utils.json_extractor import LLMJsonExtractor
|
|
from ..utils.llm_validator import LLMResponseValidator
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OllamaClient:
    """Client for a local Ollama server with optional JSON parsing/validation and retry.

    Wraps the Ollama HTTP API (`/api/generate`, `/api/show`), retrying failed
    requests and optionally validating the model's JSON output against a
    schema or a named response type before returning it.
    """

    def __init__(self,
                 model_name: str,
                 base_url: str = "http://localhost:11434",
                 max_retries: int = 3,
                 timeout: Optional[float] = None):
        """Initialize Ollama client.

        Args:
            model_name (str): Name of the Ollama model to use
            base_url (str): Ollama server base URL
            max_retries (int): Maximum number of retries for failed requests
            timeout (Optional[float]): Per-request timeout in seconds passed to
                `requests.post`. Defaults to None (wait indefinitely), which
                preserves the historical behavior; LLM generations can be slow.
        """
        self.model_name = model_name
        self.base_url = base_url
        self.max_retries = max_retries
        self.timeout = timeout
        self.headers = {"Content-Type": "application/json"}

    def generate(self,
                 prompt: str,
                 strip_think: bool = True,
                 validation_schema: Optional[Dict[str, Any]] = None,
                 response_type: Optional[str] = None,
                 return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
        """Process a document using the Ollama API with optional validation and retry.

        Args:
            prompt (str): The prompt to send to the model
            strip_think (bool): Whether to strip thinking tags from response
            validation_schema (Optional[Dict]): JSON schema for validation
            response_type (Optional[str]): Type of response for validation
                ('entity_extraction', 'entity_linkage', etc.)
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Response from the model (raw string or parsed JSON)

        Raises:
            RequestException: If the API call fails after all retries
            ValueError: If JSON parsing or validation fails after all retries
        """
        # Parsing is needed whenever the caller asked for validation or a dict.
        needs_parsing = bool(validation_schema or response_type or return_parsed)

        for attempt in range(self.max_retries):
            last_attempt = attempt == self.max_retries - 1
            try:
                raw_response = self._make_api_call(prompt, strip_think)

                # Fast path: nothing to parse or validate.
                if not needs_parsing:
                    return raw_response

                parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
                if not parsed_response:
                    logger.warning(f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
                    if last_attempt:
                        raise ValueError("Failed to parse JSON response after all retries")
                    continue

                if validation_schema and not self._validate_with_schema(parsed_response, validation_schema):
                    logger.warning(f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
                    if last_attempt:
                        raise ValueError("Schema validation failed after all retries")
                    continue

                if response_type and not LLMResponseValidator.validate_response_by_type(parsed_response, response_type):
                    logger.warning(f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
                    if last_attempt:
                        raise ValueError("Response type validation failed after all retries")
                    continue

                # All requested checks passed.
                return parsed_response if return_parsed else raw_response

            except requests.exceptions.RequestException as e:
                logger.error(f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
                if last_attempt:
                    logger.error("Max retries reached, raising exception")
                    raise
                logger.info("Retrying...")
            except Exception as e:
                # Also catches the ValueErrors raised above on the last attempt,
                # which are logged and re-raised unchanged.
                logger.error(f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
                if last_attempt:
                    logger.error("Max retries reached, raising exception")
                    raise
                logger.info("Retrying...")

        # This should never be reached, but just in case
        raise Exception("Unexpected error: max retries exceeded without proper exception handling")

    def generate_with_validation(self,
                                 prompt: str,
                                 response_type: str,
                                 strip_think: bool = True,
                                 return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
        """Generate response with automatic validation based on response type.

        Args:
            prompt (str): The prompt to send to the model
            response_type (str): Type of response for validation
            strip_think (bool): Whether to strip thinking tags from response
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Validated response from the model
        """
        return self.generate(
            prompt=prompt,
            strip_think=strip_think,
            response_type=response_type,
            return_parsed=return_parsed
        )

    def generate_with_schema(self,
                             prompt: str,
                             schema: Dict[str, Any],
                             strip_think: bool = True,
                             return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
        """Generate response with custom schema validation.

        Args:
            prompt (str): The prompt to send to the model
            schema (Dict): JSON schema for validation
            strip_think (bool): Whether to strip thinking tags from response
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Validated response from the model
        """
        return self.generate(
            prompt=prompt,
            strip_think=strip_think,
            validation_schema=schema,
            return_parsed=return_parsed
        )

    @staticmethod
    def _strip_think_tags(text: str) -> str:
        """Strip a leading `<think>...</think>` block from a model response.

        The response is expected to look like `<think>...</think>response_text`.
        Returns the part after the first closing tag, stripped of surrounding
        whitespace. If there is no `<think>` tag, or no closing `</think>`,
        the whole text is returned stripped.

        Args:
            text (str): Raw response text from the model

        Returns:
            str: Response text with the thinking block removed
        """
        if "<think>" in text:
            parts = text.split("</think>")
            if len(parts) > 1:
                # Keep what follows the first closing tag.
                return parts[1].strip()
        return text.strip()

    def _make_api_call(self, prompt: str, strip_think: bool) -> str:
        """Make the actual API call to Ollama.

        Args:
            prompt (str): The prompt to send
            strip_think (bool): Whether to strip thinking tags

        Returns:
            str: Raw response from the API

        Raises:
            RequestException: If the HTTP request fails or returns an error status
        """
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False
        }

        logger.debug(f"Sending request to Ollama API: {url}")
        response = requests.post(url, json=payload, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()

        result = response.json()
        logger.debug(f"Received response from Ollama API: {result}")

        text = result.get("response", "")
        if strip_think:
            return self._strip_think_tags(text)
        # If strip_think is False, return the response untouched (not stripped).
        return text

    def _validate_with_schema(self, response: Dict[str, Any], schema: Dict[str, Any]) -> bool:
        """Validate response against a JSON schema.

        Args:
            response (Dict): The parsed response to validate
            schema (Dict): The JSON schema to validate against

        Returns:
            bool: True if valid, False otherwise (including when jsonschema
                is not installed)
        """
        # Import outside the validation try-block: previously a missing
        # jsonschema package raised ImportError inside a try whose first
        # `except ValidationError` clause referenced the never-bound name,
        # producing a NameError instead of reaching `except ImportError`.
        try:
            from jsonschema import validate, ValidationError
        except ImportError:
            logger.error("jsonschema library not available for validation")
            return False

        try:
            validate(instance=response, schema=schema)
            logger.debug(f"Schema validation passed for response: {response}")
            return True
        except ValidationError as e:
            logger.warning(f"Schema validation failed: {e}")
            logger.warning(f"Response that failed validation: {response}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model.

        Returns:
            Dict[str, Any]: Model information

        Raises:
            RequestException: If the API call fails
        """
        try:
            url = f"{self.base_url}/api/show"
            payload = {"name": self.model_name}

            response = requests.post(url, json=payload, headers=self.headers, timeout=self.timeout)
            response.raise_for_status()

            return response.json()

        except requests.exceptions.RequestException as e:
            logger.error(f"Error getting model info: {str(e)}")
            raise