From ece8894cf1fd4b4406b03199f487b348909d049c Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 26 Dec 2024 15:28:36 +0800 Subject: [PATCH] refactor: Workflow execution logic --- apps/application/flow/common.py | 17 +++ apps/application/flow/i_step_node.py | 4 +- .../impl/base_search_dataset_node.py | 4 +- apps/application/flow/workflow_manage.py | 120 ++++++++---------- apps/setting/models_provider/tools.py | 3 + .../serializers/model_apply_serializers.py | 3 + apps/smartdoc/conf.py | 8 +- installer/config.yaml | 2 +- pyproject.toml | 3 +- 9 files changed, 94 insertions(+), 70 deletions(-) diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py index 96db0011926..e0bfcdbb2f8 100644 --- a/apps/application/flow/common.py +++ b/apps/application/flow/common.py @@ -19,3 +19,20 @@ def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_no def to_dict(self): return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id, 'chat_record_id': self.chat_record_id, 'child_node': self.child_node} + + +class NodeChunk: + def __init__(self): + self.status = 0 + self.chunk_list = [] + + def add_chunk(self, chunk): + self.chunk_list.append(chunk) + + def end(self, chunk=None): + if chunk is not None: + self.add_chunk(chunk) + self.status = 200 + + def is_end(self): + return self.status == 200 diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index a9316770b47..e4279b949fa 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -17,7 +17,7 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError, ErrorDetail -from application.flow.common import Answer +from application.flow.common import Answer, NodeChunk from application.models import ChatRecord from application.models.api_key_model import ApplicationPublicAccessClient from common.constants.authentication_type import AuthenticationType @@ -175,6 +175,7 @@ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None, if up_node_id_list is None: up_node_id_list = [] self.up_node_id_list = up_node_id_list + self.node_chunk = NodeChunk() self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS, "".join([*sorted(up_node_id_list), node.id]))), @@ -214,6 +215,7 @@ def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]: def get_write_error_context(self, e): self.status = 500 + self.answer_text = str(e) self.err_message = str(e) self.context['run_time'] = time.time() - self.context['start_time'] diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 84af258bedf..a74e5cc8e96 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -10,7 +10,7 @@ from typing import List, Dict from django.db.models import QuerySet - +from django.db import connection from application.flow.i_step_node import NodeResult from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode from common.config.embedding_config import VectorStore @@ -77,6 +77,8 @@ def execute(self, dataset_id_list, dataset_setting, question, embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list, exclude_paragraph_id_list, True, dataset_setting.get('top_n'), dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode'))) + # 手动关闭数据库连接 + connection.close() if embedding_list is None: return get_none_result(question) paragraph_list = self.list_paragraph(embedding_list, vector) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 27598c43add..66e67a0824d 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -6,6 +6,7 @@ @date:2024/1/9 17:40 @desc: """ +import concurrent import json import threading import traceback @@ -13,6 +14,7 @@ from functools import reduce from typing import List, Dict +from django.db import close_old_connections from django.db.models import QuerySet from langchain_core.prompts import PromptTemplate from rest_framework import status @@ -223,23 +225,6 @@ def pop(self): return None -class NodeChunk: - def __init__(self): - self.status = 0 - self.chunk_list = [] - - def add_chunk(self, chunk): - self.chunk_list.append(chunk) - - def end(self, chunk=None): - if chunk is not None: - self.add_chunk(chunk) - self.status = 200 - - def is_end(self): - return self.status == 200 - - class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, @@ -273,8 +258,9 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl self.status = 200 self.base_to_response = base_to_response self.chat_record = chat_record - self.await_future_map = {} self.child_node = child_node + self.future_list = [] + self.lock = threading.Lock() if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) else: @@ -319,6 +305,7 @@ def get_node_params(n): self.node_context.append(node) def run(self): + close_old_connections() if self.params.get('stream'): return self.run_stream(self.start_node, None) return self.run_block() @@ -328,8 +315,9 @@ def run_block(self): 非流式响应 @return: 结果 """ - result = self.run_chain_async(None, None) - result.result() + self.run_chain_async(None, None) + while self.is_run(): + pass details = self.get_runtime_details() message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row and row.get('message_tokens') is not None]) @@ -350,12 +338,22 @@ def run_stream(self, current_node, node_result_future): 流式响应 @return: """ - result = self.run_chain_async(current_node, node_result_future) - return tools.to_stream_response_simple(self.await_result(result)) + self.run_chain_async(current_node, node_result_future) + return tools.to_stream_response_simple(self.await_result()) - def await_result(self, result): + def is_run(self, timeout=0.1): + self.lock.acquire() try: - while await_result(result): + r = concurrent.futures.wait(self.future_list, timeout) + return len(r.not_done) > 0 + except Exception as e: + return True + finally: + self.lock.release() + + def await_result(self): + try: + while self.is_run(): while True: chunk = self.node_chunk_manage.pop() if chunk is not None: @@ -383,12 +381,16 @@ def await_result(self, result): '', True, message_tokens, answer_tokens, {}) def run_chain_async(self, current_node, node_result_future): - return executor.submit(self.run_chain_manage, current_node, node_result_future) + future = executor.submit(self.run_chain_manage, current_node, node_result_future) + self.future_list.append(future) def run_chain_manage(self, current_node, node_result_future): if current_node is None: start_node = self.get_start_node() current_node = get_node(start_node.type)(start_node, self.params, self) + self.node_chunk_manage.add_node_chunk(current_node.node_chunk) + # 添加节点 + self.append_node(current_node) result = self.run_chain(current_node, node_result_future) if result is None: return @@ -396,29 +398,22 @@ def run_chain_manage(self, current_node, node_result_future): if len(node_list) == 1: self.run_chain_manage(node_list[0], None) elif len(node_list) > 1: - + sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y) # 获取到可执行的子节点 result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in - node_list] - self.set_await_map(result_list) - [r.get('future').result() for r in result_list] - - def set_await_map(self, node_run_list): - sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y) - for index in range(len(sorted_node_run_list)): - self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [ - sorted_node_run_list[i].get('future') - for i in range(index)] + sorted_node_run_list] + try: + self.lock.acquire() + for r in result_list: + self.future_list.append(r.get('future')) + finally: + self.lock.release() def run_chain(self, current_node, node_result_future=None): if node_result_future is None: node_result_future = self.run_node_future(current_node) try: is_stream = self.params.get('stream', True) - # 处理节点响应 - await_future_list = self.await_future_map.get(current_node.runtime_node_id, None) - if await_future_list is not None: - [f.result() for f in await_future_list] result = self.hand_event_node_result(current_node, node_result_future) if is_stream else self.hand_node_result( current_node, node_result_future) @@ -434,16 +429,14 @@ def hand_node_result(self, current_node, node_result_future): if result is not None: # 阻塞获取结果 list(result) - # 添加节点 - self.node_context.append(current_node) return current_result except Exception as e: - # 添加节点 - self.node_context.append(current_node) traceback.print_exc() self.status = 500 current_node.get_write_error_context(e) self.answer += str(e) + finally: + current_node.node_chunk.end() def append_node(self, current_node): for index in range(len(self.node_context)): @@ -454,15 +447,14 @@ def append_node(self, current_node): self.node_context.append(current_node) def hand_event_node_result(self, current_node, node_result_future): - node_chunk = NodeChunk() real_node_id = current_node.runtime_node_id child_node = {} + view_type = current_node.view_type try: current_result = node_result_future.result() result = current_result.write_context(current_node, self) if result is not None: if self.is_result(current_node, current_result): - self.node_chunk_manage.add_node_chunk(node_chunk) for r in result: content = r child_node = {} @@ -487,26 +479,24 @@ def hand_event_node_result(self, current_node, node_result_future): 'child_node': child_node, 'node_is_end': node_is_end, 'real_node_id': real_node_id}) - node_chunk.add_chunk(chunk) - chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], - self.params['chat_record_id'], - current_node.id, - current_node.up_node_id_list, - '', False, 0, 0, {'node_is_end': True, - 'runtime_node_id': current_node.runtime_node_id, - 'node_type': current_node.type, - 'view_type': view_type, - 'child_node': child_node, - 'real_node_id': real_node_id}) - node_chunk.end(chunk) + current_node.node_chunk.add_chunk(chunk) + chunk = (self.base_to_response + .to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + current_node.id, + current_node.up_node_id_list, + '', False, 0, 0, {'node_is_end': True, + 'runtime_node_id': current_node.runtime_node_id, + 'node_type': current_node.type, + 'view_type': view_type, + 'child_node': child_node, + 'real_node_id': real_node_id})) + current_node.node_chunk.add_chunk(chunk) else: list(result) - # 添加节点 - self.append_node(current_node) return current_result except Exception as e: # 添加节点 - self.append_node(current_node) traceback.print_exc() chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], @@ -519,12 +509,12 @@ def hand_event_node_result(self, current_node, node_result_future): 'view_type': current_node.view_type, 'child_node': {}, 'real_node_id': real_node_id}) - if not self.node_chunk_manage.contains(node_chunk): - self.node_chunk_manage.add_node_chunk(node_chunk) - node_chunk.end(chunk) + current_node.node_chunk.add_chunk(chunk) current_node.get_write_error_context(e) self.status = 500 return None + finally: + current_node.node_chunk.end() def run_node_async(self, node): future = executor.submit(self.run_node, node) @@ -636,6 +626,8 @@ def get_next_node(self): @staticmethod def dependent_node(up_node_id, node): + if not node.node_chunk.is_end(): + return False if node.id == up_node_id: if node.type == 'form-node': if node.context.get('form_data', None) is not None: diff --git a/apps/setting/models_provider/tools.py b/apps/setting/models_provider/tools.py index 66060436972..e353acc356d 100644 --- a/apps/setting/models_provider/tools.py +++ b/apps/setting/models_provider/tools.py @@ -6,6 +6,7 @@ @date:2024/7/22 11:18 @desc: """ +from django.db import connection from django.db.models import QuerySet from common.config.embedding_config import ModelManage @@ -15,6 +16,8 @@ def get_model_by_id(_id, user_id): model = QuerySet(Model).filter(id=_id).first() + # 手动关闭数据库连接 + connection.close() if model is None: raise Exception("模型不存在") if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): diff --git a/apps/setting/serializers/model_apply_serializers.py b/apps/setting/serializers/model_apply_serializers.py index fd418698705..12b6fafd472 100644 --- a/apps/setting/serializers/model_apply_serializers.py +++ b/apps/setting/serializers/model_apply_serializers.py @@ -6,6 +6,7 @@ @date:2024/8/20 20:39 @desc: """ +from django.db import connection from django.db.models import QuerySet from langchain_core.documents import Document from rest_framework import serializers @@ -18,6 +19,8 @@ def get_embedding_model(model_id): model = QuerySet(Model).filter(id=model_id).first() + # 手动关闭数据库连接 + connection.close() embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, use_local=True)) return embedding_model diff --git a/apps/smartdoc/conf.py b/apps/smartdoc/conf.py index 0349739813f..e041c85ad85 100644 --- a/apps/smartdoc/conf.py +++ b/apps/smartdoc/conf.py @@ -80,7 +80,7 @@ class Config(dict): "DB_PORT": 5432, "DB_USER": "root", "DB_PASSWORD": "Password123@postgres", - "DB_ENGINE": "django.db.backends.postgresql_psycopg2", + "DB_ENGINE": "dj_db_conn_pool.backends.postgresql", # 向量模型 "EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese", "EMBEDDING_DEVICE": "cpu", @@ -108,7 +108,11 @@ def get_db_setting(self) -> dict: "PORT": self.get('DB_PORT'), "USER": self.get('DB_USER'), "PASSWORD": self.get('DB_PASSWORD'), - "ENGINE": self.get('DB_ENGINE') + "ENGINE": self.get('DB_ENGINE'), + "POOL_OPTIONS": { + "POOL_SIZE": 20, + "MAX_OVERFLOW": 5 + } } def __init__(self, *args): diff --git a/installer/config.yaml b/installer/config.yaml index c9f45db869f..8127fc9ab67 100644 --- a/installer/config.yaml +++ b/installer/config.yaml @@ -13,7 +13,7 @@ DB_HOST: 127.0.0.1 DB_PORT: 5432 DB_USER: root DB_PASSWORD: Password123@postgres -DB_ENGINE: django.db.backends.postgresql_psycopg2 +DB_ENGINE: dj_db_conn_pool.backends.postgresql EMBEDDING_MODEL_PATH: /opt/maxkb/model/embedding EMBEDDING_MODEL_NAME: /opt/maxkb/model/embedding/shibing624_text2vec-base-chinese diff --git a/pyproject.toml b/pyproject.toml index 4e6ccfcf0e3..34cd3891aa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ django-filter = "23.2" langchain = "0.2.16" langchain_community = "0.2.17" langchain-huggingface = "^0.0.3" -psycopg2-binary = "2.9.7" +psycopg2-binary = "2.9.10" jieba = "^0.42.1" diskcache = "^5.6.3" pillow = "^10.2.0" @@ -57,6 +57,7 @@ pylint = "3.1.0" pydub = "^0.25.1" cffi = "^1.17.1" pysilk = "^0.0.1" +django-db-connection-pool = "^1.2.5" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"