llm_hub/app/api/v1/zhipu_controller.py

150 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging,json
from flask import Blueprint, request, Response
from app.services.zhipu_service import ZhipuService
from app.services.zhipu_alltool_service import ZhipuAlltoolService
from app.services.zhipu_file_service import ZhipuFileService
from app.utils.prompt_repository import PromptRepository # Add this import
logger = logging.getLogger(__name__)

# Flask blueprint that the Zhipu endpoints below are registered on.
zhipu_controller = Blueprint('zhipu_controller', __name__)

# Service objects are built once at import time and shared by every handler
# in this module (construction order kept as in the original).
zhipu_service = ZhipuService()
zhipu_alltool_service = ZhipuAlltoolService()
zhipu_file_service = ZhipuFileService()
@zhipu_controller.route('/zhipu/stream', methods=['POST'])
def stream_sse():
    """Stream a chat reply for the posted ``message`` as Server-Sent Events.

    Request JSON fields:
        message (str): prompt forwarded to ``talk_to_zhipu_sse``; defaults to ''.

    Falsy chunks produced by the service are dropped before being sent.
    """
    payload = request.json
    user_message = payload.get('message', '')

    def _relay():
        # The service call lives inside the generator, so it only runs once
        # Flask starts iterating the response body.
        yield from filter(None, zhipu_service.talk_to_zhipu_sse(user_message))

    return Response(_relay(), content_type='text/event-stream')
@zhipu_controller.route('/zhipu/non-stream', methods=['POST'])
def non_stream():
    """Return a single (non-streamed) chat reply for the posted ``message``.

    Request JSON fields:
        message (str): prompt forwarded to ``talk_to_zhipu``; defaults to ''.

    Returns whatever ``zhipu_service.talk_to_zhipu`` returns, unchanged.
    """
    data = request.json
    message = data.get('message', '')
    response = zhipu_service.talk_to_zhipu(message)
    # Use the module logger (lazy %-formatting) instead of print so the output
    # follows the application's logging configuration.
    logger.debug('response: %s', response)
    return response
@zhipu_controller.route('/zhipu/retrive/non-stream', methods=['POST'])
def retrive_non_stream():
    """Answer ``message`` from a knowledge base in a single (non-streamed) reply.

    Request JSON fields (each defaults to ''):
        message: question forwarded to ``zhipu_service.retrive``.
        knowledge_id: knowledge base identifier forwarded to ``retrive``.

    Returns whatever ``zhipu_service.retrive`` returns, unchanged.
    """
    data = request.json
    message = data.get('message', '')
    knowledge_id = data.get('knowledge_id', '')
    # NOTE(review): clients may send ``prompt_template``, but it was never used:
    # ``retrive`` is always called with None as its third argument. Confirm
    # whether the template was meant to be forwarded there.
    response = zhipu_service.retrive(message, knowledge_id, None)
    logger.debug('response: %s', response)
    return response
@zhipu_controller.route('/zhipu/retrive/stream', methods=['POST'])
def retrive_stream():
    """Classify ``message`` and stream an answer from the matching tool as SSE.

    Request JSON fields (each defaults to ''):
        message: user question; also the input to the classifier.
        knowledge_id: knowledge base passed to ``retrive_sse``.

    ``func_call_classify`` is expected to return a JSON-object string with a
    ``category`` key, which selects the handling:
        * ``web_search``        -> stream ``web_search_sse(message)``.
        * ``retrive_knowledge`` -> stream ``retrive_sse(message, knowledge_id, None)``.
        * ``generate_report``   -> not implemented; HTTP 501 JSON error.
    Classifier output that is not a JSON object, or an unknown category,
    yields an HTTP 502 JSON error. Previously those paths returned ``None``
    or raised, both of which Flask turns into an opaque 500.
    """
    data = request.json
    message = data.get('message', '')
    knowledge_id = data.get('knowledge_id', '')
    # NOTE(review): clients may send ``prompt_template``, but it was never used;
    # ``retrive_sse`` is called with None as its third argument. Confirm intent.

    classification_result_str = zhipu_alltool_service.func_call_classify(message)
    logger.debug('classification_result: %s', classification_result_str)
    try:
        classification_result = json.loads(classification_result_str)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        logger.exception('Classifier returned non-JSON output')
        return {'error': 'message classification failed'}, 502
    if not isinstance(classification_result, dict):
        logger.error('Classifier returned a non-object JSON value: %r', classification_result)
        return {'error': 'message classification failed'}, 502

    category = classification_result.get('category')

    if category == 'web_search':
        def event_stream_websearch_sse():
            for chunk in zhipu_alltool_service.web_search_sse(message):
                if chunk:
                    yield chunk
        return Response(event_stream_websearch_sse(), content_type='text/event-stream')

    if category == 'retrive_knowledge':
        def event_stream_retrive():
            for chunk in zhipu_service.retrive_sse(message, knowledge_id, None):
                if chunk:
                    yield chunk
        return Response(event_stream_retrive(), content_type='text/event-stream')

    if category == 'generate_report':
        # TODO: report generation is not implemented yet. Answer explicitly
        # instead of falling through to an implicit ``None`` return.
        return {'error': 'generate_report is not implemented'}, 501

    logger.warning('Unsupported classification category: %r', category)
    return {'error': f'unsupported category: {category}'}, 502
@zhipu_controller.route('/zhipu/analysis/stream', methods=['POST'])
def analysis_stream():
    """Stream a sales-stage analysis of a business opportunity as SSE.

    Runs two LLM passes:
      1. ``zhipu_service.retrive`` looks up the opportunity's current status,
         close cadence and sales stage in knowledge base ``knowledge_id``.
         This call is synchronous, before any bytes are streamed.
      2. That result plus the ``sales_analysis`` prompt from
         ``PromptRepository`` is sent to ``talk_to_zhipu_sse``. Its non-empty
         chunks are streamed back.

    Request JSON fields (each defaults to ''):
        message: user-supplied opportunity details.
        knowledge_id: knowledge base used for the lookup in step 1.
    """
    data = request.json
    message = data.get('message', '')
    knowledge_id = data.get('knowledge_id', '')

    # Step 1: retrieve the business (opportunity) info.
    prompt_get_business_info = f"""
请根据用户提供的如下信息,查找相关的 '当前详细状态及Close节奏','Sales stage' 信息,并返回给用户:
{message}
"""
    business_info = zhipu_service.retrive(prompt_get_business_info, knowledge_id, None)
    logger.debug('business_info: %s', business_info)
    analysis_rule = PromptRepository().get_prompt('sales_analysis')
    logger.debug('analysis_rule: %s', analysis_rule)

    # Step 2: analyse based on the current detailed status, close cadence and sales stage.
    prompt_analysis = f"""
请根据查询到的上述商机信息:
{business_info}
根据如下各销售阶段的销售阶段任务、销售关键动作、阶段转化标准:
{analysis_rule}
结合上述商机信息的对应阶段,分析并判断其销售动作是否完成了前一阶段的准出标准,以及是否支持将销售阶段转化到当前阶段
1. **销售阶段分析**
2. **销售动作日志分析**
3. **销售动作与销售阶段的关系**
4. **判断结果**
5. **销售阶段分析报告**
"""

    def event_stream():
        for chunk in zhipu_service.talk_to_zhipu_sse(prompt_analysis):
            if chunk:
                yield chunk

    return Response(event_stream(), content_type='text/event-stream')
@zhipu_controller.route('/zhipu/alltool/websearch/stream', methods=['POST'])
def alltool_stream():
    """Stream web-search results for the posted ``message`` as Server-Sent Events.

    Request JSON fields:
        message (str): query forwarded to ``web_search_sse``; defaults to ''.

    Falsy chunks produced by the service are dropped before being sent.
    """
    payload = request.json
    query = payload.get('message', '')

    def _relay():
        # Deferred: web_search_sse is only invoked once the body is iterated.
        yield from filter(None, zhipu_alltool_service.web_search_sse(query))

    return Response(_relay(), content_type='text/event-stream')
@zhipu_controller.route('/zhipu/alltool/classify/non-stream', methods=['POST'])
def alltool_classify_non_stream():
    """Classify the posted ``message`` and return the classifier output as-is.

    Request JSON fields:
        message (str): text forwarded to ``func_call_classify``; defaults to ''.

    Returns whatever ``zhipu_alltool_service.func_call_classify`` returns.
    """
    data = request.json
    message = data.get('message', '')
    response = zhipu_alltool_service.func_call_classify(message)
    # Module logger instead of print, consistent with the rest of the service.
    logger.debug('response: %s', response)
    return response
@zhipu_controller.route('/zhipu/file', methods=['POST'])
def submit_report():
    """Submit report text as a file through ``zhipu_file_service.submit_file``.

    Request JSON fields (each defaults to ''):
        report_text: file content, passed as ``file_content``.
        project_name: passed as ``project_name``.
        prefix: passed as ``prefix``.

    Returns whatever ``submit_file`` returns, unchanged.
    """
    data = request.json
    report_text = data.get('report_text', '')
    project_name = data.get('project_name', '')
    prefix = data.get('prefix', '')
    return zhipu_file_service.submit_file(
        prefix=prefix,
        project_name=project_name,
        file_content=report_text,
    )
@zhipu_controller.route('/zhipu/file/list', methods=['POST'])
def get_file_list():
    """Return the listing from ``zhipu_file_service.get_file_list`` unchanged."""
    return zhipu_file_service.get_file_list()