100 lines
3.6 KiB
Python
100 lines
3.6 KiB
Python
import logging
|
|
from typing import Optional
|
|
from openai import OpenAI
|
|
from config import Config
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
class GrokClient:
    """Client for interacting with Grok-3 model via OpenRouter API.

    Wraps the synchronous ``openai.OpenAI`` client pointed at OpenRouter.
    The coroutine methods run the blocking SDK call in the event loop's
    default executor so that awaiting them does not stall other tasks.
    """

    def __init__(self, config: "Config"):
        """Store *config* and build the underlying OpenAI client.

        Args:
            config: Settings object; its ``openrouter_*`` attributes are read.
        """
        self.config = config
        # Stays None when no API key is configured; every call checks for it.
        self.client = None
        self._initialize_client()

    def _initialize_client(self):
        """Initialize the OpenAI client for OpenRouter.

        Leaves ``self.client`` as None (and logs a warning) when
        ``config.openrouter_api_key`` is empty.
        """
        if not self.config.openrouter_api_key:
            logger.warning("OpenRouter API key not configured")
            return

        self.client = OpenAI(
            base_url=self.config.openrouter_base_url,
            api_key=self.config.openrouter_api_key,
        )

    def _build_extra_headers(self) -> dict:
        """Return the optional OpenRouter attribution headers set in config."""
        headers = {}
        if self.config.openrouter_site_url:
            headers["HTTP-Referer"] = self.config.openrouter_site_url
        if self.config.openrouter_site_name:
            headers["X-Title"] = self.config.openrouter_site_name
        return headers

    async def _create_completion(self, **kwargs):
        """Run the blocking ``chat.completions.create`` call off the event loop.

        Args:
            **kwargs: Forwarded unchanged to ``chat.completions.create``.

        Returns:
            The completion object returned by the SDK.
        """
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        call = functools.partial(self.client.chat.completions.create, **kwargs)
        # The SDK client is synchronous; awaiting it directly would block the
        # event loop for the full HTTP round-trip.
        return await loop.run_in_executor(None, call)

    @staticmethod
    def _strip_think_tags(text: str) -> str:
        """Remove every ``<think>...</think>`` span (across newlines) and trim."""
        import re

        stripped = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
        removed = len(text) - len(stripped)
        if removed:
            logger.info(
                "Stripped <think></think> tags from response (reduced length by %d characters)",
                removed,
            )
        return stripped

    async def generate_response(self, prompt: str, strip_think_tags: bool = True) -> Optional[str]:
        """Generate a response from Grok-3 for the given prompt.

        Args:
            prompt: The input prompt for the model.
            strip_think_tags: If True, removes ``<think>...</think>`` blocks
                (tags and their content) from the response.

        Returns:
            The whitespace-trimmed response text, or None when the client is
            not initialized, the API call fails, or the model returned no
            text content.
        """
        if not self.client:
            logger.error("Grok client not initialized. Please check your OpenRouter API key configuration.")
            return None

        try:
            completion = await self._create_completion(
                extra_headers=self._build_extra_headers(),
                model=self.config.openrouter_model,
                messages=[{"role": "user", "content": prompt}],
            )

            # ``message.content`` is optional in the chat-completions schema and
            # ``choices`` can be empty, so check before calling ``.strip()``.
            if not completion.choices or completion.choices[0].message.content is None:
                logger.warning("Grok-3 response contained no text content")
                return None

            response_text = completion.choices[0].message.content.strip()
            if strip_think_tags:
                response_text = self._strip_think_tags(response_text)
            return response_text

        except Exception as e:
            # Callers rely on None-on-failure rather than an exception.
            logger.exception("Error calling Grok-3 via OpenRouter API: %s", e)
            return None

    async def check_health(self) -> bool:
        """Check if OpenRouter service is available.

        Sends a one-token completion request for the configured model.

        Returns:
            True if the request succeeded; False if the client is not
            initialized or the request raised.
        """
        if not self.client:
            return False

        try:
            await self._create_completion(
                extra_headers=self._build_extra_headers(),
                model=self.config.openrouter_model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.error("Grok health check failed: %s", e)
            return False

    def close(self) -> None:
        """Close the underlying OpenAI client, if one was created."""
        if self.client is not None:
            self.client.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # The OpenAI client holds an HTTP connection pool; release it when the
        # context exits rather than leaving it to garbage collection.
        # Returns None, so exceptions raised in the body still propagate.
        self.close()