Commit 3cce291

fix: fixed bug in truncate prompt (#143)
adubovik authored Oct 24, 2024
1 parent c4eb9ed commit 3cce291
Showing 7 changed files with 55 additions and 29 deletions.
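
This commit fixes two related bugs. First, `truncate_prompt` returned the positional tuple `(discarded_messages, prompt)`, while the call site in `chat_completion.py` unpacked it as `prompt, discarded_messages`, silently swapping the two values; the commit replaces the tuple with a named `TruncatedPrompt` model. Second, `_truncate_prompt_request` passed raw request messages to `truncate_prompt` instead of a parsed prompt (see the `chat_completion.py` diff below). A minimal, self-contained sketch of the first hazard, with toy types standing in for the real ones:

from typing import List, Tuple

DiscardedMessages = List[int]


def truncate_old(prompt: str) -> Tuple[DiscardedMessages, str]:
    # Returns (discarded message indices, truncated prompt).
    return [0], prompt


# The unpacking order does not match the return order:
prompt, discarded_messages = truncate_old("hello")
print(prompt)  # [0] -- the index list, not the prompt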
6 changes: 3 additions & 3 deletions aidial_adapter_vertexai/chat/bison/base.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import AsyncIterator, List, Tuple
+from typing import AsyncIterator, List

 from aidial_sdk.chat_completion import FinishReason, Message
 from typing_extensions import override
@@ -15,7 +15,7 @@
 )
 from aidial_adapter_vertexai.chat.consumer import Consumer
 from aidial_adapter_vertexai.chat.tools import ToolsConfig
-from aidial_adapter_vertexai.chat.truncate_prompt import DiscardedMessages
+from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
 from aidial_adapter_vertexai.dial_api.request import ModelParameters
 from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
 from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
@@ -44,7 +44,7 @@ async def parse_prompt(
     @override
     async def truncate_prompt(
         self, prompt: BisonPrompt, max_prompt_tokens: int
-    ) -> Tuple[DiscardedMessages, BisonPrompt]:
+    ) -> TruncatedPrompt[BisonPrompt]:
         return await prompt.truncate(
             tokenizer=self.count_prompt_tokens, user_limit=max_prompt_tokens
         )
6 changes: 3 additions & 3 deletions aidial_adapter_vertexai/chat/chat_completion_adapter.py
@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Generic, List, Tuple, TypeVar
+from typing import Generic, List, TypeVar

 from aidial_sdk.chat_completion import Message

 from aidial_adapter_vertexai.chat.consumer import Consumer
 from aidial_adapter_vertexai.chat.errors import UserError
 from aidial_adapter_vertexai.chat.tools import ToolsConfig
-from aidial_adapter_vertexai.chat.truncate_prompt import DiscardedMessages
+from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
 from aidial_adapter_vertexai.dial_api.request import ModelParameters
 from aidial_adapter_vertexai.utils.not_implemented import not_implemented

@@ -29,7 +29,7 @@ async def chat(
     @not_implemented
     async def truncate_prompt(
         self, prompt: P, max_prompt_tokens: int
-    ) -> Tuple[DiscardedMessages, P]: ...
+    ) -> TruncatedPrompt: ...

     @not_implemented
     async def count_prompt_tokens(self, prompt: P) -> int: ...
5 changes: 2 additions & 3 deletions aidial_adapter_vertexai/chat/gemini/adapter.py
@@ -6,7 +6,6 @@
     Dict,
     List,
     Optional,
-    Tuple,
     TypeVar,
     assert_never,
     cast,
@@ -43,7 +42,7 @@
     Gemini_1_5_Prompt,
 )
 from aidial_adapter_vertexai.chat.tools import ToolsConfig
-from aidial_adapter_vertexai.chat.truncate_prompt import DiscardedMessages
+from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
 from aidial_adapter_vertexai.deployments import (
     ChatCompletionDeployment,
     GeminiDeployment,
@@ -228,7 +227,7 @@ async def chat(
     @override
     async def truncate_prompt(
         self, prompt: GeminiPrompt, max_prompt_tokens: int
-    ) -> Tuple[DiscardedMessages, GeminiPrompt]:
+    ) -> TruncatedPrompt[GeminiPrompt]:
         return await prompt.truncate(
             tokenizer=self.count_prompt_tokens, user_limit=max_prompt_tokens
         )
8 changes: 4 additions & 4 deletions aidial_adapter_vertexai/chat/imagen/adapter.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional

 from aidial_sdk.chat_completion import Attachment, Message
 from PIL import Image as PIL_Image
@@ -15,7 +15,7 @@
 from aidial_adapter_vertexai.chat.consumer import Consumer
 from aidial_adapter_vertexai.chat.errors import ValidationError
 from aidial_adapter_vertexai.chat.tools import ToolsConfig
-from aidial_adapter_vertexai.chat.truncate_prompt import DiscardedMessages
+from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
 from aidial_adapter_vertexai.dial_api.request import (
     ModelParameters,
     collect_text_content,
@@ -59,8 +59,8 @@ async def parse_prompt(
     @override
     async def truncate_prompt(
         self, prompt: ImagenPrompt, max_prompt_tokens: int
-    ) -> Tuple[DiscardedMessages, ImagenPrompt]:
-        return [], prompt
+    ) -> TruncatedPrompt[ImagenPrompt]:
+        return TruncatedPrompt(discarded_messages=[], prompt=prompt)

     @staticmethod
     def get_image_type(image: PIL_Image.Image) -> str:
31 changes: 25 additions & 6 deletions aidial_adapter_vertexai/chat/truncate_prompt.py
@@ -1,5 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Awaitable, Callable, List, Optional, Self, Set, Sized, Tuple
+from typing import (
+    Awaitable,
+    Callable,
+    Generic,
+    List,
+    Optional,
+    Self,
+    Set,
+    Sized,
+    TypeVar,
+)

 from aidial_sdk.exceptions import ContextLengthExceededError
 from aidial_sdk.exceptions import HTTPException as DialException
@@ -9,6 +19,15 @@
 )
 from pydantic import BaseModel

+DiscardedMessages = List[int]
+
+_P = TypeVar("_P")
+
+
+class TruncatedPrompt(BaseModel, Generic[_P]):
+    prompt: _P
+    discarded_messages: DiscardedMessages
+
+
 class TruncatePromptError(ABC, BaseModel):
     @abstractmethod
@@ -64,9 +83,6 @@ def _partition_indexer(chunks: List[int]) -> Callable[[int], List[int]]:
     return mapping.__getitem__


-DiscardedMessages = List[int]
-
-
 class TruncatablePrompt(ABC, Sized):

     @abstractmethod
@@ -114,7 +130,7 @@ async def truncate(
         tokenizer: Callable[[Self], Awaitable[int]],
         model_limit: Optional[int] = None,
         user_limit: Optional[int] = None,
-    ) -> Tuple[DiscardedMessages, Self]:
+    ) -> TruncatedPrompt[Self]:
         """
         Returns a list of indices of discarded messages and
         the truncated prompt that doesn't include the discarded messages and fits into the given user limit.
@@ -139,7 +155,10 @@
         if isinstance(result, TruncatePromptError):
             raise result.to_dial_exception()

-        return (list(result), self.omit(set(result)))
+        return TruncatedPrompt(
+            discarded_messages=list(result),
+            prompt=self.omit(set(result)),
+        )

     async def compute_discarded_messages(
         self,
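
The new `TruncatedPrompt` is a generic pydantic model, so each adapter parametrizes it with its own prompt type (`TruncatedPrompt[BisonPrompt]`, `TruncatedPrompt[GeminiPrompt]`, and so on). A sketch of how such a model behaves, assuming pydantic v2 generics; `ToyPrompt` is a stand-in, not a type from the repository:

from typing import Generic, List, TypeVar

from pydantic import BaseModel

DiscardedMessages = List[int]
_P = TypeVar("_P")


class TruncatedPrompt(BaseModel, Generic[_P]):
    prompt: _P
    discarded_messages: DiscardedMessages


class ToyPrompt(BaseModel):
    messages: List[str]


# Parametrization makes the prompt field validate as ToyPrompt.
result = TruncatedPrompt[ToyPrompt](
    prompt=ToyPrompt(messages=["hi"]),
    discarded_messages=[0, 1],
)
print(result.prompt.messages, result.discarded_messages)

Named field access (`result.prompt`, `result.discarded_messages`) replaces the positional indexing that made the original swap possible.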
26 changes: 17 additions & 9 deletions aidial_adapter_vertexai/chat_completion.py
@@ -26,6 +26,7 @@
 from aidial_adapter_vertexai.adapters import get_chat_completion_model
 from aidial_adapter_vertexai.chat.chat_completion_adapter import (
     ChatCompletionAdapter,
+    TruncatedPrompt,
 )
 from aidial_adapter_vertexai.chat.consumer import ChoiceConsumer
 from aidial_adapter_vertexai.chat.errors import UserError, ValidationError
@@ -69,14 +70,16 @@ async def chat_completion(self, request: Request, response: Response):
         if n > 1 and params.stream:
             raise ValidationError("n>1 is not supported in streaming mode")

-        discarded_messages: List[int] = []
-        if params.max_prompt_tokens is not None:
+        if params.max_prompt_tokens is None:
+            truncated_prompt = TruncatedPrompt(
+                prompt=prompt, discarded_messages=[]
+            )
+        else:
             if not is_implemented(model.truncate_prompt):
                 raise ValidationError(
                     "max_prompt_tokens request parameter is not supported"
                 )

-            prompt, discarded_messages = await model.truncate_prompt(
+            truncated_prompt = await model.truncate_prompt(
                 prompt, params.max_prompt_tokens
             )

@@ -85,7 +88,7 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
             choice.open()

             consumer = ChoiceConsumer(choice)
-            await model.chat(params, consumer, prompt)
+            await model.chat(params, consumer, truncated_prompt.prompt)
             usage.accumulate(consumer.usage)

             finish_reason = consumer.finish_reason
@@ -102,7 +105,7 @@ async def generate_response(usage: TokenUsage, choice_idx: int) -> None:
         response.set_usage(usage.prompt_tokens, usage.completion_tokens)

         if params.max_prompt_tokens is not None:
-            response.set_discarded_messages(discarded_messages)
+            response.set_discarded_messages(truncated_prompt.discarded_messages)

     @override
     @dial_exception_decorator
@@ -174,9 +177,14 @@ async def _truncate_prompt_request(
         if request.max_prompt_tokens is None:
             raise ValidationError("max_prompt_tokens is required")

-        discarded_messages, _prompt = await model.truncate_prompt(
-            request.messages, request.max_prompt_tokens
+        tools = ToolsConfig.from_request(request)
+        prompt = await model.parse_prompt(tools, request.messages)
+
+        truncated_prompt = await model.truncate_prompt(
+            prompt, request.max_prompt_tokens
         )
-        return TruncatePromptSuccess(discarded_messages=discarded_messages)
+        return TruncatePromptSuccess(
+            discarded_messages=truncated_prompt.discarded_messages
+        )
     except Exception as e:
         return TruncatePromptError(error=str(e))
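
The other half of the fix is in `_truncate_prompt_request` above: the old code handed raw `request.messages` (a `List[Message]`) straight to `model.truncate_prompt`, which expects the adapter's parsed prompt type; the new code parses the prompt first and only then truncates. A toy sketch of the corrected parse-then-truncate flow (all names are stand-ins, not the adapter's real API):

from dataclasses import dataclass
from typing import List


@dataclass
class Message:  # stand-in for the SDK's Message
    role: str
    content: str


@dataclass
class ParsedPrompt:  # stand-in for an adapter-specific prompt type
    contents: List[str]


def parse_prompt(messages: List[Message]) -> ParsedPrompt:
    # The real adapters also resolve tools, attachments, etc.
    return ParsedPrompt(contents=[m.content for m in messages])


def discarded_indices(prompt: ParsedPrompt, limit: int) -> List[int]:
    # Toy truncation: drop the oldest entries until the total fits.
    discarded: List[int] = []
    total = sum(len(c) for c in prompt.contents)
    for i, c in enumerate(prompt.contents):
        if total <= limit:
            break
        total -= len(c)
        discarded.append(i)
    return discarded


messages = [Message("user", "x" * 80), Message("user", "y" * 10)]
print(discarded_indices(parse_prompt(messages), limit=20))  # [0]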
2 changes: 1 addition & 1 deletion tests/unit_tests/prompt_truncation/utils.py
@@ -15,4 +15,4 @@ async def get_discarded_messages(
 ) -> DiscardedMessages:
     return (
         await prompt.truncate(tokenizer=tokenizer, user_limit=max_prompt_tokens)
-    )[0]
+    ).discarded_messages
