增加网页查询功能
This commit is contained in:
parent
2c7d706777
commit
1db093b2a8
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
from flask import Blueprint, request, Response
|
||||
from app.services.zhipu_service import ZhipuService
|
||||
from app.services.zhipu_alltool_service import ZhipuAlltoolService
|
||||
from app.utils.prompt_repository import PromptRepository # Add this import
|
||||
|
||||
zhipu_controller = Blueprint('zhipu_controller', __name__)
|
||||
zhipu_service = ZhipuService()
|
||||
zhipu_alltool_service = ZhipuAlltoolService()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@zhipu_controller.route('/zhipu/stream', methods=['POST'])
|
||||
|
|
@ -91,3 +93,15 @@ def analysis_stream():
|
|||
yield chunk
|
||||
|
||||
return Response(event_stream(), content_type='text/event-stream')
|
||||
|
||||
|
||||
@zhipu_controller.route('/zhipu/alltool/websearch/stream', methods=['POST'])
|
||||
def alltool_stream():
|
||||
data = request.json
|
||||
message = data.get('message', '')
|
||||
|
||||
def event_stream():
|
||||
for chunk in zhipu_alltool_service.web_search_sse(message):
|
||||
if chunk:
|
||||
yield chunk
|
||||
return Response(event_stream(), content_type='text/event-stream')
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
import logging
|
||||
import time
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ZhipuAlltoolService:
|
||||
def __init__(self):
|
||||
self.model_name = "glm-4-alltools"
|
||||
self.app_secret_key = "d54f764a1d67c17d857bd3983b772016.GRjowY0fyiMNurLc"
|
||||
logger.info("ZhipuAlltoolService initialized with model: %s", self.model_name)
|
||||
|
||||
def func_call(self, message):
|
||||
logger.info("Starting web_search call")
|
||||
start_time = time.time()
|
||||
client = ZhipuAI(api_key=self.app_secret_key)
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4-alltools", # 填写需要调用的模型名称
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content":[
|
||||
{
|
||||
"type":"text",
|
||||
"text":"帮我查询2018年至2024年,每年五一假期全国旅游出行数据,并绘制成柱状图展示数据趋势。"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
stream=True,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_tourist_data_by_year",
|
||||
"description": "用于查询每一年的全国出行数据,输入年份范围(from_year,to_year),返回对应的出行数据,包括总出行人次、分交通方式的人次等。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"description": "交通方式,默认为by_all,火车=by_train,飞机=by_plane,自驾=by_car",
|
||||
"type": "string"
|
||||
},
|
||||
"from_year": {
|
||||
"description": "开始年份,格式为yyyy",
|
||||
"type": "string"
|
||||
},
|
||||
"to_year": {
|
||||
"description": "结束年份,格式为yyyy",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["from_year","to_year"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "code_interpreter"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error("Error in web_search call: %s", str(e))
|
||||
raise e
|
||||
|
||||
def web_search_sse(self, message):
|
||||
logger.info("Starting web_search_sse call")
|
||||
start_time = time.time()
|
||||
client = ZhipuAI(api_key=self.app_secret_key)
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4-alltools", # 填写需要调用的模型名称
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content":[
|
||||
{
|
||||
"type":"text",
|
||||
"text":message
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
stream=True,
|
||||
tools=[
|
||||
{
|
||||
"type": "web_browser"
|
||||
}
|
||||
]
|
||||
)
|
||||
for chunk in response:
|
||||
# print(chunk)
|
||||
print("content: ",chunk.choices[0].delta.content)
|
||||
print("tool_calls: ",chunk.choices[0].delta.tool_calls)
|
||||
logger.info("content: %s", str(chunk))
|
||||
yield chunk.choices[0].delta.content
|
||||
except Exception as e:
|
||||
logger.error("Error in web_search_sse call: %s", str(e))
|
||||
raise e
|
||||
Loading…
Reference in New Issue