# Python source: 189 lines, 7.2 KiB
import json
import logging
import os
import time
from collections import Counter

from zhipuai import ZhipuAI

logger = logging.getLogger(__name__)


class ZhipuAlltoolService:
    """Wrapper around the ZhipuAI chat-completions API.

    Offers three kinds of calls:

    * intent classification through a function-calling tool
      (``func_call_classify``), optionally repeated over several models with
      a majority vote (``func_call_classify_with_voting``);
    * yes/no questions through the same mechanism (``func_call_yes_or_no``
      and ``func_call_yes_or_no_with_voting``);
    * a streamed request that enables the ``web_browser`` tool
      (``web_search_sse``).
    """

    # Environment variable consulted for the API key when none is passed in.
    API_KEY_ENV_VAR = "ZHIPUAI_API_KEY"

    # SECURITY: this credential was hard-coded in the original source and is
    # kept ONLY as a last-resort fallback so existing deployments keep working.
    # It must be treated as leaked: rotate it, supply the new key through the
    # constructor or the environment, and delete this constant.
    _LEGACY_API_KEY = "d54f764a1d67c17d857bd3983b772016.GRjowY0fyiMNurLc"

    # Model used by the single-shot function-call helpers when none is given.
    DEFAULT_FUNC_CALL_MODEL = "glm-4-flash"

    # Models polled by the *_with_voting methods, one request per entry
    # (a model listed twice is asked twice and therefore votes twice).
    VOTING_MODELS = ("glm-4-plus", "glm-4-air", "glm-4-plus")

    # Intent labels offered to the model when the caller supplies none.
    DEFAULT_CATEGORIES = (
        "web_search",
        "retrive_knowledge",
        "generate_report",
        "update_report",
        "clear_report",
    )

    def __init__(self, api_key=None):
        """Create the service.

        Args:
            api_key: ZhipuAI API key. When omitted, the ``ZHIPUAI_API_KEY``
                environment variable is used, and only if that is unset too
                does the service fall back to the legacy embedded key.
        """
        self.model_name = "glm-4-alltools"
        configured_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
        if not configured_key:
            logger.warning(
                "No API key supplied and %s is not set; falling back to the "
                "key embedded in source. Rotate it and configure a key explicitly.",
                self.API_KEY_ENV_VAR,
            )
        self.app_secret_key = configured_key or self._LEGACY_API_KEY
        logger.info("ZhipuAlltoolService initialized with model: %s", self.model_name)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _create_client(self):
        """Return a new ZhipuAI client authenticated with the service key."""
        return ZhipuAI(api_key=self.app_secret_key)

    @staticmethod
    def _single_string_tool(description, field, field_description):
        """Build a function tool spec with one required string parameter."""
        return {
            "type": "function",
            "function": {
                "name": "classify_user_input",
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        field: {
                            "type": "string",
                            "description": field_description,
                        }
                    },
                    "required": [field],
                },
            },
        }

    def _request_tool_arguments(self, caller, model_name, messages, tools):
        """Send one completion request and return the first tool call's arguments.

        Args:
            caller: Name of the public method, used in log messages.
            model_name: Model to query.
            messages: Chat messages to send.
            tools: Tool specifications to send.

        Returns:
            The ``function.arguments`` value of the first tool call in the
            first choice of the response.

        Raises:
            ValueError: If the response contains no tool call (possible
                because ``tool_choice`` is ``"auto"``).
            Exception: Any error raised by the client is logged and re-raised.
        """
        started = time.monotonic()
        try:
            response = self._create_client().chat.completions.create(
                model=model_name,
                messages=messages,
                tools=tools,
                tool_choice="auto",
            )
            logger.debug("%s response: %s", caller, response)
            tool_calls = response.choices[0].message.tool_calls
            if not tool_calls:
                raise ValueError(f"Model {model_name} returned no tool call")
            arguments = tool_calls[0].function.arguments
        except Exception as e:
            logger.error("Error in %s call: %s", caller, e)
            raise
        logger.info("%s call finished in %.2fs", caller, time.monotonic() - started)
        return arguments

    def _collect_votes(self, ask, field, caller, normalize=None):
        """Query every voting model and gather the non-empty answers.

        Args:
            ask: Callable taking a model name and returning a JSON string.
            field: Key to read from the decoded JSON object.
            caller: Name of the single-shot method, used in log messages.
            normalize: Optional callable applied to each answer before it is
                counted.

        Returns:
            The list of collected answers, in model order. A model whose
            request or decoding fails is logged and skipped.
        """
        votes = []
        for model in self.VOTING_MODELS:
            try:
                value = json.loads(ask(model)).get(field)
                if value and normalize is not None:
                    value = normalize(value)
                if value:
                    votes.append(value)
            except Exception as e:
                # Best effort: one failing model must not abort the whole vote.
                logger.error("Error in %s with model %s: %s", caller, model, e)
        return votes

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def func_call_classify_with_voting(self, message, categories: list = None, additional_desc=None):
        """Classify ``message`` with every voting model and return the majority label.

        Args:
            message: User input to classify.
            categories: Candidate labels; defaults to ``DEFAULT_CATEGORIES``.
            additional_desc: Optional extra text appended to the tool description.

        Returns:
            ``{"category": <most frequent label>}``. On a tie the label that
            was voted first wins.

        Raises:
            Exception: If no model produced a usable label.
        """
        logger.info("Starting func_call_classify_with_voting call")
        votes = self._collect_votes(
            lambda model: self.func_call_classify(
                message, categories, additional_desc, model_name=model
            ),
            "category",
            "func_call_classify",
        )
        if not votes:
            raise Exception("All classification attempts failed")

        final_category = Counter(votes).most_common(1)[0][0]
        logger.info("Voting results: %s, Final category: %s", votes, final_category)
        return {"category": final_category}

    def func_call_classify(self, message, categories: list = None, additional_desc=None, model_name=None):
        """Ask one model to classify ``message`` into one of ``categories``.

        Args:
            message: User input to classify.
            categories: Candidate labels; defaults to ``DEFAULT_CATEGORIES``
                when empty or ``None``.
            additional_desc: Optional extra text appended to the tool
                description. Omitted entirely when empty or ``None``.
            model_name: Model to query; defaults to ``DEFAULT_FUNC_CALL_MODEL``.

        Returns:
            The raw tool-call arguments returned by the model (callers in
            this class decode it as JSON with a ``category`` key).

        Raises:
            ValueError: If the model replied without a tool call.
            Exception: Any client error, after being logged.
        """
        logger.info("Starting func_call_classify call")
        if model_name is None:
            model_name = self.DEFAULT_FUNC_CALL_MODEL
        categories = categories if categories else list(self.DEFAULT_CATEGORIES)

        description = "根据用户输入,判断用户意图,返回意图类型。"
        if additional_desc:
            # Only append real text; interpolating None would send "None".
            description += str(additional_desc)

        tools = [
            self._single_string_tool(
                description,
                "category",
                "用户的意图有以下选项:" + ",".join(categories),
            )
        ]
        messages = [
            {
                "role": "user",
                "content": f"判断以下用户输入的意图并将其分类返回意图类型:{message}",
            }
        ]
        return self._request_tool_arguments("func_call_classify", model_name, messages, tools)

    def func_call_yes_or_no_with_voting(self, message, question):
        """Ask every voting model a yes/no question and return the majority answer.

        Args:
            message: Information the question is about.
            question: The question to answer.

        Returns:
            ``{"answer": <most frequent answer>}``, where answers are
            stripped and lower-cased before counting. On a tie the answer
            that was voted first wins.

        Raises:
            Exception: If no model produced a usable answer.
        """
        logger.info("Starting func_call_yes_or_no_with_voting call")
        votes = self._collect_votes(
            lambda model: self.func_call_yes_or_no(message, question, model_name=model),
            "answer",
            "func_call_yes_or_no",
            normalize=lambda answer: answer.strip().lower(),
        )
        if not votes:
            raise Exception("All yes/no classification attempts failed")

        final_answer = Counter(votes).most_common(1)[0][0]
        logger.info("Voting results: %s, Final answer: %s", votes, final_answer)
        return {"answer": final_answer}

    def func_call_yes_or_no(self, message, question, model_name=None):
        """Ask one model to answer ``question`` about ``message`` with yes or no.

        Args:
            message: Information the question is about.
            question: The question to answer.
            model_name: Model to query; defaults to ``DEFAULT_FUNC_CALL_MODEL``.

        Returns:
            The raw tool-call arguments returned by the model (callers in
            this class decode it as JSON with an ``answer`` key).

        Raises:
            ValueError: If the model replied without a tool call.
            Exception: Any client error, after being logged.
        """
        logger.info("Starting func_call_yes_or_no call")
        if model_name is None:
            model_name = self.DEFAULT_FUNC_CALL_MODEL

        tools = [
            self._single_string_tool(
                "根据用户输入的信息和问题,回答yes或者no",
                "answer",
                "判断结果有以下选项:yes,no",
            )
        ]
        messages = [{"role": "user", "content": f"{question}:{message}"}]
        return self._request_tool_arguments("func_call_yes_or_no", model_name, messages, tools)

    def web_search_sse(self, message):
        """Stream a completion for ``message`` with the ``web_browser`` tool enabled.

        Args:
            message: User text to send.

        Yields:
            The ``delta.content`` of each streamed chunk, in order. The value
            is yielded as received, so it may be ``None`` for chunks that
            carry no text. Chunks without any choice are skipped.

        Raises:
            Exception: Any error raised while requesting or iterating the
                stream, after being logged.
        """
        logger.info("Starting web_search_sse call")
        started = time.monotonic()
        try:
            response = self._create_client().chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": message}],
                    }
                ],
                stream=True,
                tools=[{"type": "web_browser"}],
            )
            for chunk in response:
                logger.info("content: %s", chunk)
                if not chunk.choices:
                    # Nothing to forward for a chunk that carries no choice.
                    continue
                yield chunk.choices[0].delta.content
            logger.info("web_search_sse call finished in %.2fs", time.monotonic() - started)
        except Exception as e:
            logger.error("Error in web_search_sse call: %s", e)
            raise