import json
import logging
import re
from typing import Any, Dict, Optional

import aiohttp

from config import Config

logger = logging.getLogger(__name__)

# Matches one complete <think>...</think> block, tags included. The match is
# lazy so several blocks in one response are removed individually instead of
# also swallowing the real answer text between them; DOTALL lets a block span
# multiple lines.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def _strip_think_tags(text: str) -> str:
    """Remove every ``<think>...</think>`` block from *text* and trim whitespace."""
    return _THINK_BLOCK_RE.sub("", text).strip()


class OllamaClient:
    """Client for interacting with Ollama API.

    Usable as an async context manager (``async with OllamaClient(cfg) as c``),
    which closes the HTTP session on exit. If used without ``async with``, a
    session is created lazily on the first request; call :meth:`close` when
    finished to release it.
    """

    def __init__(self, config: Config):
        self.config = config
        # Created lazily by _ensure_session(); reset to None by close().
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        # Reuse a lazily-created session instead of leaking it.
        self._ensure_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP session, if one is open.

        Safe to call more than once; after closing, the next request
        transparently opens a new session.
        """
        if self.session is not None and not self.session.closed:
            await self.session.close()
        self.session = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return an open session, creating one if none exists or the old one was closed."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
        return self.session

    def _url(self, path: str) -> str:
        """Join *path* (which starts with '/') onto the configured endpoint."""
        # rstrip avoids a double slash when the endpoint is configured with a
        # trailing '/'.
        return f"{str(self.config.ollama_endpoint).rstrip('/')}{path}"

    async def generate_response(self, prompt: str, strip_think_tags: bool = True) -> Optional[str]:
        """Generate a response from Ollama for the given prompt.

        Args:
            prompt: The input prompt for the model.
            strip_think_tags: If True, removes ``<think>...</think>`` blocks
                (tags and their content) from the response.

        Returns:
            The response text with surrounding whitespace trimmed, or ``None``
            if the API returned a non-200 status or the request/parsing failed
            (the failure is logged).
        """
        session = self._ensure_session()
        url = self._url("/api/generate")
        payload = {
            "model": self.config.ollama_model,
            "prompt": prompt,
            # Ask for a single JSON object rather than a stream of chunks.
            "stream": False,
        }
        try:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    logger.error("Ollama API error: %s", response.status)
                    return None

                data = await response.json()
                response_text = data.get("response", "").strip()

                if strip_think_tags:
                    original_length = len(response_text)
                    response_text = _strip_think_tags(response_text)
                    removed = original_length - len(response_text)
                    if removed:
                        logger.info(
                            "Stripped tags from response (reduced length by %d characters)",
                            removed,
                        )

                return response_text
        except Exception as e:
            # Boundary handler: callers treat None as "no response available".
            logger.error("Error calling Ollama API: %s", e)
            return None

    async def check_health(self) -> bool:
        """Check if Ollama service is available.

        Returns:
            True if ``GET /api/tags`` responds with HTTP 200; False on any other
            status or on a request error (which is logged).
        """
        session = self._ensure_session()
        try:
            async with session.get(self._url("/api/tags")) as response:
                return response.status == 200
        except Exception as e:
            logger.error("Ollama health check failed: %s", e)
            return False