Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: check model use model_params #1911

Merged
merged 1 commit into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/setting/models_provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_model_type_list(provider):
return get_provider(provider).get_model_type_list()


def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params, raise_exception=False):
"""
校验模型认证参数
@param provider: 供应商字符串
Expand All @@ -91,4 +91,4 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
@param raise_exception: 是否抛出错误
@return: True|False
"""
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, raise_exception)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is one significant change needed:

# Original line with error
model_credential: Dict[str, object],

# Suggested corrected version with additional parameter

Explanation:
In the provided function is_valid_credential, there's an unexpected comma in the type hint of the parameter model_params. Without a leading colon after the variable list [provider, model_type, model_name, model_credential], Python will interpret this as a trailing comma for the previous argument.

The suggested correction is to replace it with the correct syntax : Dict[str, object], which properly separates parameters from their types. This prevents errors like unbalanced parentheses due to the extra comma before the type hint.

10 changes: 7 additions & 3 deletions apps/setting/models_provider/base_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def get_model_credential(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential

def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
def get_model_params(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential.is_valid(model_type, model_name, model_credential, self,
return model_info.model_credential

def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], model_params: Dict[str, object], raise_exception=False):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
raise_exception=raise_exception)

def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
Expand Down Expand Up @@ -105,7 +109,7 @@ def filter_optional_params(model_kwargs):
class BaseModelCredential(ABC):

@abstractmethod
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True):
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider, raise_exception=True):
pass

@abstractmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some points to consider for your code:

  1. The get_model_credential method returns the model's credential. You might want to add more descriptive comments or documentation.

  2. The parameter name provider is not used anywhere within the method. Consider using a meaningful argument name that conveys its purpose.

  3. In both methods is_valid_credential, you're passing the same self argument. This could be redundant and can improve performance slightly if removed.

  4. The get_model_params method currently retrieves the entire model information, which includes the credentials. It would be better to use this method to retrieve just the parameters rather than accessing them directly from the instance variable.

  5. If the model parameters differ based on various factors (e.g., context), it might make sense to encapsulate these dependencies in a separate class. However, without more details about how they change, this isn't strictly necessary here.

Overall, your code is mostly clean with minor improvements possible related to naming conventions and readability.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class QwenModelParams(BaseForm):

class QwenVLModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -49,7 +49,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
for chunk in res:
print(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BaiLianLLMModelParams(BaseForm):

class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -41,7 +41,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
if not model_type == 'RERANKER':
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField("API Key", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class QwenModelParams(BaseForm):

class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -73,7 +73,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AliyunBaiLianTTSModelGeneralParams(BaseForm):
class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField("API Key", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -58,7 +58,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
with open(credentials_path, 'w') as file:
file.write(content)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
with open(credentials_path, 'w') as file:
file.write(content)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
Expand All @@ -62,7 +62,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
model_credential['secret_access_key'])
model_credential['credentials_profile_name'] = 'aws-profile'
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except AppApiException:
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -46,7 +46,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
for chunk in res:
print(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AzureLLMModelParams(BaseForm):

class AzureLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -48,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -64,7 +64,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)
api_key = forms.PasswordInputField("API Key (api_key)", required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -41,7 +41,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DeepSeekLLMModelParams(BaseForm):

class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -48,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class GeminiImageModelParams(BaseForm):
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -44,7 +44,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
for chunk in res:
print(chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class GeminiLLMModelParams(BaseForm):

class GeminiLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand All @@ -48,7 +48,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.invoke([HumanMessage(content='你好')])
print(res)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
Expand Down
Loading
Loading