Merge pull request #3 from mobiusml/blip2
BLIP2 Integration and Image Capabilities
movchan74 authored Nov 10, 2023
2 parents 26142e9 + 8e93fde commit 7d7ad12
Showing 26 changed files with 1,892 additions and 44 deletions.
4 changes: 2 additions & 2 deletions aana/api/api_generation.py
@@ -256,8 +256,8 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None):
field_value = getattr(data, field_name)
# check if it has a method convert_to_entities
# if it does, call it to convert the model to an entity
- if hasattr(field_value, "convert_to_entity"):
-     field_value = field_value.convert_to_entity()
+ if hasattr(field_value, "convert_input_to_object"):
+     field_value = field_value.convert_input_to_object()
data_dict[field_name] = field_value

if self.output_filter:
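The rename in this hunk keeps the handler's duck-typing contract: any Pydantic input model that defines convert_input_to_object() is converted before its value is handed to the pipeline. A minimal sketch of a model implementing that hook, assuming a URL-backed input and an Image constructor that accepts a url argument (neither detail is shown in this diff):

# Sketch only: the real input models live in aana/models/pydantic/ and their
# fields are not part of this diff; the url field and Image(url=...) call are
# assumptions for illustration.
from pydantic import BaseModel

from aana.models.core.image import Image


class UrlImageInput(BaseModel):
    url: str  # assumed field

    def convert_input_to_object(self) -> Image:
        # Picked up by route_func_body via hasattr(..., "convert_input_to_object").
        return Image(url=self.url)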
14 changes: 14 additions & 0 deletions aana/configs/deployments.py
@@ -1,6 +1,9 @@
from aana.deployments.hf_blip2_deployment import HFBlip2Config, HFBlip2Deployment
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
from aana.models.core.dtype import Dtype
from aana.models.pydantic.sampling_params import SamplingParams


deployments = {
"vllm_deployment_llama2_7b_chat": VLLMDeployment.options(
num_replicas=1,
@@ -31,4 +34,15 @@
),
).dict(),
),
"hf_blip2_deployment_opt_2_7b": HFBlip2Deployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 0.5},
user_config=HFBlip2Config(
model="Salesforce/blip2-opt-2.7b",
dtype=Dtype.FLOAT16,
batch_size=2,
num_processing_threads=2,
).dict(),
),
}
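The new hf_blip2_deployment_opt_2_7b entry is the full recipe: half a GPU per replica, float16 weights, and micro-batches of two images processed on two threads. Other BLIP2 checkpoints can be registered the same way; a hedged sketch that reuses the imports already present in aana/configs/deployments.py (the 6.7B model ID and its resource numbers are illustrative, not part of this commit):

# Illustrative only: "Salesforce/blip2-opt-6.7b" and num_gpus=1 are assumptions,
# not values introduced by this commit. Assumes the imports shown above.
deployments["hf_blip2_deployment_opt_6_7b"] = HFBlip2Deployment.options(
    num_replicas=1,
    max_concurrent_queries=1000,
    ray_actor_options={"num_gpus": 1},
    user_config=HFBlip2Config(
        model="Salesforce/blip2-opt-6.7b",
        dtype=Dtype.FLOAT16,
        batch_size=2,
        num_processing_threads=2,
    ).dict(),
)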
8 changes: 8 additions & 0 deletions aana/configs/endpoints.py
@@ -18,4 +18,12 @@
outputs=["vllm_zephyr_7b_beta_output"],
)
],
"blip2": [
Endpoint(
name="blip2_generate",
path="/blip2/generate_captions",
summary="Generate captions using BLIP2 OPT-2.7B",
outputs=["captions_hf_blip2_opt_2_7b"],
)
],
}
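The endpoint exposes the pipeline output captions_hf_blip2_opt_2_7b on the route /blip2/generate_captions. A rough client sketch follows; the "body" form field and the image-by-URL payload shape are assumptions about how the SDK serializes inputs rather than anything documented in this diff:

# Rough client sketch; the request format is assumed, not specified in this diff.
import json

import requests

payload = {"image_batch": {"images": [{"url": "https://example.com/cat.jpg"}]}}
response = requests.post(
    "http://localhost:8000/blip2/generate_captions",
    data={"body": json.dumps(payload)},
)
print(response.json())  # expected to include "captions_hf_blip2_opt_2_7b"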
47 changes: 46 additions & 1 deletion aana/configs/pipeline.py
@@ -3,6 +3,7 @@
It is used to generate the pipeline and the API endpoints.
"""

from aana.models.pydantic.image_input import ImageListInput
from aana.models.pydantic.prompt import Prompt
from aana.models.pydantic.sampling_params import SamplingParams

@@ -15,6 +16,16 @@
# vllm_llama2_7b_chat_output: str
# vllm_zephyr_7b_beta_output_stream: str
# vllm_zephyr_7b_beta_output: str
# image_batch: ImageBatch
#
# class ImageBatch:
# images: list[Image]
#
# class Image:
# image: ImageInput
# caption_hf_blip2_opt_2_7b: str

# pipeline configuration


nodes = [
@@ -83,7 +94,7 @@
}
],
},
{
{
"name": "vllm_stream_zephyr_7b_beta",
"type": "ray_deployment",
"deployment_name": "vllm_deployment_zephyr_7b_beta",
@@ -127,4 +138,38 @@
}
],
},
{
"name": "images",
"type": "input",
"inputs": [],
"outputs": [
{
"name": "images",
"key": "images",
"path": "image_batch.images.[*].image",
"data_model": ImageListInput,
}
],
},
{
"name": "hf_blip2_opt_2_7b",
"type": "ray_deployment",
"deployment_name": "hf_blip2_deployment_opt_2_7b",
"method": "generate_batch",
"inputs": [
{
"name": "images",
"key": "images",
"path": "image_batch.images.[*].image",
"data_model": ImageListInput,
}
],
"outputs": [
{
"name": "captions_hf_blip2_opt_2_7b",
"key": "captions",
"path": "image_batch.images.[*].caption_hf_blip2_opt_2_7b",
}
],
},
]
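The path strings tie these nodes to the data model sketched in the comment at the top of the file: image_batch.images.[*].image addresses the image field of every element of ImageBatch.images, so the hf_blip2_opt_2_7b node consumes exactly the images collected by the images input node and writes one caption back per element. A small illustration of the logical layout those paths address, using plain dicts (the pipeline's real in-memory representation is not shown in this diff):

# Plain-dict illustration of the layout addressed by the "path" strings;
# the pipeline's actual internal representation is an assumption left out here.
image_batch = {
    "images": [
        {"image": "<first image>", "caption_hf_blip2_opt_2_7b": "a cat on a sofa"},
        {"image": "<second image>", "caption_hf_blip2_opt_2_7b": "a dog in a park"},
    ]
}

# "image_batch.images.[*].image" selects every "image" value:
images = [item["image"] for item in image_batch["images"]]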
14 changes: 14 additions & 0 deletions aana/configs/settings.py
@@ -0,0 +1,14 @@
from pathlib import Path
from pydantic import BaseSettings


class Settings(BaseSettings):
"""
A pydantic model for SDK settings.
"""

tmp_data_dir: Path = Path("/tmp/aana_data")


settings = Settings()
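Because Settings subclasses pydantic's BaseSettings, tmp_data_dir can be overridden from the environment without touching code; for example:

# BaseSettings matches environment variables to field names case-insensitively,
# so TMP_DATA_DIR overrides the default /tmp/aana_data.
import os

os.environ["TMP_DATA_DIR"] = "/data/aana"

from aana.configs.settings import Settings

print(Settings().tmp_data_dir)  # /data/aana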
187 changes: 187 additions & 0 deletions aana/deployments/hf_blip2_deployment.py
@@ -0,0 +1,187 @@
from typing import Any, Dict, List, TypedDict
from pydantic import BaseModel, Field, validator
from ray import serve
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from aana.deployments.base_deployment import BaseDeployment

from aana.exceptions.general import InferenceException
from aana.models.core.dtype import Dtype
from aana.models.core.image import Image
from aana.utils.batch_processor import BatchProcessor


class HFBlip2Config(BaseModel):
"""
The configuration for the BLIP2 deployment with HuggingFace models.
Attributes:
model (str): the model ID on HuggingFace
dtype (Dtype): the data type (optional, default: "auto"), one of "auto", "float32", "float16"
batch_size (int): the batch size (optional, default: 1)
num_processing_threads (int): the number of processing threads (optional, default: 1)
"""

model: str
dtype: Dtype = Field(default=Dtype.AUTO)
batch_size: int = Field(default=1)
num_processing_threads: int = Field(default=1)

@validator("dtype", pre=True, always=True)
def validate_dtype(cls, value: Dtype) -> Dtype:
"""
Validate the data type. For BLIP2, only "auto", "float32", and "float16" are supported.
Args:
value (Dtype): the data type
Returns:
Dtype: the validated data type
Raises:
ValueError: if the data type is not supported
"""
if value not in {Dtype.AUTO, Dtype.FLOAT32, Dtype.FLOAT16}:
raise ValueError(
f"Invalid dtype: {value}. BLIP2 only supports 'auto', 'float32', and 'float16'."
)
return value


class CaptioningOutput(TypedDict):
"""
The output of the captioning model.
Attributes:
caption (str): the caption
"""

caption: str


class CaptioningBatchOutput(TypedDict):
"""
The batched output of the captioning model.
Attributes:
captions (List[str]): the list of captions
"""

captions: List[str]


@serve.deployment
class HFBlip2Deployment(BaseDeployment):
"""
Deployment to serve BLIP2 models using HuggingFace.
"""

async def apply_config(self, config: Dict[str, Any]):
"""
Apply the configuration.
The method is called when the deployment is created or updated.
It loads the model and processor from HuggingFace.
The configuration should conform to the HFBlip2Config schema.
"""

config_obj = HFBlip2Config(**config)

# Create the batch processor to split the requests into batches
# and process them in parallel
self.batch_size = config_obj.batch_size
self.num_processing_threads = config_obj.num_processing_threads
# The actual inference is done in _generate()
# We use lambda because BatchProcessor expects dict as input
# and we use **kwargs to unpack the dict into named arguments for _generate()
self.batch_processor = BatchProcessor(
process_batch=lambda request: self._generate(**request),
batch_size=self.batch_size,
num_threads=self.num_processing_threads,
)

# Load the model and processor for BLIP2 from HuggingFace
self.model_id = config_obj.model
self.dtype = config_obj.dtype
self.torch_dtype = self.dtype.to_torch()
self.model = Blip2ForConditionalGeneration.from_pretrained(
self.model_id, torch_dtype=self.torch_dtype
)
self.processor = Blip2Processor.from_pretrained(self.model_id)

self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)

async def generate(self, image: Image) -> CaptioningOutput:
"""
Generate captions for the given image.
Args:
image (Image): the image
Returns:
CaptioningOutput: the dictionary with one key "caption"
and the caption for the image as value
Raises:
InferenceException: if the inference fails
"""
captions: CaptioningBatchOutput = await self.batch_processor.process(
{"images": [image]}
)
return CaptioningOutput(caption=captions["captions"][0])

async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
"""
Generate captions for the given images.
Args:
images (List[Image]): the images
Returns:
CaptioningBatchOutput: the dictionary with one key "captions"
and the list of captions for the images as value
Raises:
InferenceException: if the inference fails
"""
# Call the batch processor to process the requests
# The actual inference is done in _generate()
return await self.batch_processor.process(kwargs)

def _generate(self, images: List[Image]) -> CaptioningBatchOutput:
"""
Generate captions for the given images.
This method is called by the batch processor.
Args:
images (List[Image]): the images
Returns:
CaptioningBatchOutput: the dictionary with one key "captions"
and the list of captions for the images as value
Raises:
InferenceException: if the inference fails
"""
numpy_images = []
for image in images:
numpy_images.append(image.get_numpy()) # loading images
inputs = self.processor(numpy_images, return_tensors="pt").to(
self.device, self.torch_dtype
)

try:
generated_ids = self.model.generate(**inputs)
generated_texts = self.processor.batch_decode(
generated_ids, skip_special_tokens=True
)
generated_texts = [
generated_text.strip() for generated_text in generated_texts
]
return CaptioningBatchOutput(captions=generated_texts)
except Exception as e:
raise InferenceException(self.model_id) from e
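generate() and generate_batch() both funnel into the BatchProcessor configured in apply_config(), which splits large requests into batches of batch_size and runs _generate() on num_processing_threads worker threads. A toy sketch of that contract, reusing only the BatchProcessor API this file itself relies on and assuming that process() concatenates the list fields returned by each batch:

# Toy sketch: fake_generate stands in for HFBlip2Deployment._generate(); the
# assumption that BatchProcessor.process() merges the per-batch "captions"
# lists follows from how generate_batch() is used above.
import asyncio

from aana.utils.batch_processor import BatchProcessor


def fake_generate(images):
    return {"captions": [f"caption for {image}" for image in images]}


async def main():
    processor = BatchProcessor(
        process_batch=lambda request: fake_generate(**request),
        batch_size=2,
        num_threads=2,
    )
    # Three images with batch_size=2 are processed as two batches and merged.
    result = await processor.process({"images": ["img0", "img1", "img2"]})
    print(result["captions"])


asyncio.run(main())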
(The remaining 20 changed files are not shown here.)
