增加分类器和是非器的voting机制
This commit is contained in:
parent
467ca12898
commit
50ff1de1c6
|
|
@ -33,9 +33,10 @@ def retrive_stream():
|
|||
logger.info(f'/zhipu/retrive/stream v2: {message}')
|
||||
|
||||
classify_rule = "只有当输入明确表示需要搜索互联网时,才返回web_search"
|
||||
classification_result_str = zhipu_alltool_service.func_call_classify(message, additional_desc=classify_rule)
|
||||
classification_result_str = zhipu_alltool_service.func_call_classify_with_voting(message, additional_desc=classify_rule)
|
||||
print(f'classification_result: {classification_result_str}')
|
||||
classification_result = json.loads(classification_result_str)
|
||||
# classification_result = json.loads(classification_result_str)
|
||||
classification_result = classification_result_str
|
||||
|
||||
if classification_result.get('category') == 'web_search':
|
||||
logger.info(f'question classify: web_search')
|
||||
|
|
@ -52,15 +53,15 @@ def retrive_stream():
|
|||
logger.info(f'question classify: retrive_knowledge')
|
||||
message = message.replace("我", "我(徐春峰)")
|
||||
|
||||
ask_for_project = zhipu_alltool_service.func_call_yes_or_no(message, "输入问题是否在查询负责的项目信息")
|
||||
ask_for_project = json.loads(ask_for_project)
|
||||
ask_for_project = zhipu_alltool_service.func_call_yes_or_no_with_voting(message, "输入问题是否在查询负责的项目信息")
|
||||
# ask_for_project = json.loads(ask_for_project)
|
||||
if ask_for_project.get('answer') == 'yes':
|
||||
message += " 用markdown表格形式输出,输出字段:客户名称, 商机名称, Sales stage, 预估 ACV, 预计签单时间"
|
||||
message += " ,用markdown表格形式输出,输出字段:客户名称, 商机名称, Sales stage, 预估 ACV, 预计签单时间"
|
||||
else:
|
||||
ask_for_project = zhipu_alltool_service.func_call_yes_or_no(message, "输入问题是否在询问项目的销售阶段或销售信息")
|
||||
ask_for_project = json.loads(ask_for_project)
|
||||
ask_for_project = zhipu_alltool_service.func_call_yes_or_no_with_voting(message, "输入问题是否在询问项目的销售阶段或销售信息")
|
||||
# ask_for_project = json.loads(ask_for_project)
|
||||
if ask_for_project.get('answer') == 'yes':
|
||||
message += " 用markdown表格形式输出,输出字段:客户名称, 商机名称, Sales stage, 预测类型,Timing 风险, 预估 ACV, 预计签单时间"
|
||||
message += " ,用markdown表格形式输出,输出字段:客户名称, 商机名称, Sales stage, 预测类型,Timing 风险, 预估 ACV, 预计签单时间"
|
||||
|
||||
|
||||
def event_stream_retrive():
|
||||
|
|
@ -148,9 +149,10 @@ def analysis_stream():
|
|||
|
||||
intent_categories =["analyze_sales","provide_sales_update_info"]
|
||||
|
||||
classification_result_str = zhipu_alltool_service.func_call_classify(message, intent_categories)
|
||||
classification_result_str = zhipu_alltool_service.func_call_classify_with_voting(message, intent_categories)
|
||||
print(f'classification_result: {classification_result_str}')
|
||||
classification_result = json.loads(classification_result_str)
|
||||
classification_result = classification_result_str
|
||||
# classification_result = json.loads(classification_result_str)
|
||||
|
||||
|
||||
additional_business_info = ""
|
||||
|
|
@ -160,9 +162,9 @@ def analysis_stream():
|
|||
pass
|
||||
elif classification_result.get('category') == 'provide_sales_update_info':
|
||||
logger.info(f'question classify: provide_sales_update_info')
|
||||
contain_project_info = zhipu_alltool_service.func_call_yes_or_no(message, "是否包含项目信息")
|
||||
contain_project_info = zhipu_alltool_service.func_call_yes_or_no_with_voting(message, "是否包含项目信息")
|
||||
logger.info(f'contain_project_info: {contain_project_info}')
|
||||
contain_project_info = json.loads(contain_project_info)
|
||||
# contain_project_info = json.loads(contain_project_info)
|
||||
if contain_project_info.get('answer') == 'yes':
|
||||
additional_business_info = message
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
import time
|
||||
import json
|
||||
from zhipuai import ZhipuAI
|
||||
from collections import Counter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -10,9 +12,36 @@ class ZhipuAlltoolService:
|
|||
self.app_secret_key = "d54f764a1d67c17d857bd3983b772016.GRjowY0fyiMNurLc"
|
||||
logger.info("ZhipuAlltoolService initialized with model: %s", self.model_name)
|
||||
|
||||
def func_call_classify(self, message, categories:list=None, additional_desc = None):
|
||||
def func_call_classify_with_voting(self, message, categories:list=None, additional_desc=None):
|
||||
logger.info("Starting func_call_classify_with_voting call")
|
||||
models = ["glm-4-plus", "glm-4-air", "glm-4-plus"]
|
||||
votes = []
|
||||
|
||||
for model in models:
|
||||
try:
|
||||
result = self.func_call_classify(message, categories, additional_desc, model_name=model)
|
||||
category = json.loads(result).get('category')
|
||||
# category = result.get('category')
|
||||
if category:
|
||||
votes.append(category)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in func_call_classify with model {model}: {str(e)}")
|
||||
|
||||
if not votes:
|
||||
raise Exception("All classification attempts failed")
|
||||
|
||||
most_common = Counter(votes).most_common(1)
|
||||
final_category = most_common[0][0]
|
||||
|
||||
logger.info(f"Voting results: {votes}, Final category: {final_category}")
|
||||
return {"category": final_category}
|
||||
|
||||
def func_call_classify(self, message, categories:list=None, additional_desc = None, model_name=None):
|
||||
logger.info("Starting func_call_classify call")
|
||||
start_time = time.time()
|
||||
default_model_name = "glm-4-flash"
|
||||
if model_name is None:
|
||||
model_name = default_model_name
|
||||
default_categories = ["web_search", "retrive_knowledge", "generate_report", "update_report", "clear_report"]
|
||||
categories = categories if categories else default_categories
|
||||
client = ZhipuAI(api_key=self.app_secret_key)
|
||||
|
|
@ -44,7 +73,7 @@ class ZhipuAlltoolService:
|
|||
]
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4-flash", # 填写需要调用的模型名称
|
||||
model=model_name, # 填写需要调用的模型名称
|
||||
messages= messages,
|
||||
tools= tools,
|
||||
tool_choice="auto"
|
||||
|
|
@ -55,9 +84,34 @@ class ZhipuAlltoolService:
|
|||
logger.error("Error in web_search call: %s", str(e))
|
||||
raise e
|
||||
|
||||
def func_call_yes_or_no(self, message, question):
|
||||
logger.info("Starting func_call_yes_or_no call")
|
||||
def func_call_yes_or_no_with_voting(self, message, question):
|
||||
logger.info("Starting func_call_yes_or_no_with_voting call")
|
||||
models = ["glm-4-plus", "glm-4-air", "glm-4-plus"]
|
||||
votes = []
|
||||
|
||||
for model in models:
|
||||
try:
|
||||
result = self.func_call_yes_or_no(message, question, model_name=model)
|
||||
answer =json.loads(result).get('answer')
|
||||
if answer:
|
||||
votes.append(answer.lower())
|
||||
except Exception as e:
|
||||
logger.error(f"Error in func_call_yes_or_no with model {model}: {str(e)}")
|
||||
|
||||
if not votes:
|
||||
raise Exception("All yes/no classification attempts failed")
|
||||
|
||||
most_common = Counter(votes).most_common(1)
|
||||
final_answer = most_common[0][0]
|
||||
|
||||
logger.info(f"Voting results: {votes}, Final answer: {final_answer}")
|
||||
return {"answer": final_answer}
|
||||
|
||||
def func_call_yes_or_no(self, message, question, model_name=None):
|
||||
logger.info("Starting func_call_yes_or_no call")
|
||||
default_model_name = "glm-4-flash"
|
||||
if model_name is None:
|
||||
model_name = default_model_name
|
||||
client = ZhipuAI(api_key=self.app_secret_key)
|
||||
tools = [
|
||||
{
|
||||
|
|
@ -86,7 +140,7 @@ class ZhipuAlltoolService:
|
|||
]
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4-flash", # 填写需要调用的模型名称
|
||||
model=model_name, # 填写需要调用的模型名称
|
||||
messages= messages,
|
||||
tools= tools,
|
||||
tool_choice="auto"
|
||||
|
|
@ -131,4 +185,4 @@ class ZhipuAlltoolService:
|
|||
yield chunk.choices[0].delta.content
|
||||
except Exception as e:
|
||||
logger.error("Error in web_search_sse call: %s", str(e))
|
||||
raise e
|
||||
raise e
|
||||
|
|
|
|||
Loading…
Reference in New Issue