Merge pull request #120 from mobiusml/openai_compatible_api
OpenAI-compatible API
movchan74 authored Jun 21, 2024
2 parents 4ebcfd9 + 25933a0 commit 0798595
Showing 5 changed files with 359 additions and 0 deletions.
80 changes: 80 additions & 0 deletions aana/api/request_handler.py
@@ -1,14 +1,22 @@
import json
import time
from typing import Any
from uuid import uuid4

import ray
from fastapi.openapi.utils import get_openapi
from fastapi.responses import StreamingResponse
from ray import serve

from aana.api.api_generation import Endpoint, add_custom_schemas_to_openapi_schema
from aana.api.app import app
from aana.api.event_handlers.event_manager import EventManager
from aana.api.responses import AanaJSONResponse
from aana.configs.settings import settings as aana_settings
from aana.core.models.chat import ChatCompletetion, ChatCompletionRequest, ChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.task import TaskId
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.storage.services.task import TaskInfo, delete_task, get_task_info


@@ -124,3 +132,75 @@ async def delete_task_endpoint(self, task_id: str) -> TaskId:
"""
task = delete_task(task_id)
return TaskId(task_id=str(task.id))

@app.post("/chat/completions", response_model=ChatCompletetion)
async def chat_completions(self, request: ChatCompletionRequest):
"""Handle chat completions requests for OpenAI compatible API."""

async def _async_chat_completions(
handle: AanaDeploymentHandle,
dialog: ChatDialog,
sampling_params: SamplingParams,
):
async for response in handle.chat_stream(
dialog=dialog, sampling_params=sampling_params
):
chunk = {
"id": f"chatcmpl-{uuid4().hex}",
"object": "chat.completion.chunk",
"model": request.model,
"created": int(time.time()),
"choices": [
{
"index": 0,
"delta": {"content": response["text"], "role": "assistant"},
}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"

# Check if the deployment exists
try:
handle = await AanaDeploymentHandle.create(request.model)
except ray.serve.exceptions.RayServeException:
return AanaJSONResponse(
content={
"error": {"message": f"The model `{request.model}` does not exist."}
},
status_code=404,
)

# Check if the deployment is a chat model
if not hasattr(handle, "chat") or not hasattr(handle, "chat_stream"):
return AanaJSONResponse(
content={
"error": {"message": f"The model `{request.model}` does not exist."}
},
status_code=404,
)

dialog = ChatDialog(
messages=request.messages,
)

sampling_params = SamplingParams(
temperature=request.temperature,
max_tokens=request.max_tokens,
top_p=request.top_p,
)

if request.stream:
return StreamingResponse(
_async_chat_completions(handle, dialog, sampling_params),
media_type="application/x-ndjson",
)
else:
response = await handle.chat(dialog=dialog, sampling_params=sampling_params)
return {
"id": f"chatcmpl-{uuid4().hex}",
"object": "chat.completion",
"model": request.model,
"created": int(time.time()),
"choices": [{"index": 0, "message": response["message"]}],
}
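
For reference, the 404 branches above return an OpenAI-style error body, which the OpenAI client surfaces as `NotFoundError` (see the test below). A minimal sketch probing the endpoint directly with `requests` (the port and model name are illustrative):

```python
import requests

# Hypothetical port and model name; the error shape matches the handler above.
response = requests.post(
    "http://localhost:8000/chat/completions",
    json={"model": "no_such_model", "messages": [{"role": "user", "content": "Hi!"}]},
)
assert response.status_code == 404
print(response.json()["error"]["message"])  # The model `no_such_model` does not exist.
```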
70 changes: 70 additions & 0 deletions aana/core/models/chat.py
@@ -83,3 +83,73 @@ def from_list(cls, messages: list[dict[str, str]]) -> "ChatDialog":
ChatDialog: the chat dialog
"""
return ChatDialog(messages=[ChatMessage(**message) for message in messages])


class ChatCompletionRequest(BaseModel):
"""A chat completion request for OpenAI compatible API."""

model: str = Field(..., description="The model name (name of the LLM deployment).")
messages: list[ChatMessage] = Field(
..., description="A list of messages comprising the conversation so far."
)
temperature: float | None = Field(
default=None,
ge=0.0,
description=(
"Float that controls the randomness of the sampling. "
"Lower values make the model more deterministic, "
"while higher values make the model more random. "
"Zero means greedy sampling."
),
)
top_p: float | None = Field(
default=None,
gt=0.0,
le=1.0,
description=(
"Float that controls the cumulative probability of the top tokens to consider. "
"Must be in (0, 1]. Set to 1 to consider all tokens."
),
)
max_tokens: int | None = Field(
default=None, ge=1, description="The maximum number of tokens to generate."
)

stream: bool | None = Field(
default=False,
description=(
"If set, partial message deltas will be sent, like in ChatGPT. "
"Tokens will be sent as data-only server-sent events as they become available, "
"with the stream terminated by a data: [DONE] message."
),
)


class ChatCompletetionChoice(BaseModel):
"""A chat completion choice for OpenAI compatible API."""

index: int = Field(
..., description="The index of the choice in the list of choices."
)
message: ChatMessage = Field(
..., description="A chat completion message generated by the model."
)


class ChatCompletetion(BaseModel):
"""A chat completion for OpenAI compatible API."""

id: str = Field(..., description="A unique identifier for the chat completion.")
model: str = Field(..., description="The model used for the chat completion.")
created: int = Field(
...,
description="The Unix timestamp (in seconds) of when the chat completion was created.",
)
choices: list[ChatCompletetionChoice] = Field(
...,
description="A list of chat completion choices.",
)
object: Literal["chat.completion"] = Field(
"chat.completion",
description="The object type, which is always `chat.completion`.",
)
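
For reference, the request model defined above can be constructed directly in Python; a minimal sketch with illustrative values:

```python
from aana.core.models.chat import ChatCompletionRequest, ChatMessage

request = ChatCompletionRequest(
    model="llm_deployment",  # illustrative deployment name
    messages=[ChatMessage(role="user", content="Hello!")],
    temperature=0.7,  # >= 0.0; zero means greedy sampling
    top_p=0.9,        # must be in (0, 1]
    max_tokens=256,   # at least 1
    stream=False,
)
```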
110 changes: 110 additions & 0 deletions aana/tests/units/test_chat_completion.py
@@ -0,0 +1,110 @@
# ruff: noqa: S101, S113
from collections.abc import AsyncGenerator

import pytest
import requests
from openai import NotFoundError, OpenAI
from ray import serve

from aana.core.models.chat import ChatDialog, ChatMessage
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_text_generation_deployment import (
BaseTextGenerationDeployment,
ChatOutput,
LLMOutput,
)


@serve.deployment
class LowercaseLLM(BaseTextGenerationDeployment):
"""Ray deployment that returns the lowercase version of a text structured as an LLM."""

async def generate_stream(
self, prompt: str, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
"""Generate text stream.
Args:
prompt (str): The prompt.
sampling_params (SamplingParams): The sampling parameters.
Yields:
LLMOutput: The generated text.
"""
for char in prompt:
yield LLMOutput(text=char.lower())

async def chat(
self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
) -> ChatOutput:
"""Dummy chat method."""
text = dialog.messages[-1].content
return ChatOutput(message=ChatMessage(content=text.lower(), role="assistant"))

async def chat_stream(
self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
) -> AsyncGenerator[LLMOutput, None]:
"""Dummy chat stream method."""
text = dialog.messages[-1].content
for char in text:
yield LLMOutput(text=char.lower())


deployments = [
{
"name": "lowercase_deployment",
"instance": LowercaseLLM,
}
]


def test_chat_completion(app_setup):
"""Test the chat completion endpoint for OpenAI compatible API."""
aana_app = app_setup(deployments, [])

port = aana_app.port
route_prefix = ""

# Check that the server is ready
response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready")
assert response.status_code == 200
assert response.json() == {"ready": True}

messages = [
{"role": "user", "content": "Hello World!"},
]
expected_output = messages[0]["content"].lower()

client = OpenAI(
api_key="token",
base_url=f"http://localhost:{port}",
)

# Test chat completion endpoint
completion = client.chat.completions.create(
messages=messages,
model="lowercase_deployment",
)
assert completion.choices[0].message.content == expected_output

# Test chat completion endpoint with stream
stream = client.chat.completions.create(
messages=messages,
model="lowercase_deployment",
stream=True,
)
generated_text = ""
for chunk in stream:
generated_text += chunk.choices[0].delta.content or ""
assert generated_text == expected_output

# Test chat completion endpoint with non-existent model
with pytest.raises(NotFoundError) as exc_info:
completion = client.chat.completions.create(
messages=messages,
model="non_existent_model",
)
assert (
exc_info.value.body["message"]
== "The model `non_existent_model` does not exist."
)
4 changes: 4 additions & 0 deletions docs/integrations.md
@@ -64,3 +64,7 @@ HfPipelineDeployment.options(
Haystack integration allows you to build Retrieval-Augmented Generation (RAG) systems with the [Deepset Haystack](https://github.com/deepset-ai/haystack).

TODO: Add example

## OpenAI-compatible Chat Completions API

The OpenAI-compatible Chat Completions API lets you access Aana applications with any OpenAI-compatible client. See the [OpenAI-compatible API docs](/docs/openai_api.md) for more details.
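
As a quick sketch (the port and deployment name here are assumptions), any standard OpenAI client can talk to a running Aana app:

```python
from openai import OpenAI

# Any non-empty string works as the API key; Aana does not require one.
client = OpenAI(api_key="token", base_url="http://localhost:8000")

completion = client.chat.completions.create(
    model="llm_deployment",  # name of your registered LLM deployment (assumed)
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)
```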
95 changes: 95 additions & 0 deletions docs/openai_api.md
@@ -0,0 +1,95 @@
# OpenAI-compatible API

Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to integrate Aana with any OpenAI-compatible application.

The Chat Completions API is available at the `/chat/completions` endpoint.

It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for the OpenAI API.

```python
from openai import OpenAI

client = OpenAI(
api_key="token", # Any non empty string will work, we don't require an API key
base_url="http://localhost:8000",
)

messages = [
{"role": "user", "content": "What is the capital of France?"}
]

completion = client.chat.completions.create(
messages=messages,
model="llm_deployment",
)

print(completion.choices[0].message.content)
```
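
Because the endpoint is plain HTTP, you can also call it without the OpenAI client. A minimal sketch with `requests` (the base URL and deployment name are assumptions; the response shape follows the `ChatCompletetion` model):

```python
import requests

# Hypothetical host/port and deployment name; adjust to your app.
response = requests.post(
    "http://localhost:8000/chat/completions",
    json={
        "model": "llm_deployment",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
    },
)
print(response.json()["choices"][0]["message"]["content"])
```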

The API also supports streaming:

```python
from openai import OpenAI

client = OpenAI(
api_key="token", # Any non empty string will work, we don't require an API key
base_url="http://localhost:8000",
)

messages = [
{"role": "user", "content": "What is the capital of France?"}
]

stream = client.chat.completions.create(
messages=messages,
model="llm_deployment",
stream=True,
)
for chunk in stream:
print(chunk.choices[0].delta.content or "", end="")
```
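
Under the hood, streamed responses are data-only server-sent events terminated by a `data: [DONE]` message. A sketch of consuming the raw stream without the OpenAI client (the URL and model name are assumptions):

```python
import json

import requests

with requests.post(
    "http://localhost:8000/chat/completions",
    json={
        "model": "llm_deployment",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "stream": True,
    },
    stream=True,
) as response:
    for line in response.iter_lines():
        if not line:
            continue  # skip blank separators between SSE events
        payload = line.decode().removeprefix("data: ")
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"]["content"] or "", end="")
```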

The API requires an LLM deployment. Aana SDK provides support for [vLLM](/docs/integrations.md#vllm) and [Hugging Face Transformers](/docs/integrations.md#hugging-face-transformers).

The name of the model matches the name of the deployment. For example, if you registered a vLLM deployment with the name `llm_deployment`, you can use it with the OpenAI API as `model="llm_deployment"`.

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
from aana.sdk import AanaSDK

llm_deployment = VLLMDeployment.options(
num_replicas=1,
ray_actor_options={"num_gpus": 1},
user_config=VLLMConfig(
model="TheBloke/Llama-2-7b-Chat-AWQ",
dtype=Dtype.AUTO,
quantization="awq",
gpu_memory_reserved=13000,
enforce_eager=True,
default_sampling_params=SamplingParams(
temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
),
chat_template="llama2",
).model_dump(mode="json"),
)

aana_app = AanaSDK(name="llm_app")
aana_app.register_deployment(name="llm_deployment", instance=llm_deployment)

if __name__ == "__main__":
aana_app.connect()
aana_app.migrate()
aana_app.deploy()
```

You can also use the example project `llama2` to deploy the Llama-2-7b Chat model.

```bash
CUDA_VISIBLE_DEVICES=0 aana deploy aana.projects.llama2.app:aana_app
```
