From b8750ff2826d739cbe8f1543b4825351a6e7ec1c Mon Sep 17 00:00:00 2001
From: Franck Stephane Ndzomga
Date: Fri, 22 Nov 2024 13:51:12 +0100
Subject: [PATCH] adding nebius api client

---
 .../components/model_client/__init__.py      |   5 +
 .../components/model_client/nebius_client.py | 309 ++++++++++++++++++
 adalflow/tests/test_nebius_client.py          | 101 ++++++
 3 files changed, 415 insertions(+)
 create mode 100644 adalflow/adalflow/components/model_client/nebius_client.py
 create mode 100644 adalflow/tests/test_nebius_client.py

diff --git a/adalflow/adalflow/components/model_client/__init__.py b/adalflow/adalflow/components/model_client/__init__.py
index 64db136f7..bbd1b2332 100644
--- a/adalflow/adalflow/components/model_client/__init__.py
+++ b/adalflow/adalflow/components/model_client/__init__.py
@@ -44,6 +44,10 @@ "adalflow.components.model_client.openai_client.OpenAIClient",
     OptionalPackages.OPENAI,
 )
+NebiusClient = LazyImport(
+    "adalflow.components.model_client.nebius_client.NebiusClient",
+    OptionalPackages.OPENAI,
+)
 GoogleGenAIClient = LazyImport(
     "adalflow.components.model_client.google_client.GoogleGenAIClient",
     OptionalPackages.GOOGLE_GENERATIVEAI,
 )
@@ -75,6 +79,7 @@
     "BedrockAPIClient",
     "GroqAPIClient",
     "OpenAIClient",
+    "NebiusClient",
     "GoogleGenAIClient",
 ]

diff --git a/adalflow/adalflow/components/model_client/nebius_client.py b/adalflow/adalflow/components/model_client/nebius_client.py
new file mode 100644
index 000000000..954c9596d
--- /dev/null
+++ b/adalflow/adalflow/components/model_client/nebius_client.py
@@ -0,0 +1,309 @@
+"""NEBIUS AI ModelClient integration."""
+
+import os
+from typing import (
+    Dict,
+    Sequence,
+    Optional,
+    List,
+    Any,
+    TypeVar,
+    Callable,
+    Generator,
+    Union,
+    Literal,
+)
+import re
+
+import logging
+import backoff
+
+# optional import
+from adalflow.utils.lazy_import import safe_import, OptionalPackages
+
+
+openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])
+
+from openai import OpenAI, AsyncOpenAI, Stream
+from openai import (
+    APITimeoutError,
+    InternalServerError,
+    RateLimitError,
+    UnprocessableEntityError,
+    BadRequestError,
+)
+from openai.types import (
+    Completion,
+    CreateEmbeddingResponse,
+)
+from openai.types.chat import ChatCompletionChunk, ChatCompletion
+
+from adalflow.core.model_client import ModelClient
+from adalflow.core.types import (
+    ModelType,
+    EmbedderOutput,
+    TokenLogProb,
+    CompletionUsage,
+    GeneratorOutput,
+)
+from adalflow.components.model_client.utils import parse_embedding_response
+
+log = logging.getLogger(__name__)
+T = TypeVar("T")
+
+
+# Completion parsing functions; they can be combined into a single chat completion parser.
+def get_first_message_content(completion: ChatCompletion) -> str:
+    r"""When we only need the content of the first message.
+    It is the default parser for chat completion."""
+    return completion.choices[0].message.content
+
+
+def parse_stream_response(completion: ChatCompletionChunk) -> str:
+    r"""Parse the response of the stream API."""
+    return completion.choices[0].delta.content
+
+
+def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
+    r"""Handle the streaming response."""
+    for completion in generator:
+        log.debug(f"Raw chunk completion: {completion}")
+        parsed_content = parse_stream_response(completion)
+        yield parsed_content
+
+
+def get_all_messages_content(completion: ChatCompletion) -> List[str]:
+    r"""When n > 1, get the content of all messages."""
+    return [c.message.content for c in completion.choices]
+
+
+def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]:
+    r"""Get the probabilities of each token in the completion."""
+    log_probs = []
+    for c in completion.choices:
+        content = c.logprobs.content
+        log.debug(content)
+        log_probs_for_choice = []
+        for nebius_token_logprob in content:
+            token = nebius_token_logprob.token
+            logprob = nebius_token_logprob.logprob
+            log_probs_for_choice.append(TokenLogProb(token=token, logprob=logprob))
+        log_probs.append(log_probs_for_choice)
+    return log_probs
+
+
+class NebiusClient(ModelClient):
+    __doc__ = r"""A component wrapper for the Nebius API client.
+
+    Supports both the embedding and chat completion APIs.
+
+    Args:
+        api_key (Optional[str], optional): Nebius API key. Defaults to None.
+        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse
+            the chat completion into a str. Defaults to None, in which case
+            ``get_first_message_content`` is used.
+
+    References:
+        - API: https://docs.nebius.com/studio/inference/api
+        - Models: https://docs.nebius.com/studio/inference/models
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        chat_completion_parser: Callable[[Completion], Any] = None,
+        input_type: Literal["text", "messages"] = "text",
+    ):
+        r"""It is recommended to set the NEBIUS_API_KEY environment variable instead of passing it as an argument.
+
+        Args:
+            api_key (Optional[str], optional): Nebius API key. Defaults to None.
+        """
+        super().__init__()
+        self._api_key = api_key
+        self._base_url = "https://api.studio.nebius.ai/v1/"
+        self.sync_client = self.init_sync_client()
+        self.async_client = None  # only initialized when the async call is used
+        self.chat_completion_parser = (
+            chat_completion_parser or get_first_message_content
+        )
+        self._input_type = input_type
+
+    def init_sync_client(self):
+        api_key = self._api_key or os.getenv("NEBIUS_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable NEBIUS_API_KEY must be set")
+        return OpenAI(api_key=api_key, base_url=self._base_url)
+
+    def init_async_client(self):
+        api_key = self._api_key or os.getenv("NEBIUS_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable NEBIUS_API_KEY must be set")
+        return AsyncOpenAI(api_key=api_key, base_url=self._base_url)
+
+    def parse_chat_completion(
+        self,
+        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
+    ) -> "GeneratorOutput":
+        """Parse the completion, and put it into the raw_response."""
+        log.debug(f"completion: {completion}, parser: {self.chat_completion_parser}")
+        try:
+            data = self.chat_completion_parser(completion)
+            usage = self.track_completion_usage(completion)
+            return GeneratorOutput(
+                data=None, error=None, raw_response=data, usage=usage
+            )
+        except Exception as e:
+            log.error(f"Error parsing the completion: {e}")
+            return GeneratorOutput(data=None, error=str(e), raw_response=completion)
+
+    def track_completion_usage(
+        self,
+        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
+    ) -> CompletionUsage:
+        if isinstance(completion, ChatCompletion):
+            usage: CompletionUsage = CompletionUsage(
+                completion_tokens=completion.usage.completion_tokens,
+                prompt_tokens=completion.usage.prompt_tokens,
+                total_tokens=completion.usage.total_tokens,
+            )
+            return usage
+        else:
+            raise NotImplementedError(
+                "streaming completion usage tracking is not implemented"
+            )
+
+    def parse_embedding_response(
+        self, response: CreateEmbeddingResponse
+    ) -> EmbedderOutput:
+        r"""Parse the embedding response into a structure Adalflow components can understand.
+
+        Should be called in ``Embedder``.
+        """
+        try:
+            return parse_embedding_response(response)
+        except Exception as e:
+            log.error(f"Error parsing the embedding response: {e}")
+            return EmbedderOutput(data=[], error=str(e), raw_response=response)
+
+    def convert_inputs_to_api_kwargs(
+        self,
+        input: Optional[Any] = None,
+        model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> Dict:
+        r"""
+        Specify the API input type and output the api_kwargs that will be used in the call and acall methods.
+        Convert the Component's standard input (and the system prompt for chat models) and model_kwargs into the API-specific format.
+        """
+        final_model_kwargs = model_kwargs.copy()
+        if model_type == ModelType.EMBEDDER:
+            if isinstance(input, str):
+                input = [input]
+            # the embeddings API expects a sequence of strings
+            if not isinstance(input, Sequence):
+                raise TypeError("input must be a sequence of text")
+            final_model_kwargs["input"] = input
+        elif model_type == ModelType.LLM:
+            # convert input to messages
+            messages: List[Dict[str, str]] = []
+
+            if self._input_type == "messages":
+                system_start_tag = ""
+                system_end_tag = ""
+                user_start_tag = ""
+                user_end_tag = ""
+                pattern = f"{system_start_tag}(.*?){system_end_tag}{user_start_tag}(.*?){user_end_tag}"
+                # Compile the regular expression
+                regex = re.compile(pattern)
+                # Match the pattern
+                match = regex.search(input)
+                system_prompt, input_str = None, None
+
+                if match:
+                    system_prompt = match.group(1)
+                    input_str = match.group(2)
+                else:
+                    log.debug("No match found.")
+                if system_prompt and input_str:
+                    messages.append({"role": "system", "content": system_prompt})
+                    messages.append({"role": "user", "content": input_str})
+            if len(messages) == 0:
+                messages.append({"role": "system", "content": input})
+            final_model_kwargs["messages"] = messages
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
+        return final_model_kwargs
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            APITimeoutError,
+            InternalServerError,
+            RateLimitError,
+            UnprocessableEntityError,
+            BadRequestError,
+        ),
+        max_time=5,
+    )
+    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+        """
+        api_kwargs is the combined input and model_kwargs. Supports streaming calls.
+        """
+        log.info(f"api_kwargs: {api_kwargs}")
+        if model_type == ModelType.EMBEDDER:
+            return self.sync_client.embeddings.create(**api_kwargs)
+        elif model_type == ModelType.LLM:
+            if "stream" in api_kwargs and api_kwargs.get("stream", False):
+                log.debug("streaming call")
+                self.chat_completion_parser = handle_streaming_response
+                return self.sync_client.chat.completions.create(**api_kwargs)
+            return self.sync_client.chat.completions.create(**api_kwargs)
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            APITimeoutError,
+            InternalServerError,
+            RateLimitError,
+            UnprocessableEntityError,
+            BadRequestError,
+        ),
+        max_time=5,
+    )
+    async def acall(
+        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
+    ):
+        """
+        api_kwargs is the combined input and model_kwargs.
+        """
+        if self.async_client is None:
+            self.async_client = self.init_async_client()
+        if model_type == ModelType.EMBEDDER:
+            return await self.async_client.embeddings.create(**api_kwargs)
+        elif model_type == ModelType.LLM:
+            return await self.async_client.chat.completions.create(**api_kwargs)
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
+
+    @classmethod
+    def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
+        obj = super().from_dict(data)
+        # recreate the existing clients
+        obj.sync_client = obj.init_sync_client()
+        obj.async_client = obj.init_async_client()
+        return obj
+
+    def to_dict(self) -> Dict[str, Any]:
+        r"""Convert the component to a dictionary."""
+        # TODO: not exclude but save yes or no for recreating the clients
+        exclude = [
+            "sync_client",
+            "async_client",
+        ]  # unserializable objects
+        output = super().to_dict(exclude=exclude)
+        return output
diff --git a/adalflow/tests/test_nebius_client.py b/adalflow/tests/test_nebius_client.py
new file mode 100644
index 000000000..1a84d1232
--- /dev/null
+++ b/adalflow/tests/test_nebius_client.py
@@ -0,0 +1,101 @@
+import unittest
+from unittest.mock import patch, AsyncMock, Mock
+
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion
+
+from adalflow.core.types import ModelType, GeneratorOutput
+from adalflow.components.model_client.nebius_client import NebiusClient
+
+
+def getenv_side_effect(key):
+    # Environment variable mapping for tests
+    env_vars = {"NEBIUS_API_KEY": "fake_nebius_api_key"}
+    return env_vars.get(key, None)  # Returns None if the key is not found
+
+
+class TestNebiusClient(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.client = NebiusClient(api_key="fake_nebius_api_key")
+        self.mock_response = {
+            "id": "cmpl-3Q8Z5J9Z1Z5z5",
+            "created": 1635820005,
+            "object": "chat.completion",
+            "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+            "choices": [
+                {
+                    "message": {
+                        "content": "Hello, world!",
+                        "role": "assistant",
+                    },
+                    "index": 0,
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": CompletionUsage(
+                completion_tokens=10, prompt_tokens=20, total_tokens=30
+            ),
+        }
+        self.mock_response = ChatCompletion(**self.mock_response)
+        self.api_kwargs = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+        }
+
+    @patch("adalflow.components.model_client.nebius_client.AsyncOpenAI")
+    async def test_acall_llm(self, MockAsyncOpenAI):
+        mock_async_client = AsyncMock()
+        MockAsyncOpenAI.return_value = mock_async_client
+
+        # Mock the response
+        mock_async_client.chat.completions.create = AsyncMock(
+            return_value=self.mock_response
+        )
+
+        # Call the `acall` method
+        result = await self.client.acall(
+            api_kwargs=self.api_kwargs, model_type=ModelType.LLM
+        )
+
+        # Assertions
+        MockAsyncOpenAI.assert_called_once()
+        mock_async_client.chat.completions.create.assert_awaited_once_with(
+            **self.api_kwargs
+        )
+        self.assertEqual(result, self.mock_response)
+
+    @patch(
+        "adalflow.components.model_client.nebius_client.NebiusClient.init_sync_client"
+    )
+    @patch("adalflow.components.model_client.nebius_client.OpenAI")
+    def test_call(self, MockSyncOpenAI, mock_init_sync_client):
+        mock_sync_client = Mock()
+        MockSyncOpenAI.return_value = mock_sync_client
+        mock_init_sync_client.return_value = mock_sync_client
+
+        # Mock the client's API: chat.completions.create
+        mock_sync_client.chat.completions.create = Mock(return_value=self.mock_response)
+
+        # Set the sync client
+        self.client.sync_client = mock_sync_client
+
+        # Call the `call` method
+        result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
+
+        # Assertions
+        mock_sync_client.chat.completions.create.assert_called_once_with(
+            **self.api_kwargs
+        )
+        self.assertEqual(result, self.mock_response)
+
+        # Test `parse_chat_completion`
+        output = self.client.parse_chat_completion(completion=self.mock_response)
+        self.assertTrue(isinstance(output, GeneratorOutput))
+        self.assertEqual(output.raw_response, "Hello, world!")
+        self.assertEqual(output.usage.completion_tokens, 10)
+        self.assertEqual(output.usage.prompt_tokens, 20)
+        self.assertEqual(output.usage.total_tokens, 30)
+
+
+if __name__ == "__main__":
+    unittest.main()
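
Usage sketch (not part of the patch): a minimal, non-streaming chat completion call that exercises only the methods introduced above. The API key handling follows the client's own NEBIUS_API_KEY convention, and the model id is borrowed from the test fixture; any chat model listed on the Nebius models page linked in the class docstring should work in its place.

    from adalflow.core.types import ModelType
    from adalflow.components.model_client import NebiusClient

    # Assumes NEBIUS_API_KEY is set in the environment; alternatively pass api_key=... here.
    client = NebiusClient(input_type="text")

    model_kwargs = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",  # model id taken from the test file
        "temperature": 0.2,
    }

    # Turn a plain prompt into OpenAI-compatible chat kwargs, call the API, and parse the result.
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="What is the capital of France?",
        model_kwargs=model_kwargs,
        model_type=ModelType.LLM,
    )
    completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
    output = client.parse_chat_completion(completion)  # GeneratorOutput
    print(output.raw_response)  # text of the first message, via get_first_message_content
    print(output.usage)         # CompletionUsage from track_completion_usage

In normal adalflow usage the client would typically be handed to a Generator as model_client, but the direct calls above depend only on code added in this patch.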
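
A similar sketch for the embedding path, again limited to methods defined in this patch. The model id below is a deliberate placeholder; substitute an embedding model from the Nebius models documentation referenced in the class docstring.

    from adalflow.core.types import ModelType
    from adalflow.components.model_client import NebiusClient

    client = NebiusClient()  # reads NEBIUS_API_KEY from the environment

    api_kwargs = client.convert_inputs_to_api_kwargs(
        input=["first sentence to embed", "second sentence to embed"],
        model_kwargs={"model": "<embedding-model-id>"},  # placeholder model id
        model_type=ModelType.EMBEDDER,
    )
    response = client.call(api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER)
    embedder_output = client.parse_embedding_response(response)  # EmbedderOutput with .data / .error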