# llm_hub/app/services/zhipu_alltool_service.py
#
# (Repository web-view metadata: 189 lines, 7.2 KiB, Python; Raw | Permalink | Blame | History)
#
# Note: this file contains ambiguous Unicode characters.
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import time
from collections import Counter
from typing import Optional

from zhipuai import ZhipuAI
# Module-level logger named after this module's import path, so its level
# and handlers can be configured per module.
logger = logging.getLogger(name=__name__)
class ZhipuAlltoolService:
    """Wrapper around the ZhipuAI chat-completions API.

    Offers intent classification and yes/no judgement via function calling
    (either with a single model or by majority vote across several models),
    plus a streaming web-search helper built on ``self.model_name``
    (``glm-4-alltools``). A fresh ``ZhipuAI`` client is created per request.
    """

    # Models polled, in order, by the *_with_voting helpers. glm-4-plus is
    # listed twice, so its answers carry double weight in the vote.
    DEFAULT_VOTING_MODELS = ("glm-4-plus", "glm-4-air", "glm-4-plus")
    # Model used by the single-shot function-calling helpers when the caller
    # does not pass model_name.
    DEFAULT_FUNCTION_CALL_MODEL = "glm-4-flash"
    # Intent labels offered to the classifier when the caller passes none.
    # "retrive_knowledge" is misspelled, but it is a runtime value that callers
    # may compare against, so it is intentionally left unchanged.
    DEFAULT_CATEGORIES = (
        "web_search",
        "retrive_knowledge",
        "generate_report",
        "update_report",
        "clear_report",
    )

    def __init__(self, api_key: Optional[str] = None):
        """Create the service.

        Args:
            api_key: ZhipuAI API key. When omitted, the ``ZHIPUAI_API_KEY``
                environment variable is used.
        """
        self.model_name = "glm-4-alltools"
        # The key used to be hard-coded here. A credential committed to source
        # control is readable by anyone with repo access; the old key should be
        # rotated and the new one supplied via api_key or the environment.
        self.app_secret_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
        if not self.app_secret_key:
            logger.warning(
                "No ZhipuAI API key configured; pass api_key or set ZHIPUAI_API_KEY"
            )
        logger.info("ZhipuAlltoolService initialized with model: %s", self.model_name)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _majority(votes):
        """Return the most frequent element of a non-empty *votes* list.

        Ties go to the value that was recorded first (Counter.most_common
        orders equal counts by first occurrence).
        """
        return Counter(votes).most_common(1)[0][0]

    @staticmethod
    def _extract_arguments(response):
        """Return the JSON argument string of the first tool call in *response*.

        Raises:
            ValueError: if the model answered without calling a tool. This can
                happen because tool_choice="auto" leaves that decision to the
                model.
        """
        message = response.choices[0].message
        tool_calls = message.tool_calls
        if not tool_calls:
            raise ValueError(
                "Model returned no tool call; content: %r"
                % (getattr(message, "content", None),)
            )
        return tool_calls[0].function.arguments

    @staticmethod
    def _classify_tools(categories, additional_desc=None):
        """Build the function-calling schema for intent classification.

        Args:
            categories: Iterable of allowed intent labels (joined with ",").
            additional_desc: Optional extra text appended to the tool
                description; omitted entirely when falsy.
        """
        # Falsy additional_desc must not be rendered as the literal "None".
        suffix = additional_desc or ""
        return [
            {
                "type": "function",
                "function": {
                    "name": "classify_user_input",
                    # "Judge the user's intent from the input and return the intent type."
                    "description": f"根据用户输入,判断用户意图,返回意图类型。{suffix}",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "category": {
                                "type": "string",
                                # "The user's intent has the following options: ..."
                                "description": "用户的意图有以下选项:" + ",".join(categories),
                            }
                        },
                        "required": ["category"],
                    },
                },
            }
        ]

    @staticmethod
    def _yes_or_no_tools():
        """Build the function-calling schema for a yes/no judgement."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "classify_user_input",
                    # "Answer yes or no based on the user's information and question."
                    "description": "根据用户输入的信息和问题回答yes或者no",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "answer": {
                                "type": "string",
                                # "The judgement has the following options: yes, no"
                                "description": "判断结果有以下选项:yes,no",
                            }
                        },
                        "required": ["answer"],
                    },
                },
            }
        ]

    def _function_call(self, model_name, content, tools, label):
        """Send a single-turn function-calling request and return the tool arguments.

        Args:
            model_name: Model to call.
            content: Text of the single user message.
            tools: Tool schema list passed to the API.
            label: Name of the calling method, used in log messages.

        Returns:
            The raw JSON argument string of the first tool call.

        Raises:
            ValueError: if the model replied without calling a tool.
            Exception: whatever the SDK raises for API/transport errors.
        """
        client = ZhipuAI(api_key=self.app_secret_key)
        started = time.perf_counter()
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": content}],
                tools=tools,
                tool_choice="auto",
            )
            logger.debug(
                "%s response from %s in %.2fs: %s",
                label, model_name, time.perf_counter() - started, response,
            )
            return self._extract_arguments(response)
        except Exception:
            logger.exception("Error in %s call", label)
            raise

    def _collect_votes(self, call, field, models, label, normalize=None):
        """Poll each model once and collect the non-empty values of *field*.

        Args:
            call: ``call(model_name) -> str`` returning the tool-call JSON.
            field: Key to read from the decoded JSON object.
            models: Model names to poll; None means DEFAULT_VOTING_MODELS.
            label: Name of the per-model method, used in error logs.
            normalize: Optional transform applied to each non-empty value.

        Returns:
            The list of collected votes. Per-model failures are logged and
            skipped, so the list may be empty.
        """
        votes = []
        for model in self.DEFAULT_VOTING_MODELS if models is None else models:
            try:
                value = json.loads(call(model)).get(field)
                if value and normalize is not None:
                    value = normalize(value)
            except Exception as e:
                logger.error("Error in %s with model %s: %s", label, model, e)
                continue
            if value:
                votes.append(value)
        return votes

    # --------------------------------------------------------------- public API

    def func_call_classify_with_voting(self, message, categories: Optional[list] = None,
                                       additional_desc=None, models: Optional[list] = None):
        """Classify *message* by majority vote across several models.

        Args:
            message: User input to classify.
            categories: Allowed intent labels; defaults to DEFAULT_CATEGORIES.
            additional_desc: Extra text appended to the tool description.
            models: Model names to poll; defaults to DEFAULT_VOTING_MODELS.

        Returns:
            ``{"category": <winning label>}``. The winner is not checked
            against *categories*.

        Raises:
            Exception: if no model produced a non-empty category.
        """
        logger.info("Starting func_call_classify_with_voting call")
        votes = self._collect_votes(
            lambda model: self.func_call_classify(
                message, categories, additional_desc, model_name=model
            ),
            "category",
            models,
            "func_call_classify",
        )
        if not votes:
            raise Exception("All classification attempts failed")
        final_category = self._majority(votes)
        logger.info("Voting results: %s, Final category: %s", votes, final_category)
        return {"category": final_category}

    def func_call_classify(self, message, categories: Optional[list] = None,
                           additional_desc=None, model_name=None):
        """Classify *message* into one intent label using a single model.

        Args:
            message: User input to classify.
            categories: Allowed intent labels; defaults to DEFAULT_CATEGORIES.
            additional_desc: Extra text appended to the tool description.
            model_name: Model to call; defaults to DEFAULT_FUNCTION_CALL_MODEL.

        Returns:
            The raw JSON argument string of the tool call, presumably of the
            form ``{"category": "..."}`` per the schema (not validated here).
        """
        logger.info("Starting func_call_classify call")
        categories = categories if categories else self.DEFAULT_CATEGORIES
        return self._function_call(
            self.DEFAULT_FUNCTION_CALL_MODEL if model_name is None else model_name,
            # "Judge the intent of the following user input and return its type:"
            f"判断以下用户输入的意图并将其分类返回意图类型:{message}",
            self._classify_tools(categories, additional_desc),
            "func_call_classify",
        )

    def func_call_yes_or_no_with_voting(self, message, question, models: Optional[list] = None):
        """Answer *question* about *message* with yes/no by majority vote.

        Answers are whitespace-stripped and lower-cased before counting.

        Args:
            message: Information the question refers to.
            question: The question text (sent before *message*).
            models: Model names to poll; defaults to DEFAULT_VOTING_MODELS.

        Returns:
            ``{"answer": <winning answer>}``.

        Raises:
            Exception: if no model produced a non-empty answer.
        """
        logger.info("Starting func_call_yes_or_no_with_voting call")
        votes = self._collect_votes(
            lambda model: self.func_call_yes_or_no(message, question, model_name=model),
            "answer",
            models,
            "func_call_yes_or_no",
            normalize=lambda answer: answer.strip().lower(),
        )
        if not votes:
            raise Exception("All yes/no classification attempts failed")
        final_answer = self._majority(votes)
        logger.info("Voting results: %s, Final answer: %s", votes, final_answer)
        return {"answer": final_answer}

    def func_call_yes_or_no(self, message, question, model_name=None):
        """Ask a single model to answer *question* about *message* with yes/no.

        Args:
            message: Information the question refers to.
            question: The question text; sent directly followed by *message*.
            model_name: Model to call; defaults to DEFAULT_FUNCTION_CALL_MODEL.

        Returns:
            The raw JSON argument string of the tool call, presumably of the
            form ``{"answer": "yes"|"no"}`` per the schema (not validated here).
        """
        logger.info("Starting func_call_yes_or_no call")
        return self._function_call(
            self.DEFAULT_FUNCTION_CALL_MODEL if model_name is None else model_name,
            f"{question}{message}",
            self._yes_or_no_tools(),
            "func_call_yes_or_no",
        )

    def web_search_sse(self, message):
        """Stream an answer to *message* from the alltools model with web browsing.

        Yields:
            The text content of each streamed delta. Chunks without choices or
            without text content (content is None) are skipped.

        Raises:
            Exception: whatever the SDK raises; it is logged and re-raised.
        """
        logger.info("Starting web_search_sse call")
        client = ZhipuAI(api_key=self.app_secret_key)
        try:
            response = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": message}],
                    }
                ],
                stream=True,
                tools=[{"type": "web_browser"}],
            )
            for chunk in response:
                logger.debug("content: %s", chunk)
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                logger.debug("tool_calls: %s", delta.tool_calls)
                # Skip None so consumers can concatenate/format chunks safely.
                if delta.content is not None:
                    yield delta.content
        except Exception:
            logger.exception("Error in web_search_sse call")
            raise