From 21db91ee711af9a7d630bc386c92790c6f120786 Mon Sep 17 00:00:00 2001
From: Tiger Ren
Date: Sat, 26 Oct 2024 22:13:01 +0800
Subject: [PATCH] Add OpenAI service support; sales analysis can use openai or
 zhipu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apitest/api_test_local.sh            | 37 +++++++++++++
 app/api/v1/zhipu_controller.py       |  6 +-
 app/api/v2/zhipu_controller_v2.py    | 42 ++++++++++++--
 app/services/ai_service_interface.py | 11 ++++
 app/services/openai_service.py       | 61 ++++++++++++++++++++
 app/services/zhipu_kb_service.py     | 83 ++++++++++++++++++++++++++++
 app/services/zhipu_service.py        | 45 +++++++--------
 app/utils/sessions.py                |  2 +
 requirements.txt                     |  3 +-
 9 files changed, 257 insertions(+), 33 deletions(-)
 create mode 100644 app/services/ai_service_interface.py
 create mode 100644 app/services/zhipu_kb_service.py

diff --git a/apitest/api_test_local.sh b/apitest/api_test_local.sh
index 63ff30f..377b84b 100644
--- a/apitest/api_test_local.sh
+++ b/apitest/api_test_local.sh
@@ -48,4 +48,41 @@ curl -N -X POST http://127.0.0.1:5002/api/v2/zhipu/retrive/stream \
 -d '{
 "message":"视睿电子教学课件系统续费项目更新日志:**项目进展描述**:了解到客户需求和降本要求后,与项目经理杨建线下沟通。客户同意在剩余2套系统上增加模块,但要求降价,具体数量待内部讨论。",
 "knowledge_id":"1843318172036575232"
+}'
+
+
+curl -N -X POST http://127.0.0.1:5002/api/v2/zhipu/analysis/stream \
+-H "Content-Type: application/json" \
+-d '{
+"message":"分析商机广汽汇理汽车金融有限公司的商机建议",
+"knowledge_id":"1843318172036575232"
+}'
+
+curl -N -X POST http://127.0.0.1:5002/api/v2/zhipu/analysis/stream \
+-H "Content-Type: application/json" \
+-d '{
+"message":"2. 更新我负责的广汽汇理汽车金融有限公司项目的最新动作:1,今日和客户有做了一次技术交流,他们最近和Ocean Base和Gold DB也做了交流,以及内部也做了沟通,接下来他们希望能够拿出一个业务场景做测试,已确定哪个产品更适合他们。",
+"knowledge_id":"1843318172036575232"
+}'
+
+
+curl -N -X POST http://127.0.0.1:5002/api/v2/zhipu/analysis/stream \
+-H "Content-Type: application/json" \
+-d '{
+"message":"更新我负责的广汽汇理汽车金融有限公司项目的最新动作:已经完成了POC,客户对POC效果表示满意",
+"knowledge_id":"1843318172036575232"
+}'
+
+curl -N -X POST http://127.0.0.1:5002/api/v2/zhipu/analysis/stream \
+-H "Content-Type: application/json" \
+-d '{
+"message":"openai",
+"knowledge_id":"1843318172036575232"
+}'
+
+curl -N -X POST http://127.0.0.1:5002/api/v2/zhipu/analysis/stream \
+-H "Content-Type: application/json" \
+-d '{
+"message":"zhipu",
+"knowledge_id":"1843318172036575232"
+}'
\ No newline at end of file
diff --git a/app/api/v1/zhipu_controller.py b/app/api/v1/zhipu_controller.py
index eb08037..f687573 100644
--- a/app/api/v1/zhipu_controller.py
+++ b/app/api/v1/zhipu_controller.py
@@ -19,7 +19,7 @@ def stream_sse():
     message = data.get('message', '')
 
     def event_stream():
-        for chunk in zhipu_service.talk_to_zhipu_sse(message):
+        for chunk in zhipu_service.generate_response_sse(message):
             if chunk:
                 yield chunk
 
@@ -30,7 +30,7 @@ def non_stream():
     data = request.json
     message = data.get('message', '')
 
-    response = zhipu_service.talk_to_zhipu(message)
+    response = zhipu_service.generate_response(message)
     print(f'response: {response}')
     return response
 
@@ -188,7 +188,7 @@ def analysis_stream():
     """
 
     def event_stream():
-        for chunk in zhipu_service.talk_to_zhipu_sse(prompt_analysis):
+        for chunk in zhipu_service.generate_response_sse(prompt_analysis):
             if chunk:
                 yield chunk
 
diff --git a/app/api/v2/zhipu_controller_v2.py b/app/api/v2/zhipu_controller_v2.py
index 4558478..fc3e7b4 100644
--- a/app/api/v2/zhipu_controller_v2.py
+++ b/app/api/v2/zhipu_controller_v2.py
@@ -4,12 +4,18 @@ from flask import Blueprint, request, Response,session
 from app.services.zhipu_service import ZhipuService
 from app.services.zhipu_alltool_service import ZhipuAlltoolService
 from app.services.zhipu_file_service import ZhipuFileService
+from app.services.zhipu_kb_service import ZhipuKbService
+from app.services.openai_service import OpenaiService
 from app.utils.prompt_repository import PromptRepository # Add this import
+from app.utils.sessions import init_session
 
 zhipu_controller_v2 = Blueprint('zhipu_controller_v2', __name__)
+
 zhipu_service = ZhipuService()
+openai_service = OpenaiService()
 zhipu_alltool_service = ZhipuAlltoolService()
 zhipu_file_service = ZhipuFileService()
+zhipu_kb_service = ZhipuKbService()
 
 response_headers = {'Content-Type': 'text/event-stream',
                     'Cache-Control': 'no-cache',
@@ -66,7 +72,7 @@ def retrive_stream():
 
     def event_stream_retrive():
         accumulated_result = ""
-        for chunk in zhipu_service.retrive_sse(message, knowledge_id, system_prompt="你是一个销售助理,语言对话请以第一人称你我进行"):
+        for chunk in zhipu_kb_service.retrive_sse(message, knowledge_id, system_prompt="你是一个销售助理,语言对话请以第一人称你我进行"):
             if chunk:
                 accumulated_result += chunk
                 chunk_out = format_chunk(chunk, None, None)
@@ -92,7 +98,7 @@ def retrive_stream():
         prompt_report_template = PromptRepository().get_prompt("report_template")
         prompt_report_title = f"根据用户提问中\"\"\" {message} \"\"\" 中提到的项目信息 在知识库中查找该项目的销售日志。如果销售日志中缺乏模板中的要点(时间,参与人,事件,获得信息,信息来源,项目进展描述)信息,则该要点内容留空,不要填充信息 日报模板: \"\"\" {prompt_report_template} \"\"\"。输出: 日志报告"
         generated_report = ""
-        for chunk in zhipu_service.retrive_sse(prompt_report_title + message, knowledge_id, None):
+        for chunk in zhipu_kb_service.retrive_sse(prompt_report_title + message, knowledge_id, None):
             if chunk:
                 print(chunk)
                 generated_report += chunk
@@ -123,7 +129,7 @@ def retrive_stream():
         prompt_report_template = PromptRepository().get_prompt("report_template")
         prompt_report_title = f"根据用户提问中\"\"\" {message} \"\"\" 中提到的项目信息 在知识库中查找该项目的销售日志并结合用户提供的新的日志信息 \"\"\"{message} \"\"\"生成日报。如果销售日志中缺乏模板中的要点(时间,参与人,事件,获得信息,信息来源,项目进展描述)信息,则该要点内容留空,不要填充信息 日报模板: \"\"\" {prompt_report_template} \"\"\"。输出: 日志报告"
         generated_report = ""
-        for chunk in zhipu_service.retrive_sse(prompt_report_title + message, knowledge_id, None):
+        for chunk in zhipu_kb_service.retrive_sse(prompt_report_title + message, knowledge_id, None):
             if chunk:
                 print(chunk)
                 generated_report += chunk
@@ -142,11 +148,32 @@ def retrive_stream():
 
 @zhipu_controller_v2.route('/zhipu/analysis/stream', methods=['POST'])
 def analysis_stream():
+    init_session()
     data = request.json
     message = data.get('message', '')
     knowledge_id = data.get('knowledge_id', '')
     message = message.replace("我", "我(徐春峰)")
 
+    if 'zhipu' in message.lower() or '智谱' in message:
+        logger.info(f'switch to zhipu service, save to session')
+        session['llm_service'] = 'zhipu'
+        return format_chunk("切换到智谱AI服务", None, None)
+    if 'openai' in message.lower():
+        logger.info(f'switch to openai service, save to session')
+        session['llm_service'] = 'openai'
+        return format_chunk("切换到openai服务", None, None)
+    # 默认使用智谱AI服务
+    llm_service = zhipu_service
+    logger.info(f'llm_service: {session["llm_service"]}')
+    current_service = session.get('llm_service', 'zhipu')  # Default to 'zhipu' if not set
+
+    if current_service == 'openai':
+        logger.info('Using OpenAI service')
+        llm_service = openai_service
+    else:
+        logger.info('Using Zhipu service')
+        llm_service = zhipu_service
+
     logger.info(f'/zhipu/analysis/stream v2: {message}')
 
     intent_categories =["analyze_sales","provide_sales_update_info"]
@@ -158,6 +185,11 @@ def analysis_stream():
     if '更新' in message and '我负责' in message and '项目' in message and '最新' in message:
         classification_result = {"category":"provide_sales_update_info"}
 
+    # if 'openai' in message.lower():
+    #     logger.info(f'switch to openai service')
+    #     llm_service = openai_service
+    #     message = message.replace("openai", "")
+
     additional_business_info = ""
 
     if classification_result.get('category') == 'analyze_sales':
@@ -182,7 +214,7 @@ def analysis_stream():
         请根据用户提供的如下信息,查找相关的 '当前详细状态及已完成工作','Sales stage' 信息,并返回给用户:
         {message}
         """
-        business_info = zhipu_service.retrive(prompt_get_business_info, knowledge_id, None)
+        business_info = zhipu_kb_service.retrive(prompt_get_business_info, knowledge_id, None)
         logger.info(f'business_info: {business_info}')
 
         analysis_rule = PromptRepository().get_prompt('sales_analysis')
@@ -220,7 +252,7 @@ def analysis_stream():
 
     def event_stream():
         accumulated_result = ""
-        for chunk in zhipu_service.talk_to_zhipu_sse(prompt_analysis):
+        for chunk in llm_service.generate_response_sse(prompt_analysis):
             if chunk:
                 accumulated_result += chunk
                 chunk_out = format_chunk(chunk, None, None)
diff --git a/app/services/ai_service_interface.py b/app/services/ai_service_interface.py
new file mode 100644
index 0000000..3d710b1
--- /dev/null
+++ b/app/services/ai_service_interface.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+class AIServiceInterface(ABC):
+    @abstractmethod
+    def generate_response(self, prompt):
+        pass
+
+    @abstractmethod
+    def generate_response_sse(self, prompt):
+        pass
+
diff --git a/app/services/openai_service.py b/app/services/openai_service.py
index e69de29..493e524 100644
--- a/app/services/openai_service.py
+++ b/app/services/openai_service.py
@@ -0,0 +1,61 @@
+import logging
+import time
+from openai import OpenAI
+from app.services.ai_service_interface import AIServiceInterface
+# from app.utils.prompt_repository import PromptRepository
+
+logger = logging.getLogger(__name__)
+
+class OpenaiService(AIServiceInterface):
+    def __init__(self):
+        self.client = OpenAI(base_url="https://ai.xorbit.link:8443/e5b2a5e5-b41d-4715-9d50-d4a3b0c1a85f/v1", api_key="sk-proj-e5b2a5e5b41d47159d50d4a3b0c1a85f")
+
+    def generate_response(self, prompt):
+        try:
+            response = self.client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[{"role": "system", "content": prompt}],
+                temperature=0.7,
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            logger.error(f"Error generating response: {e}")
+            return "An error occurred while generating the response."
+
+    def generate_response_sse(self, prompt):
+        try:
+            response = self.client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[{"role": "system", "content": prompt}],
+                temperature=0.7,
+                stream=True
+            )
+            for chunk in response:
+                if chunk.choices[0].delta.content is not None:
+                    yield chunk.choices[0].delta.content
+            yield "answer provided by openai"
+        except Exception as e:
+            logger.error(f"Error generating SSE response: {e}")
+            yield "An error occurred while generating the SSE response."
+
+
+if __name__ == "__main__":
+    # Set up logging
+    logging.basicConfig(level=logging.INFO)
+
+    # Create an instance of OpenaiService
+    openai_service = OpenaiService()
+
+    # Test the generate_response method
+    test_prompt = "What is the capital of France?"
+    response = openai_service.generate_response(test_prompt)
+    print(f"Response to '{test_prompt}': {response}")
+
+    # Test the generate_response_sse method
+    print("\nTesting generate_response_sse:")
+    sse_prompt = "Count from 1 to 5 slowly."
+    for chunk in openai_service.generate_response_sse(sse_prompt):
+        print(chunk, end='', flush=True)
+        time.sleep(0.1)  # Add a small delay to simulate streaming
+    print("\nSSE response complete.")
+
diff --git a/app/services/zhipu_kb_service.py b/app/services/zhipu_kb_service.py
new file mode 100644
index 0000000..b0cba2c
--- /dev/null
+++ b/app/services/zhipu_kb_service.py
@@ -0,0 +1,83 @@
+import logging
+import time
+from zhipuai import ZhipuAI
+from app.utils.prompt_repository import PromptRepository
+
+logger = logging.getLogger(__name__)
+
+class ZhipuKbService:
+    def __init__(self):
+        self.model_name = "glm-4"
+        self.app_secret_key = "d54f764a1d67c17d857bd3983b772016.GRjowY0fyiMNurLc"
+        logger.info("ZhipuKbService initialized with model: %s", self.model_name)
+
+
+    def retrive(self, message, knowledge_id, prompt_template):
+        logger.info("Starting retrive call with knowledge_id: %s", knowledge_id)
+        start_time = time.time()
+        client = ZhipuAI(api_key=self.app_secret_key)
+        default_prompt = "从文档\n\"\"\"\n{{knowledge}}\n\"\"\"\n中找问题\n\"\"\"\n{{question}}\n\"\"\"\n的答案,找到答案就仅使用文档语句回答问题,找不到答案就用自身知识回答并且告诉用户该信息不是来自文档。\n不要复述问题,直接开始回答。"
+
+        if prompt_template is None or prompt_template == "":
+            prompt_template = default_prompt
+        try:
+            response = client.chat.completions.create(
+                model="glm-4",
+                messages=[
+                    {"role": "user", "content": message},
+                ],
+                tools=[
+                    {
+                        "type": "retrieval",
+                        "retrieval": {
+                            "knowledge_id": knowledge_id,
+                            "prompt_template": prompt_template
+                        }
+                    }
+                ],
+                stream=False,
+                temperature=0.01,
+                top_p=0.1,
+            )
+            result = response.choices[0].message.content
+            end_time = time.time()
+            logger.info("retrive call completed in %.2f seconds", end_time - start_time)
+            return result
+        except Exception as e:
+            logger.error("Error in retrive: %s", str(e))
+            raise
+
+    def retrive_sse(self, message, knowledge_id, prompt_template=None,system_prompt=None):
+        logger.info("Starting retrive_sse call with knowledge_id: %s, message:%s", knowledge_id, message)
+        start_time = time.time()
+        client = ZhipuAI(api_key=self.app_secret_key)
+        default_prompt = "从文档\n\"\"\"\n{{knowledge}}\n\"\"\"\n中找问题\n\"\"\"\n{{question}}\n\"\"\"\n的答案,找到答案就仅使用文档语句回答问题,找不到答案就告诉用户知识库中没有该信息。\n不要复述问题,直接开始回答。"
+        messages = [{"role": "user", "content": message}]
+        # if system_prompt != None:
+        #     messages.append({"role": "system", "content": system_prompt})
+        if prompt_template is None or prompt_template == "":
+            prompt_template = default_prompt
+        try:
+            response = client.chat.completions.create(
+                model="glm-4",
+                messages=messages,
+                tools=[
+                    {
+                        "type": "retrieval",
+                        "retrieval": {
+                            "knowledge_id": knowledge_id,
+                            "prompt_template": prompt_template
+                        }
+                    }
+                ],
+                stream=True,
+                temperature=0.01,
+                top_p=0.1,
+            )
+            for chunk in response:
+                yield chunk.choices[0].delta.content
+            end_time = time.time()
+            logger.info("retrive_sse call completed in %.2f seconds", end_time - start_time)
+        except Exception as e:
+            logger.error("Error in retrive_sse: %s", str(e))
+            raise
diff --git a/app/services/zhipu_service.py b/app/services/zhipu_service.py
index 7872262..6d0875d 100644
--- a/app/services/zhipu_service.py
+++ b/app/services/zhipu_service.py
@@ -1,47 +1,47 @@
 import logging
 import time
 from zhipuai import ZhipuAI
+from app.services.ai_service_interface import AIServiceInterface
 from app.utils.prompt_repository import PromptRepository
 
 logger = logging.getLogger(__name__)
 
-class ZhipuService:
+class ZhipuService(AIServiceInterface):
     def __init__(self):
         self.model_name = "glm-4"
         self.app_secret_key = "d54f764a1d67c17d857bd3983b772016.GRjowY0fyiMNurLc"
+        self.client = ZhipuAI(api_key=self.app_secret_key)
         logger.info("ZhipuService initialized with model: %s", self.model_name)
 
-    def talk_to_zhipu(self, message):
-        logger.info("Starting talk_to_zhipu call")
+    def generate_response(self, prompt):
+        logger.info("Starting generate_response call")
         start_time = time.time()
-        client = ZhipuAI(api_key=self.app_secret_key)
         try:
-            response = client.chat.completions.create(
+            response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[
-                    {"role": "user", "content": message},
+                    {"role": "user", "content": prompt},
                 ],
                 stream=False,
                 temperature=0.01,
                 top_p=0.1,
             )
-            accum_resp = response.choices[0].message.content
+            result = response.choices[0].message.content
             end_time = time.time()
-            logger.info("talk_to_zhipu call completed in %.2f seconds", end_time - start_time)
-            return accum_resp
+            logger.info("generate_response call completed in %.2f seconds", end_time - start_time)
+            return result
         except Exception as e:
-            logger.error("Error in talk_to_zhipu: %s", str(e))
+            logger.error("Error in generate_response: %s", str(e))
             raise
 
-    def talk_to_zhipu_sse(self, message):
-        logger.info("Starting talk_to_zhipu_sse call")
+    def generate_response_sse(self, prompt):
+        logger.info("Starting generate_response_sse call")
         start_time = time.time()
-        client = ZhipuAI(api_key=self.app_secret_key)
         try:
-            response = client.chat.completions.create(
+            response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[
-                    {"role": "user", "content": message},
+                    {"role": "user", "content": prompt},
                 ],
                 stream=True,
                 temperature=0.01,
@@ -50,21 +50,20 @@ class ZhipuService:
             for chunk in response:
                 yield chunk.choices[0].delta.content
             end_time = time.time()
-            logger.info("talk_to_zhipu_sse call completed in %.2f seconds", end_time - start_time)
+            logger.info("generate_response_sse call completed in %.2f seconds", end_time - start_time)
         except Exception as e:
-            logger.error("Error in talk_to_zhipu_sse: %s", str(e))
+            logger.error("Error in generate_response_sse: %s", str(e))
             raise
 
     def retrive(self, message, knowledge_id, prompt_template):
         logger.info("Starting retrive call with knowledge_id: %s", knowledge_id)
         start_time = time.time()
-        client = ZhipuAI(api_key=self.app_secret_key)
         default_prompt = "从文档\n\"\"\"\n{{knowledge}}\n\"\"\"\n中找问题\n\"\"\"\n{{question}}\n\"\"\"\n的答案,找到答案就仅使用文档语句回答问题,找不到答案就用自身知识回答并且告诉用户该信息不是来自文档。\n不要复述问题,直接开始回答。"
 
         if prompt_template is None or prompt_template == "":
             prompt_template = default_prompt
         try:
-            response = client.chat.completions.create(
+            response = self.client.chat.completions.create(
                 model="glm-4",
                 messages=[
                     {"role": "user", "content": message},
@@ -93,7 +92,6 @@ class ZhipuService:
     def retrive_sse(self, message, knowledge_id, prompt_template=None,system_prompt=None):
         logger.info("Starting retrive_sse call with knowledge_id: %s, message:%s", knowledge_id, message)
         start_time = time.time()
-        client = ZhipuAI(api_key=self.app_secret_key)
         default_prompt = "从文档\n\"\"\"\n{{knowledge}}\n\"\"\"\n中找问题\n\"\"\"\n{{question}}\n\"\"\"\n的答案,找到答案就仅使用文档语句回答问题,找不到答案就告诉用户知识库中没有该信息。\n不要复述问题,直接开始回答。"
         messages = [{"role": "user", "content": message}]
         # if system_prompt != None:
@@ -101,7 +99,7 @@ class ZhipuService:
         if prompt_template is None or prompt_template == "":
"": prompt_template = default_prompt try: - response = client.chat.completions.create( + response = self.client.chat.completions.create( model="glm-4", messages=messages, tools=[ @@ -132,11 +130,10 @@ class ZhipuService: prompt_report_template = PromptRepository().get_prompt("report_template") prompt_report_missing_check = f"""{prompt_report_template} - 请检查以下日志信息是否完整,如果信息缺失则提示要求用户需要补充的信息要点,如果信息完整请直接返回“上述日志信息为全部信息”。日志信息如下:\n\"\"\"\n{message}\n\"\"\"\n""" + 请检查以下日志信息是否完整,如果信息缺失则提示要求用户需要补充的信息要点,如果信息完整请直接返回"上述日志信息为全部信息"。日志信息如下:\n\"\"\"\n{message}\n\"\"\"\n""" - client = ZhipuAI(api_key=self.app_secret_key) try: - response = client.chat.completions.create( + response = self.client.chat.completions.create( model="glm-4-flash", messages=[ {"role": "user", "content": prompt_report_missing_check}, diff --git a/app/utils/sessions.py b/app/utils/sessions.py index ebd94fb..2636b81 100644 --- a/app/utils/sessions.py +++ b/app/utils/sessions.py @@ -9,3 +9,5 @@ def init_session(): if 'history' not in session: session['history'] = [] # 初始化会话历史 session['session_info'] = {} + if 'llm_service' not in session: + session['llm_service'] = 'zhipu' diff --git a/requirements.txt b/requirements.txt index 798083d..1ac4363 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,5 @@ Werkzeug==3.0.4 zipp==3.20.2 zhipuai==2.1.5.20230904 pytz==2024.2 -flask-debug==0.4.3 \ No newline at end of file +flask-debug==0.4.3 +openai==1.52.2 \ No newline at end of file