Skip to content

Commit

Permalink
refactor: check model use model_params
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Dec 25, 2024
1 parent 628cf70 commit 6412825
Show file tree
Hide file tree
Showing 66 changed files with 116 additions and 107 deletions.
4 changes: 2 additions & 2 deletions apps/setting/models_provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_model_type_list(provider):
return get_provider(provider).get_model_type_list()


def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params, raise_exception=False):
"""
校验模型认证参数
@param provider: 供应商字符串
Expand All @@ -91,4 +91,4 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
@param raise_exception: 是否抛出错误
@return: True|False
"""
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, raise_exception)
10 changes: 7 additions & 3 deletions apps/setting/models_provider/base_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def get_model_credential(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential

def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
def get_model_params(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
return model_info.model_credential

def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], model_params: Dict[str, object], raise_exception=False):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
raise_exception=raise_exception)

def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
Expand Down Expand Up @@ -105,7 +109,7 @@ def filter_optional_params(model_kwargs):
class BaseModelCredential(ABC):

@abstractmethod
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider, raise_exception=True):
pass

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class QwenModelParams(BaseForm):

class QwenVLModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -49,7 +49,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
for chunk in res:
print(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BaiLianLLMModelParams(BaseForm):

class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -41,7 +41,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
if not model_type == 'RERANKER':
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField("API Key", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class QwenModelParams(BaseForm):

class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -73,7 +73,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AliyunBaiLianTTSModelGeneralParams(BaseForm):
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField("API Key", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -58,7 +58,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
with open(credentials_path, 'w') as file:
file.write(content)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
with open(credentials_path, 'w') as file:
file.write(content)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
Expand All @@ -62,7 +62,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
model_credential['secret_access_key'])
model_credential['credentials_profile_name'] = 'aws-profile'
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except AppApiException:
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -46,7 +46,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
for chunk in res:
print(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AzureLLMModelParams(BaseForm):

class AzureLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -48,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -64,7 +64,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -41,7 +41,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DeepSeekLLMModelParams(BaseForm):

class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -48,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class GeminiImageModelParams(BaseForm):
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -44,7 +44,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
for chunk in res:
print(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class GeminiLLMModelParams(BaseForm):

class GeminiLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -48,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.invoke([HumanMessage(content='你好')])
print(res)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Loading

0 comments on commit 6412825

Please sign in to comment.