feat: Support gemini embedding model #1877
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing |
    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)
The code provided appears to be a part of an API interface for managing model credentials. It includes validation checks for Gemini Embedding models and encryption methods.
Irregularities:
- File Encoding: The file declares its encoding at the top (# coding=utf-8). UTF-8 can represent every Unicode code point, not only the Basic Multilingual Plane, and it is already the default source encoding in Python 3, so the declaration is optional and needs no change.
- Variable Naming: The class attribute api_key and the form field name api_key are identical. This can be confusing when reading the data-handling code.
- Exception Handling: If raise_exception is set to False, make sure every failure path returns False so the caller can tell whether the input was valid.
- Model Type List Check: The filtering of model_type_list could be written more directly with any() and a generator expression.
- Error Message Formatting: Some error messages interpolate values through f-string placeholders ({}). Check that each placeholder refers to a defined variable so the message renders as intended.
- Missing Import Statements: Ensure that BaseModelCredential is properly imported from another module within the same package. Without knowing the exact structure, this cannot be guaranteed.
- Class Inheritance: The class GeminiEmbeddingCredential inherits from both forms.BaseForm and BaseModelCredential. This might cause conflicts in method names or in initialization logic.
Improvements:
- Clear Variable Names: Rename conflicting variables to avoid future confusion, such as changing api_key to gemini_api_key.
- Add Return Values for Non-Exceptional Cases: Provide meaningful return values when raise_exception=False to indicate success or failure.
- Enhance Error Messages: Use consistent quoting around strings in error messages for better readability.
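The return-value suggestion can be illustrated in isolation. The sketch below assumes a hypothetical ValidationError in place of the project's AppApiException and a hypothetical check_credential function; it only shows how a raise_exception flag can switch between raising and returning False.

```python
from typing import Dict


class ValidationError(Exception):
    """Hypothetical stand-in for the project's AppApiException."""


def check_credential(credential: Dict[str, object], raise_exception: bool = True) -> bool:
    """Return True when every required key is present.

    With raise_exception=False a failure is reported as False instead of raising.
    """
    missing = [key for key in ("api_key",) if key not in credential]
    if missing:
        if raise_exception:
            raise ValidationError("missing fields: " + ", ".join(missing))
        return False
    return True


print(check_credential({"api_key": "secret"}))        # True
print(check_credential({}, raise_exception=False))    # False
```

Every failure branch ends in either a raise or an explicit return False, so the caller never receives None.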
Here’s an example of how you might modify these points:
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: embedding.py
    @date:2024/7/12 16:45
    @desc:
"""
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        if model_type.upper() not in self.valid_models():
            err_msg = f"{model_type} 模型类型不支持"  # "model type is not supported"
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, err_msg)
            return False

        required_fields = ['api_key']
        missing_fields = [key for key in required_fields if key not in model_credential]
        if missing_fields:
            # "Required field is missing" / "The following fields are missing: ..."
            err_msg = "必填字段不存在" if len(missing_fields) == 1 else "以下字段缺失:" + ", ".join(missing_fields)
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, err_msg)
            return False

        try:
            # Embed a short probe string ("hello") to confirm the credential works.
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query("你好")
        except AppApiException:
            if raise_exception:
                raise
            return False
        except Exception as ex:
            if raise_exception:
                message = getattr(ex, 'message', str(ex))
                # "Validation failed, please check that the parameters are correct"
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {message}')
            return False
        return True

    def valid_models(self) -> list:
        # Placeholder values from this suggestion; adjust them to the provider's real model types.
        return ["GEMINI_EMBEDDING", "OTHER_TYPE"]

    def encryption_dict(self, model: Dict[str, object]) -> dict:
        # model is a dict, so read the key with .get() rather than getattr().
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)
This cleaned-up version adheres to best practices regarding clear naming conventions, proper exception handling, and clearer error messaging. Also updated the model types list checking process and added a custom method to validate available models.
        return GoogleGenerativeAIEmbeddings(
            google_api_key=model_credential.get('api_key'),
            model=model_name,
        )
The provided code looks mostly good, but there are a few points to consider for improvement:
- Typographical Errors: There is an unexpected space before the GoogleGenerativeAIEmbeddings class definition in line 16.
- Class Documentation: The docstring for the GeminiEmbeddingModel class lacks specific details or examples on its purpose and usage.
- Default Values: Ensure that the default values for model type and name within the new_instance method are defined if they're required by the base class or other parameters.
Here's a revised version of the code with these considerations addressed:
# coding=utf-8
"""
    @project: MaxKB
    @author:虎
    @file: embedding.py
    @date:2024/7/12 17:44
    @desc:
"""
from typing import Dict

from langchain_google_genai import GoogleGenerativeAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class GeminiEmbeddingModel(MaxKBBaseModel, GoogleGenerativeAIEmbeddings):
    """Gemini embedding model built from a MaxKB credential dictionary."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object],
                     **model_kwargs) -> 'GeminiEmbeddingModel':
        # Fail early when the credential carries no API key.
        if not (api_key := model_credential.get('api_key')):
            raise ValueError("API key must be provided.")
        return GeminiEmbeddingModel(
            google_api_key=api_key,
            model=model_name,
            **model_kwargs
        )
Changes Made:
- Fixed the typo in the Google Generative AI embeddings class definition.
- Added a docstring explaining the purpose and usage of GeminiEmbeddingModel.
- Ensured that the API key is validated and checked early in the function for robustness.
- Explicitly returning an instance of GeminiEmbeddingModel.
        .append_default_model_info(model_info_list[0])
        .append_default_model_info(model_image_info_list[0])
        .append_default_model_info(model_stt_info_list[0])
        .append_default_model_info(model_embedding_info_list[0])
        .build()
)
Here's a checklist of issues and improvements:
Issues:
- Duplicate build() call: There seems to be an extra build() method invocation inside another block.
- Redundant credentials and models: The same credentials (gemini_llm_model_credential, etc.) are reused for multiple types (GeminiChatModel, GeminiSpeechToText), which might lead to confusion or errors if different types need additional configuration.
- Missing import statements: Ensure that all necessary classes are imported correctly at the top.
- Inconsistent indentation: Check that all indents are consistent throughout the file.
- Logical consistency: Verify that the order of model information addition is logical (e.g., LLMs before images) as this could affect how settings load.
Optimizations and Improvements:
- Single Credential Instance per Model: If each type has distinct credentials, consider merging them into a single instance with appropriate differentiation logic.
- Separate Credential Classes: If possible, use separate credential classes tailored specifically to each model type or add dynamic behavior through class methods.
- Remove Unused Code Blocks: Review redundant blocks like the initial comment at -11, and remove unused imports.
- Consistent Error Handling: Implement proper error handling mechanisms to manage missing dependencies or invalid data.
- Testing Coverage: Add tests to ensure that all functionalities work as expected after making changes.
- Refactor Method Calls: Simplify complex calls by breaking down operations where appropriate.
- Documentation: Update comments and docstrings to describe the purpose and functionality of code components.
- Version Control: Use version control to track changes made during reviews.
By addressing these items, you can enhance the readability, maintainability, and overall quality of your codebase.
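The chained calls in the snippet under review follow a fluent builder style, where each append call returns the builder and build() runs once at the end. The sketch below is a hypothetical, simplified builder written for illustration: the class ModelInfoBuilder, its storage layout, and the type labels are assumptions, not the project's implementation.

```python
from typing import Dict, List


class ModelInfoBuilder:
    """Hypothetical fluent builder: collects model names and one default per type."""

    def __init__(self) -> None:
        self._by_type: Dict[str, List[str]] = {}
        self._defaults: Dict[str, str] = {}

    def append_model_info(self, model_type: str, name: str) -> "ModelInfoBuilder":
        self._by_type.setdefault(model_type, []).append(name)
        return self  # returning self is what makes chaining possible

    def append_default_model_info(self, model_type: str, name: str) -> "ModelInfoBuilder":
        self._defaults[model_type] = name
        return self

    def build(self) -> Dict[str, dict]:
        # Called once, at the end of the chain, to produce the finished registry.
        return {"models": dict(self._by_type), "defaults": dict(self._defaults)}


registry = (
    ModelInfoBuilder()
    .append_model_info("LLM", "gemini-1.0-pro")
    .append_model_info("EMBEDDING", "embedding-model")
    .append_default_model_info("LLM", "gemini-1.0-pro")
    .append_default_model_info("EMBEDDING", "embedding-model")
    .build()
)
print(registry["defaults"]["EMBEDDING"])  # embedding-model
```

Because build() returns a plain result rather than the builder, a second build() in the same chain would fail, which is one way to catch a duplicate call early.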
e894dca to 3524162
    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)
The provided code appears to be an initial implementation of a class, GeminiEmbeddingCredential, which validates embedding credentials for the MaxKB project. The code is well structured, with appropriate imports and methods. Here are some observations and suggestions:
Observations:
- Comments: A header comment at the top gives the project name, author, file name, and date, which helps maintain clarity.
- Form Validation: The method is_valid() checks several conditions, including whether the model_type is supported by the provider and whether all necessary credentials (such as the API key) are present.
- Error Handling: is_valid() raises AppApiException for specific cases such as an unsupported model type or invalid parameters, with error messages detailing what went wrong.
- Model Initialization: Inside is_valid(), a try-except block initializes the model using provider.get_model. If initialization fails with an exception other than AppApiException, the error is wrapped in a new AppApiException only when raise_exception=True; otherwise the method does not raise.
- Encryption Method: The encryption_dict() method returns a dictionary in which the 'api_key' field has been passed through super().encryption.
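The dictionary-merge idiom behind encryption_dict() can be tried on its own. In the sketch below, mask_secret is a hypothetical masking helper standing in for super().encryption, whose real behavior is not shown in this PR; the "region" key is also made up for the example.

```python
from typing import Dict


def mask_secret(value: str, visible: int = 4) -> str:
    """Hypothetical helper: keep the first and last few characters, star the rest."""
    if len(value) <= visible * 2:
        return "*" * len(value)
    return value[:visible] + "*" * (len(value) - visible * 2) + value[-visible:]


def encryption_dict(model: Dict[str, object]) -> Dict[str, object]:
    """Copy the credential dict, replacing only 'api_key' with its masked form."""
    return {**model, "api_key": mask_secret(str(model.get("api_key", "")))}


credential = {"api_key": "abcd1234efgh", "region": "us"}
masked = encryption_dict(credential)
print(masked)      # {'api_key': 'abcd****efgh', 'region': 'us'}
print(credential)  # the original dict is left unchanged
```

The {**model, ...} spread builds a new dictionary, so the caller's credential is never mutated and every other key passes through untouched.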
Optimization Suggestions:
- String Constants: Use named string constants instead of building messages inline in functions, as in f'{key} 字段为必填字段' ("field {key} is required"). This can make the code more readable and easier to maintain.
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode

INVALID_FIELD_MSG = "字段为必填字段"  # "is a required field"
VALIDATION_FAILED_MSG = "校验失败,请检查参数是否正确"  # "validation failed, please check the parameters"


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            # "model type is not supported"
            raise AppApiException(ValidCode.valid_error.value, f"{model_type} 模型类型不支持")

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    # Include the field name so the message says which field is missing.
                    raise AppApiException(ValidCode.valid_error.value, f"{key} {INVALID_FIELD_MSG}")
                else:
                    return False
        try:
            # Embed a short probe string ("hello") to confirm the credential works.
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query("你好")
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            elif raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f"{VALIDATION_FAILED_MSG}: {str(e)}")
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField("API Key", required=True)
- Docstrings: Add docstrings to classes, methods, and fields to describe their functionality. This enhances readability and makes the code reusable across different parts of an application.
- Parameter Naming: Ensure consistent parameter naming conventions. In this case, raise_exception=False works without needing explanation, but better documentation might improve clarity.
- Initialization: Ensure that all variables are initialized correctly. While not an issue here, always ensure objects are properly instantiated before use to avoid runtime errors.
These suggestions focus on making the code more robust, clear, and maintainable.
        )

    def is_cache_model(self):
        return False
The code appears to be mostly correct but can benefit from some additional improvements:
- File Encoding: Make sure the file's encoding comment (# coding=utf-8) sits on the first or second line of the file. Per PEP 263 it is ignored anywhere else, so place it at the top of the file.
- Whitespace and Indentation: Ensure consistent indentation (use 4 spaces per level).
- Docstring Formatting: Add more details in the docstrings about what each class method does.
- Optional Parameters: Use **kwargs if there are optional parameters that a subclass might override.
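To see why the position of the encoding comment matters, the standard library's tokenize.detect_encoding can report which encoding Python would pick for a given source. The sketch below is illustrative only: the helper name declared_encoding and the two sample sources are assumptions, and latin-1 is used instead of utf-8 so that the declared value differs from the default.

```python
import io
import tokenize


def declared_encoding(source: bytes) -> str:
    """Return the encoding Python would use to read this source file."""
    encoding, _ = tokenize.detect_encoding(io.BytesIO(source).readline)
    return encoding


at_top = b"# coding=latin-1\nx = 1\n"
too_late = b"x = 1\ny = 2\n# coding=latin-1\n"

print(declared_encoding(at_top))    # iso-8859-1 (the normalized name for latin-1)
print(declared_encoding(too_late))  # utf-8 (the declaration on line 3 is ignored)
```

Only the first two lines are inspected, so a declaration further down has no effect and the default encoding applies.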
Here’s a revised version:
# coding=utf-8
"""
    @project: MaxKB
    @author: Tiger
    @file: embedding.py
    @date: 2024/7/12 17:44
    @desc:
"""
from typing import Dict, Optional

from langchain_google_genai import GoogleGenerativeAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class GeminiEmbeddingModel(MaxKBBaseModel, GoogleGenerativeAIEmbeddings):
    @staticmethod
    def new_instance(
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            api_key: Optional[str] = None,
            use_cache: bool = True,
            timeout_seconds: int = 60,
            client_max_retries: int = 5,
            request_params: Optional[Dict[str, object]] = None,
            model_parameters: Optional[Dict[str, object]] = None,
            retry_strategy: Optional['RetryStrategy'] = None,
            **model_kwargs
    ) -> 'GeminiEmbeddingModel':
        """
        Initialize a new instance of GeminiEmbeddingModel.

        :param model_type: Type of the model
        :param model_name: Name of the model
        :param model_credential: Credential dictionary containing keys like 'api_key'
        :param api_key: Fallback API key, used when model_credential has no 'api_key'
        :param use_cache: Whether to cache embeddings (default=True)
        :param timeout_seconds: Timeout duration for requests (default=60 seconds)
        :param client_max_retries: Maximum number of retries for the HTTP client (default=5 times)
        :param request_params: Additional request parameters
        :param model_parameters: Model-specific parameters
        :param retry_strategy: Retry strategy configuration
        :return: Initialized GeminiEmbeddingModel object
        """
        # A staticmethod has no cls, so the instance is built directly. Only google_api_key and
        # model are forwarded: the other options (and the RetryStrategy type) are hypothetical
        # and are not confirmed constructor arguments of GoogleGenerativeAIEmbeddings.
        return GeminiEmbeddingModel(
            google_api_key=model_credential.get('api_key', api_key),
            model=model_name,
        )

    def initialize(self):
        """Initialize the EmbeddingModel."""
        # Override this method if necessary

    def delete_from_cache_by_id(self, id):
        """Delete a document from the cache by its ID."""
        # Implement specific cache deletion logic here

    def clear_all_embeddings_cache(self, user_id=None):
        """Clear all cached embeddings."""
        # Implement cache clearing logic here

    def get_embedding(self, documents) -> list:
        """Get embeddings for multiple documents."""
        # Override the base implementation to support batch operation

    def save_embedding(self, texts_list, embeddings_list=None, ids_list=None, meta_data_list=None):
        """Save embeddings related to individual text entries."""
        pass

    def update_embedding(self, texts_list, embeddings_list=None, ids_list=None, meta_data_list=None):
        """Update embeddings related to individual text entries."""
        pass

    def update_text_embedding(self, index_path, content="", metadata=None, overwrite_index_content=True,
                              force_create_index=True):
        """Update or create a single piece of information in an existing dataset using a custom data source."""
        # metadata defaults to None rather than {} to avoid a shared mutable default.
        pass
Key Changes:
- Moved Encoding Comment: Placed the UTF-8 code comment at the top of the file.
- Consistent Indentation: Ensured proper indentation throughout the code.
- Docstring Details: Added descriptions to static methods with optional parameters.
- Use of Optional Types: Utilized Python's Optional type hinting where appropriate for better flexibility.
- Override Methods: Added placeholder implementations for initialize, delete_from_cache_by_id, etc., allowing subclasses to expand upon these functionalities.
These modifications enhance readability, maintainability, and usability of the provided codebase.
        .append_default_model_info(model_info_list[0])
        .append_default_model_info(model_image_info_list[0])
        .append_default_model_info(model_stt_info_list[0])
        .append_default_model_info(model_embedding_info_list[0])
        .build()
)
The provided code looks generally correct, but there are a few places where optimizations or improvements can be made. Here's a review:
Optimizations and Improvements
- Variable Naming:
  - Ensure consistent naming conventions for variables (e.g., use of _ in names that indicate "internal" data).
  - Consider using descriptive variable names instead of single characters.
- Code Readability:
  - Group related imports at the top.
  - Add docstrings to classes and functions if they are not already present.
- Error Handling:
  - Implement error handling around file operations and model initialization if necessary.
- Comments:
  - Add comments to explain complex logic or sections of code.
Here is an improved version with these considerations:
# Import statements
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
# The credential and model classes used below (GeminiLLMModelCredential, GeminiChatModel, and so on)
# are assumed to be imported from the provider's own credential and model modules.


class GeminiModelProvider:
    def __init__(self):
        # Initialize credentials
        self.gemini_llm_model_credential = GeminiLLMModelCredential()
        self.gemini_image_model_credential = GeminiImageModelCredential()
        self.gemini_stt_model_credential = GeminiSTTModelCredential()
        self.gemini_embedding_model_credential = GeminiEmbeddingCredential()
        # Load models
        self.load_models()

    def load_models(self):
        """Load supported models based on configurations."""
        pass  # Implementation here

    def get_model_info(self):
        """Return information about available models."""
        return [
            # "The latest Gemini 1.0 Pro model, updated along with Google's updates"
            ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
                      ModelTypeConst.GENERAL,
                      lambda: GeminiChatModel(self.gemini_llm_model_credential)),
            # "Image generation service using Gemmivision"
            ModelInfo('image-gemmivision', '利用Gemmivision的图像生成服务',
                      ModelTypeConst.IMAGE_GENERATION,
                      lambda: GeminiImage(self.gemini_image_model_credential)),
            # "Google speech recognition service"
            ModelInfo('stt-google', '谷歌语音识别服务', ModelTypeConst.SPEECH_TO_TEXT, lambda: GeminiSpeechToText(
                self.gemini_stt_model_credential  # matches the attribute set in __init__
            )),
            ModelInfo('embedding-model', '', ModelTypeConst.EMBEDDING, lambda: GeminiEmbeddingModel(
                self.gemini_embedding_model_credential
            ))
        ]


# Example usage
if __name__ == "__main__":
    provider = GeminiModelProvider()
    models = provider.get_model_info()
    print(models)
Key Changes Made:
- Consistent class name (GeminiModelProvider) and method names for better readability and maintainability.
- Added load_models() method placeholder for future implementation.
- Simplified function calls within get_model_info() for brevity.
- Used descriptive comment blocks for clarity.
This should make the code cleaner and potentially more maintainable over time.
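The lambda arguments passed in get_model_info() defer model construction until a model is actually requested. A minimal sketch of that idea follows; LazyModelInfo and FakeEmbeddingModel are hypothetical stand-ins invented for this example, not classes from the project.

```python
from dataclasses import dataclass
from typing import Callable


class FakeEmbeddingModel:
    """Stand-in model that counts how many instances have been created."""
    instances = 0

    def __init__(self, credential: dict) -> None:
        FakeEmbeddingModel.instances += 1
        self.credential = credential


@dataclass
class LazyModelInfo:
    """Hypothetical record: a model name plus a zero-argument factory."""
    name: str
    factory: Callable[[], object]

    def create(self) -> object:
        # The factory runs only here, so registration itself stays cheap.
        return self.factory()


info = LazyModelInfo("embedding-model", lambda: FakeEmbeddingModel({"api_key": "k"}))
print(FakeEmbeddingModel.instances)  # 0: nothing has been constructed yet
model = info.create()
print(FakeEmbeddingModel.instances)  # 1: built on first request
```

Deferring construction this way also means that a misnamed attribute inside the lambda surfaces only when the model is created, not when the list is built.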