-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support Anthropic #1974
feat: Support Anthropic #1974
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
# coding=utf-8
"""
@project: maxkb
@Author: 虎
@file: __init__.py
@date: 2024/3/28 16:25
@desc: Package marker for the Anthropic model provider.
"""
Original file line number | Diff line number | Diff line change |
---|---|---|
# coding=utf-8
"""
@project: maxkb
@Author: 虎
@file: anthropic_model_provider.py
@date: 2024/3/28 16:26
@desc: Registers Anthropic (Claude) models with MaxKB's model-provider framework.
"""
import os

from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
    ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.anthropic_model_provider.credential.image import AnthropicImageModelCredential
from setting.models_provider.impl.anthropic_model_provider.credential.llm import AnthropicLLMModelCredential
from setting.models_provider.impl.anthropic_model_provider.model.image import AnthropicImage
from setting.models_provider.impl.anthropic_model_provider.model.llm import AnthropicChatModel
from smartdoc.conf import PROJECT_DIR

# Shared credential forms. Renamed from the copy-pasted ``openai_*`` names so the
# module reads consistently with the Anthropic classes it actually instantiates.
anthropic_llm_model_credential = AnthropicLLMModelCredential()
anthropic_image_model_credential = AnthropicImageModelCredential()

# Chat (LLM) models offered by this provider; the first entry becomes the default.
model_info_list = [
    ModelInfo('claude-3-opus-20240229', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
    ModelInfo('claude-3-sonnet-20240229', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
    ModelInfo('claude-3-haiku-20240307', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
    ModelInfo('claude-3-5-sonnet-20240620', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
    ModelInfo('claude-3-5-haiku-20241022', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
    ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
]

# Vision-capable models exposed under the IMAGE type; first entry is the default.
image_model_info = [
    ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.IMAGE,
              anthropic_image_model_credential, AnthropicImage),
]

model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(model_info_list[0])
    .append_model_info_list(image_model_info)
    .append_default_model_info(image_model_info[0])
    .build()
)


class AnthropicModelProvider(IModelProvider):
    """Provider entry point wiring the Anthropic models into MaxKB."""

    def get_model_info_manage(self):
        # Registry of every model (LLM + image) this provider supports.
        return model_info_manage

    def get_model_provide_info(self):
        # ``anthropic_icon_svg`` is the on-disk icon file name used by the diff
        # in this PR (no ``.svg`` extension) — keep it in sync with the icon dir.
        return ModelProvideInfo(provider='model_anthropic_provider', name='Anthropic', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'anthropic_model_provider', 'icon',
                         'anthropic_icon_svg')))
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The provided code looks mostly correct, but there are a few minor improvements and optimizations that can be considered: Code Improvements
Here is an optimized version of the code with these improvements: # coding=utf-8
"""
@project: maxkb
@author : 虎
@file: openai_model_provider.py
@date:2024/3/28 16:26
@desc:
"""
import os
import sys
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.anthropic_model_provider.credential.image import AnthropicImageModelCredential
from setting.models_provider.impl.anthropic_model_provider.credential.llm import AnthropicLLMModelCredential
from setting.models_provider.impl.anthropic_model_provider.model.image import AnthropicImage
from setting.models_provider.impl.anthropic_model_provider.model.llm import AnthropicChatModel
from smartdoc.conf import ProjectConstants
# Assuming ProjectConstants contains settings such as PROJECT_DIR
OPENAI_LLMMODELCREDENTIAL = AnthropicLLMModelCredential()
OPENAI_IMAGEMODELCREDENTIAL = AnthropicImageModelCredential()
MODEL_INFO_LIST = [
ModelInfo(name='claude-3-opus-20240229', description='', type=ModelTypeConst.LLM,
              credential=OPENAI_LLMMODELCREDENTIAL, model_class=AnthropicChatModel
),
# ... other ModelInfo entries ...
]
IMAGE_MODEL_INFO = [
ModelInfo(name='claude-3-5-sonnet-20241022', description='', type=ModelTypeConst.IMAGE,
credential=OPENAI_IMAGEMODELCREDENTIAL, model_class=AnthropicImage),
# ... other ImageModelInfo entries ...
]
def load_icon(path):
try:
return get_file_content(os.path.join(ProjectConstants.PROJECT_DIR, path))
except FileNotFoundError:
print(f"Icon file not found at {path}")
sys.exit(1)
ICON_PATH = os.path.join('apps', 'setting', 'models_provider', 'impl', 'anthropic_model_provider', 'icon', 'anthropic_icon_svg')
DEFAULT_MODEL_INFO = MODEL_INFO_LIST[0] if MODEL_INFO_LIST else None
default_image_model_info = IMAGE_MODEL_INFO[0] if IMAGE_MODEL_INFO else None
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(MODEL_INFO_LIST)
.append_default_model_info(DEFAULT_MODEL_INFO)
.append_model_info_list(IMAGE_MODEL_INFO)
.append_default_model_info(default_image_model_info)
.build()
)
class AnthropicModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(
provider='model_anthropic_provider',
name='Anthropic',
icon=load_icon(str(ICON_PATH)) # Convert Path object to str for compatibility
) Key Changes
This refactoring aims to improve readability and maintainability while keeping the core logic intact. |
Original file line number | Diff line number | Diff line change |
---|---|---|
# coding=utf-8
"""Credential form and validation for Anthropic image (vision) models."""
import base64
import os
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class AnthropicImageModelParams(BaseForm):
    """Runtime-tunable generation parameters shown in the model settings UI."""

    # Sampling temperature: higher is more random, lower more deterministic.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Upper bound on generated tokens per response.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
    """Connection credential (endpoint + key) for Anthropic vision models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by issuing a minimal streaming request.

        Returns True when the request succeeds; returns False (or raises
        AppApiException when ``raise_exception`` is set) otherwise.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Start the stream to force an authenticated request; a single chunk
            # proves the credentials work. Previously the whole response was
            # printed to stdout, leaking debug noise and wasting tokens.
            for _ in model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]):
                break
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is persisted or returned.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # Same parameter form for every Anthropic vision model.
        return AnthropicImageModelParams()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code appears mostly well-written but contains some improvements and optimizations:
Here's an improved version of the code: # coding=utf-8
import base64
import os
from typing import Any, Dict
from langchain_core.messages import HumanMessage
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class AnthropicImageModelParams(BaseForm):
temperature = forms.SliderField(
TooltipLabel('Temperature', 'Higher values make outputs more random; lower values make them more focused.'),
required=True,
default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2
)
max_tokens = forms.SliderField(
TooltipLabel('Output Maximum Tokens', 'Specify how many tokens the model can generate.'),
required=True,
default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0
)
class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API Domain', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def is_valid(self, model_type: str, model_name: str, model_credential: Dict[str, object],
model_params: Dict[str, object], provider: Any, raise_exception: bool) -> bool:
model_type_list = provider.get_model_type_list()
if not any(mt['value'] == model_type for mt in model_type_list):
raise_app_api_exception(ValueError(f'Model type {model_type} is not supported'))
missing_keys = ['api_base', 'api_key']
if [key for key in missing_keys if key not in model_credential]:
missing_fields = ', '.join(missing_keys)
if raise_exception:
raise_app_api_exception(AppApiException(ValidCode.valid_error.value, f'{missing_fields} fields are required'))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
            response = model.stream([HumanMessage(content=[{"type": "text", "text": "Hello"}])])
for chunk in response:
print(chunk)
except Exception as e:
log_warning(f'Failed to validate API credentials: {e}')
if isinstance(e, AppApiException):
raise e
if raise_exception:
extra_msg = str(e)
err_code = ValidCode.valid_error.value if extra_msg.startswith('valid_') else None
raise_app_api_exception(
AppApiException(err_code or ExtraErrorCode.invalid_api_request.value,
msg=f'Streaming failed: {extra_msg}' if err_code != ValidCode.valid_error.value else f'Failed to authenticate API request'))
else:
return False
log_info(f'Credentials validated successfully')
return True
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, Any]:
encrypted_api_key = super().encryption(model.get('api_key', ""))
return {**model, 'api_key': encrypted_api_key}
@staticmethod
def get_model_params_setting_form(model_name) -> AnthropicImageModelParams:
params = AnthropicImageModelParams()
# Optional initialization logic could go here
return params
log_info = lambda *args: print(*args, sep=' ')
log_warning = lambda *args: print('WARNING:', *args, sep=' ')
def raise_app_api_exception(exception):
# Implement exception raising with proper logging and formatting
pass Key Changes:
These changes should help enhance the robustness and maintainability of your codebase. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The provided code seems to be implementing an authentication system for a third-party service, likely using Anthropic's API. Here's a breakdown of potential issues and suggestions: Potential Issues
Optimization Suggestions
Here's a revised version snippet that includes some suggested changes: @@ -103,7 +103,7 @@
return super().get_model_params_setting_form(model_name)
def encrypt_text(text):
# Placeholder for actual password encryption function
return base64.b64encode(text.encode()).decode()
class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True, validators=[lambda x: len(x) > 0])
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_types_list = provider.get_model_type_list() Remember these improvements should align with your overall security and development practices. |
Original file line number | Diff line number | Diff line change |
---|---|---|
# coding=utf-8
"""
@project: MaxKB
@Author: 虎
@file: llm.py
@date: 2024/7/11 18:32
@desc: Credential form and validation for Anthropic chat (LLM) models.
"""
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class AnthropicLLMModelParams(BaseForm):
    """Runtime-tunable generation parameters shown in the model settings UI."""

    # Sampling temperature: higher is more random, lower more deterministic.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Upper bound on generated tokens per response.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
    """Connection credential (endpoint + key) for Anthropic chat models."""

    # Field declarations grouped at the top of the class; they previously sat
    # between methods, which obscured the form definition. Relative field order
    # (api_base before api_key) is unchanged.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by invoking the model with a one-word prompt.

        Returns True when the call succeeds; returns False (or raises
        AppApiException when ``raise_exception`` is set) otherwise.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # A single short invocation is enough to prove the endpoint and key.
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is persisted or returned.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # Same parameter form for every Anthropic chat model.
        return AnthropicLLMModelParams()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The provided code looks mostly clean, but there are a few areas that could be improved:
Here’s an updated version with these considerations: # coding=utf-8
"""
@project: MaxKB
@author: Tiger
@file: llm.py
@date:2024/7/11 18:32
@desc:
"""
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class AnthropicLLMModelParams(BaseForm):
temperature = forms.SliderField(
tooltip_label="Temperature",
description='A higher value will produce more random responses while a lower value produces more focused ones.',
required=True,
default_value=0.7,
min_value=0.1,
max_value=1.0,
step=0.01,
precision=2
)
max_tokens = forms.SliderField(
tooltip_label="Output Maximum Tokens",
description='Specify the maximum number of tokens that can be generated by the model.',
required=True,
default_value=800,
min_value=1,
max_value=100000,
step=1,
precision=0
)
class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False) -> bool:
model_type_list = provider.get_model_type_list()
if not any(mt['value'] == model_type for mt in model_type_list):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
mandatory_fields = ['api_base', 'api_key']
missing_fields = [field for field in mandatory_fields if field not in model_credential]
if missing_fields:
message = ', '.join(f"'{field}' 字段" for field in missing_fields)
error_code = ValidCode.valid_error.value
if raise_exception:
raise AppApiException(error_code, f'{message} 必须填写')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
elif raise_exception:
error_msg = str(e).strip() or '未知错误'
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {error_msg}')
else:
return False
return True
def encryption_dict(self, model: dict) -> dict:
return {**model, 'api_key': super().encrypt(model.get('api_key', ''))}
api_base = forms.TextInputField("API 域名", required=True)
api_key = forms.PasswordInputField("API Key", required=True)
def get_model_params_setting_form(self, model_name) -> AnthropicLLMModelParams:
return AnthropicLLMModelParams() Key Changes:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code you've shared appears to be a Python file that defines a form for configuring LLM (Language Model) parameters using LangChain's
Here’s a refined version of the code with these suggestions incorporated: @@ -2,6 +2,7 @@
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
logger = logging.getLogger(__name__)
class AnthropicLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
        logger.debug("Validating Anthropic LLM credentials.")
logger.info(f"Checking supported model types.")
supported_models = [mt['value'] for mt in provider.get_model_type_list()]
if model_type not in supported_models:
return False
logger.info("Verifying API base and API key.")
keys_to_check = ['api_base', 'api_key']
missing_fields = [key for key in keys_to_check if key not in model_credential]
if missing_fields and raise_exception:
logger.error("Missing required fields: %s", ', '.join(missing_fields))
raise AppApiException(ValidCode.valid_error.value, f'Missing required fields: {", ".join(missing_fields)}')
try:
logger.info("Attempting to authenticate with model")
model = provider.get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
return True
except Exception as e:
logger.error("An error occurred during validation: %s", str(e))
if isinstance(e, AppApiException):
return False
elif raise_exception:
logger.error("Raising valid error: %s", str(e))
raise AppApiException(ValidCode.valid_error.value, f'Check failed: {str(e)}')
else:
return False
finally:
logger.debug("Validation completed.")
def encryption_dict(self, model: Dict[str, object]):
logger.info("Encrypting API key.")
encrypted_api_key = super().encryption(model.get('api_key', ''))
return {**model, 'api_key': encrypted_api_key}
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_model_params_setting_form(self, model_name):
logger.info("Fetching parameter settings for model '%s'", model_name)
params_class = getattr(locals(), f'AnthropicLLMModelParams_{model_name}', None)
if params_class:
logger.debug("Using existing param settings form for model '%s'", model_name)
return params_class()
# Placeholder for new model-specific parameters
return AnthropicLLMModelParams() Key Changes Made:
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
<svg xmlns="http://www.w3.org/2000/svg" shape-rendering="geometricPrecision" text-rendering="geometricPrecision" image-rendering="optimizeQuality" fill-rule="evenodd" clip-rule="evenodd" viewBox="0 0 512 512"><rect fill="#CC9B7A" width="512" height="512" rx="104.187" ry="105.042"/><path fill="#1F1F1E" fill-rule="nonzero" d="M318.663 149.787h-43.368l78.952 212.423 43.368.004-78.952-212.427zm-125.326 0l-78.952 212.427h44.255l15.932-44.608 82.846-.004 16.107 44.612h44.255l-79.126-212.427h-45.317zm-4.251 128.341l26.91-74.701 27.083 74.701h-53.993z"/></svg> |
Original file line number | Diff line number | Diff line change |
---|---|---|
from typing import Dict

from langchain_anthropic import ChatAnthropic

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    """Tokenize *text* with the project-wide tokenizer.

    NOTE(review): unlike the LLM sibling module, this function is not passed to
    the model instance below — confirm whether that is intentional.
    """
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AnthropicImage(MaxKBBaseModel, ChatAnthropic):
    """Anthropic vision-capable chat model wrapped for MaxKB."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a streaming AnthropicImage from a stored credential dict.

        ``model_credential`` supplies ``api_base``/``api_key``; framework-only
        kwargs are stripped by ``filter_optional_params`` before being forwarded
        to the underlying client.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return AnthropicImage(
            model=model_name,
            anthropic_api_url=model_credential.get('api_base'),
            anthropic_api_key=model_credential.get('api_key'),
            streaming=True,
            **optional_params,
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
# coding=utf-8
"""
@project: maxkb
@Author: 虎
@file: llm.py
@date: 2024/4/18 15:28
@desc: Anthropic chat model wrapper with local token-counting fallbacks.
"""
from typing import List, Dict

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    """Tokenize *text* with the project-wide tokenizer."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class AnthropicChatModel(MaxKBBaseModel, ChatAnthropic):
    """Anthropic chat (Claude) model wrapped for MaxKB."""

    @staticmethod
    def is_cache_model():
        # Instances are credential-specific, so they must not be cached/shared.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an AnthropicChatModel from a stored credential dict."""
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        # Renamed from the copy-pasted ``azure_chat_open_ai``: this is an
        # Anthropic model, not Azure OpenAI.
        chat_model = AnthropicChatModel(
            model=model_name,
            anthropic_api_url=model_credential.get('api_base'),
            anthropic_api_key=model_credential.get('api_key'),
            **optional_params,
            custom_get_token_ids=custom_get_token_ids
        )
        return chat_model

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens for *messages*; fall back to the local tokenizer when
        the upstream counter fails (e.g. offline or bad credentials)."""
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """Count tokens for raw *text* with the same local-tokenizer fallback."""
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code appears to be incomplete and contains several issues that need attention:
Imports and Dependencies: The code imports
os
,get_file_content
from another module (common.util.file_util
), but does not include the actual implementation of these functions or where they should be imported from.File Content Retrieval: The
get_file_content
function is used to retrieve an icon file, but it doesn't specify how this content will be handled. Ensure that this function returns a valid string containing the image data.Missing Models Class Definitions: The
AnthropicChatModel
andAnthropicImage
classes are assumed to exist and implement specific interfaces within other modules. You may need to define these classes if they do not already exist.Default Settings Configuration: While most details seem correct regarding default models and providers, ensure that all necessary configuration settings are properly specified and accessible at runtime.
Code Consistency and Readability: Some lines have extra spaces which can make them harder to read. Consider aligning variable names and using consistent indentation for better readability.
Here's a general outline of what you might want to address with additional context assuming you have access to those missing components:
This structure provides a foundation that requires further clarification on class implementations. Let me know if you need more specific guidance!