Skip to content

Commit

Permalink
Merge pull request #1 from igorbenav/unified-error-handling
Browse files Browse the repository at this point in the history
unified error handling
  • Loading branch information
igorbenav authored Oct 25, 2024
2 parents 411f8d7 + c3b3aa9 commit 3676ebf
Show file tree
Hide file tree
Showing 16 changed files with 995 additions and 152 deletions.
6 changes: 2 additions & 4 deletions clientai/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def generate_text(
return_full_response: bool = False,
stream: bool = False,
**kwargs: Any,
) -> GenericResponse[R, T, S]:
...
) -> GenericResponse[R, T, S]: ...

def chat(
self,
Expand All @@ -36,8 +35,7 @@ def chat(
return_full_response: bool = False,
stream: bool = False,
**kwargs: Any,
) -> GenericResponse[R, T, S]:
...
) -> GenericResponse[R, T, S]: ...


P = TypeVar("P", bound=AIProviderProtocol)
Expand Down
104 changes: 104 additions & 0 deletions clientai/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Optional, Type


class ClientAIError(Exception):
"""Base exception class for ClientAI errors."""

def __init__(
self,
message: str,
status_code: Optional[int] = None,
original_error: Optional[Exception] = None,
):
super().__init__(message)
self.status_code = status_code
self.original_error = original_error

def __str__(self):
error_msg = super().__str__()
if self.status_code:
error_msg = f"[{self.status_code}] {error_msg}"
return error_msg

@property
def original_exception(self) -> Optional[Exception]:
"""Returns the original exception object if available."""
return self.original_error


class AuthenticationError(ClientAIError):
"""Raised when there's an authentication problem with the AI provider."""


class APIError(ClientAIError):
"""Raised when there's an API-related error from the AI provider."""


class RateLimitError(ClientAIError):
"""Raised when the AI provider's rate limit is exceeded."""


class InvalidRequestError(ClientAIError):
"""Raised when the request to the AI provider is invalid."""


class ModelError(ClientAIError):
"""Raised when there's an issue with the specified model."""


class ProviderNotInstalledError(ClientAIError):
"""Raised when the required provider package is not installed."""


class TimeoutError(ClientAIError):
"""Raised when a request to the AI provider times out."""


def map_status_code_to_exception(
status_code: int, message: str, original_error: Optional[Exception] = None
) -> Type[ClientAIError]:
"""
Maps an HTTP status code to the appropriate ClientAI exception class.
Args:
status_code (int): The HTTP status code.
message (str): The error message.
original_error (Exception, optional): The original exception caught.
Returns:
Type[ClientAIError]: The appropriate ClientAI exception class.
"""
if status_code == 401:
return AuthenticationError
elif status_code == 429:
return RateLimitError
elif status_code == 400:
return InvalidRequestError
elif status_code == 404:
return ModelError
elif status_code == 408:
return TimeoutError
elif status_code >= 500:
return APIError
else:
return APIError


def raise_clientai_error(
status_code: int, message: str, original_error: Optional[Exception] = None
) -> None:
"""
Raises the appropriate ClientAI exception based on the status code.
Args:
status_code (int): The HTTP status code.
message (str): The error message.
original_error (Exception, optional): The original exception caught.
Raises:
ClientAIError: The appropriate ClientAI exception.
"""
exception_class = map_status_code_to_exception(
status_code, message, original_error
)
raise exception_class(message, status_code, original_error)
6 changes: 2 additions & 4 deletions clientai/ollama/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,15 @@ class OllamaChatResponse(TypedDict):
class OllamaClientProtocol(Protocol):
def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs: Any
) -> Union[OllamaResponse, Iterator[OllamaStreamResponse]]:
...
) -> Union[OllamaResponse, Iterator[OllamaStreamResponse]]: ...

def chat(
self,
model: str,
messages: List[Message],
stream: bool = False,
**kwargs: Any,
) -> Union[OllamaChatResponse, Iterator[OllamaStreamResponse]]:
...
) -> Union[OllamaChatResponse, Iterator[OllamaStreamResponse]]: ...


Client = "ollama.Client"
110 changes: 78 additions & 32 deletions clientai/ollama/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
from typing import Any, List, Optional, Union, cast

from ..ai_provider import AIProvider
from ..exceptions import (
APIError,
AuthenticationError,
ClientAIError,
InvalidRequestError,
ModelError,
RateLimitError,
TimeoutError,
)
from . import OLLAMA_INSTALLED
from ._typing import (
Message,
Expand Down Expand Up @@ -98,6 +107,35 @@ def _stream_chat_response(
else:
yield chunk["message"]["content"]

def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError:
"""
Maps an Ollama exception to the appropriate ClientAI exception.
Args:
e (Exception): The exception caught during the API call.
Returns:
ClientAIError: An instance of the appropriate ClientAI exception.
"""
message = str(e)

if isinstance(e, ollama.RequestError):
if "authentication" in message.lower():
return AuthenticationError(message, original_error=e)
elif "rate limit" in message.lower():
return RateLimitError(message, original_error=e)
elif "not found" in message.lower():
return ModelError(message, original_error=e)
else:
return InvalidRequestError(message, original_error=e)
elif isinstance(e, ollama.ResponseError):
if "timeout" in message.lower() or "timed out" in message.lower():
return TimeoutError(message, original_error=e)
else:
return APIError(message, original_error=e)
else:
return ClientAIError(message, original_error=e)

def generate_text(
self,
prompt: str,
Expand Down Expand Up @@ -152,24 +190,28 @@ def generate_text(
print(chunk, end="", flush=True)
```
"""
response = self.client.generate(
model=model, prompt=prompt, stream=stream, **kwargs
)

if stream:
return cast(
OllamaGenericResponse,
self._stream_generate_response(
cast(Iterator[OllamaStreamResponse], response),
return_full_response,
),
try:
response = self.client.generate(
model=model, prompt=prompt, stream=stream, **kwargs
)
else:
response = cast(OllamaResponse, response)
if return_full_response:
return response

if stream:
return cast(
OllamaGenericResponse,
self._stream_generate_response(
cast(Iterator[OllamaStreamResponse], response),
return_full_response,
),
)
else:
return response["response"]
response = cast(OllamaResponse, response)
if return_full_response:
return response
else:
return response["response"]

except Exception as e:
raise self._map_exception_to_clientai_error(e)

def chat(
self,
Expand Down Expand Up @@ -231,21 +273,25 @@ def chat(
print(chunk, end="", flush=True)
```
"""
response = self.client.chat(
model=model, messages=messages, stream=stream, **kwargs
)

if stream:
return cast(
OllamaGenericResponse,
self._stream_chat_response(
cast(Iterator[OllamaChatResponse], response),
return_full_response,
),
try:
response = self.client.chat(
model=model, messages=messages, stream=stream, **kwargs
)
else:
response = cast(OllamaChatResponse, response)
if return_full_response:
return response

if stream:
return cast(
OllamaGenericResponse,
self._stream_chat_response(
cast(Iterator[OllamaChatResponse], response),
return_full_response,
),
)
else:
return response["message"]["content"]
response = cast(OllamaChatResponse, response)
if return_full_response:
return response
else:
return response["message"]["content"]

except Exception as e:
raise self._map_exception_to_clientai_error(e)
6 changes: 2 additions & 4 deletions clientai/openai/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ class OpenAIStreamResponse:
class OpenAIChatCompletionProtocol(Protocol):
def create(
self, **kwargs: Any
) -> Union[OpenAIResponse, Iterator[OpenAIStreamResponse]]:
...
) -> Union[OpenAIResponse, Iterator[OpenAIStreamResponse]]: ...


class OpenAIChatProtocol(Protocol):
Expand All @@ -90,8 +89,7 @@ def create(
messages: List[Message],
stream: bool = False,
**kwargs: Any,
) -> Union[OpenAIResponse, OpenAIStreamResponse]:
...
) -> Union[OpenAIResponse, OpenAIStreamResponse]: ...


OpenAIProvider = Any
Expand Down
Loading

0 comments on commit 3676ebf

Please sign in to comment.