import logging
from typing import Any, Callable, Dict, Optional, Union

import requests

from ..utils.json_extractor import LLMJsonExtractor
from ..utils.llm_validator import LLMResponseValidator

logger = logging.getLogger(__name__)

# Closing tag of the chain-of-thought block emitted by "thinking" models
# (e.g. deepseek-r1 wraps its reasoning in <think>...</think>). Everything up
# to and including this tag is reasoning, not the answer.
_THINK_CLOSE_TAG = "</think>"


class OllamaClient:
    """Client for the Ollama HTTP API with optional JSON validation and retry."""

    def __init__(self, model_name: str, base_url: str = "http://localhost:11434",
                 max_retries: int = 3, timeout: Optional[float] = None):
        """Initialize Ollama client.

        Args:
            model_name (str): Name of the Ollama model to use
            base_url (str): Ollama server base URL
            max_retries (int): Maximum number of retries for failed requests
            timeout (Optional[float]): Per-request timeout in seconds.
                ``None`` (default) keeps the previous wait-forever behavior.
        """
        self.model_name = model_name
        self.base_url = base_url
        self.max_retries = max_retries
        self.timeout = timeout
        self.headers = {"Content-Type": "application/json"}

    def generate(self, prompt: str, strip_think: bool = True,
                 validation_schema: Optional[Dict[str, Any]] = None,
                 response_type: Optional[str] = None,
                 return_parsed: bool = False) -> Union[str, Dict[str, Any]]:
        """Process a document using the Ollama API with optional validation and retry.

        Args:
            prompt (str): The prompt to send to the model
            strip_think (bool): Whether to strip thinking tags from response
            validation_schema (Optional[Dict]): JSON schema for validation
            response_type (Optional[str]): Type of response for validation
                ('entity_extraction', 'entity_linkage', etc.)
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Response from the model (raw string or
                parsed JSON)

        Raises:
            requests.exceptions.RequestException: If the API call fails after
                all retries
            ValueError: If JSON parsing or validation fails after all retries
        """
        for attempt in range(self.max_retries):
            try:
                raw_response = self._make_api_call(prompt, strip_think)

                # Fast path: caller wants the raw text, no checks requested.
                if not validation_schema and not response_type and not return_parsed:
                    return raw_response

                # Any of the three options requires a parsed JSON payload.
                parsed_response = LLMJsonExtractor.parse_raw_json_str(raw_response)
                if not parsed_response:
                    logger.warning(
                        f"Failed to parse JSON on attempt {attempt + 1}/{self.max_retries}")
                    if attempt < self.max_retries - 1:
                        continue
                    raise ValueError("Failed to parse JSON response after all retries")

                if validation_schema:
                    if not self._validate_with_schema(parsed_response, validation_schema):
                        logger.warning(
                            f"Schema validation failed on attempt {attempt + 1}/{self.max_retries}")
                        if attempt < self.max_retries - 1:
                            continue
                        raise ValueError("Schema validation failed after all retries")

                if response_type:
                    if not LLMResponseValidator.validate_response_by_type(
                            parsed_response, response_type):
                        logger.warning(
                            f"Response type validation failed on attempt {attempt + 1}/{self.max_retries}")
                        if attempt < self.max_retries - 1:
                            continue
                        raise ValueError("Response type validation failed after all retries")

                return parsed_response if return_parsed else raw_response

            except requests.exceptions.RequestException as e:
                logger.error(
                    f"API call failed on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error("Max retries reached, raising exception")
                    raise
            except Exception as e:
                logger.error(
                    f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}")
                if attempt < self.max_retries - 1:
                    logger.info("Retrying...")
                else:
                    logger.error("Max retries reached, raising exception")
                    raise

        # Defensive: every loop exit above either returns, continues, or raises.
        raise Exception(
            "Unexpected error: max retries exceeded without proper exception handling")

    def generate_with_validation(self, prompt: str, response_type: str,
                                 strip_think: bool = True,
                                 return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
        """Generate response with automatic validation based on response type.

        Args:
            prompt (str): The prompt to send to the model
            response_type (str): Type of response for validation
            strip_think (bool): Whether to strip thinking tags from response
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Validated response from the model
        """
        return self.generate(
            prompt=prompt,
            strip_think=strip_think,
            response_type=response_type,
            return_parsed=return_parsed
        )

    def generate_with_schema(self, prompt: str, schema: Dict[str, Any],
                             strip_think: bool = True,
                             return_parsed: bool = True) -> Union[str, Dict[str, Any]]:
        """Generate response with custom schema validation.

        Args:
            prompt (str): The prompt to send to the model
            schema (Dict): JSON schema for validation
            strip_think (bool): Whether to strip thinking tags from response
            return_parsed (bool): Whether to return parsed JSON instead of raw string

        Returns:
            Union[str, Dict[str, Any]]: Validated response from the model
        """
        return self.generate(
            prompt=prompt,
            strip_think=strip_think,
            validation_schema=schema,
            return_parsed=return_parsed
        )

    def _make_api_call(self, prompt: str, strip_think: bool) -> str:
        """Make the actual API call to Ollama.

        Args:
            prompt (str): The prompt to send
            strip_think (bool): Whether to strip thinking tags

        Returns:
            str: Raw response text from the API (think-block removed when
                requested)

        Raises:
            requests.exceptions.RequestException: On HTTP/connection errors.
        """
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False
        }

        logger.debug(f"Sending request to Ollama API: {url}")
        response = requests.post(url, json=payload, headers=self.headers,
                                 timeout=self.timeout)
        response.raise_for_status()

        result = response.json()
        logger.debug(f"Received response from Ollama API: {result}")

        response_text = result.get("response", "")
        if not strip_think:
            # Caller wants the verbatim response, reasoning included.
            return response_text

        # The response is expected to be <think>...</think>answer_text.
        # BUGFIX: the previous code split on an empty separator (the tag
        # literals had been lost), which raises ValueError; split on the
        # closing think tag instead and keep only the answer part.
        if _THINK_CLOSE_TAG in response_text:
            return response_text.split(_THINK_CLOSE_TAG, 1)[1].strip()
        # No think block present — return the whole response, trimmed.
        return response_text.strip()

    def _validate_with_schema(self, response: Dict[str, Any],
                              schema: Dict[str, Any]) -> bool:
        """Validate response against a JSON schema.

        Args:
            response (Dict): The parsed response to validate
            schema (Dict): The JSON schema to validate against

        Returns:
            bool: True if valid, False otherwise (including when the optional
                ``jsonschema`` dependency is missing)
        """
        try:
            # Imported lazily so the client works without jsonschema installed
            # when schema validation is never requested.
            from jsonschema import validate, ValidationError
            validate(instance=response, schema=schema)
            logger.debug(f"Schema validation passed for response: {response}")
            return True
        except ValidationError as e:
            logger.warning(f"Schema validation failed: {e}")
            logger.warning(f"Response that failed validation: {response}")
            return False
        except ImportError:
            logger.error("jsonschema library not available for validation")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model.

        Returns:
            Dict[str, Any]: Model information as returned by ``/api/show``

        Raises:
            requests.exceptions.RequestException: If the API call fails
        """
        try:
            url = f"{self.base_url}/api/show"
            payload = {"name": self.model_name}
            response = requests.post(url, json=payload, headers=self.headers,
                                     timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error getting model info: {str(e)}")
            raise