76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
import json
import logging
import re
from typing import Dict, Any, Optional

import aiohttp

from config import Config
|
|
|
|
# Module-scoped logger, named after this module per the standard logging convention.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class OllamaClient:
    """Async client for interacting with the Ollama HTTP API.

    Preferred usage is as an async context manager, which closes the HTTP
    session on exit::

        async with OllamaClient(config) as client:
            text = await client.generate_response("Hello")

    The request methods also work without ``async with``: a session is then
    created lazily on first use, and the caller should ``await client.close()``
    when done so the session is not leaked. A closed session is transparently
    replaced, so the client can be reused after leaving a context manager.

    ``config`` must provide ``ollama_endpoint`` (base URL of the server) and
    ``ollama_model`` (model name sent with each generate request).
    """

    # Matches one <think>...</think> section: non-greedy so several sections
    # are removed independently, DOTALL so a section may span multiple lines.
    _THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

    def __init__(self, config: Config):
        self.config = config
        # Created lazily by _ensure_session(); reset to None by close().
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        # Reuse an already-open session instead of overwriting (and leaking) it.
        self._ensure_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self) -> None:
        """Close the HTTP session if one is open. Safe to call more than once."""
        # Detach first so a later request creates a fresh session instead of
        # reusing the closed one.
        session, self.session = self.session, None
        if session is not None and not session.closed:
            await session.close()

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return an open session, creating one if none exists or the old one was closed."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
        return self.session

    def _url(self, path: str) -> str:
        """Join the configured endpoint with an API path, tolerating a trailing slash."""
        return str(self.config.ollama_endpoint).rstrip("/") + path

    @classmethod
    def _strip_think_tags(cls, text: str) -> str:
        """Return text with every <think>...</think> section removed and whitespace trimmed."""
        return cls._THINK_TAG_RE.sub("", text).strip()

    async def generate_response(self, prompt: str, strip_think_tags: bool = True) -> Optional[str]:
        """Generate a response from Ollama for the given prompt.

        Sends a single non-streaming request to ``/api/generate``.

        Args:
            prompt: The input prompt for the model.
            strip_think_tags: If True, removes ``<think>...</think>`` sections
                (tags and their content) from the response.

        Returns:
            The response text with surrounding whitespace trimmed, or None if
            the API answered with a non-200 status or the request/decoding
            failed. Failures are logged, never raised.
        """
        session = self._ensure_session()
        payload = {
            "model": self.config.ollama_model,
            "prompt": prompt,
            "stream": False,
        }

        try:
            async with session.post(self._url("/api/generate"), json=payload) as response:
                if response.status != 200:
                    # Include a truncated body to help diagnose the failure.
                    detail = await response.text()
                    logger.error("Ollama API error: %s %s", response.status, detail[:500])
                    return None
                data = await response.json()
            # `or ""` also covers an explicit null "response" value.
            response_text = (data.get("response") or "").strip()
        except Exception as e:
            # Deliberate best-effort contract: callers get None, not an exception.
            # %r keeps the exception type visible; some (e.g. timeouts) have an
            # empty str().
            logger.error("Error calling Ollama API: %r", e)
            return None

        if strip_think_tags:
            stripped = self._strip_think_tags(response_text)
            removed = len(response_text) - len(stripped)
            if removed:
                logger.info(
                    "Stripped <think></think> tags from response "
                    "(reduced length by %d characters)",
                    removed,
                )
            response_text = stripped

        return response_text

    async def check_health(self, timeout: float = 10.0) -> bool:
        """Check if the Ollama service is available.

        Args:
            timeout: Maximum total seconds to wait for ``/api/tags``. Without an
                explicit limit aiohttp's default total timeout (5 minutes)
                would apply, stalling callers when the server hangs.

        Returns:
            True if ``/api/tags`` answers with HTTP 200; False otherwise,
            including on any error (which is logged).
        """
        session = self._ensure_session()
        try:
            async with session.get(
                self._url("/api/tags"),
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as response:
                return response.status == 200
        except Exception as e:
            logger.error("Ollama health check failed: %r", e)
            return False