From b2524eec49010f36a82d0c67bec77e5aa7f15bbc Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 22 Oct 2024 11:38:37 +0800 Subject: [PATCH] fix sequence2txt error and usage total token issue (#2961) ### What problem does this PR solve? #1363 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/conversation_app.py | 2 +- api/db/services/llm_service.py | 3 ++- api/utils/file_utils.py | 2 ++ rag/llm/chat_model.py | 18 ++++++++++-------- rag/llm/sequence2txt_model.py | 2 +- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 525ca5b5ef..7e00dfeda7 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -26,7 +26,6 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService from api.settings import RetCode, retrievaler -from api.utils import get_uuid from api.utils.api_utils import get_json_result from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from graphrag.mind_map_extractor import MindMapExtractor @@ -187,6 +186,7 @@ def stream(): yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: + traceback.print_exc() yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index be650a61c6..7677b700ad 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -133,7 +133,8 @@ def model_instance(cls, tenant_id, llm_type, if model_config["llm_factory"] not in Seq2txtModel: return return Seq2txtModel[model_config["llm_factory"]]( - model_config["api_key"], model_config["llm_name"], lang, + key=model_config["api_key"], model_name=model_config["llm_name"], + lang=lang, base_url=model_config["api_base"] ) if llm_type == LLMType.TTS: diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 32ead6cc70..06fa36e996 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -197,6 +197,7 @@ def thumbnail_img(filename, blob): pass return None + def thumbnail(filename, blob): img = thumbnail_img(filename, blob) if img is not None: @@ -205,6 +206,7 @@ def thumbnail(filename, blob): else: return '' + def traversal_files(base): for root, ds, fs in os.walk(base): for f in fs: diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index f5b584b297..58d5b0ded6 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -67,14 +67,16 @@ def chat_streamly(self, system, history, gen_conf): if not resp.choices[0].delta.content: resp.choices[0].delta.content = "" ans += resp.choices[0].delta.content - total_tokens = ( - ( - total_tokens - + num_tokens_from_string(resp.choices[0].delta.content) - ) - if not hasattr(resp, "usage") or not resp.usage - else resp.usage.get("total_tokens", total_tokens) - ) + total_tokens += 1 + if not hasattr(resp, "usage") or not resp.usage: + total_tokens = ( + total_tokens + + num_tokens_from_string(resp.choices[0].delta.content) + ) + elif isinstance(resp.usage, dict): + total_tokens = resp.usage.get("total_tokens", total_tokens) + else: total_tokens = resp.usage.total_tokens + if resp.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index a2d3ea0ef5..950ea10ec3 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -87,7 +87,7 @@ def __init__(self, key, model_name, lang="Chinese", **kwargs): class XinferenceSeq2txt(Base): - def __init__(self,key,model_name="whisper-small",**kwargs): + def __init__(self, key, model_name="whisper-small", **kwargs): self.base_url = kwargs.get('base_url', None) self.model_name = model_name self.key = key