chore: added tests for truncate_prompt
adubovik committed Jan 8, 2025
1 parent ea966f4 commit a164f15
Showing 7 changed files with 309 additions and 37 deletions.
46 changes: 26 additions & 20 deletions aidial_adapter_bedrock/chat_completion.py
@@ -20,7 +20,7 @@
TruncatePromptResult,
TruncatePromptSuccess,
)
from aidial_sdk.exceptions import ResourceNotFoundError
from aidial_sdk.exceptions import HTTPException as DialException
from typing_extensions import override

from aidial_adapter_bedrock.adapter_deployments import (
@@ -29,17 +29,16 @@
from aidial_adapter_bedrock.aws_client_config import AWSClientConfigFactory
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import (
ChatCompletionAdapter,
TextCompletionAdapter,
)
from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.server.exceptions import dial_exception_decorator
from aidial_adapter_bedrock.server.exceptions import (
dial_exception_decorator,
not_implemented_handler,
)
from aidial_adapter_bedrock.utils.log_config import app_logger as log
from aidial_adapter_bedrock.utils.not_implemented import is_implemented


class BedrockChatCompletion(ChatCompletion):
@@ -72,11 +71,6 @@ async def generate_response(usage: TokenUsage) -> None:
nonlocal discarded_messages

with ChoiceConsumer(response=response) as consumer:
if isinstance(model, TextCompletionAdapter):
consumer.set_tools_emulator(
model.tools_emulator(params.tool_config)
)

try:
await model.chat(consumer, params, request.messages)
except UserError as e:
@@ -101,14 +95,10 @@

@override
@dial_exception_decorator
@not_implemented_handler
async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
model = await self._get_model(request)

if not is_implemented(
model.count_completion_tokens
) or not is_implemented(model.count_prompt_tokens):
raise ResourceNotFoundError("The endpoint is not implemented")

outputs: List[TokenizeOutput] = []
for input in request.inputs:
match input:
@@ -130,6 +120,12 @@ async def _tokenize_string(
try:
tokens = await model.count_completion_tokens(value)
return TokenizeSuccess(token_count=tokens)
except NotImplementedError:
raise
except DialException as e:
# FIXME: remove when the issue is fixed:
# https://github.com/epam/ai-dial-sdk/issues/207
return TokenizeError(error=e.message)
except Exception as e:
return TokenizeError(error=str(e))

@@ -143,19 +139,23 @@ async def _tokenize_request(
params, request.messages
)
return TokenizeSuccess(token_count=token_count)
except NotImplementedError:
raise
except DialException as e:
# FIXME: remove when the issue is fixed:
# https://github.com/epam/ai-dial-sdk/issues/207
return TokenizeError(error=e.message)
except Exception as e:
return TokenizeError(error=str(e))

@override
@dial_exception_decorator
@not_implemented_handler
async def truncate_prompt(
self, request: TruncatePromptRequest
) -> TruncatePromptResponse:
model = await self._get_model(request)

if not is_implemented(model.compute_discarded_messages):
raise ResourceNotFoundError("The endpoint is not implemented")

outputs: List[TruncatePromptResult] = []
for input in request.inputs:
outputs.append(await self._truncate_prompt_request(model, input))
@@ -177,5 +177,11 @@ async def _truncate_prompt_request(
return TruncatePromptSuccess(
discarded_messages=discarded_messages or []
)
except NotImplementedError:
raise
except DialException as e:
# FIXME: remove when the issue is fixed:
# https://github.com/epam/ai-dial-sdk/issues/207
return TruncatePromptError(error=e.message)
except Exception as e:
return TruncatePromptError(error=str(e))
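
All three endpoints now share the same per-input error ladder: NotImplementedError is re-raised so the new not_implemented_handler (added below in server/exceptions.py) can report a 404, a DialException is converted into a per-input error carrying e.message (the FIXME workaround pending ai-dial-sdk issue 207), and any other exception is stringified. A minimal self-contained sketch of that ladder, with stand-in types rather than the adapter's real classes:

# Sketch only: DialException stands in for aidial_sdk.exceptions.HTTPException,
# and count_tokens is a hypothetical async model call.
class DialException(Exception):
    def __init__(self, message: str):
        super().__init__(message)
        self.message = message

async def tokenize_one(count_tokens, value: str) -> dict:
    try:
        return {"status": "success", "token_count": await count_tokens(value)}
    except NotImplementedError:
        raise  # handled at the endpoint level as a 404
    except DialException as e:
        # mirror the commit's FIXME workaround: surface e.message as a
        # per-input error instead of letting the DialException propagate
        return {"status": "error", "error": e.message}
    except Exception as e:
        return {"status": "error", "error": str(e)}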
11 changes: 5 additions & 6 deletions aidial_adapter_bedrock/llm/chat_model.py
@@ -21,7 +21,6 @@
truncate_prompt,
)
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log
from aidial_adapter_bedrock.utils.not_implemented import not_implemented


def _is_empty_system_message(msg: Message) -> bool:
@@ -44,15 +43,14 @@ async def chat(
) -> None:
pass

@not_implemented
async def count_prompt_tokens(
self, params: ModelParameters, messages: List[Message]
) -> int: ...
) -> int:
raise NotImplementedError

@not_implemented
async def count_completion_tokens(self, string: str) -> int: ...
async def count_completion_tokens(self, string: str) -> int:
raise NotImplementedError

@not_implemented
async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> DiscardedMessages | None:
@@ -64,6 +62,7 @@ async def compute_discarded_messages(
Otherwise, returns the indices of _discarded_ messages which should be
removed from the list to make the rest fit into the token limit.
"""
raise NotImplementedError


class TextCompletionPrompt(BaseModel):
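
With the @not_implemented marker gone, ChatCompletionAdapter's optional methods simply raise NotImplementedError, and callers catch the exception instead of probing with the removed is_implemented helper. A hypothetical adapter illustrating the contract:

import asyncio

class MiniAdapter:
    # stand-in for ChatCompletionAdapter's optional-method defaults
    async def count_completion_tokens(self, string: str) -> int:
        raise NotImplementedError

class WordCountAdapter(MiniAdapter):
    async def count_completion_tokens(self, string: str) -> int:
        return len(string.split())  # naive tokenizer, purely illustrative

async def main() -> None:
    for model in (WordCountAdapter(), MiniAdapter()):
        try:
            print(await model.count_completion_tokens("hello world"))
        except NotImplementedError:
            print("tokenization not supported")

asyncio.run(main())  # -> 2, then "tokenization not supported"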
17 changes: 16 additions & 1 deletion aidial_adapter_bedrock/server/exceptions.py
@@ -23,7 +23,11 @@
from typing import assert_never

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InternalServerError, InvalidRequestError
from aidial_sdk.exceptions import (
InternalServerError,
InvalidRequestError,
ResourceNotFoundError,
)
from anthropic import APIStatusError
from botocore.exceptions import ClientError

@@ -149,3 +153,14 @@ async def wrapper(*args, **kwargs):
raise dial_exception from e

return wrapper


def not_implemented_handler(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except NotImplementedError:
raise ResourceNotFoundError("The endpoint is not implemented")

return wrapper
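
The handler stacks under @dial_exception_decorator on the tokenize and truncate_prompt endpoints, turning any NotImplementedError raised by the model's default methods into a 404. A runnable sketch, with a stub exception type standing in for aidial_sdk's ResourceNotFoundError:

import asyncio
from functools import wraps

class ResourceNotFoundError(Exception):
    """Stand-in for aidial_sdk.exceptions.ResourceNotFoundError (HTTP 404)."""

def not_implemented_handler(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except NotImplementedError:
            raise ResourceNotFoundError("The endpoint is not implemented")
    return wrapper

@not_implemented_handler
async def truncate_prompt_endpoint():
    # e.g. a model whose compute_discarded_messages is still the default stub
    raise NotImplementedError

try:
    asyncio.run(truncate_prompt_endpoint())
except ResourceNotFoundError as e:
    print(e)  # The endpoint is not implemented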
7 changes: 0 additions & 7 deletions aidial_adapter_bedrock/utils/not_implemented.py

This file was deleted.

202 changes: 202 additions & 0 deletions tests/integration_tests/test_truncate_prompt.py
@@ -0,0 +1,202 @@
import json
import re
from dataclasses import dataclass
from typing import Callable, List

import httpx
import pytest
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
from tests.integration_tests.test_chat_completion import chat_deployments
from tests.utils.json import match_objects
from tests.utils.openai import (
ai,
sanitize_test_name,
tokenize,
truncate_prompt,
user,
)


class ExpectedException(BaseModel):
type: type[Exception]
message: str
status_code: int | None = None


def expected_success(*args, **kwargs):
return True


@dataclass
class TestCase:
__test__ = False

name: str
deployment: ChatCompletionDeployment
messages: List[ChatCompletionMessageParam]
expected: dict | Callable[[dict], bool] | ExpectedException
max_prompt_tokens: int | List[ChatCompletionMessageParam] | None

def get_id(self):
if isinstance(self.max_prompt_tokens, int):
max_prompt_tokens_str = str(self.max_prompt_tokens)
elif isinstance(self.max_prompt_tokens, list):
max_prompt_tokens_str = "dynamic"
else:
max_prompt_tokens_str = "none"

return sanitize_test_name(
f"{self.deployment.value}/maxpt:{max_prompt_tokens_str}/{self.name}"
)


def get_test_cases(deployment: ChatCompletionDeployment) -> List[TestCase]:
test_cases: List[TestCase] = []

def test_case(
name: str,
messages: List[ChatCompletionMessageParam],
expected: (
dict | Callable[[dict], bool] | ExpectedException
) = expected_success,
max_prompt_tokens: int | List[ChatCompletionMessageParam] | None = None,
) -> None:
test_cases.append(
TestCase(name, deployment, messages, expected, max_prompt_tokens)
)

test_case(
name="invalid request",
messages=[
{"role": "foo", "content": "bar"}, # type: ignore
],
expected=ExpectedException(
type=httpx.HTTPStatusError,
message=json.dumps(
{
"error": {
"message": "Your request contained invalid structure on path inputs.0.messages.0.role. value is not a valid enumeration member; permitted: 'system', 'user', 'assistant', 'function', 'tool'",
"type": "invalid_request_error",
"code": "400",
}
},
separators=(",", ":"),
),
status_code=400,
),
)

test_case(
name="no max_prompt_tokens",
messages=[
user("ping"),
ai("pong"),
user("test"),
],
expected={
"outputs": [
{"status": "error", "error": "max_prompt_tokens is required"}
]
},
)

test_case(
name="keep all",
messages=[
user("ping"),
ai("pong"),
user("test"),
],
max_prompt_tokens=1000,
expected={"outputs": [{"status": "success", "discarded_messages": []}]},
)

test_case(
name="max_prompt_tokens is too small",
messages=[
user("ping"),
ai("pong"),
user("hello world"),
],
max_prompt_tokens=1,
expected={
"outputs": [
{
"status": "error",
"error": re.compile(
r"The requested maximum prompt tokens is 1. However, the system messages and the last user message resulted in \d+ tokens. Please reduce the length of the messages or increase the maximum prompt tokens."
),
}
]
},
)

test_case(
name="keep last user message",
messages=[
user("ping"),
ai("pong"),
user("hello world"),
],
max_prompt_tokens=[user("hello world")],
expected={
"outputs": [{"status": "success", "discarded_messages": [0, 1]}]
},
)

return test_cases


@pytest.mark.parametrize(
"test",
[
test
for deployment, _region in chat_deployments.items()
for test in get_test_cases(deployment)
],
ids=lambda test: test.get_id(),
)
async def test_truncate_prompt(
test_http_client: httpx.AsyncClient, test: TestCase
):
async def run_truncate_prompt() -> dict:
if test.max_prompt_tokens is None:
max_prompt_tokens = None
elif isinstance(test.max_prompt_tokens, int):
max_prompt_tokens = test.max_prompt_tokens
elif isinstance(test.max_prompt_tokens, list):
max_prompt_tokens = await tokenize(
test_http_client, test.deployment.value, test.max_prompt_tokens
)
max_prompt_tokens = max_prompt_tokens["outputs"][0]["token_count"]

return await truncate_prompt(
test_http_client,
test.deployment.value,
test.messages,
max_prompt_tokens,
)

if isinstance(test.expected, ExpectedException):
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await run_truncate_prompt()

actual_exc = exc_info.value

assert isinstance(
actual_exc, test.expected.type
), f"Actual exception type ({type(actual_exc)}) doesn't match the expected one ({test.expected.type})"
assert test.expected.status_code == actual_exc.response.status_code
assert re.search(test.expected.message, actual_exc.response.text)
else:
actual_output = await run_truncate_prompt()

if isinstance(test.expected, dict):
match_objects(test.expected, actual_output)
else:
assert test.expected(
actual_output
), f"Failed output test, actual output: {actual_output}"