Skip to content

Commit

Permalink
refactor: image model get_num_tokens override
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Jan 8, 2025
1 parent 8db35c4 commit 1310c8a
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# coding=utf-8

from typing import Dict
from typing import Dict, List

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer from ``TokenizerManage``.

    :param text: string to encode.
    :return: whatever ``tokenizer.encode`` returns for ``text``
        (presumably a sequence of token ids -- confirm against the tokenizer).
    """
    return TokenizerManage.get_tokenizer().encode(text)


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):

@staticmethod
Expand All @@ -21,3 +28,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
**optional_params,
)
return chat_tong_yi

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict
from typing import Dict, List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
Expand All @@ -26,3 +26,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict
from typing import Dict, List
from urllib.parse import urlparse, ParseResult

from langchain_core.messages import get_buffer_string, BaseMessage
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
Expand Down Expand Up @@ -36,3 +37,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from typing import Dict, List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
Expand All @@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# coding=utf-8

from typing import Dict
from typing import Dict, List

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer from ``TokenizerManage``.

    :param text: string to encode.
    :return: whatever ``tokenizer.encode`` returns for ``text``
        (presumably a sequence of token ids -- confirm against the tokenizer).
    """
    return TokenizerManage.get_tokenizer().encode(text)


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):

@staticmethod
Expand All @@ -21,3 +28,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
**optional_params,
)
return chat_tong_yi

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from typing import Dict, List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
Expand All @@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from typing import Dict, List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
Expand All @@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from typing import Dict, List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
Expand All @@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from typing import Dict, List

from langchain_core.messages import get_buffer_string, BaseMessage
from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
Expand All @@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
streaming=True,
**optional_params,
)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count the tokens in ``messages``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the count falls back to the tokenizer returned by
    ``TokenizerManage.get_tokenizer()``: each message is rendered on its own
    with ``get_buffer_string`` and the encoded lengths are summed.

    :param messages: chat messages to measure.
    :return: total number of tokens across all messages.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        # NOTE(review): presumably the parent counter cannot handle this
        # provider's models -- confirm.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

def get_num_tokens(self, text: str) -> int:
    """Count the tokens in ``text``.

    The parent class's counter is tried first. If it raises any
    ``Exception``, the text is encoded with the tokenizer returned by
    ``TokenizerManage.get_tokenizer()`` and the encoded length is returned.

    :param text: string to measure.
    :return: number of tokens in ``text``.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Deliberate best-effort fallback, so the exception is not re-raised.
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))

0 comments on commit 1310c8a

Please sign in to comment.