dom/ollama_client.py

76 lines
2.8 KiB
Python

import json
import logging
import re
from typing import Dict, Any, Optional

import aiohttp

from config import Config
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
class OllamaClient:
    """Client for interacting with the Ollama HTTP API.

    Usable as an async context manager (``async with OllamaClient(cfg) as c``),
    which closes the underlying HTTP session on exit. When used without a
    context manager, the session is created lazily on the first request; call
    :meth:`close` when finished to release it.
    """

    # Matches one <think>...</think> block, including newlines inside it
    # (DOTALL); non-greedy so separate blocks are removed individually.
    _THINK_TAG_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)

    def __init__(self, config: "Config"):
        """Store the configuration; no network resources are allocated here.

        Args:
            config: Object providing ``ollama_endpoint`` (base URL) and
                ``ollama_model`` (model name sent with each request).
        """
        self.config = config
        self.session: Optional["aiohttp.ClientSession"] = None

    async def __aenter__(self):
        # Reuse a session that may already have been opened lazily instead of
        # overwriting (and leaking) it.
        self._get_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self) -> None:
        """Close the HTTP session if one is open. Safe to call repeatedly."""
        # Detach first so a later request creates a fresh session rather than
        # reusing a closed one.
        session, self.session = self.session, None
        if session is not None and not session.closed:
            await session.close()

    def _get_session(self) -> "aiohttp.ClientSession":
        """Return the open HTTP session, creating one if it is missing or closed."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
        return self.session

    @classmethod
    def _strip_think_tags(cls, text: str) -> str:
        """Remove every ``<think>...</think>`` block from *text* and trim whitespace."""
        return cls._THINK_TAG_RE.sub("", text).strip()

    async def generate_response(self, prompt: str, strip_think_tags: bool = True) -> Optional[str]:
        """Generate a non-streaming response from Ollama for the given prompt.

        Args:
            prompt: The input prompt for the model.
            strip_think_tags: If True, removes ``<think>...</think>`` blocks
                (and their content) from the response.

        Returns:
            The whitespace-trimmed response text, or None if the request
            failed, returned a non-200 status, or could not be processed.
        """
        session = self._get_session()
        url = f"{self.config.ollama_endpoint}/api/generate"
        payload = {
            "model": self.config.ollama_model,
            "prompt": prompt,
            "stream": False,
        }
        try:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    # Include a snippet of the body to make the failure diagnosable.
                    body = await response.text()
                    logger.error("Ollama API error: %s %s", response.status, body[:500])
                    return None
                data = await response.json()
                # A missing or null "response" field is treated as an empty reply.
                response_text = (data.get("response") or "").strip()
                if strip_think_tags:
                    stripped = self._strip_think_tags(response_text)
                    removed = len(response_text) - len(stripped)
                    if removed:
                        logger.info(
                            "Stripped <think></think> tags from response "
                            "(reduced length by %d characters)",
                            removed,
                        )
                    response_text = stripped
                return response_text
        except Exception as e:
            # API boundary: this method's contract is "None on any failure",
            # so transport, decoding and shape errors are logged, not raised.
            logger.error("Error calling Ollama API: %s", e)
            return None

    async def check_health(self) -> bool:
        """Return True if ``GET {ollama_endpoint}/api/tags`` answers with HTTP 200.

        Any exception raised by the request is logged and reported as False.
        """
        session = self._get_session()
        try:
            async with session.get(f"{self.config.ollama_endpoint}/api/tags") as response:
                return response.status == 200
        except Exception as e:
            logger.error("Ollama health check failed: %s", e)
            return False