新增提问分类器

web_search, retrieve_knowledge, generate_report
This commit is contained in:
Tiger Ren 2024-10-14 19:15:27 +08:00
parent 1db093b2a8
commit b4e0395871
2 changed files with 41 additions and 49 deletions

View File

@ -104,4 +104,12 @@ def alltool_stream():
for chunk in zhipu_alltool_service.web_search_sse(message):
if chunk:
yield chunk
return Response(event_stream(), content_type='text/event-stream')
return Response(event_stream(), content_type='text/event-stream')
@zhipu_controller.route('/zhipu/alltool/classify/non-stream', methods=['POST'])
def alltool_classify_non_stream():
data = request.json
message = data.get('message', '')
response = zhipu_alltool_service.func_call_classify(message)
print(f'response: {response}')
return response

View File

@ -10,60 +10,44 @@ class ZhipuAlltoolService:
self.app_secret_key = "d54f764a1d67c17d857bd3983b772016.GRjowY0fyiMNurLc"
logger.info("ZhipuAlltoolService initialized with model: %s", self.model_name)
def func_call(self, message):
def func_call_classify(self, message):
logger.info("Starting web_search call")
start_time = time.time()
client = ZhipuAI(api_key=self.app_secret_key)
tools = [
{
"type": "function",
"function": {
"name": "classify_user_input",
"description": "根据用户输入,判断用户意图,返回意图类型",
"parameters": {
"type": "object",
"properties": {
"category": {
"type": "string",
"description": "用户的意图有以下选项:web_search,retrive_knowledge,generate_report",
}
},
"required": ["category"],
},
}
}
]
messages = [
{
"role": "user",
"content": f"判断以下用户输入的意图并将其分类返回意图类型:{message}"
}
]
try:
response = client.chat.completions.create(
model="glm-4-alltools", # 填写需要调用的模型名称
messages=[
{
"role": "user",
"content":[
{
"type":"text",
"text":"帮我查询2018年至2024年每年五一假期全国旅游出行数据并绘制成柱状图展示数据趋势。"
}
]
}
],
stream=True,
tools=[
{
"type": "function",
"function": {
"name": "get_tourist_data_by_year",
"description": "用于查询每一年的全国出行数据,输入年份范围(from_year,to_year),返回对应的出行数据,包括总出行人次、分交通方式的人次等。",
"parameters": {
"type": "object",
"properties": {
"type": {
"description": "交通方式默认为by_all火车=by_train飞机=by_plane自驾=by_car",
"type": "string"
},
"from_year": {
"description": "开始年份格式为yyyy",
"type": "string"
},
"to_year": {
"description": "结束年份格式为yyyy",
"type": "string"
}
},
"required": ["from_year","to_year"]
}
}
},
{
"type": "code_interpreter"
}
]
model="glm-4-flash", # 填写需要调用的模型名称
messages= messages,
tools= tools,
tool_choice="auto"
)
for chunk in response:
print(chunk)
return response
print(response)
return response.choices[0].message.tool_calls[0].function.arguments
except Exception as e:
logger.error("Error in web_search call: %s", str(e))
raise e