Merge pull request #120 from mobiusml/openai_compatible_api
OpenAI-compatible API

Showing 5 changed files with 359 additions and 0 deletions.

@@ -0,0 +1,110 @@
# ruff: noqa: S101, S113
from collections.abc import AsyncGenerator

import pytest
import requests
from openai import NotFoundError, OpenAI
from ray import serve

from aana.core.models.chat import ChatDialog, ChatMessage
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_text_generation_deployment import (
    BaseTextGenerationDeployment,
    ChatOutput,
    LLMOutput,
)


@serve.deployment
class LowercaseLLM(BaseTextGenerationDeployment):
    """Ray deployment that returns the lowercase version of a text structured as an LLM."""

    async def generate_stream(
        self, prompt: str, sampling_params: SamplingParams | None = None
    ) -> AsyncGenerator[LLMOutput, None]:
        """Generate text stream.

        Args:
            prompt (str): The prompt.
            sampling_params (SamplingParams): The sampling parameters.

        Yields:
            LLMOutput: The generated text.
        """
        for char in prompt:
            yield LLMOutput(text=char.lower())

    async def chat(
        self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
    ) -> ChatOutput:
        """Dummy chat method."""
        text = dialog.messages[-1].content
        return ChatOutput(message=ChatMessage(content=text.lower(), role="assistant"))

    async def chat_stream(
        self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
    ) -> AsyncGenerator[LLMOutput, None]:
        """Dummy chat stream method."""
        text = dialog.messages[-1].content
        for char in text:
            yield LLMOutput(text=char.lower())


deployments = [
    {
        "name": "lowercase_deployment",
        "instance": LowercaseLLM,
    }
]


def test_chat_completion(app_setup):
    """Test the chat completion endpoint for OpenAI compatible API."""
    aana_app = app_setup(deployments, [])

    port = aana_app.port
    route_prefix = ""

    # Check that the server is ready
    response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready")
    assert response.status_code == 200
    assert response.json() == {"ready": True}

    messages = [
        {"role": "user", "content": "Hello World!"},
    ]
    expected_output = messages[0]["content"].lower()

    client = OpenAI(
        api_key="token",
        base_url=f"http://localhost:{port}",
    )

    # Test chat completion endpoint
    completion = client.chat.completions.create(
        messages=messages,
        model="lowercase_deployment",
    )
    assert completion.choices[0].message.content == expected_output

    # Test chat completion endpoint with stream
    stream = client.chat.completions.create(
        messages=messages,
        model="lowercase_deployment",
        stream=True,
    )
    generated_text = ""
    for chunk in stream:
        generated_text += chunk.choices[0].delta.content or ""
    assert generated_text == expected_output

    # Test chat completion endpoint with non-existent model
    with pytest.raises(NotFoundError) as exc_info:
        completion = client.chat.completions.create(
            messages=messages,
            model="non_existent_model",
        )
    assert (
        exc_info.value.body["message"]
        == "The model `non_existent_model` does not exist."
    )

@@ -0,0 +1,95 @@

# OpenAI-compatible API

Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to integrate Aana with any OpenAI-compatible application.

The Chat Completions API is available at the `/chat/completions` endpoint.
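
Because the endpoint follows the standard Chat Completions request and response format, you can also call it directly over HTTP. The snippet below is a minimal sketch using the `requests` library; the port and the deployment name `llm_deployment` are taken from the examples later on this page, and the response fields are assumed to follow the OpenAI chat completion schema.

```python
import requests

# Call the OpenAI-compatible endpoint directly over HTTP.
# Assumes an Aana app is running on localhost:8000 with a deployment
# registered under the name "llm_deployment" (see the examples below).
response = requests.post(
    "http://localhost:8000/chat/completions",
    json={
        "model": "llm_deployment",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
    },
    timeout=60,
)
response.raise_for_status()

# The response is assumed to follow the OpenAI chat completion schema.
print(response.json()["choices"][0]["message"]["content"])
```

For most applications, the OpenAI client shown below is the more convenient way to call the API; the raw request is included only to illustrate the endpoint itself.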

It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for the OpenAI API.

```python
from openai import OpenAI

client = OpenAI(
    api_key="token",  # Any non-empty string will work; an API key is not required
    base_url="http://localhost:8000",
)

messages = [
    {"role": "user", "content": "What is the capital of France?"}
]

completion = client.chat.completions.create(
    messages=messages,
    model="llm_deployment",
)

print(completion.choices[0].message.content)
```

The API also supports streaming:

```python
from openai import OpenAI

client = OpenAI(
    api_key="token",  # Any non-empty string will work; an API key is not required
    base_url="http://localhost:8000",
)

messages = [
    {"role": "user", "content": "What is the capital of France?"}
]

stream = client.chat.completions.create(
    messages=messages,
    model="llm_deployment",
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```

The API requires an LLM deployment. Aana SDK provides support for [vLLM](/docs/integrations.md#vllm) and [Hugging Face Transformers](/docs/integrations.md#hugging-face-transformers).

The name of the model matches the name of the deployment. For example, if you registered a vLLM deployment with the name `llm_deployment`, you can use it with the OpenAI API as `model="llm_deployment"`.

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
from aana.sdk import AanaSDK

llm_deployment = VLLMDeployment.options(
    num_replicas=1,
    ray_actor_options={"num_gpus": 1},
    user_config=VLLMConfig(
        model="TheBloke/Llama-2-7b-Chat-AWQ",
        dtype=Dtype.AUTO,
        quantization="awq",
        gpu_memory_reserved=13000,
        enforce_eager=True,
        default_sampling_params=SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
        ),
        chat_template="llama2",
    ).model_dump(mode="json"),
)

aana_app = AanaSDK(name="llm_app")
aana_app.register_deployment(name="llm_deployment", instance=llm_deployment)

if __name__ == "__main__":
    aana_app.connect()
    aana_app.migrate()
    aana_app.deploy()
```
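
Registration for the Hugging Face Transformers backend follows the same pattern. The sketch below is illustrative only: the deployment and config class names (`HfTextGenerationDeployment`, `HfTextGenerationConfig`), their fields, and the example model ID are assumptions based on the integrations documentation linked above, not part of this pull request, so check that page for the exact API.

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Assumed imports: see /docs/integrations.md#hugging-face-transformers for the exact names.
from aana.deployments.hf_text_generation_deployment import (
    HfTextGenerationConfig,
    HfTextGenerationDeployment,
)
from aana.sdk import AanaSDK

# Hypothetical Hugging Face Transformers deployment; the config fields are assumptions.
hf_llm_deployment = HfTextGenerationDeployment.options(
    num_replicas=1,
    ray_actor_options={"num_gpus": 1},
    user_config=HfTextGenerationConfig(
        model_id="microsoft/Phi-3-mini-4k-instruct",
        model_kwargs={"trust_remote_code": True},
    ).model_dump(mode="json"),
)

aana_app = AanaSDK(name="llm_app")
# The deployment name is what you pass as `model=` to the OpenAI client.
aana_app.register_deployment(name="hf_llm_deployment", instance=hf_llm_deployment)

if __name__ == "__main__":
    aana_app.connect()
    aana_app.migrate()
    aana_app.deploy()
```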

You can also use the example project `llama2` to deploy the Llama-2-7b Chat model.

```bash
CUDA_VISIBLE_DEVICES=0 aana deploy aana.projects.llama2.app:aana_app
```