feat: Support gemini embedding model #1877
@@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 16:45
@desc:
"""
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query('你好')
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)
The provided code seems to be an initial implementation of the GeminiEmbeddingCredential class. The observations and optimization suggestions focus on making the code more robust, clear, and maintainable.
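For reference, here is a minimal sketch of how this credential class might be exercised during validation; the provider import path, the model type value, and the API key are assumptions for illustration, not part of the PR:

# Hypothetical validation sketch; the provider import path, model type value, and API key
# are assumptions for illustration.
from setting.models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential
from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider  # assumed path

provider = GeminiModelProvider()
credential = GeminiEmbeddingCredential()
is_ok = credential.is_valid(
    'EMBEDDING',                          # must match a 'value' in provider.get_model_type_list()
    'models/embedding-001',               # model name registered by the provider
    {'api_key': 'your-gemini-api-key'},   # only 'api_key' is required by this credential
    provider,
    raise_exception=False,
)
# With raise_exception=False, validation failures return False instead of raising AppApiException
# (except the model type check, which always raises).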
@@ -11,9 +11,11 @@
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
    ModelInfoManage
from setting.models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential
from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
from setting.models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
from setting.models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel
from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText

@@ -22,6 +24,7 @@
gemini_llm_model_credential = GeminiLLMModelCredential()
gemini_image_model_credential = GeminiImageModelCredential()
gemini_stt_model_credential = GeminiSTTModelCredential()
gemini_embedding_model_credential = GeminiEmbeddingCredential()

model_info_list = [
    ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',

@@ -56,14 +59,23 @@
              GeminiSpeechToText),
]

model_embedding_info_list = [
    ModelInfo('models/embedding-001', '',
              ModelTypeConst.EMBEDDING,
              gemini_embedding_model_credential,
              GeminiEmbeddingModel),
]

model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_model_info_list(model_image_info_list)
    .append_model_info_list(model_stt_info_list)
    .append_model_info_list(model_embedding_info_list)
    .append_default_model_info(model_info_list[0])
    .append_default_model_info(model_image_info_list[0])
    .append_default_model_info(model_stt_info_list[0])
    .append_default_model_info(model_embedding_info_list[0])
    .build()
)
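If additional Gemini embedding models are added later, they can presumably be registered through the same ModelInfo pattern shown above; the second entry below is purely illustrative and not part of this PR:

# Illustrative only: a second embedding entry following the same registration pattern.
model_embedding_info_list = [
    ModelInfo('models/embedding-001', '',
              ModelTypeConst.EMBEDDING,
              gemini_embedding_model_credential,
              GeminiEmbeddingModel),
    # Hypothetical additional model name, not registered by this PR.
    ModelInfo('models/text-embedding-004', '',
              ModelTypeConst.EMBEDDING,
              gemini_embedding_model_credential,
              GeminiEmbeddingModel),
]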
Here's a checklist of issues and improvements; addressing these items can enhance the readability, maintainability, and overall quality of the codebase.

The provided code looks generally correct, but there are a few places where optimizations or improvements can be made.
Here is an improved version with these considerations:

# Import statements
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential
from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
from setting.models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
from setting.models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel
from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText


class GeminiModelProvider:
    def __init__(self):
        # Initialize credentials
        self.gemini_llm_model_credential = GeminiLLMModelCredential()
        self.gemini_image_model_credential = GeminiImageModelCredential()
        self.gemini_stt_model_credential = GeminiSTTModelCredential()
        self.gemini_embedding_model_credential = GeminiEmbeddingCredential()

        # Load models
        self.load_models()

    def load_models(self):
        """Load supported models based on configurations."""
        pass  # Implementation here

    def get_model_info(self):
        """Return information about available models."""
        return [
            ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
                      ModelTypeConst.GENERAL,
                      lambda: GeminiChatModel(self.gemini_llm_model_credential)),
            ModelInfo('image-gemmivision', '利用Gemmivision的图像生成服务',
                      ModelTypeConst.IMAGE_GENERATION,
                      lambda: GeminiImage(self.gemini_image_model_credential)),
            ModelInfo('stt-google', '谷歌语音识别服务', ModelTypeConst.SPEECH_TO_TEXT,
                      lambda: GeminiSpeechToText(self.gemini_stt_model_credential)),
            ModelInfo('embedding-model', '', ModelTypeConst.EMBEDDING,
                      lambda: GeminiEmbeddingModel(self.gemini_embedding_model_credential)),
        ]


# Example usage
if __name__ == "__main__":
    provider = GeminiModelProvider()
    models = provider.get_model_info()
    print(models)

This should make the code cleaner and potentially more maintainable over time.
@@ -0,0 +1,25 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/7/12 17:44
@desc:
"""
from typing import Dict

from langchain_google_genai import GoogleGenerativeAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class GeminiEmbeddingModel(MaxKBBaseModel, GoogleGenerativeAIEmbeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return GoogleGenerativeAIEmbeddings(
            google_api_key=model_credential.get('api_key'),
            model=model_name,
        )
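For orientation, here is a small usage sketch of this factory once the model is registered; the credential values and query texts are illustrative, not part of the PR:

# Hypothetical usage sketch; the credential values and texts are assumptions.
from setting.models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel

embedding_model = GeminiEmbeddingModel.new_instance(
    'EMBEDDING',                          # model_type (not used by this implementation)
    'models/embedding-001',               # model name registered by the provider
    {'api_key': 'your-gemini-api-key'},
)
vector = embedding_model.embed_query('hello world')           # one text -> one vector
vectors = embedding_model.embed_documents(['a', 'b', 'c'])    # batch of texts -> list of vectors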
The provided code looks mostly good, but there are a few points to consider for improvement. Here's a revised version of the code with these considerations addressed:

# coding=utf-8
"""
@project: MaxKB
@author:虎
@file: embedding.py
@date:2024/7/12 17:44
@desc:
"""
from typing import Dict

from langchain_google_genai import GoogleGenerativeAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class GeminiEmbeddingModel(MaxKBBaseModel, GoogleGenerativeAIEmbeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> 'GeminiEmbeddingModel':
        # Fail fast if the API key is missing, and return the subclass so MaxKB-specific
        # behaviour (e.g. is_cache_model) is preserved.
        if not (api_key := model_credential.get('api_key')):
            raise ValueError("API key must be provided.")
        return GeminiEmbeddingModel(
            google_api_key=api_key,
            model=model_name,
            **model_kwargs
        )
    def is_cache_model(self):
        return False
The code appears to be mostly correct but can benefit from some additional improvements. Here's a revised version:

# coding=utf-8
"""
@project: MaxKB
@author: Tiger
@file: embedding.py
@date: 2024/7/12 17:44
@desc:
"""
from typing import Dict, Optional

from langchain_google_genai import GoogleGenerativeAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class GeminiEmbeddingModel(MaxKBBaseModel, GoogleGenerativeAIEmbeddings):
    @staticmethod
    def new_instance(
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            api_key: Optional[str] = None,
            use_cache: bool = True,
            timeout_seconds: int = 60,
            client_max_retries: int = 5,
            request_params: Optional[Dict[str, object]] = None,
            model_parameters: Optional[Dict[str, object]] = None,
            retry_strategy: Optional['RetryStrategy'] = None,
            **model_kwargs
    ) -> 'GeminiEmbeddingModel':
        """
        Initialize a new instance of GeminiEmbeddingModel.

        :param model_type: Type of the model
        :param model_name: Name of the model
        :param model_credential: Credential dictionary containing keys like 'api_key'
        :param api_key: API key fallback if it is not present in model_credential
        :param use_cache: Whether to cache embeddings (default=True)
        :param timeout_seconds: Timeout duration for requests (default=60 seconds)
        :param client_max_retries: Maximum number of retries for the HTTP client (default=5 times)
        :param request_params: Additional request parameters
        :param model_parameters: Model-specific parameters
        :param retry_strategy: Retry strategy configuration
        :return: Initialized GeminiEmbeddingModel object
        """
        # The caching/timeout/retry options above are part of this proposal and would still
        # need to be wired into the wrapper; GoogleGenerativeAIEmbeddings itself only needs
        # the API key and model name.
        return GeminiEmbeddingModel(
            google_api_key=model_credential.get('api_key', api_key),
            model=model_name,
        )

    def initialize(self):
        """Initialize the EmbeddingModel."""
        # Override this method if necessary

    def delete_from_cache_by_id(self, id):
        """Delete a document from the cache by its ID."""
        # Implement specific cache deletion logic here

    def clear_all_embeddings_cache(self, user_id=None):
        """Clear all cached embeddings."""
        # Implement cache clearing logic here

    def get_embedding(self, documents) -> list:
        """Get embeddings for multiple documents."""
        # Override the base implementation to support batch operation

    def save_embedding(self, texts_list, embeddings_list=None, ids_list=None, meta_data_list=None):
        """Save embeddings related to individual text entries."""
        pass

    def update_embedding(self, texts_list, embeddings_list=None, ids_list=None, meta_data_list=None):
        """Update embeddings related to individual text entries."""
        pass

    def update_text_embedding(self, index_path, content="", metadata=None, overwrite_index_content=True, force_create_index=True):
        """Update or create a single piece of information in an existing dataset using a custom data source."""
        pass

These modifications enhance the readability, maintainability, and usability of the provided codebase.
The code provided appears to be part of an API interface for managing model credentials. It includes validation checks for Gemini embedding models and encryption methods.

Irregularities:

File Encoding: The file encoding is specified at the beginning (# coding=utf-8), which is good practice; note that UTF-8 already covers the full Unicode range, and the declaration is optional in Python 3.
Variable Name Conflicts: There seems to be a conflict between the class variable api_key and the form field name api_key. This might lead to confusion or incorrect data handling in certain scenarios.
Exception Handling: When raise_exception is set to False, make sure every path returns an explicit boolean; the model type check currently raises regardless of this flag, so the caller gets no return value indicating whether the input was valid.
Model Type List Check: The any(list(filter(...))) construct used to check model_type_list can be simplified to any(...) over a generator expression.
Error Message Formatting: The error messages mix f-string formatting and quoting styles inconsistently, which can make debugging harder.
Missing Import Statements: Ensure that BaseModelCredential is properly imported from the base provider module. Without knowing the exact package structure, this cannot be guaranteed.
Class Inheritance Misalignment: The class GeminiEmbeddingCredential inherits from both forms.BaseForm and BaseModelCredential. This might cause conflicts in method names or other aspects like initialization logic.

Improvements:

Clear Variable Names: Rename conflicting variables to avoid future confusion, such as changing api_key to gemini_api_key.
Add Return Values for Non-Exceptional Cases: Provide meaningful return values when raise_exception=False to indicate success or failure.
Enhance Error Messages: Use consistent quoting around strings in error messages for better readability.
Here's an example of how you might modify these points:

# ... rest of the code ...
This cleaned-up version adheres to best practices regarding clear naming conventions, proper exception handling, and clearer error messaging. It also updates the model type list check and adds a custom method to validate available models.
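As a concrete illustration of the return-value and error-message suggestions, here is a minimal sketch against the is_valid method shown in the diff; the English messages and the 'hello' probe text are illustrative, and the gemini_api_key rename is omitted because it would also require renaming the credential key used by the model class. This is not the reviewer's elided example:

# Sketch only: explicit boolean returns on every path, a simplified any() check,
# and consistent error messages.
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True) -> bool:
        if not any(mt.get('value') == model_type for mt in provider.get_model_type_list()):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'model type {model_type} is not supported')
            return False
        if 'api_key' not in model_credential:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, 'api_key is a required field')
            return False
        try:
            provider.get_model(model_type, model_name, model_credential).embed_query('hello')
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'validation failed, please check the parameters: {e}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}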