diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index 2ed87ef1..74a0ae60 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -256,8 +256,8 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None): field_value = getattr(data, field_name) # check if it has a method convert_to_entities # if it does, call it to convert the model to an entity - if hasattr(field_value, "convert_to_entity"): - field_value = field_value.convert_to_entity() + if hasattr(field_value, "convert_input_to_object"): + field_value = field_value.convert_input_to_object() data_dict[field_name] = field_value if self.output_filter: diff --git a/aana/configs/deployments.py b/aana/configs/deployments.py index 7c66724e..45f32f0e 100644 --- a/aana/configs/deployments.py +++ b/aana/configs/deployments.py @@ -1,6 +1,9 @@ +from aana.deployments.hf_blip2_deployment import HFBlip2Config, HFBlip2Deployment from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment +from aana.models.core.dtype import Dtype from aana.models.pydantic.sampling_params import SamplingParams + deployments = { "vllm_deployment_llama2_7b_chat": VLLMDeployment.options( num_replicas=1, @@ -31,4 +34,15 @@ ), ).dict(), ), + "hf_blip2_deployment_opt_2_7b": HFBlip2Deployment.options( + num_replicas=1, + max_concurrent_queries=1000, + ray_actor_options={"num_gpus": 0.5}, + user_config=HFBlip2Config( + model="Salesforce/blip2-opt-2.7b", + dtype=Dtype.FLOAT16, + batch_size=2, + num_processing_threads=2, + ).dict(), + ), } diff --git a/aana/configs/endpoints.py b/aana/configs/endpoints.py index 147544ec..16577728 100644 --- a/aana/configs/endpoints.py +++ b/aana/configs/endpoints.py @@ -18,4 +18,12 @@ outputs=["vllm_zephyr_7b_beta_output"], ) ], + "blip2": [ + Endpoint( + name="blip2_generate", + path="/blip2/generate_captions", + summary="Generate captions using BLIP2 OPT-2.7B", + outputs=["captions_hf_blip2_opt_2_7b"], + ) + ], } diff --git 
a/aana/configs/pipeline.py b/aana/configs/pipeline.py index 57bc31ae..81ca9857 100644 --- a/aana/configs/pipeline.py +++ b/aana/configs/pipeline.py @@ -3,6 +3,7 @@ It is used to generate the pipeline and the API endpoints. """ +from aana.models.pydantic.image_input import ImageListInput from aana.models.pydantic.prompt import Prompt from aana.models.pydantic.sampling_params import SamplingParams @@ -15,6 +16,16 @@ # vllm_llama2_7b_chat_output: str # vllm_zephyr_7b_beta_output_stream: str # vllm_zephyr_7b_beta_output: str +# image_batch: ImageBatch +# +# class ImageBatch: +# images: list[Image] +# +# class Image: +# image: ImageInput +# caption_hf_blip2_opt_2_7b: str + +# pipeline configuration nodes = [ @@ -83,7 +94,7 @@ } ], }, - { + { "name": "vllm_stream_zephyr_7b_beta", "type": "ray_deployment", "deployment_name": "vllm_deployment_zephyr_7b_beta", @@ -127,4 +138,38 @@ } ], }, + { + "name": "images", + "type": "input", + "inputs": [], + "outputs": [ + { + "name": "images", + "key": "images", + "path": "image_batch.images.[*].image", + "data_model": ImageListInput, + } + ], + }, + { + "name": "hf_blip2_opt_2_7b", + "type": "ray_deployment", + "deployment_name": "hf_blip2_deployment_opt_2_7b", + "method": "generate_batch", + "inputs": [ + { + "name": "images", + "key": "images", + "path": "image_batch.images.[*].image", + "data_model": ImageListInput, + } + ], + "outputs": [ + { + "name": "captions_hf_blip2_opt_2_7b", + "key": "captions", + "path": "image_batch.images.[*].caption_hf_blip2_opt_2_7b", + } + ], + }, ] diff --git a/aana/configs/settings.py b/aana/configs/settings.py new file mode 100644 index 00000000..44a661dd --- /dev/null +++ b/aana/configs/settings.py @@ -0,0 +1,14 @@ +from pathlib import Path +from pydantic import BaseSettings + + +class Settings(BaseSettings): + """ + A pydantic model for SDK settings. 
+ + """ + + tmp_data_dir: Path = Path("/tmp/aana_data") + + +settings = Settings() diff --git a/aana/deployments/hf_blip2_deployment.py b/aana/deployments/hf_blip2_deployment.py new file mode 100644 index 00000000..f5b5c677 --- /dev/null +++ b/aana/deployments/hf_blip2_deployment.py @@ -0,0 +1,187 @@ +from typing import Any, Dict, List, TypedDict +from pydantic import BaseModel, Field, validator +from ray import serve +import torch +from transformers import Blip2Processor, Blip2ForConditionalGeneration +from aana.deployments.base_deployment import BaseDeployment + +from aana.exceptions.general import InferenceException +from aana.models.core.dtype import Dtype +from aana.models.core.image import Image +from aana.utils.batch_processor import BatchProcessor + + +class HFBlip2Config(BaseModel): + """ + The configuration for the BLIP2 deployment with HuggingFace models. + + Attributes: + model (str): the model ID on HuggingFace + dtype (str): the data type (optional, default: "auto"), one of "auto", "float32", "float16" + batch_size (int): the batch size (optional, default: 1) + num_processing_threads (int): the number of processing threads (optional, default: 1) + """ + + model: str + dtype: Dtype = Field(default=Dtype.AUTO) + batch_size: int = Field(default=1) + num_processing_threads: int = Field(default=1) + + @validator("dtype", pre=True, always=True) + def validate_dtype(cls, value: Dtype) -> Dtype: + """ + Validate the data type. For BLIP2 only "float32" and "float16" are supported. + + Args: + value (Dtype): the data type + + Returns: + Dtype: the validated data type + + Raises: + ValueError: if the data type is not supported + """ + if value not in {Dtype.AUTO, Dtype.FLOAT32, Dtype.FLOAT16}: + raise ValueError( + f"Invalid dtype: {value}. BLIP2 only supports 'auto', 'float32', and 'float16'." + ) + return value + + +class CaptioningOutput(TypedDict): + """ + The output of the captioning model. 
+ + Attributes: + caption (str): the caption + """ + + caption: str + + +class CaptioningBatchOutput(TypedDict): + """ + The output of the captioning model. + + Attributes: + captions (List[str]): the list of captions + """ + + captions: List[str] + + +@serve.deployment +class HFBlip2Deployment(BaseDeployment): + """ + Deployment to serve BLIP2 models using HuggingFace. + """ + + async def apply_config(self, config: Dict[str, Any]): + """ + Apply the configuration. + + The method is called when the deployment is created or updated. + + It loads the model and processor from HuggingFace. + + The configuration should conform to the HFBlip2Config schema. + """ + + config_obj = HFBlip2Config(**config) + + # Create the batch processor to split the requests into batches + # and process them in parallel + self.batch_size = config_obj.batch_size + self.num_processing_threads = config_obj.num_processing_threads + # The actual inference is done in _generate() + # We use lambda because BatchProcessor expects dict as input + # and we use **kwargs to unpack the dict into named arguments for _generate() + self.batch_processor = BatchProcessor( + process_batch=lambda request: self._generate(**request), + batch_size=self.batch_size, + num_threads=self.num_processing_threads, + ) + + # Load the model and processor for BLIP2 from HuggingFace + self.model_id = config_obj.model + self.dtype = config_obj.dtype + self.torch_dtype = self.dtype.to_torch() + self.model = Blip2ForConditionalGeneration.from_pretrained( + self.model_id, torch_dtype=self.torch_dtype + ) + self.processor = Blip2Processor.from_pretrained(self.model_id) + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model.to(self.device) + + async def generate(self, image: Image) -> CaptioningOutput: + """ + Generate captions for the given image. 
+ + Args: + image (Image): the image + + Returns: + CaptioningOutput: the dictionary with one key "captions" + and the list of captions for the image as value + + Raises: + InferenceException: if the inference fails + """ + captions: CaptioningBatchOutput = await self.batch_processor.process( + {"images": [image]} + ) + return CaptioningOutput(caption=captions["captions"][0]) + + async def generate_batch(self, **kwargs) -> CaptioningBatchOutput: + """ + Generate captions for the given images. + + Args: + images (List[Image]): the images + + Returns: + CaptioningBatchOutput: the dictionary with one key "captions" + and the list of captions for the images as value + + Raises: + InferenceException: if the inference fails + """ + # Call the batch processor to process the requests + # The actual inference is done in _generate() + return await self.batch_processor.process(kwargs) + + def _generate(self, images: List[Image]) -> CaptioningBatchOutput: + """ + Generate captions for the given images. + + This method is called by the batch processor. 
+ + Args: + images (List[Image]): the images + + Returns: + CaptioningBatchOutput: the dictionary with one key "captions" + and the list of captions for the images as value + + Raises: + InferenceException: if the inference fails + """ + numpy_images = [] + for image in images: + numpy_images.append(image.get_numpy()) # loading images + inputs = self.processor(numpy_images, return_tensors="pt").to( + self.device, self.torch_dtype + ) + + try: + generated_ids = self.model.generate(**inputs) + generated_texts = self.processor.batch_decode( + generated_ids, skip_special_tokens=True + ) + generated_texts = [ + generated_text.strip() for generated_text in generated_texts + ] + return CaptioningBatchOutput(captions=generated_texts) + except Exception as e: + raise InferenceException(self.model_id) from e diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py index 1175c783..a58a9d69 100644 --- a/aana/deployments/vllm_deployment.py +++ b/aana/deployments/vllm_deployment.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict from pydantic import BaseModel, Field from ray import serve from vllm.engine.arg_utils import AsyncEngineArgs @@ -34,6 +34,28 @@ class VLLMConfig(BaseModel): max_model_len: Optional[int] = Field(default=None) +class LLMOutput(TypedDict): + """ + The output of the LLM model. + + Attributes: + text (str): the generated text + """ + + text: str + + +class LLMBatchOutput(TypedDict): + """ + The output of the LLM model for a batch of inputs. 
+ + Attributes: + texts (List[str]): the list of generated texts + """ + + texts: List[str] + + @serve.deployment class VLLMDeployment(BaseDeployment): """ @@ -78,7 +100,9 @@ async def apply_config(self, config: Dict[str, Any]): # create the engine self.engine = AsyncLLMEngine.from_engine_args(args) - async def generate_stream(self, prompt: str, sampling_params: SamplingParams): + async def generate_stream( + self, prompt: str, sampling_params: SamplingParams + ) -> AsyncGenerator[LLMOutput, None]: """ Generate completion for the given prompt and stream the results. @@ -87,7 +111,7 @@ async def generate_stream(self, prompt: str, sampling_params: SamplingParams): sampling_params (SamplingParams): the sampling parameters Yields: - dict: the generated text + LLMOutput: the dictionary with the key "text" and the generated text as the value """ prompt = str(prompt) sampling_params = merged_options(self.default_sampling_params, sampling_params) @@ -108,7 +132,7 @@ async def generate_stream(self, prompt: str, sampling_params: SamplingParams): num_returned = 0 async for request_output in results_generator: text_output = request_output.outputs[0].text[num_returned:] - yield {"text": text_output} + yield LLMOutput(text=text_output) num_returned += len(text_output) except GeneratorExit as e: # If the generator is cancelled, we need to cancel the request @@ -118,7 +142,7 @@ async def generate_stream(self, prompt: str, sampling_params: SamplingParams): except Exception as e: raise InferenceException(model_name=self.model) from e - async def generate(self, prompt: str, sampling_params: SamplingParams): + async def generate(self, prompt: str, sampling_params: SamplingParams) -> LLMOutput: """ Generate completion for the given prompt. 
@@ -127,14 +151,16 @@ async def generate(self, prompt: str, sampling_params: SamplingParams): sampling_params (SamplingParams): the sampling parameters Returns: - dict: the generated text + LLMOutput: the dictionary with the key "text" and the generated text as the value """ generated_text = "" async for chunk in self.generate_stream(prompt, sampling_params): generated_text += chunk["text"] - return {"text": generated_text} + return LLMOutput(text=generated_text) - async def generate_batch(self, prompts: List[str], sampling_params: SamplingParams): + async def generate_batch( + self, prompts: List[str], sampling_params: SamplingParams + ) -> LLMBatchOutput: """ Generate completion for the batch of prompts. @@ -143,11 +169,12 @@ async def generate_batch(self, prompts: List[str], sampling_params: SamplingPara sampling_params (SamplingParams): the sampling parameters Returns: - dict: the generated texts + LLMBatchOutput: the dictionary with the key "texts" + and the list of generated texts as the value """ texts = [] for prompt in prompts: text = await self.generate(prompt, sampling_params) texts.append(text["text"]) - return {"texts": texts} + return LLMBatchOutput(texts=texts) diff --git a/aana/exceptions/general.py b/aana/exceptions/general.py index ba0e4a3a..fa19ea51 100644 --- a/aana/exceptions/general.py +++ b/aana/exceptions/general.py @@ -1,5 +1,8 @@ -from typing import Any, Dict from mobius_pipeline.exceptions import BaseException +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from aana.models.core.image import Image class InferenceException(BaseException): @@ -16,9 +19,8 @@ def __init__(self, model_name): Args: model_name (str): name of the model that caused the exception """ - super().__init__() + super().__init__(model_name=model_name) self.model_name = model_name - self.extra["model_name"] = model_name def __reduce__(self): # This method is called when the exception is pickled @@ -43,8 +45,52 @@ def __init__(self, input_name: str): Args: 
input_name (str): name of the input that caused the exception """ + super().__init__(input_name=input_name) self.input_name = input_name - super().__init__() def __reduce__(self): return (self.__class__, (self.input_name,)) + + +class ImageReadingException(BaseException): + """ + Exception raised when there is an error reading an image. + + Attributes: + image (Image): the image that caused the exception + """ + + def __init__(self, image: "Image"): + """ + Initialize the exception. + + Args: + image (Image): the image that caused the exception + """ + super().__init__(image=image) + self.image = image + + def __reduce__(self): + return (self.__class__, (self.image,)) + + +class DownloadException(BaseException): + """ + Exception raised when there is an error downloading a file. + + Attributes: + url (str): the URL of the file that caused the exception + """ + + def __init__(self, url: str): + """ + Initialize the exception. + + Args: + url (str): the URL of the file that caused the exception + """ + super().__init__(url=url) + self.url = url + + def __reduce__(self): + return (self.__class__, (self.url,)) diff --git a/aana/models/core/dtype.py b/aana/models/core/dtype.py new file mode 100644 index 00000000..c656c6c1 --- /dev/null +++ b/aana/models/core/dtype.py @@ -0,0 +1,45 @@ +from enum import Enum +from typing import Union + +import torch + + +class Dtype(str, Enum): + """ + Data types. + + Possible values are "auto", "float32", "float16", and "int8". + + Attributes: + AUTO (str): auto + FLOAT32 (str): float32 + FLOAT16 (str): float16 + INT8 (str): int8 + """ + + AUTO = "auto" + FLOAT32 = "float32" + FLOAT16 = "float16" + INT8 = "int8" + + def to_torch(self) -> Union[torch.dtype, str]: + """ + Convert the instance's dtype to a torch dtype. 
+ + Returns: + Union[torch.dtype, str]: the torch dtype or "auto" + + Raises: + ValueError: if the dtype is unknown + """ + match self.value: + case self.AUTO: + return "auto" + case self.FLOAT32: + return torch.float32 + case self.FLOAT16: + return torch.float16 + case self.INT8: + return torch.int8 + case _: + raise ValueError(f"Unknown dtype: {self}") diff --git a/aana/models/core/image.py b/aana/models/core/image.py new file mode 100644 index 00000000..22ae8dbe --- /dev/null +++ b/aana/models/core/image.py @@ -0,0 +1,422 @@ +from dataclasses import dataclass, field +import hashlib +import os +from pathlib import Path +from typing import Optional, Type +import uuid +import cv2 + +import numpy as np +from aana.configs.settings import settings + +from aana.exceptions.general import ImageReadingException +from aana.utils.general import download_file + + +class AbstractImageLibrary: + """ + Abstract class for image libraries. + """ + + @classmethod + def read_file(cls, path: Path) -> np.ndarray: + """ + Read a file using the image library. + + Args: + path (Path): The path of the file to read. + + Returns: + np.ndarray: The file as a numpy array. + """ + raise NotImplementedError + + @classmethod + def read_from_bytes(cls, content: bytes) -> np.ndarray: + """ + Read bytes using the image library. + + Args: + content (bytes): The content of the file to read. + + Returns: + np.ndarray: The file as a numpy array. + """ + raise NotImplementedError + + @classmethod + def write_file(cls, path: Path, img: np.ndarray): + """ + Write a file using the image library. + + Args: + path (Path): The path of the file to write. + img (np.ndarray): The image to write. + """ + raise NotImplementedError + + @classmethod + def write_to_bytes(cls, img: np.ndarray) -> bytes: + """ + Write bytes using the image library. + + Args: + img (np.ndarray): The image to write. + + Returns: + bytes: The image as bytes. 
+ """ + raise NotImplementedError + + +class OpenCVWrapper(AbstractImageLibrary): + """ + Wrapper class for OpenCV functions. + """ + + @classmethod + def read_file(cls, path: Path) -> np.ndarray: + """ + Read a file using OpenCV. + + Args: + path (Path): The path of the file to read. + + Returns: + np.ndarray: The file as a numpy array in RGB format. + """ + img = cv2.imread(str(path)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + @classmethod + def read_from_bytes(cls, content: bytes) -> np.ndarray: + """ + Read bytes using OpenCV. + + Args: + content (bytes): The content of the file to read. + + Returns: + np.ndarray: The file as a numpy array in RGB format. + """ + img = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + @classmethod + def write_file(cls, path: Path, img: np.ndarray): + """ + Write a file using OpenCV. + + Args: + path (Path): The path of the file to write. + img (np.ndarray): The image to write. + """ + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + cv2.imwrite(str(path), img) + + @classmethod + def write_to_bytes(cls, img: np.ndarray) -> bytes: + """ + Write bytes using OpenCV. + + Args: + img (np.ndarray): The image to write. + + Returns: + bytes: The image as bytes. + """ + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + _, buffer = cv2.imencode(".bmp", img) + return buffer.tobytes() + + +@dataclass +class Image: + path: Optional[Path] = None # The file path of the image. + url: Optional[str] = None # The URL of the image. + content: Optional[ + bytes + ] = None # The content of the image in bytes (image file as bytes). + numpy: Optional[np.ndarray] = None # The image as a numpy array. 
+ image_id: Optional[str] = field( + default_factory=lambda: str(uuid.uuid4()) + ) # The ID of the image, generated automatically + save_on_disk: bool = True # Whether to save the image on disk or not + image_lib: Type[ + AbstractImageLibrary + ] = OpenCVWrapper # The image library to use, TODO: add support for PIL and allow to choose the library + is_saved: bool = False # Whether the image is saved on disk by the class or not (used for cleanup) + + def __post_init__(self): + """ + Post-initialization method. + + Performs checks: + - Checks that path is a Path object. + - Checks that at least one of 'path', 'url', 'content' or 'numpy' is provided. + - Checks if path exists if provided. + + Saves the image on disk if needed. + """ + # check that path is a Path object + if self.path: + if not isinstance(self.path, Path): + raise ValueError("Path must be a Path object.") + + # check that at least one of 'path', 'url', 'content' or 'numpy' is provided + if not any( + [ + self.path is not None, + self.url is not None, + self.content is not None, + self.numpy is not None, + ] + ): + raise ValueError( + "At least one of 'path', 'url', 'content' or 'numpy' must be provided." + ) + + # check if path exists if provided + if self.path and not self.path.exists(): + raise FileNotFoundError(f"Image file not found: {self.path}") + + if self.save_on_disk: + self.save() + + def save(self): + """ + Save the image on disk. + If the image is already available on disk, do nothing. + If the image represented as a byte string, save it on disk. + If the image is represented as a URL, download it and save it on disk. + If the image is represented as a numpy array, convert it to BMP and save it on disk. + + First check if the image is already available on disk, then content, then url, then numpy + to avoid unnecessary operations (e.g. downloading the image or converting it to BMP). + + Raises: + ValueError: If none of 'path', 'url', 'content' or 'numpy' is provided. 
+ """ + if self.path: + return + + file_dir = settings.tmp_data_dir / "images" + file_dir.mkdir(parents=True, exist_ok=True) + file_path = file_dir / (self.image_id + ".bmp") + + if self.content: + self.save_from_content(file_path) + elif self.numpy is not None: + self.save_from_numpy(file_path) + elif self.url: + self.save_from_url(file_path) + else: + raise ValueError( + "At least one of 'path', 'url', 'content' or 'numpy' must be provided." + ) + self.path = file_path + self.is_saved = True + + def save_from_content(self, file_path: Path): + """ + Save the image from content on disk. + + Args: + file_path (Path): The path of the file to write. + """ + assert self.content is not None + with open(file_path, "wb") as f: + f.write(self.content) + + def save_from_numpy(self, file_path: Path): + """ + Save the image from numpy on disk. + + Args: + file_path (Path): The path of the file to write. + """ + assert self.numpy is not None + self.image_lib.write_file(file_path, self.numpy) + + def save_from_url(self, file_path: Path): + """ + Save the image from URL on disk. + + Args: + file_path (Path): The path of the file to write. + """ + assert self.url is not None + content = download_file(self.url) + file_path.write_bytes(content) + + def get_numpy(self) -> np.ndarray: + """ + Load the image as a numpy array. + + Returns: + np.ndarray: The image as a numpy array. + + Raises: + ValueError: If none of 'path', 'url', 'content' or 'numpy' is provided. + ImageReadingException: If there is an error reading the image. + """ + if self.numpy is not None: + return self.numpy + elif self.path: + self.load_numpy_from_path() + elif self.url: + self.load_numpy_from_url() + elif self.content: + self.load_numpy_from_content() + else: + raise ValueError( + "At least one of 'path', 'url', 'content' or 'numpy' must be provided." + ) + assert self.numpy is not None + return self.numpy + + def load_numpy_from_path(self): + """ + Load the image as a numpy array from a path. 
+ + Raises: + ImageReadingException: If there is an error reading the image. + """ + assert self.path is not None + try: + self.numpy = self.image_lib.read_file(self.path) + except Exception as e: + raise ImageReadingException(self) from e + + def load_numpy_from_image_bytes(self, img_bytes: bytes): + """ + Load the image as a numpy array from image bytes (downloaded from URL or read from file). + + Raises: + ImageReadingException: If there is an error reading the image. + """ + try: + self.numpy = self.image_lib.read_from_bytes(img_bytes) + except Exception as e: + raise ImageReadingException(self) from e + + def load_numpy_from_url(self): + """ + Load the image as a numpy array from a URL. + + Raises: + ImageReadingException: If there is an error reading the image. + """ + assert self.url is not None + content: bytes = download_file(self.url) + self.load_numpy_from_image_bytes(content) + + def load_numpy_from_content(self): + """ + Load the image as a numpy array from content. + + Raises: + ImageReadingException: If there is an error reading the image. + """ + assert self.content is not None + self.load_numpy_from_image_bytes(self.content) + + def get_content(self) -> bytes: + """ + Get the content of the image as bytes. + + Returns: + bytes: The content of the image as bytes. + + Raises: + ValueError: If none of 'path', 'url', 'content' or 'numpy' is provided. + """ + if self.content: + return self.content + elif self.path: + self.load_content_from_path() + elif self.url: + self.load_content_from_url() + elif self.numpy is not None: + self.load_content_from_numpy() + else: + raise ValueError( + "At least one of 'path', 'url', 'content' or 'numpy' must be provided." + ) + assert self.content is not None + return self.content + + def load_content_from_numpy(self): + """ + Load the content of the image from numpy. 
+ """ + assert self.numpy is not None + self.content = self.image_lib.write_to_bytes(self.numpy) + + def load_content_from_path(self): + """ + Load the content of the image from the path. + + Raises: + FileNotFoundError: If the image file does not exist. + """ + assert self.path is not None + with open(self.path, "rb") as f: + self.content = f.read() + + def load_content_from_url(self): + """ + Load the content of the image from the URL using requests. + + Raises: + DownloadException: If there is an error downloading the image. + """ + assert self.url is not None + self.content = download_file(self.url) + + def __repr__(self) -> str: + """ + Get the representation of the image. + + Use md5 hash for the content of the image if it is available. + + For numpy array, use the shape of the array with the md5 hash of the array if it is available. + + Returns: + str: The representation of the image. + """ + content_hash = hashlib.md5(self.content).hexdigest() if self.content else None + if self.numpy is not None: + assert self.numpy is not None + numpy_hash = hashlib.md5(self.numpy.tobytes()).hexdigest() + numpy_repr = f"ndarray(shape={self.numpy.shape}, dtype={self.numpy.dtype}, md5={numpy_hash})" + else: + numpy_repr = None + return ( + f"Image(path={self.path}, " + f"url={self.url}, " + f"content={content_hash}, " + f"numpy={numpy_repr}, " + f"image_id={self.image_id})" + ) + + def __str__(self) -> str: + """ + Get the string representation of the image. + + Returns: + str: The string representation of the image. + """ + return self.__repr__() + + def cleanup(self): + """ + Cleanup the image. + + Remove the image from disk if it was saved by the class. 
+ """ + # Remove the image from disk if it was saved by the class + if self.is_saved and self.path: + self.path.unlink(missing_ok=True) diff --git a/aana/models/pydantic/base.py b/aana/models/pydantic/base.py new file mode 100644 index 00000000..f2d79910 --- /dev/null +++ b/aana/models/pydantic/base.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + + +class BaseListModel(BaseModel): + """ + The base model for pydantic models with list as root. + + It makes pydantic models with list as root behave like normal lists. + """ + + def __iter__(self): + return iter(self.__root__) + + def __len__(self): + return len(self.__root__) + + def __getitem__(self, index): + return self.__root__[index] + + def __setitem__(self, index, value): + self.__root__[index] = value + + def __delitem__(self, index): + del self.__root__[index] + + def __contains__(self, item): + return item in self.__root__ diff --git a/aana/models/pydantic/exception_response.py b/aana/models/pydantic/exception_response.py index dadcacd6..0ced696b 100644 --- a/aana/models/pydantic/exception_response.py +++ b/aana/models/pydantic/exception_response.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Optional from pydantic import BaseModel, Extra @@ -9,13 +9,13 @@ class ExceptionResponseModel(BaseModel): Attributes: error (str): The error that occurred. message (str): The message of the error. - data (Optional[Dict]): The extra data that helps to debug the error. + data (Optional[Any]): The extra data that helps to debug the error. stacktrace (Optional[str]): The stacktrace of the error. 
""" error: str message: str - data: Optional[Dict] = None + data: Optional[Any] = None stacktrace: Optional[str] = None class Config: diff --git a/aana/models/pydantic/image_input.py b/aana/models/pydantic/image_input.py new file mode 100644 index 00000000..9019abd3 --- /dev/null +++ b/aana/models/pydantic/image_input.py @@ -0,0 +1,225 @@ +import io +from pathlib import Path +import numpy as np +from typing import Dict, List, Optional +from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic.error_wrappers import ErrorWrapper + +from aana.models.core.image import Image +from aana.models.pydantic.base import BaseListModel + + +class ImageInput(BaseModel): + """ + An image input. + + Exactly one of 'path', 'url', or 'content' must be provided. + + If 'content' or 'numpy' is set to 'file', + the image will be loaded from the files uploaded to the endpoint. + + + Attributes: + path (str): the file path of the image + url (str): the URL of the image + content (bytes): the content of the image in bytes + numpy (bytes): the image as a numpy array + """ + + path: Optional[str] = Field(None, description="The file path of the image.") + url: Optional[str] = Field( + None, description="The URL of the image." + ) # TODO: validate url + content: Optional[bytes] = Field( + None, + description=( + "The content of the image in bytes. " + "Set this field to 'file' to upload files to the endpoint." + ), + ) + numpy: Optional[bytes] = Field( + None, + description=( + "The image as a numpy array. " + "Set this field to 'file' to upload files to the endpoint." + ), + ) + + def set_file(self, file: bytes): + """ + If 'content' or 'numpy' is set to 'file', + the image will be loaded from the file uploaded to the endpoint. + + set_file() should be called after the files are uploaded to the endpoint. 
+ + Args: + file (bytes): the file uploaded to the endpoint + + Raises: + ValueError: if the content or numpy isn't set to 'file' + """ + if self.content == b"file": + self.content = file + elif self.numpy == b"file": + self.numpy = file + else: + raise ValueError( + "The content or numpy of the image must be 'file' to set files." + ) + + def set_files(self, files: List[bytes]): + """ + Set the files for the image. + + Args: + files (List[bytes]): the files uploaded to the endpoint + + Raises: + ValidationError: if the number of images and files aren't the same + """ + if len(files) != 1: + error = ErrorWrapper( + ValueError("The number of images and files must be the same."), + loc=("images",), + ) + raise ValidationError([error], self.__class__) + self.set_file(files[0]) + + @root_validator + def check_only_one_field(cls, values: Dict) -> Dict: + """ + Check that exactly one of 'path', 'url', 'content' or 'numpy' is provided. + + Args: + values (Dict): the values of the fields + + Returns: + Dict: the values of the fields + + Raises: + ValueError: if not exactly one of 'path', 'url', 'content' or 'numpy' is provided + """ + count = sum( + value is not None + for key, value in values.items() + if key in ["path", "url", "content", "numpy"] + ) + if count != 1: + raise ValueError( + "Exactly one of 'path', 'url', 'content' or 'numpy' must be provided." + ) + return values + + def convert_input_to_object(self) -> Image: + """ + Convert the image input to an image object. + + Returns: + Image: the image object corresponding to the image input + + Raises: + ValueError: if the numpy file isn't set + """ + if self.numpy and self.numpy != b"file": + try: + numpy = np.load(io.BytesIO(self.numpy), allow_pickle=False) + except ValueError: + raise ValueError("The numpy file isn't valid.") + elif self.numpy == b"file": + raise ValueError("The numpy file isn't set. 
Call set_files() to set it.") + else: + numpy = None + + return Image( + path=Path(self.path) if self.path else None, + url=self.url, + content=self.content, + numpy=numpy, + ) + + class Config: + schema_extra = { + "description": ( + "An image. \n" + "Exactly one of 'path', 'url', or 'content' must be provided. \n" + "If 'path' is provided, the image will be loaded from the path. \n" + "If 'url' is provided, the image will be downloaded from the url. \n" + "The 'content' will be loaded automatically " + "if files are uploaded to the endpoint (should be set to 'file' for that)." + ) + } + validate_assignment = True + file_upload = True + file_upload_description = "Upload image file." + + +class ImageListInput(BaseListModel): + """ + A pydantic model for a list of images to be used as input. + + Only used for the requests, DO NOT use it for anything else. + Convert it to a list of image objects with convert_input_to_object(). + """ + + __root__: List[ImageInput] + + @validator("__root__", pre=True) + def check_non_empty(cls, v: List[ImageInput]) -> List[ImageInput]: + """ + Check that the list of images isn't empty. + + Args: + v (List[ImageInput]): the list of images + + Returns: + List[ImageInput]: the list of images + + Raises: + ValueError: if the list of images is empty + """ + if len(v) == 0: + raise ValueError("The list of images must not be empty.") + return v + + def set_files(self, files: List[bytes]): + """ + Set the files for the images. 
+ + Args: + files (List[bytes]): the files uploaded to the endpoint + + Raises: + ValidationError: if the number of images and files aren't the same + """ + + if len(self.__root__) != len(files): + error = ErrorWrapper( + ValueError("The number of images and files must be the same."), + loc=("images",), + ) + raise ValidationError([error], self.__class__) + for image, file in zip(self.__root__, files): + image.set_file(file) + + def convert_input_to_object(self) -> List[Image]: + """ + Convert the list of image inputs to a list of image objects. + + Returns: + List[Image]: the list of image objects corresponding to the image inputs + """ + return [image.convert_input_to_object() for image in self.__root__] + + class Config: + schema_extra = { + "description": ( + "A list of images. \n" + "Exactly one of 'path', 'url', or 'content' must be provided for each image. \n" + "If 'path' is provided, the image will be loaded from the path. \n" + "If 'url' is provided, the image will be downloaded from the url. \n" + "The 'content' will be loaded automatically " + "if files are uploaded to the endpoint (should be set to 'file' for that)." + ) + } + file_upload = True + file_upload_description = "Upload image files." 
diff --git a/aana/tests/const.py b/aana/tests/const.py new file mode 100644 index 00000000..37f66b3a --- /dev/null +++ b/aana/tests/const.py @@ -0,0 +1 @@ +ALLOWED_LEVENSTEIN_ERROR_RATE = 0.1 diff --git a/aana/tests/deployments/test_hf_blip2_deployment.py b/aana/tests/deployments/test_hf_blip2_deployment.py new file mode 100644 index 00000000..5fc94d11 --- /dev/null +++ b/aana/tests/deployments/test_hf_blip2_deployment.py @@ -0,0 +1,50 @@ +from importlib import resources +import random +import pytest +import ray +from ray import serve + +from aana.configs.deployments import deployments +from aana.models.core.image import Image +from aana.tests.utils import compare_texts, is_gpu_available + + +def ray_setup(deployment): + # Setup ray environment and serve + ray.init(ignore_reinit_error=True) + app = deployment.bind() + # random port from 30000 to 40000 + port = random.randint(30000, 40000) + handle = serve.run(app, port=port) + return handle + + +@pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available") +@pytest.mark.asyncio +@pytest.mark.parametrize( + "image_name, expected_text", + [("Starry_Night.jpeg", "the starry night by vincent van gogh")], +) +async def test_hf_blip2_deployments(image_name, expected_text): + for name, deployment in deployments.items(): + # skip if not an HF BLIP2 deployment + if deployment.name != "HFBlip2Deployment": + continue + + handle = ray_setup(deployment) + + path = resources.path("aana.tests.files.images", image_name) + image = Image(path=path, save_on_disk=False) + + output = await handle.generate.remote(image=image) + caption = output["caption"] + compare_texts(expected_text, caption) + + images = [image] * 8 + + output = await handle.generate_batch.remote(images=images) + captions = output["captions"] + + assert len(captions) == 8 + for caption in captions: + compare_texts(expected_text, caption) diff --git a/aana/tests/deployments/test_vllm_deployment.py b/aana/tests/deployments/test_vllm_deployment.py index
231de644..8601c969 100644 --- a/aana/tests/deployments/test_vllm_deployment.py +++ b/aana/tests/deployments/test_vllm_deployment.py @@ -1,14 +1,11 @@ import random import pytest -import rapidfuzz import ray from ray import serve from aana.configs.deployments import deployments from aana.models.pydantic.sampling_params import SamplingParams -from aana.tests.utils import is_gpu_available - -ALLOWED_LEVENSTEIN_ERROR_RATE = 0.1 +from aana.tests.utils import compare_texts, is_gpu_available def expected_output(name): @@ -36,29 +33,14 @@ def ray_setup(deployment): return handle -def compare_texts(expected_text: str, text: str): - """ - Compare two texts using Levenshtein distance. - The error rate is allowed to be less than ALLOWED_LEVENSTEIN_ERROR_RATE. - - Args: - expected_text (str): the expected text - text (str): the actual text - - Raises: - AssertionError: if the error rate is too high - """ - dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) - assert dist < len(expected_text) * ALLOWED_LEVENSTEIN_ERROR_RATE, ( - expected_text, - text, - ) - - @pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available") @pytest.mark.asyncio async def test_vllm_deployments(): for name, deployment in deployments.items(): + # skip if not a VLLM deployment + if deployment.name != "VLLMDeployment": + continue + handle = ray_setup(deployment) # test generate method diff --git a/aana/tests/files/images/Starry_Night.jpeg b/aana/tests/files/images/Starry_Night.jpeg new file mode 100644 index 00000000..5e748716 Binary files /dev/null and b/aana/tests/files/images/Starry_Night.jpeg differ diff --git a/aana/tests/test_batch_processor.py b/aana/tests/test_batch_processor.py new file mode 100644 index 00000000..c2ff6923 --- /dev/null +++ b/aana/tests/test_batch_processor.py @@ -0,0 +1,108 @@ +import pytest +import numpy as np + +from aana.utils.batch_processor import BatchProcessor + +NUM_IMAGES = 5 +IMAGE_SIZE = 100 +FEATURE_SIZE = 10 + + +@pytest.fixture +def 
images(): + return np.array([np.random.rand(IMAGE_SIZE, IMAGE_SIZE) for _ in range(NUM_IMAGES)]) + + +@pytest.fixture +def texts(): + return [f"text{i}" for i in range(NUM_IMAGES)] + + +@pytest.fixture +def features(): + return np.random.rand(NUM_IMAGES, FEATURE_SIZE) + + +@pytest.fixture +def request_batch(images, texts, features): + return {"images": images, "texts": texts, "features": features} + + +@pytest.fixture +def process_batch(): + # Dummy processing function that just returns the batch + return lambda batch: batch + + +def test_batch_iterator(request_batch, process_batch): + """ + Test batch iterator. + """ + batch_size = 2 + processor = BatchProcessor( + process_batch=process_batch, batch_size=batch_size, num_threads=2 + ) + + batches = list(processor.batch_iterator(request_batch)) + # We expect 3 batches with a batch_size of 2 for 5 items + assert len(batches) == NUM_IMAGES // batch_size + 1 + assert all( + len(batch["texts"]) == batch_size for batch in batches[:-1] + ) # All but the last should have batch_size items + assert len(batches[-1]["texts"]) == 1 # Last batch should have the remaining item + + +@pytest.mark.asyncio +async def test_process_batches(request_batch, process_batch): + """ + Test processing of batches. + """ + batch_size = 2 + processor = BatchProcessor( + process_batch=process_batch, batch_size=batch_size, num_threads=2 + ) + + result = await processor.process(request_batch) + # Ensure all texts are processed + assert len(result["texts"]) == NUM_IMAGES + # Check if images are concatenated properly + assert result["images"].shape == (NUM_IMAGES, IMAGE_SIZE, IMAGE_SIZE) + # Check if features are concatenated properly + assert result["features"].shape == (NUM_IMAGES, FEATURE_SIZE) + + +def test_merge_outputs(request_batch, process_batch): + """ + Test merging of outputs from multiple batches. + + images and features should be concatenated into numpy arrays because they are numpy arrays. 
+ texts should be concatenated because they are lists. + """ + batch_size = 2 + processor = BatchProcessor( + process_batch=process_batch, batch_size=batch_size, num_threads=2 + ) + + # Assume the processor has already batched and processed the data + processed_batches = [ + { + "images": request_batch["images"][:2], + "texts": request_batch["texts"][:2], + "features": request_batch["features"][:2], + }, + { + "images": request_batch["images"][2:4], + "texts": request_batch["texts"][2:4], + "features": request_batch["features"][2:4], + }, + { + "images": request_batch["images"][4:], + "texts": request_batch["texts"][4:], + "features": request_batch["features"][4:], + }, + ] + + merged_output = processor.merge_outputs(processed_batches) + assert merged_output["texts"] == request_batch["texts"] + assert np.array_equal(merged_output["images"], request_batch["images"]) + assert np.array_equal(merged_output["features"], request_batch["features"]) diff --git a/aana/tests/test_image.py b/aana/tests/test_image.py new file mode 100644 index 00000000..e1cfc767 --- /dev/null +++ b/aana/tests/test_image.py @@ -0,0 +1,182 @@ +from importlib import resources +from pathlib import Path +import numpy as np +import pytest +from aana.models.core.image import Image + + +def load_numpy_from_image_bytes(content: bytes) -> np.ndarray: + """ + Load a numpy array from image bytes. + """ + image = Image(content=content, save_on_disk=False) + return image.get_numpy() + + +@pytest.fixture +def mock_download_file(mocker): + """ + Mock download_file. + """ + mock = mocker.patch("aana.models.core.image.download_file", autospec=True) + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + content = path.read_bytes() + mock.return_value = content + return mock + + +def test_image(mock_download_file): + """ + Test that the image can be created from path, url, content, or numpy. 
+ """ + + try: + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + image = Image(path=path, save_on_disk=False) + assert image.path == path + assert image.content is None + assert image.numpy is None + assert image.url is None + + numpy = image.get_numpy() + assert numpy.shape == (720, 909, 3) + + content = image.get_content() + numpy = load_numpy_from_image_bytes(content) + assert numpy.shape == (720, 909, 3) + finally: + image.cleanup() + + try: + url = "http://example.com/Starry_Night.jpeg" + image = Image(url=url, save_on_disk=False) + assert image.path is None + assert image.content is None + assert image.numpy is None + assert image.url == url + + numpy = image.get_numpy() + assert numpy.shape == (720, 909, 3) + + content = image.get_content() + numpy = load_numpy_from_image_bytes(content) + assert numpy.shape == (720, 909, 3) + finally: + image.cleanup() + + try: + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + content = path.read_bytes() + image = Image(content=content, save_on_disk=False) + assert image.path is None + assert image.content == content + assert image.numpy is None + assert image.url is None + + numpy = image.get_numpy() + assert numpy.shape == (720, 909, 3) + + content = image.get_content() + numpy = load_numpy_from_image_bytes(content) + assert numpy.shape == (720, 909, 3) + finally: + image.cleanup() + + try: + numpy = np.random.rand(100, 100, 3).astype(np.uint8) + image = Image(numpy=numpy, save_on_disk=False) + assert image.path is None + assert image.content is None + assert np.array_equal(image.numpy, numpy) + assert image.url is None + + numpy = image.get_numpy() + assert np.array_equal(numpy, numpy) + + content = image.get_content() + numpy = load_numpy_from_image_bytes(content) + assert np.array_equal(numpy, numpy) + finally: + image.cleanup() + + +def test_image_path_not_exist(): + """ + Test that the image can't be created from path if the path doesn't exist. 
+ """ + path = Path("path/to/image_that_does_not_exist.jpeg") + with pytest.raises(FileNotFoundError): + Image(path=path) + + +def test_save_image(mock_download_file): + """ + Test that save_on_disk works. + """ + + try: + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + image = Image(path=path, save_on_disk=True) + assert image.path == path + assert image.content is None + assert image.numpy is None + assert image.url is None + assert image.path.exists() + finally: + image.cleanup() + + try: + url = "http://example.com/Starry_Night.jpeg" + image = Image(url=url, save_on_disk=True) + assert image.content is None + assert image.numpy is None + assert image.url == url + assert image.path.exists() + finally: + image.cleanup() + + try: + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + content = path.read_bytes() + image = Image(content=content, save_on_disk=True) + assert image.content == content + assert image.numpy is None + assert image.url is None + assert image.path.exists() + finally: + image.cleanup() + + try: + numpy = np.random.rand(100, 100, 3).astype(np.uint8) + image = Image(numpy=numpy, save_on_disk=True) + assert image.content is None + assert np.array_equal(image.numpy, numpy) + assert image.url is None + assert image.path.exists() + finally: + image.cleanup() + + +def test_cleanup(mock_download_file): + """ + Test that cleanup works. + """ + + try: + url = "http://example.com/Starry_Night.jpeg" + image = Image(url=url, save_on_disk=True) + assert image.path.exists() + finally: + image.cleanup() + assert not image.path.exists() + + +def test_at_least_one_input(): + """ + Test that at least one input is provided. 
+ """ + with pytest.raises(ValueError): + Image(save_on_disk=False) + + with pytest.raises(ValueError): + Image(save_on_disk=True) diff --git a/aana/tests/test_image_input.py b/aana/tests/test_image_input.py new file mode 100644 index 00000000..1591c187 --- /dev/null +++ b/aana/tests/test_image_input.py @@ -0,0 +1,252 @@ +from importlib import resources +import io +from pathlib import Path +import pytest +import numpy as np +from pydantic import ValidationError +from aana.models.pydantic.image_input import ImageInput, ImageListInput + + +@pytest.fixture +def mock_download_file(mocker): + """ + Mock download_file. + """ + mock = mocker.patch("aana.models.core.image.download_file", autospec=True) + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + content = path.read_bytes() + mock.return_value = content + return mock + + +def test_new_imageinput_success(): + """ + Test that ImageInput can be created successfully. + """ + image_input = ImageInput(path="image.png") + assert image_input.path == "image.png" + + image_input = ImageInput(url="http://image.png") + assert image_input.url == "http://image.png" + + image_input = ImageInput(content=b"file") + assert image_input.content == b"file" + + image_input = ImageInput(numpy=b"file") + assert image_input.numpy == b"file" + + +def test_imageinput_check_only_one_field(): + """ + Test that exactly one of 'path', 'url', 'content', or 'numpy' is provided. 
+ """ + fields = { + "path": "image.png", + "url": "http://image.png", + "content": b"file", + "numpy": b"file", + } + + # check all combinations of two fields + for field1 in fields: + for field2 in fields: + if field1 != field2: + with pytest.raises(ValidationError): + ImageInput(**{field1: fields[field1], field2: fields[field2]}) + + # check all combinations of three fields + for field1 in fields: + for field2 in fields: + for field3 in fields: + if field1 != field2 and field1 != field3 and field2 != field3: + with pytest.raises(ValidationError): + ImageInput( + **{ + field1: fields[field1], + field2: fields[field2], + field3: fields[field3], + } + ) + + # check all combinations of four fields + with pytest.raises(ValidationError): + ImageInput(**fields) + + # check that no fields is also invalid + with pytest.raises(ValidationError): + ImageInput() + + +def test_imageinput_set_file(): + """ + Test that the file can be set for the image. + """ + file_content = b"image data" + + # If 'content' is set to 'file', + # the image can be set from the file uploaded to the endpoint. + image_input = ImageInput(content=b"file") + image_input.set_file(file_content) + assert image_input.content == file_content + + # If 'numpy' is set to 'file', + # the image can be set from the file uploaded to the endpoint. + image_input = ImageInput(numpy=b"file") + image_input.set_file(file_content) + assert image_input.numpy == file_content + + # If neither 'content' nor 'numpy' is set to 'file', + # an error should be raised. + image_input = ImageInput(path="image.png") + with pytest.raises(ValueError): + image_input.set_file(file_content) + + +def test_imageinput_set_files(): + """ + Test that the files can be set for the image. + """ + files = [b"image data"] + + # If 'content' is set to 'file', + # the image can be set from the file uploaded to the endpoint. 
+ image_input = ImageInput(content=b"file") + image_input.set_files(files) + assert image_input.content == files[0] + + # If 'numpy' is set to 'file', + # the image can be set from the file uploaded to the endpoint. + image_input = ImageInput(numpy=b"file") + image_input.set_files(files) + assert image_input.numpy == files[0] + + # If neither 'content' nor 'numpy' is set to 'file', + # an error should be raised. + image_input = ImageInput(path="image.png") + with pytest.raises(ValueError): + image_input.set_files(files) + + # If the number of images and files aren't the same, + # an error should be raised. + files = [b"image data", b"another image data"] + image_input = ImageInput(content=b"file") + with pytest.raises(ValidationError): + image_input.set_files(files) + + files = [] + image_input = ImageInput(content=b"file") + with pytest.raises(ValidationError): + image_input.set_files(files) + + +def test_imageinput_convert_input_to_object(mock_download_file): + """ + Test that ImageInput can be converted to Image. 
+ """ + path = resources.path("aana.tests.files.images", "Starry_Night.jpeg") + image_input = ImageInput(path=str(path)) + try: + image_object = image_input.convert_input_to_object() + assert image_object.path == path + finally: + image_object.cleanup() + + url = "http://example.com/Starry_Night.jpeg" + image_input = ImageInput(url=url) + try: + image_object = image_input.convert_input_to_object() + assert image_object.url == url + finally: + image_object.cleanup() + + content = Path(path).read_bytes() + image_input = ImageInput(content=content) + try: + image_object = image_input.convert_input_to_object() + assert image_object.content == content + finally: + image_object.cleanup() + + numpy = np.random.rand(100, 100, 3).astype(np.uint8) + # convert numpy array to bytes + buffer = io.BytesIO() + np.save(buffer, numpy) + numpy_bytes = buffer.getvalue() + image_input = ImageInput(numpy=numpy_bytes) + try: + image_object = image_input.convert_input_to_object() + assert np.array_equal(image_object.numpy, numpy) + finally: + image_object.cleanup() + + +def test_imageinput_convert_input_to_object_invalid_numpy(): + """ + Test that ImageInput can't be converted to Image if numpy is invalid. + """ + numpy = np.random.rand(100, 100, 3).astype(np.uint8) + # convert numpy array to bytes + buffer = io.BytesIO() + np.save(buffer, numpy) + numpy_bytes = buffer.getvalue() + # remove the last byte + numpy_bytes = numpy_bytes[:-1] + image_input = ImageInput(numpy=numpy_bytes) + with pytest.raises(ValueError): + image_input.convert_input_to_object() + + +def test_imageinput_convert_input_to_object_numpy_not_set(): + """ + Test that ImageInput can't be converted to Image if numpy file isn't set with set_file(). + """ + image_input = ImageInput(numpy=b"file") + with pytest.raises(ValueError): + image_input.convert_input_to_object() + + +def test_imagelistinput(): + """ + Test that ImageListInput can be created successfully. 
+ """ + + images = [ + ImageInput(path="image.png"), + ImageInput(url="http://image.png"), + ImageInput(content=b"file"), + ImageInput(numpy=b"file"), + ] + + image_list_input = ImageListInput(__root__=images) + assert image_list_input.__root__ == images + assert len(image_list_input) == len(images) + assert image_list_input[0] == images[0] + assert image_list_input[1] == images[1] + assert image_list_input[2] == images[2] + assert image_list_input[3] == images[3] + + +def test_imagelistinput_set_files(): + """ + Test that the files can be set for the images. + """ + files = [b"image data 1", b"image data 2"] + + images = [ + ImageInput(content=b"file"), + ImageInput(numpy=b"file"), + ] + + image_list_input = ImageListInput(__root__=images) + image_list_input.set_files(files) + + assert image_list_input[0].content == files[0] + assert image_list_input[1].numpy == files[1] + + +def test_imagelistinput_non_empty(): + """ + Test that ImageListInput must not be empty. + """ + with pytest.raises(ValidationError): + ImageListInput(__root__=[]) diff --git a/aana/tests/test_settings.py b/aana/tests/test_settings.py new file mode 100644 index 00000000..cbd16f3c --- /dev/null +++ b/aana/tests/test_settings.py @@ -0,0 +1,16 @@ +from pathlib import Path +from aana.configs.settings import Settings + + +def test_default_tmp_data_dir(): + """Test that the default temporary data directory is set correctly.""" + settings = Settings() + assert settings.tmp_data_dir == Path("/tmp/aana_data") + + +def test_custom_tmp_data_dir(monkeypatch): + """Test that the custom temporary data directory with environment variable is set correctly.""" + test_path = "/tmp/override/path" + monkeypatch.setenv("TMP_DATA_DIR", test_path) + settings = Settings() + assert settings.tmp_data_dir == Path(test_path) diff --git a/aana/tests/utils.py b/aana/tests/utils.py index 159e620e..a3e66489 100644 --- a/aana/tests/utils.py +++ b/aana/tests/utils.py @@ -1,3 +1,8 @@ +import rapidfuzz + +from aana.tests.const 
import ALLOWED_LEVENSTEIN_ERROR_RATE + + def is_gpu_available() -> bool: """ Check if a GPU is available. @@ -9,3 +14,22 @@ def is_gpu_available() -> bool: # TODO: find the way to check if GPU is available without importing torch return torch.cuda.is_available() + + +def compare_texts(expected_text: str, text: str): + """ + Compare two texts using Levenshtein distance. + The error rate is allowed to be less than ALLOWED_LEVENSTEIN_ERROR_RATE. + + Args: + expected_text (str): the expected text + text (str): the actual text + + Raises: + AssertionError: if the error rate is too high + """ + dist = rapidfuzz.distance.Levenshtein.distance(text, expected_text) + assert dist < len(expected_text) * ALLOWED_LEVENSTEIN_ERROR_RATE, ( + expected_text, + text, + ) diff --git a/aana/utils/batch_processor.py b/aana/utils/batch_processor.py new file mode 100644 index 00000000..385526ba --- /dev/null +++ b/aana/utils/batch_processor.py @@ -0,0 +1,146 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, List, Any, Iterator +import asyncio +import numpy as np + + +class BatchProcessor: + """ + The BatchProcessor class encapsulates the logic required to take a large collection of data, + split it into manageable batches, process these batches in parallel, and then combine the results + into a single cohesive output. + + Batching works by iterating through the input request, which is a dictionary where each key maps + to a list-like collection of data. The class splits each collection into sublists of length up + to `batch_size`, ensuring that corresponding elements across the collections remain grouped + together in their respective batches. + + Merging takes the output from each processed batch, which is also a dictionary structure, and + combines these into a single dictionary. Lists are extended, numpy arrays are concatenated, and + dictionaries are updated. 
If a new data type is encountered, an error is raised prompting the + implementer to specify how these types should be merged. + + This class is particularly useful for batching of requests to a machine learning model. + + The thread pool for parallel processing is managed internally and is shut down automatically when + the BatchProcessor instance is garbage collected. + + Attributes: + process_batch (Callable): A function to process each batch. + batch_size (int): The size of each batch to be processed. + num_threads (int): The number of threads in the thread pool for parallel processing. + """ + + def __init__(self, process_batch: Callable, batch_size: int, num_threads: int): + """ + Args: + process_batch (Callable): Function that processes each batch. + batch_size (int): Size of the batches. + num_threads (int): Number of threads in the pool. + """ + self.process_batch = process_batch + self.batch_size = batch_size + self.pool = ThreadPoolExecutor(num_threads) + + def __del__(self): + """ + Destructor that shuts down the thread pool when the instance is destroyed. + """ + self.pool.shutdown() + + def batch_iterator(self, request: Dict[str, Any]) -> Iterator[Dict[str, List[Any]]]: + """ + Iterates over the input request, breaking it into smaller batches for processing. + Each batch is a dictionary with the same keys as the input request, but the values + are sublists containing only the elements for that batch. + + Example: + request = { + 'images': [img1, img2, img3, img4, img5], + 'texts': ['text1', 'text2', 'text3', 'text4', 'text5'] + } + # Assuming a batch size of 2, this iterator would yield: + # 1st iteration: {'images': [img1, img2], 'texts': ['text1', 'text2']} + # 2nd iteration: {'images': [img3, img4], 'texts': ['text3', 'text4']} + # 3rd iteration: {'images': [img5], 'texts': ['text5']} + + Args: + request (Dict[str, List[Any]]): The request data to split into batches. 
+ + Yields: + Iterator[Dict[str, List[Any]]]: An iterator over the batched requests. + """ + lengths = [len(value) for value in request.values()] + if len(set(lengths)) > 1: + raise ValueError("All inputs must have the same length") + + total_batches = (max(lengths) + self.batch_size - 1) // self.batch_size + for i in range(total_batches): + start = i * self.batch_size + end = start + self.batch_size + yield {key: value[start:end] for key, value in request.items()} + + async def process(self, request: Dict[str, Any]) -> Dict[str, Any]: + """ + Splits the input request into batches, processes each batch in parallel, and then merges + the results into a single dictionary. + + Args: + request (Dict[str, Any]): The request data to process. + + Returns: + Dict[str, Any]: The merged results from processing all batches. + """ + loop = asyncio.get_running_loop() + futures = [ + loop.run_in_executor(self.pool, self.process_batch, batch) + for batch in self.batch_iterator(request) + ] + outputs = await asyncio.gather(*futures) + return self.merge_outputs(outputs) + + def merge_outputs(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Merges processed batch outputs into a single dictionary. It handles various data types + by extending lists, updating dictionaries, and concatenating numpy arrays. + + Example: + outputs = [ + {'images': [processed_img1, processed_img2], 'labels': ['cat', 'dog']}, + {'images': [processed_img3, processed_img4], 'labels': ['bird', 'mouse']}, + {'images': [processed_img5], 'labels': ['fish']} + ] + # The merged result would be: + # { + # 'images': [processed_img1, processed_img2, processed_img3, processed_img4, processed_img5], + # 'labels': ['cat', 'dog', 'bird', 'mouse', 'fish'] + # } + + Args: + outputs (List[Dict[str, Any]]): List of outputs from the processed batches. + + Returns: + Dict[str, Any]: The merged result. 
+ """ + merged_output = {} + for output in outputs: + for key, value in output.items(): + if key not in merged_output: + merged_output[key] = value + else: + if isinstance(value, list): + merged_output[key].extend(value) + elif isinstance(value, dict): + merged_output[key].update(value) + elif isinstance(value, np.ndarray): + if key in merged_output: + merged_output[key] = np.concatenate( + (merged_output[key], value) + ) + else: + merged_output[key] = value + else: + raise NotImplementedError( + "Merging of this data type is not implemented" + ) + return merged_output diff --git a/aana/utils/general.py b/aana/utils/general.py index 7b0c4d28..14f5730e 100644 --- a/aana/utils/general.py +++ b/aana/utils/general.py @@ -1,6 +1,9 @@ from typing import TypeVar from pydantic import BaseModel +import requests + +from aana.exceptions.general import DownloadException OptionType = TypeVar("OptionType", bound=BaseModel) @@ -27,3 +30,24 @@ def merged_options(default_options: OptionType, options: OptionType) -> OptionTy if v is not None: default_options_dict[k] = v return options.__class__.parse_obj(default_options_dict) + + +def download_file(url: str) -> bytes: + """ + Download a file from a URL. + + Args: + url (str): the URL of the file to download + + Returns: + bytes: the file content + + Raises: + DownloadException: if the request to the URL fails + """ + # TODO: add retries, check status code, etc.
+ try: + response = requests.get(url) + except Exception as e: + raise DownloadException(url) from e + return response.content diff --git a/mobius-pipeline b/mobius-pipeline index 8bdf633a..f455c656 160000 --- a/mobius-pipeline +++ b/mobius-pipeline @@ -1 +1 @@ -Subproject commit 8bdf633aaa9227b732b56a096ae04c6ebe4e8060 +Subproject commit f455c656da566866f031d74d8676bef0f558b4f3 diff --git a/pyproject.toml b/pyproject.toml index b4d6a4a3..610f0437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,9 @@ torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.2 vllm = "^0.2.1.post1" scipy = "^1.11.3" rapidfuzz = "^3.4.0" +transformers = "^4.34.1" +opencv-python = "^4.8.1.78" +pytest-mock = "^3.12.0" [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.2"