From d37921099d1ac30f0ebc8121448cb21dc675c3fb Mon Sep 17 00:00:00 2001 From: Tiger Ren Date: Sat, 26 Oct 2024 22:29:09 +0800 Subject: [PATCH] =?UTF-8?q?=E5=89=8D=E7=AB=AF=E4=B8=8D=E5=8F=91=E9=80=81co?= =?UTF-8?q?okie=EF=BC=8C=E4=BD=BF=E7=94=A8=E6=9C=AC=E5=9C=B0=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=87=E4=BB=B6=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/v2/zhipu_controller_v2.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/app/api/v2/zhipu_controller_v2.py b/app/api/v2/zhipu_controller_v2.py index fc3e7b4..0ce141e 100644 --- a/app/api/v2/zhipu_controller_v2.py +++ b/app/api/v2/zhipu_controller_v2.py @@ -8,6 +8,21 @@ from app.services.zhipu_kb_service import ZhipuKbService from app.services.openai_service import OpenaiService from app.utils.prompt_repository import PromptRepository # Add this import from app.utils.sessions import init_session +import os + +CONFIG_FILE = 'llm_service_config.json' + +def get_current_service(): + if os.path.exists(CONFIG_FILE): + with open(CONFIG_FILE, 'r') as f: + config = json.load(f) + return config.get('llm_service', 'zhipu') + return 'zhipu' # Default to zhipu if file doesn't exist + +def set_current_service(service): + config = {'llm_service': service} + with open(CONFIG_FILE, 'w') as f: + json.dump(config, f) zhipu_controller_v2 = Blueprint('zhipu_controller_v2', __name__) @@ -157,16 +172,19 @@ def analysis_stream(): if 'zhipu' in message.lower() or '智谱' in message: logger.info(f'switch to zhipu service, save to session') session['llm_service'] = 'zhipu' + set_current_service('zhipu') return format_chunk("切换到智谱AI服务", None, None) if 'openai' in message.lower() or 'openai' in message: logger.info(f'switch to openai service, save to session') session['llm_service'] = 'openai' + set_current_service('openai') return format_chunk("切换到openai服务", None, None) # 默认使用智谱AI服务 llm_service = zhipu_service logger.info(f'llm_service: {session["llm_service"]}') - current_service = session.get('llm_service', 'zhipu') # Default to 'zhipu' if not set - + # current_service = session.get('llm_service', 'zhipu') # Default to 'zhipu' if not set + current_service = get_current_service() + if current_service == 'openai': logger.info('Using OpenAI service') llm_service = openai_service @@ -263,3 +281,4 @@ def analysis_stream(): return Response(event_stream(), mimetype='text/event-stream', headers=response_headers) +