-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: check model use model_params #1911
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing |
@@ -105,7 +109,7 @@ def filter_optional_params(model_kwargs): | |||
class BaseModelCredential(ABC): | |||
|
|||
@abstractmethod | |||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True): | |||
def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider, raise_exception=True): | |||
pass | |||
|
|||
@abstractmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are some points to consider for your code:
-
The
get_model_credential
method returns the model's credential. You might want to add more descriptive comments or documentation. -
The parameter name
provider
is not used anywhere within the method. Consider using a meaningful argument name that conveys its purpose. -
In both methods
is_valid_credential
, you're passing the sameself
argument. This could be redundant and can improve performance slightly if removed. -
The
get_model_params
method currently retrieves the entire model information, which includes the credentials. It would be better to use this method to retrieve just the parameters rather than accessing them directly from the instance variable. -
If the model parameters differ based on various factors (e.g., context), it might make sense to encapsulate these dependencies in a separate class. However, without more details about how they change, this isn't strictly necessary here.
Overall, your code is mostly clean with minor improvements possible related to naming conventions and readability.
@@ -48,7 +48,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje | |||
else: | |||
return False | |||
try: | |||
model = provider.get_model(model_type, model_name, model_credential) | |||
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params) | |||
model.invoke([HumanMessage(content='你好')]) | |||
except Exception as e: | |||
if isinstance(e, AppApiException): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The provided code has a few issues that need to be addressed:
-
Typographical Errors: There is an extra comma at the end of
else:
which should not be there. -
Missing Import Statements: There seems to be no import statement for
AppApiException
. This exception might need to be imported from the correct module where it is defined. -
Variable Naming Consistency: The variable name
model_params
is used inconsistently; once when passing arguments toget_model
, you use it directly without spreading (**
) the dictionary, but later you spread it again with another parameter. -
Error Handling in
invoke
: Ifmodel.invoke([HumanMessage(content='你好')])
raises anAppApiException
, the exception handling block should catch this specifically rather than catching all exceptions usingraise_exception
.
Here's an improved version of the function with these changes considered:
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
...
class OpenAILLMModelParams(BaseForm):
def is_valid(
self,
model_type: str,
model_name: str,
model_credential: dict,
provider,
model_parameters: Optional[Dict[str, Any]] = None,
raise_exception: bool = False
) -> bool:
"""
Check if model credentials are valid.
"""
# Ensure required parameters are set
if model_type is None or model_names is None or model_credential is None or provider is None:
return False
model_type_list = provider.get_model_type_list()
if not any(mt.get('value') == model_type for mt in model_type_list):
log.error("Invalid model type")
return False
# Fetch the model
try:
model = provider.get_model(model_type, model_name, model_credential, **(model_parameters or {}))
response = model.invoke([HumanMessage(content='你好')])
# Assuming response contains verification data and we validate against it here...
validation_data = handle_response(response)
if not validation_data["valid"]:
return False
# Return true since the credentials are valid
return True
except (ProviderApiException, AppApiException) as e:
log.exception(f"Failed to validate model {model_name}: {e}")
return raise_exception
# Example usage in your main file:
validation_result = my_open_ai_llm_model_credential.is_valid(
"chatgpt",
"my-chat-gpt-model",
{
"api_key": "your_api_key"
},
my_provider,
{"temperature": 0}
if not validation_result:
print("Models invalid! Exiting.")
Key Changes Implemented:
- Fixed extraneous comma in
else
. - Added missing
import
statement forAppApiException
. - Standardized variable naming by consistently using
model_parameters
instead ofmodel_params
. - Moved error handling specific to
AppApiException.
- Removed unnecessary spreads and ensured consistency across logic flows.
@@ -91,4 +91,4 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict | |||
@param raise_exception: 是否抛出错误 | |||
@return: True|False | |||
""" | |||
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception) | |||
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, raise_exception) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, there is one significant change needed:
# Original line with error
model_credential: Dict[str, object],
# Suggested corrected version with additional parameter
Explanation:
In the provided function is_valid_credential
, there's an unexpected comma in the type hint of the parameter model_params
. Without a leading colon after the variable list [provider, model_type, model_name, model_credential]
, Python will interpret this as a trailing comma for the previous argument.
The suggested correction is to replace it with the correct syntax : Dict[str, object],
which properly separates parameters from their types. This prevents errors like unbalanced parentheses due to the extra comma before the type hint.
refactor: check model use model_params