62 lines
2.3 KiB
Python
62 lines
2.3 KiB
Python
import logging
|
|
import time
|
|
from openai import OpenAI
|
|
from app.services.ai_service_interface import AIServiceInterface
|
|
# from app.utils.prompt_repository import PromptRepository
|
|
|
|
# Module-level logger named after this module; OpenaiService uses it to
# report failures from the chat-completions calls.
logger = logging.getLogger(__name__)
|
|
|
|
class OpenaiService(AIServiceInterface):
    """AIServiceInterface implementation backed by an OpenAI-compatible chat API.

    Each prompt is sent as a single ``system`` message. Request failures are
    logged with their traceback and replaced by a fixed fallback string rather
    than being propagated to the caller.
    """

    DEFAULT_MODEL = "gpt-4o-mini"

    def __init__(self, *, api_key=None, base_url=None, model=DEFAULT_MODEL):
        """Create the underlying OpenAI client.

        Args:
            api_key: API key for the endpoint. When ``None`` the ``openai``
                client reads ``OPENAI_API_KEY`` from the environment.
            base_url: Endpoint base URL. When ``None`` the ``openai`` client
                reads ``OPENAI_BASE_URL`` from the environment and otherwise
                uses the public OpenAI API.
            model: Chat model name used by both generate methods.

        Raises:
            openai.OpenAIError: if no API key is passed and ``OPENAI_API_KEY``
                is not set.
        """
        # SECURITY: this constructor previously embedded a literal API key and
        # a private gateway URL. Anything committed to source control must be
        # treated as leaked — rotate that key, and set OPENAI_BASE_URL to the
        # gateway URL in the deployment environment to keep the old routing.
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.model = model

    def generate_response(self, prompt, *, temperature=0.9):
        """Return the model's complete reply to *prompt*.

        Args:
            prompt: Text sent to the model as the system message.
            temperature: Sampling temperature (defaults to the previous 0.9).

        Returns:
            The message content of the first choice (the SDK types it as
            optional, so it may be None), or the fixed string
            "An error occurred while generating the response." if the request
            raised. Exceptions are logged, not re-raised.
        """
        try:
            # NOTE(review): the prompt is sent as a *system* message rather
            # than a user message — confirm this is intended.
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "system", "content": prompt}],
                temperature=temperature,
            )
            return response.choices[0].message.content
        except Exception as e:
            # Service boundary: log with traceback, hand back a fallback text.
            logger.exception("Error generating response: %s", e)
            return "An error occurred while generating the response."

    def generate_response_sse(self, prompt, *, temperature=0.7):
        """Yield the model's reply to *prompt* incrementally, for SSE streaming.

        Yields each non-None content delta as it arrives, then the literal tag
        "answer provided by openai" (appended with no separator — presumably
        consumed downstream as a provider marker; confirm with callers). If the
        request or the stream fails, logs the error and yields
        "An error occurred while generating the SSE response." instead of
        raising.

        Args:
            prompt: Text sent to the model as the system message.
            temperature: Sampling temperature (defaults to the previous 0.7).
        """
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "system", "content": prompt}],
                temperature=temperature,
                stream=True,
            )
            # Closing the stream releases the HTTP connection even when the
            # consumer stops iterating early (GeneratorExit at a yield).
            with stream:
                for chunk in stream:
                    # Some chunks (e.g. usage-only chunks, certain gateways)
                    # carry no choices; indexing [0] would raise IndexError.
                    if not chunk.choices:
                        continue
                    content = chunk.choices[0].delta.content
                    if content is not None:
                        yield content
            yield "answer provided by openai"
        except Exception as e:
            logger.exception("Error generating SSE response: %s", e)
            yield "An error occurred while generating the SSE response."
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test against the live endpoint: one blocking completion and
    # one streamed completion, both printed to stdout.
    logging.basicConfig(level=logging.INFO)
    service = OpenaiService()

    # Blocking, single-shot completion.
    question = "What is the capital of France?"
    answer = service.generate_response(question)
    print(f"Response to '{question}': {answer}")

    # Streamed completion: echo each piece as soon as it arrives, pausing
    # briefly so the incremental output is visible in a terminal.
    print("\nTesting generate_response_sse:")
    stream_prompt = "Count from 1 to 5 slowly."
    for piece in service.generate_response_sse(stream_prompt):
        print(piece, end='', flush=True)
        time.sleep(0.1)
    print("\nSSE response complete.")
|
|
|