152 lines
6.4 KiB
Python
152 lines
6.4 KiB
Python
import logging
import os
import time
from typing import Iterator, Optional

from zhipuai import ZhipuAI

from app.services.ai_service_interface import AIServiceInterface
from app.utils.prompt_repository import PromptRepository
|
|
|
|
# Logger for this module, named after the module's import path (``__name__``).
logger = logging.getLogger(name=__name__)
|
|
|
|
class ZhipuService(AIServiceInterface):
    """AIServiceInterface implementation backed by the ZhipuAI (GLM) chat-completions API.

    Every request sends a single user message with the shared sampling settings
    ``TEMPERATURE`` / ``TOP_P``. ``generate_response`` and ``retrive`` return the
    complete reply text. ``generate_response_sse``, ``retrive_sse`` and
    ``check_report_missing_info`` are generators that yield text fragments as the
    streamed reply arrives. As generators, they do no work until iterated.
    """

    # Model used by every method except check_report_missing_info.
    DEFAULT_MODEL = "glm-4"
    # Model used only by check_report_missing_info.
    REPORT_CHECK_MODEL = "glm-4-flash"
    # Environment variable read for the API key when none is passed to __init__.
    API_KEY_ENV_VAR = "ZHIPUAI_API_KEY"
    # Sampling settings shared by all requests. Very low values keep replies close
    # to deterministic.
    TEMPERATURE = 0.01
    TOP_P = 0.1

    # Default retrieval prompt for `retrive`. It says: answer from the document using
    # only its sentences. If the answer is not there, answer from the model's own
    # knowledge and tell the user the information is not from the document. Do not
    # restate the question. {{knowledge}} / {{question}} are the retrieval tool's
    # template placeholders and must stay literal double braces.
    RETRIEVE_DEFAULT_PROMPT = "从文档\n\"\"\"\n{{knowledge}}\n\"\"\"\n中找问题\n\"\"\"\n{{question}}\n\"\"\"\n的答案,找到答案就仅使用文档语句回答问题,找不到答案就用自身知识回答并且告诉用户该信息不是来自文档。\n不要复述问题,直接开始回答。"
    # Default retrieval prompt for `retrive_sse`. Same as above, except that when the
    # answer is not in the document the model must say the knowledge base lacks it.
    RETRIEVE_SSE_DEFAULT_PROMPT = "从文档\n\"\"\"\n{{knowledge}}\n\"\"\"\n中找问题\n\"\"\"\n{{question}}\n\"\"\"\n的答案,找到答案就仅使用文档语句回答问题,找不到答案就告诉用户知识库中没有该信息。\n不要复述问题,直接开始回答。"

    def __init__(self, api_key: Optional[str] = None, model_name: str = DEFAULT_MODEL):
        """Create the ZhipuAI client.

        Args:
            api_key: ZhipuAI API key. When omitted or empty, the value of the
                ``ZHIPUAI_API_KEY`` environment variable is used.
            model_name: Model used by generate_response(_sse) and retrive(_sse).

        Raises:
            RuntimeError: if no API key is available from either source.
        """
        # SECURITY: the API key used to be hard-coded here. Secrets must not live in
        # source control, and the previously committed key should be revoked/rotated.
        self.model_name = model_name
        self.app_secret_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
        if not self.app_secret_key:
            raise RuntimeError(
                "ZhipuAI API key not configured: pass api_key or set the "
                f"{self.API_KEY_ENV_VAR} environment variable"
            )
        self.client = ZhipuAI(api_key=self.app_secret_key)
        logger.info("ZhipuService initialized with model: %s", self.model_name)

    def generate_response(self, prompt: str) -> Optional[str]:
        """Return the model's complete reply to a single user *prompt*.

        Raises:
            Exception: any error from the ZhipuAI client is logged with its
                traceback and re-raised unchanged.
        """
        logger.info("Starting generate_response call")
        return self._complete("generate_response", self._user_messages(prompt))

    def generate_response_sse(self, prompt: str) -> Iterator[str]:
        """Yield the model's reply to *prompt* as streamed text fragments.

        Raises:
            Exception: any client or streaming error is logged with its traceback
                and re-raised to the consumer.
        """
        logger.info("Starting generate_response_sse call")
        yield from self._stream_completion(
            "generate_response_sse", self._user_messages(prompt)
        )

    def retrive(
        self, message: str, knowledge_id: str, prompt_template: Optional[str] = None
    ) -> Optional[str]:
        """Answer *message* using knowledge-base retrieval and return the full reply.

        Args:
            message: The user's question.
            knowledge_id: ID of the ZhipuAI knowledge base to retrieve from.
            prompt_template: Retrieval prompt template. ``None`` or ``""`` selects
                ``RETRIEVE_DEFAULT_PROMPT``.

        Raises:
            Exception: any error from the ZhipuAI client is logged with its
                traceback and re-raised unchanged.
        """
        logger.info("Starting retrive call with knowledge_id: %s", knowledge_id)
        tools = self._retrieval_tools(
            knowledge_id, prompt_template or self.RETRIEVE_DEFAULT_PROMPT
        )
        return self._complete("retrive", self._user_messages(message), tools=tools)

    def retrive_sse(
        self,
        message: str,
        knowledge_id: str,
        prompt_template: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> Iterator[str]:
        """Answer *message* using knowledge-base retrieval, yielding streamed text fragments.

        Args:
            message: The user's question.
            knowledge_id: ID of the ZhipuAI knowledge base to retrieve from.
            prompt_template: Retrieval prompt template. ``None`` or ``""`` selects
                ``RETRIEVE_SSE_DEFAULT_PROMPT``.
            system_prompt: Accepted for interface compatibility but currently
                ignored. Sending it was disabled in the original implementation.

        Raises:
            Exception: any client or streaming error is logged with its traceback
                and re-raised to the consumer.
        """
        logger.info(
            "Starting retrive_sse call with knowledge_id: %s, message:%s",
            knowledge_id,
            message,
        )
        tools = self._retrieval_tools(
            knowledge_id, prompt_template or self.RETRIEVE_SSE_DEFAULT_PROMPT
        )
        yield from self._stream_completion(
            "retrive_sse", self._user_messages(message), tools=tools
        )

    def check_report_missing_info(self, message: str) -> Iterator[str]:
        """Ask the model whether the report text *message* is complete, streaming its reply.

        The prompt is the repository's ``"report_template"`` prompt followed by an
        instruction (in Chinese). The instruction tells the model to list the points
        the user still has to add, or to reply "上述日志信息为全部信息" ("the above
        log information is complete"). The report text is wrapped in triple quotes.

        Raises:
            Exception: any client or streaming error is logged with its traceback
                and re-raised to the consumer.
        """
        logger.info("Starting check_report_missing_info call")

        # 1. Report (log) template.
        prompt_report_template = PromptRepository().get_prompt("report_template")

        # Built explicitly so the prompt text cannot pick up source indentation.
        prompt_report_missing_check = (
            f"{prompt_report_template}\n"
            '请检查以下日志信息是否完整,如果信息缺失则提示要求用户需要补充的信息要点,'
            '如果信息完整请直接返回"上述日志信息为全部信息"。日志信息如下:\n'
            f'"""\n{message}\n"""\n'
        )
        yield from self._stream_completion(
            "check_report_missing_info",
            self._user_messages(prompt_report_missing_check),
            model=self.REPORT_CHECK_MODEL,
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _user_messages(content: str) -> list:
        """Wrap *content* as a conversation consisting of one user message."""
        return [{"role": "user", "content": content}]

    @staticmethod
    def _retrieval_tools(knowledge_id: str, prompt_template: str) -> list:
        """Build the ``tools`` payload that enables knowledge-base retrieval."""
        return [
            {
                "type": "retrieval",
                "retrieval": {
                    "knowledge_id": knowledge_id,
                    "prompt_template": prompt_template,
                },
            }
        ]

    def _create_completion(
        self,
        messages: list,
        *,
        stream: bool,
        model: Optional[str] = None,
        tools: Optional[list] = None,
    ):
        """Send one chat-completions request with the shared sampling settings.

        *model* defaults to ``self.model_name``.
        """
        request = {
            "model": model or self.model_name,
            "messages": messages,
            "stream": stream,
            "temperature": self.TEMPERATURE,
            "top_p": self.TOP_P,
        }
        if tools is not None:
            # Plain chat requests omit `tools` entirely rather than sending None.
            request["tools"] = tools
        return self.client.chat.completions.create(**request)

    def _complete(
        self,
        operation: str,
        messages: list,
        *,
        model: Optional[str] = None,
        tools: Optional[list] = None,
    ) -> Optional[str]:
        """Run a non-streaming request and return the first choice's message text.

        *operation* names the public method in log messages. The duration is logged
        on success; on failure the traceback is logged and the error re-raised.
        """
        start_time = time.perf_counter()
        try:
            response = self._create_completion(
                messages, stream=False, model=model, tools=tools
            )
            result = response.choices[0].message.content
        except Exception as e:
            logger.exception("Error in %s: %s", operation, e)
            raise
        logger.info(
            "%s call completed in %.2f seconds",
            operation,
            time.perf_counter() - start_time,
        )
        return result

    def _stream_completion(
        self,
        operation: str,
        messages: list,
        *,
        model: Optional[str] = None,
        tools: Optional[list] = None,
    ) -> Iterator[str]:
        """Run a streaming request and yield the text fragments as they arrive.

        *operation* names the public method in log messages. The total duration is
        logged once the stream is exhausted. On failure, whether at request time or
        mid-stream, the traceback is logged and the error re-raised.
        """
        start_time = time.perf_counter()
        try:
            response = self._create_completion(
                messages, stream=True, model=model, tools=tools
            )
            for chunk in response:
                # Skip chunks that carry no text: an empty `choices` list used to
                # raise IndexError, and a None delta was yielded to callers that
                # expect strings.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content
        except Exception as e:
            logger.exception("Error in %s: %s", operation, e)
            raise
        logger.info(
            "%s call completed in %.2f seconds",
            operation,
            time.perf_counter() - start_time,
        )
|