feat: Support Azure image tts stt model #1872
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected; please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return AzureOpenAITTIModelParams()
There appear to be no major inconsistencies or issues in your provided Python code. However, here are some optimizations and additional suggestions:
- Imports: Ensure that all imports are used appropriately and remove unnecessary ones. For example, os is not being used directly and can potentially be removed.
- Class Inheritance and Encapsulation: The classes follow good inheritance and encapsulation practices (e.g., using separate base classes for forms and credentials). Consider adding type hints more extensively if needed.
- Error Handling: The error handling could use more specific exceptions where applicable. For instance, instead of catching a generic Exception, catch more specific exceptions such as AppApiException.
- Logging: Adding logging to trace errors and debug information can help with maintenance and debugging.
- Parameter Documentation: While you have tooltips, consider adding comments above each form field to document their purpose and usage.
- String Concatenation: Use formatted strings (f'{variable}') over string concatenation to improve readability and maintainability.
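The string-formatting point can be shown with a minimal, self-contained comparison (the variable names here are purely illustrative):

```python
model_name = "dall-e-3"
size = "1024x1024"

# Concatenation: separators and spaces are easy to misplace
label_concat = "Model " + model_name + " at " + size

# f-string: the template reads like the final output
label_fstring = f"Model {model_name} at {size}"

print(label_fstring)  # Model dall-e-3 at 1024x1024
```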
Here's an optimized version with minor changes based on these points:
# coding=utf-8
import base64
import logging
import os
from typing import Dict

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import SingleSelect, SliderField, TextInputField, PasswordInputField, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode

logger = logging.getLogger(__name__)
class AzureOpenAITTIModelParams(forms.BaseForm):
    size = SingleSelect(
        label="Image Size",
        tooltip="Specify the size of generated images; e.g., 1024x1024",
        required=True,
        default_value='1024x1024',
        options=[
            {'value': '1024x1024', 'label': '1024x1024'},
            {'value': '1024x1792', 'label': '1024x1792'},
            {'value': '1792x1024', 'label': '1792x1024'}
        ],
        text_field='label',
        value_field='value'
    )

    quality = SingleSelect(
        label="Image Quality",
        tooltip="",
        required=True,
        default_value='standard',
        options=[
            {'value': 'standard', 'label': 'standard'},
            {'value': 'hd', 'label': 'hd'}
        ],
        text_field='label',
        value_field='value'
    )

    n = SliderField(
        label="Number of Images",
        tooltip="Specify the number of generated images",
        required=True,
        default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0
    )
class AzureOpenAITextToImageModelCredential(forms.BaseForm, BaseModelCredential):
    api_version = TextInputField(
        "API Version (api_version)",
        required=True
    )
    api_base = TextInputField(
        "API Domain (azure_endpoint)",
        required=True
    )
    api_key = PasswordInputField(
        "API Key (api_key)",
        required=True
    )
    def validate_api_base_and_api_key(self):
        # settings.AZURE_ENDPOINTS is a placeholder for an allow-list of endpoints
        if self.api_base not in settings.AZURE_ENDPOINTS:
            raise AppApiException(ValidCode.validation_error.value, "Invalid API endpoint")

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        self.validate_api_base_and_api_key()
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.validation_error.value,
                                      f"{model_type} model type is not supported")
            return False
        missing_keys = {'api_base', 'api_key', 'api_version'} - set(model_credential.keys())
        if missing_keys:
            if raise_exception:
                missing_fields_str = ', '.join(sorted(missing_keys))
                err = AppApiException(ValidCode.validation_error.value,
                                      f'Model credential validation failed: required fields '
                                      f'{missing_fields_str} were not supplied.')
                logger.warning(err.message)
                raise err
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            res = model.check_auth()
            logger.debug(f'Authentication check result: {res}')
        except Exception as e:
            logger.error(str(e), exc_info=True)
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.validation_error.value,
                                      'Model verification failed: failed to validate model parameters.')
            return False
        return True
    @property
    def encrypted_api_key(self) -> str:
        """Return an encrypted version of the API key (AES-GCM, base64-encoded)."""
        original_key = self.api_key.strip()
        if not original_key:
            return ""
        cipher_key = self.get_encryption_cipherkey()
        iv = os.urandom(12)  # 96-bit nonce, the recommended length for GCM
        encryptor = Cipher(algorithms.AES(cipher_key), modes.GCM(iv),
                           backend=default_backend()).encryptor()
        ciphertext = encryptor.update(original_key.encode('utf-8')) + encryptor.finalize()
        # Pack nonce + tag + ciphertext so decryption can recover each part
        return base64.urlsafe_b64encode(iv + encryptor.tag + ciphertext).decode('utf-8')

    def get_encryption_cipherkey(self) -> bytes:
        """Return the AES cipher key used for encryption and decryption."""
        # Implement key management here (e.g. load from settings or a key vault)
        raise NotImplementedError
    def decrypt_encrypted_api_key(self, encrypted_api_key: str) -> str:
        """Decrypt an API key produced by the encrypted_api_key property."""
        enc_bytes = base64.urlsafe_b64decode(encrypted_api_key.encode('utf-8'))
        # Layout matches the encryption side: 12-byte nonce, 16-byte GCM tag, then ciphertext
        nonce, tag, ciphertext = enc_bytes[:12], enc_bytes[12:28], enc_bytes[28:]
        decryptor = Cipher(algorithms.AES(self.get_encryption_cipherkey()), modes.GCM(nonce, tag),
                           backend=default_backend()).decryptor()
        return (decryptor.update(ciphertext) + decryptor.finalize()).decode('utf-8')
    def get_model_params_setting_form(self, model_name):
        return AzureOpenAITTIModelParams()
This optimized version includes enhanced documentation, better exception handling, added logging support, and improvements in parameter handling. Make sure to replace placeholders like get_encryption_cipherkey() and the settings.* values with the actual key-management functions and configuration available in the environment you're working in.
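The required-field check in is_valid boils down to set subtraction over the supplied credential dict. A minimal standalone sketch (the credential values below are made up for illustration):

```python
REQUIRED_KEYS = {'api_base', 'api_key', 'api_version'}

def find_missing_keys(credential: dict) -> set:
    # Any required key absent from the supplied credential is reported
    return REQUIRED_KEYS - set(credential.keys())

complete = {'api_base': 'https://example.openai.azure.com',
            'api_key': 'xxx',
            'api_version': '2024-02-01'}
partial = {'api_base': 'https://example.openai.azure.com'}

print(find_missing_keys(complete))         # set()
print(sorted(find_missing_keys(partial)))  # ['api_key', 'api_version']
```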
        return response.read()

    def is_cache_model(self):
        return False
The provided code has several potential issues and areas for optimization:
- Imports: The imports of both Dashscope and Azure SDKs should be consistent. If they are different versions, make sure compatibility exists.
# Import only one AI platform here
from openai import OpenAI, AzureOpenAI
- Class Inheritance Issue: AzureOpenAITextToSpeech inherits from MaxKBBaseModel, which is not defined in your imported modules. Ensure that MaxKBBaseModel and other required models/parent classes are included.
from setting.models_provider.base_model_provider import MaxKBBaseModel
- Missing Initialization Check: Add checks that variables (api_key, api_base, etc.) are initialized inside the constructor or before using them.
if not self.api_key:
    raise ValueError("API Key cannot be empty")
if not self.api_base:
    raise ValueError("API URL cannot be empty")
- Static Method Usage: A @staticmethod has no access to cls; either reference the class by name or declare the method as a @classmethod so cls is available.
- Parameters Handling: Simplify how parameters are handled to reduce redundancy, for example:
@classmethod
def new_instance(cls, model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    optional_params = {'params': {'voice': 'alloy'}}
    final_params = {**optional_params,
                    **{k: v for k, v in model_kwargs.items() if k != 'use_local'}}
    return cls(
        model=model_name,
        api_base=model_credential.get('api_base'),
        api_key=model_credential.get('api_key'),
        api_version=model_credential.get('api_version'),
        **final_params
    )
- Simplify Text-to-Speech Functionality: Use the built-in speech endpoint provided by OpenAI/Azure; audio.speech.create accepts a voice name directly and returns the complete audio payload, so no manual voice lookup or streaming loop is needed.
def text_to_speech(self, text):
    client = AzureOpenAI(
        azure_endpoint=self.api_base,
        api_key=self.api_key,
        api_version=self.api_version
    )
    result = client.audio.speech.create(
        model=self.model,
        voice=self.params.get('voice', 'alloy'),
        input=text
    )
    return result.content
These adjustments help ensure cleaner and more efficient code while addressing logical errors related to API authentication and handling specific functionalities. Note that error checking like this helps improve robustness during runtime execution.
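The parameter-merging pattern suggested for new_instance can be exercised on its own. This sketch demonstrates only the dict-merge precedence and the use_local filter, not the real model class:

```python
def merge_params(**model_kwargs):
    # Defaults are listed first so caller-supplied values win the merge;
    # 'use_local' is dropped because the model constructor does not accept it
    optional_params = {'params': {'voice': 'alloy'}}
    return {**optional_params,
            **{k: v for k, v in model_kwargs.items() if k != 'use_local'}}

print(merge_params(use_local=True, max_tokens=100))
# {'params': {'voice': 'alloy'}, 'max_tokens': 100}
print(merge_params(params={'voice': 'nova'}))
# {'params': {'voice': 'nova'}}
```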
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return AzureOpenAIImageModelParams()
The provided code appears to be written for interacting with an Azure OpenAI image generation service using LangChain's core message handling features. It includes several classes, such as AzureOpenAIImageModelParams for form fields and parameter validation, as well as AzureOpenAIImageModelCredential, which extends both BaseForm and BaseModelCredential. Here are some areas where improvements can be made:
Code Improvements
- Imports Optimization: Combine related imports into one line where possible. For example, combining base64, os, typing, and other related modules can improve readability.
- Class Definitions: Ensure that all relevant fields in the forms are properly defined. The comment indicates that this part is incomplete.
- Error Handling: Catch exceptions more specifically, or at least wrap them in try-except blocks within methods. This makes debugging easier and helps provide better error messages.
- Encryption Functionality: Consider adding more comprehensive encryption functionality if it wasn't already implemented. Ensure that sensitive data like API keys is appropriately encrypted.
- Documentation: Add docstrings to explain each method and class thoroughly. The documentation should cover what the method does, its inputs/outputs, and any assumptions or edge cases.
- Validation Logic: Validate input models for completeness before proceeding with model instantiation and stream processing. Handle missing attributes gracefully.
- Model Stream Processing: Review how the model's streaming response is processed. If there are specific formats or expected outputs from the server, ensure they are handled correctly.

Potential Enhancements
- Stream Management: Implement proper management of streams to handle timeout conditions or stop processing early.
- Logging Level Tuning: Adjust logging levels based on deployment settings to avoid unnecessary logging during production use.
- Threading/Synchronization: For multi-threaded applications, consider using threading libraries to handle requests concurrently without blocking main threads unnecessarily.

These enhancements will help make the code cleaner, more efficient, and robust, ready for integration into larger projects.
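The logging-level suggestion can be sketched with the standard logging module; the logger name and messages below are illustrative, and an in-memory buffer stands in for whatever handler a deployment would actually use:

```python
import io
import logging

def make_logger(name: str, level: int):
    # Route records to an in-memory buffer so the effect of the level is visible
    stream = io.StringIO()
    logger = logging.getLogger(name)
    logger.setLevel(level)
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter('%(levelname)s %(message)s'))
    logger.addHandler(handler)
    return logger, stream

# A production-style setting: INFO suppresses the noisy debug trace but keeps the error
prod_logger, prod_stream = make_logger('azure_tts_demo', logging.INFO)
prod_logger.debug('auth check payload: ...')
prod_logger.error('Model verification failed')

print('DEBUG' in prod_stream.getvalue())  # False
print('ERROR' in prod_stream.getvalue())  # True
```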