
feat: Support Azure image tts stt model #1872

Merged · 1 commit merged into main on Dec 19, 2024
Conversation

shaohuzhang1 (Contributor)

feat: Support Azure image tts stt model
f2c-ci-robot bot commented Dec 19, 2024

Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

f2c-ci-robot bot commented Dec 19, 2024

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

def get_model_params_setting_form(self, model_name):
return AzureOpenAITTIModelParams()

There appear to be no major inconsistencies or issues in your provided Python code. However, here are some optimizations and additional suggestions:

  1. Imports: Ensure that all imports are used appropriately and reduce unnecessary ones. For example, os is not being directly used and can potentially be removed.

  2. Class Inheritance and Encapsulation: The classes follow good inheritance and encapsulation practices (e.g., using separate base classes for forms and credentials). Consider adding type hints more extensively if needed.

  3. Error Handling: The error handling could use more specific exceptions where applicable. For instance, instead of catching generic Exception, catch more specific exceptions like AppApiException.

  4. Logging: Adding logging to trace errors and debug information can help with maintenance and debugging.

  5. Parameter Documentation: While you have tooltips, consider adding comments above each form field to document their purpose and usage.

  6. String Concatenation: Use formatted strings (f'{variable}') over string concatenation to improve readability and maintainability.

Here's an optimized version with minor changes based on these points:

# coding=utf-8
import base64
import logging
import os
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import SingleSelect, SliderField, TextInputField, PasswordInputField
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode

logger = logging.getLogger(__name__)


class AzureOpenAITTIModelParams(forms.BaseForm):
    size = SingleSelect(
        label='Image Size',
        tooltip='Specify the size of generated images, e.g. 1024x1024',
        required=True,
        default_value='1024x1024',
        options=[
            {'value': '1024x1024', 'label': '1024x1024'},
            {'value': '1024x1792', 'label': '1024x1792'},
            {'value': '1792x1024', 'label': '1792x1024'}
        ],
        text_field='label',
        value_field='value'
    )

    quality = SingleSelect(
        label='Image Quality',
        tooltip='Rendering quality of the generated image',
        required=True,
        default_value='standard',
        options=[
            {'value': 'standard', 'label': 'standard'},
            {'value': 'hd', 'label': 'hd'}
        ],
        text_field='label',
        value_field='value'
    )

    # SliderField is the imported slider type; the earlier draft referenced an
    # undefined IntegerSliderField and used invalid `min=_min=1` syntax
    n = SliderField(
        label='Number of Images',
        tooltip='Specify the number of generated images',
        required=True,
        default_value=1,
        _min=1,
        _max=10,
        _step=1,
        precision=0
    )


class AzureOpenAITextToImageModelCredential(forms.BaseForm, BaseModelCredential):
    api_version = TextInputField('API Version (api_version)', required=True)
    api_base = TextInputField('API Domain (azure_endpoint)', required=True)
    api_key = PasswordInputField('API Key (api_key)', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.validation_error.value, f'{model_type} model type is not supported')
            return False

        missing_keys = {'api_base', 'api_key', 'api_version'} - set(model_credential.keys())
        if missing_keys:
            if raise_exception:
                message = f"Required fields {', '.join(sorted(missing_keys))} were not supplied"
                logger.warning(message)
                raise AppApiException(ValidCode.validation_error.value,
                                      f'Model credential validation failed: {message}')
            return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            res = model.check_auth()
            logger.debug(f'Authentication check result: {res}')
        except Exception as e:
            logger.error(str(e), exc_info=True)
            if isinstance(e, AppApiException):
                raise
            if raise_exception:
                raise AppApiException(ValidCode.validation_error.value, f'Model verification failed: {e}')
            return False
        return True

    @property
    def encrypted_api_key(self) -> str:
        """Return an AES-GCM encrypted, base64-encoded copy of the API key."""
        # Uses the third-party `cryptography` package (pip install cryptography)
        from cryptography.hazmat.primitives.ciphers.aead import AESGCM

        original_key = self.api_key.strip()
        if not original_key:
            return ''

        key = self.get_encryption_cipherkey()  # 16/24/32-byte key from your key management
        nonce = os.urandom(12)  # 96-bit nonce, the recommended size for GCM
        ciphertext = AESGCM(key).encrypt(nonce, original_key.encode('utf-8'), None)
        # Prepend the nonce so decryption can recover it; AESGCM appends the auth tag itself
        return base64.urlsafe_b64encode(nonce + ciphertext).decode('utf-8')

    def get_encryption_cipherkey(self) -> bytes:
        """Return the symmetric cipher key; wire this to your key-management backend."""
        raise NotImplementedError

    def decrypt_encrypted_api_key(self, encrypted_api_key: str) -> str:
        """Decrypt an API key produced by encrypted_api_key."""
        from cryptography.hazmat.primitives.ciphers.aead import AESGCM

        data = base64.urlsafe_b64decode(encrypted_api_key.encode('utf-8'))
        nonce, ciphertext = data[:12], data[12:]
        return AESGCM(self.get_encryption_cipherkey()).decrypt(nonce, ciphertext, None).decode('utf-8')

    def get_model_params_setting_form(self, model_name):
        return AzureOpenAITTIModelParams()

This revised version adds documentation, tighter exception handling, logging support, and simpler parameter handling. Make sure to replace placeholders such as get_encryption_cipherkey() and any settings.* values with the actual key-management and configuration facilities available in your environment.
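As a standalone illustration of the GCM framing discussed above (nonce prepended to the ciphertext, everything base64-encoded), here is a minimal round-trip sketch using the third-party `cryptography` package. The `encrypt_api_key`/`decrypt_api_key` names are illustrative, and the generated key stands in for whatever key management your settings actually provide.

```python
import base64
import os

from cryptography.hazmat.primitives.ciphers.aead import AESGCM  # pip install cryptography


def encrypt_api_key(plaintext: str, key: bytes) -> str:
    nonce = os.urandom(12)  # 96-bit nonce, the recommended size for GCM
    ciphertext = AESGCM(key).encrypt(nonce, plaintext.encode('utf-8'), None)
    # Prepend the nonce so the decryptor can recover it; AESGCM appends
    # the authentication tag to the ciphertext itself
    return base64.urlsafe_b64encode(nonce + ciphertext).decode('utf-8')


def decrypt_api_key(token: str, key: bytes) -> str:
    data = base64.urlsafe_b64decode(token.encode('utf-8'))
    nonce, ciphertext = data[:12], data[12:]
    return AESGCM(key).decrypt(nonce, ciphertext, None).decode('utf-8')


key = AESGCM.generate_key(bit_length=256)
token = encrypt_api_key('sk-example-key', key)
assert decrypt_api_key(token, key) == 'sk-example-key'
```

Note that decryption fails with an authentication error if the token or key is tampered with, which is exactly the integrity guarantee GCM adds over plain AES.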

return response.read()

def is_cache_model(self):
return False

The provided code has several potential issues and areas for optimization:

  • Imports: Keep the Dashscope and Azure SDK imports consistent; if they target different SDK versions, verify that compatibility exists.
# Import only one AI platform here
from openai import OpenAI, AzureOpenAI
  • Class Inheritance Issue: AzureOpenAITextToSpeech inherits from MaxKBBaseModel, which is not defined in your imported modules. Ensure that MaxKBBaseModel and other required models/parent classes are included.
from setting.models_provider.base_model_provider import MaxKBBaseModel
  • Missing Initialization Check: Add checks to initialize variables (api_key, api_base, etc.) inside the constructor or before using them.
if not self.api_key:
    raise ValueError("API Key cannot be empty")

if not self.api_base:
    raise ValueError("API base URL (azure_endpoint) cannot be empty")
  • Static Method Usage: A @staticmethod receives no cls; reference the class by name, or switch to @classmethod if you need cls.
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  • Parameters Handling: Simplify how parameters are handled in the static method to reduce redundancy.
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    optional_params = {'params': {'voice': 'alloy'}}
    for key, value in model_kwargs.items():
        if key not in ('model_id', 'use_local', 'streaming'):
            optional_params[key] = value
    # Reference the class by name: a @staticmethod has no cls
    return AzureOpenAITextToSpeech(
        model=model_name,
        api_base=model_credential.get('api_base'),
        api_key=model_credential.get('api_key'),
        api_version=model_credential.get('api_version'),
        **optional_params
    )
  • Simplify Text-to-Speech Functionality: Use built-in functionalities provided by OpenAI/Azure if available, possibly reducing boilerplate.
def text_to_speech(self, text):
    client = AzureOpenAI(
        azure_endpoint=self.api_base,
        api_key=self.api_key,
        api_version=self.api_version
    )

    # The OpenAI SDK exposes no client.voice.all(); the voice is passed
    # directly to audio.speech.create (here via self.params, e.g.
    # {'voice': 'alloy'}). with_streaming_response streams the audio
    # instead of buffering the whole file in memory.
    with client.audio.speech.with_streaming_response.create(
        model=self.model,
        input=text,
        **self.params
    ) as response:
        return response.read()

These adjustments make the code cleaner and more efficient while fixing logic errors in the API authentication and text-to-speech handling. Up-front error checking like this also improves robustness at runtime.
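The parameter-merge pattern suggested above can be shown in isolation: start from defaults, then overlay caller kwargs while dropping control flags. The `merge_params` helper and the filtered flag names are illustrative, not part of the PR.

```python
def merge_params(defaults: dict, **model_kwargs) -> dict:
    # Drop control flags that should not reach the underlying client
    filtered = {k: v for k, v in model_kwargs.items()
                if k not in ('use_local', 'streaming')}
    # Caller-supplied values win over the defaults
    return {**defaults, **filtered}


result = merge_params({'voice': 'alloy'}, voice='echo', use_local=True, speed=1.2)
print(result)  # {'voice': 'echo', 'speed': 1.2}
```

Because later entries in a `{**a, **b}` merge override earlier ones, the defaults never clobber an explicit caller choice.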

return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

def get_model_params_setting_form(self, model_name):
return AzureOpenAIImageModelParams()

The provided code interacts with an Azure OpenAI image generation service using LangChain's core message handling. It defines a form class, AzureOpenAIImageModelParams, for parameter fields and validation, and a credential class, AzureOpenAIImageModelCredential, that extends both BaseForm and BaseModelCredential. Here are some areas where improvements can be made:

Code Improvements

  1. Imports Optimization:

    • Combine related imports into one line where possible. For example, combining base64, os, typing, and other related modules can improve readability.
  2. Class Definitions:

    • Ensure that all relevant fields in the forms are properly defined. The comment indicates that this part is incomplete.
  3. Error Handling:

    • Catch exceptions more specifically or at least wrap them in try-except blocks within methods. This makes debugging easier and helps in providing better error messages.
  4. Encryption Functionality:

    • Consider adding more comprehensive encryption functionality if it wasn't already implemented. Ensure that sensitive data like API keys is appropriately encrypted.
  5. Documentation:

    • Add docstrings to explain each method and class thoroughly. This documentation should cover what the method does, its inputs/outputs, and any assumptions or edge cases.
  6. Validation Logic:

    • Validate input models for completeness before proceeding with model instantiation and stream processing. Handle missing attributes gracefully.
  7. Model Stream Processing:

    • Review how the model's streaming response is processed. If there are specific formats or expected outputs from the server, ensure they are handled correctly.

Potential Enhancements

  • Stream Management:

    • Implement proper management of streams to handle timeout conditions or stop processing early.
  • Logging Level Tuning:

    • Adjust logging levels based on deployment settings to avoid unnecessary logging during production use.
  • Threading/Synchronization:

    • For multi-threaded applications, consider using threading libraries to handle requests concurrently without blocking main threads unnecessarily.

These enhancements will help make the code cleaner, more efficient, and robust, ready for integration into larger projects.
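The validation point above (validate input models for completeness before instantiation) can be sketched as an up-front credential check that reports every missing field at once. The field names mirror the Azure credentials in this PR; `validate_credential` and the plain ValueError are illustrative stand-ins for the project's own form and AppApiException machinery.

```python
REQUIRED_KEYS = ('api_base', 'api_key', 'api_version')


def validate_credential(credential: dict) -> dict:
    # Collect every missing or empty field so the caller sees one complete error
    missing = [k for k in REQUIRED_KEYS if not credential.get(k)]
    if missing:
        raise ValueError(f"Missing credential fields: {', '.join(missing)}")
    return credential


validate_credential({'api_base': 'https://example.openai.azure.com',
                     'api_key': 'sk-example', 'api_version': '2024-02-01'})  # passes

try:
    validate_credential({'api_base': 'https://example.openai.azure.com'})
except ValueError as e:
    print(e)  # Missing credential fields: api_key, api_version
```

Reporting all missing fields in one pass spares the user the fix-one-rerun-see-the-next cycle that per-field checks produce.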

@liuruibin liuruibin merged commit c98874e into main Dec 19, 2024
4 of 5 checks passed
@liuruibin liuruibin deleted the pr@main@feat_azure_models branch December 19, 2024 04:38