Skip to content

Commit

Permalink
Added streaming endpoints for video processing
Browse files Browse the repository at this point in the history
  • Loading branch information
movchan74 committed Nov 23, 2023
1 parent f148012 commit 1a045c3
Show file tree
Hide file tree
Showing 13 changed files with 436 additions and 79 deletions.
116 changes: 97 additions & 19 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from aana.exceptions.general import MultipleFileUploadNotAllowed
from aana.models.pydantic.exception_response import ExceptionResponseModel
from pydantic import ValidationError


async def run_pipeline(
Expand Down Expand Up @@ -93,6 +94,20 @@ class FileUploadField:
description: str


@dataclass
class EndpointOutput:
"""
Class used to represent an endpoint output.
Attributes:
name (str): Name of the output that should be returned by the endpoint.
output (str): Output of the pipeline that should be returned by the endpoint.
"""

name: str
output: str


@dataclass
class Endpoint:
"""
Expand All @@ -102,8 +117,7 @@ class Endpoint:
name (str): Name of the endpoint.
path (str): Path of the endpoint.
summary (str): Description of the endpoint that will be shown in the API documentation.
outputs (List[str]): List of required outputs from the pipeline that should be returned
by the endpoint.
outputs (List[EndpointOutput]): List of outputs that should be returned by the endpoint.
output_filter (Optional[OutputFilter]): The parameter will be added to the request and
will allow to choose subset of `outputs` to return.
streaming (bool): Whether the endpoint outputs a stream of data.
Expand All @@ -112,10 +126,19 @@ class Endpoint:
name: str
path: str
summary: str
outputs: List[str]
outputs: list[EndpointOutput]
output_filter: Optional[OutputFilter] = None
streaming: bool = False

def __post_init__(self):
"""
Post init method.
Creates dictionaries for fast lookup of outputs.
"""
self.name_to_output = {output.name: output.output for output in self.outputs}
self.output_to_name = {output.output: output.name for output in self.outputs}

def generate_model_name(self, suffix: str) -> str:
"""
Generate a Pydantic model name based on a given suffix.
Expand Down Expand Up @@ -145,28 +168,49 @@ def socket_to_field(self, socket: Socket) -> Tuple[Any, Any]:
data_model = Any
return (data_model, Field(None))

# check if any of the fields are required
if any(field.required for field in data_model.__fields__.values()):
# try to instantiate the data model
# to see if any of the fields are required
try:
data_model_instance = data_model()
return (data_model, data_model_instance)
except ValidationError:
# if we can't instantiate the data model
# it means that it has required fields
return (data_model, ...)

return (data_model, data_model())

def get_fields(self, sockets: List[Socket]) -> Dict[str, Tuple[Any, Any]]:
def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
"""
Generate fields for the Pydantic model based on the provided sockets.
Generate fields for the request Pydantic model based on the provided sockets.
Parameters:
sockets (List[Socket]): List of sockets.
sockets (list[Socket]): List of sockets.
Returns:
Dict[str, Tuple[Any, Field]]: Dictionary of fields for the Pydantic model.
dict[str, tuple[Any, Field]]: Dictionary of fields for the request Pydantic model.
"""
fields = {}
for socket in sockets:
field = self.socket_to_field(socket)
fields[socket.name] = field
return fields

def get_output_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
"""
Generate fields for the response Pydantic model based on the provided sockets.
Parameters:
sockets (list[Socket]): List of sockets.
Returns:
dict[str, tuple[Any, Field]]: Dictionary of fields for the response Pydantic model.
"""
fields = {}
for socket in sockets:
field = self.socket_to_field(socket)
name = self.output_to_name[socket.name]
fields[name] = field
return fields

def get_file_upload_field(
self, input_sockets: List[Socket]
) -> Optional[FileUploadField]:
Expand Down Expand Up @@ -219,7 +263,9 @@ def get_output_filter_field(self) -> Optional[Tuple[Any, Any]]:
description = self.output_filter.description
outputs_enum_name = self.generate_model_name("Outputs")
outputs_enum = Enum( # type: ignore
outputs_enum_name, [(output, output) for output in self.outputs], type=str
outputs_enum_name,
[(output.name, output.name) for output in self.outputs],
type=str,
)
field = (Optional[List[outputs_enum]], Field(None, description=description))
return field
Expand All @@ -235,7 +281,7 @@ def get_request_model(self, input_sockets: List[Socket]) -> Type[BaseModel]:
Type[BaseModel]: Pydantic model for the request.
"""
model_name = self.generate_model_name("Request")
input_fields = self.get_fields(input_sockets)
input_fields = self.get_input_fields(input_sockets)
output_filter_field = self.get_output_filter_field()
if output_filter_field and self.output_filter:
input_fields[self.output_filter.name] = output_filter_field
Expand All @@ -253,10 +299,30 @@ def get_response_model(self, output_sockets: List[Socket]) -> Type[BaseModel]:
Type[BaseModel]: Pydantic model for the response.
"""
model_name = self.generate_model_name("Response")
output_fields = self.get_fields(output_sockets)
output_fields = self.get_output_fields(output_sockets)
ResponseModel = create_model(model_name, **output_fields)
return ResponseModel

def process_output(self, output: dict[str, Any]) -> dict[str, Any]:
"""
Process the output of the pipeline.
Maps the output names of the pipeline to the names defined in the endpoint outputs.
For example, maps videos_captions_hf_blip2_opt_2_7b to captions.
Args:
output (dict): The output of the pipeline.
Returns:
dict: The processed output.
"""
output = {
self.output_to_name.get(output_name, output_name): output_value
for output_name, output_value in output.items()
}
return output

def create_endpoint_func(
self,
pipeline: Pipeline,
Expand Down Expand Up @@ -293,26 +359,36 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None):
if requested_outputs:
# get values for requested outputs because it's a list of enums
requested_outputs = [output.value for output in requested_outputs]
outputs = requested_outputs
# map the requested outputs to the actual outputs
# for example, videos_captions_hf_blip2_opt_2_7b to captions
outputs = [self.name_to_output[output] for output in requested_outputs]
# otherwise use the required outputs from the config (all outputs endpoint provides)
else:
outputs = self.outputs
outputs = [output.output for output in self.outputs]

# remove the output filter parameter from the data
if self.output_filter and self.output_filter.name in data_dict:
del data_dict[self.output_filter.name]

# run the pipeline
if self.streaming:

async def generator_wrapper():
"""
Serializes the output of the generator using ORJSONResponseCustom
"""
async for output in run_pipeline_streaming(pipeline, data_dict, outputs):
async for output in run_pipeline_streaming(
pipeline, data_dict, outputs
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body
return StreamingResponse(generator_wrapper(), media_type="application/json")

return StreamingResponse(
generator_wrapper(), media_type="application/json"
)
else:
output = await run_pipeline(pipeline, data_dict, outputs)
output = self.process_output(output)
return AanaJSONResponse(content=output)

if file_upload_field:
Expand All @@ -336,7 +412,9 @@ def register(
pipeline (Pipeline): Pipeline to register the endpoint to.
custom_schemas (Dict[str, Dict]): Dictionary of custom schemas.
"""
input_sockets, output_sockets = pipeline.get_sockets(self.outputs)
input_sockets, output_sockets = pipeline.get_sockets(
[output.output for output in self.outputs]
)
RequestModel = self.get_request_model(input_sockets)
ResponseModel = self.get_response_model(output_sockets)
file_upload_field = self.get_file_upload_field(input_sockets)
Expand Down
2 changes: 1 addition & 1 deletion aana/configs/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_configuration(target: str, endpoints, nodes, deployments) -> Dict:
# Target endpoints require the following outputs
endpoint_outputs = []
for endpoint in target_endpoints:
endpoint_outputs += endpoint.outputs
endpoint_outputs += [output.output for output in endpoint.outputs]

# Build the output graph for the whole pipeline
node_definitions = [NodeDefinition.from_dict(node_dict) for node_dict in nodes]
Expand Down
96 changes: 86 additions & 10 deletions aana/configs/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aana.api.api_generation import Endpoint
from aana.api.api_generation import Endpoint, EndpointOutput


endpoints = {
Expand All @@ -7,13 +7,19 @@
name="llm_generate",
path="/llm/generate",
summary="Generate text using LLaMa2 7B Chat",
outputs=["vllm_llama2_7b_chat_output"],
outputs=[
EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output")
],
),
Endpoint(
name="llm_generate_stream",
path="/llm/generate_stream",
summary="Generate text using LLaMa2 7B Chat (streaming)",
outputs=["vllm_llama2_7b_chat_output_stream"],
outputs=[
EndpointOutput(
name="completion", output="vllm_llama2_7b_chat_output_stream"
)
],
streaming=True,
),
],
Expand All @@ -22,29 +28,41 @@
name="zephyr_generate",
path="/llm/generate",
summary="Generate text using Zephyr 7B Beta",
outputs=["vllm_zephyr_7b_beta_output"],
outputs=[
EndpointOutput(name="completion", output="vllm_zephyr_7b_beta_output")
],
)
],
"blip2": [
Endpoint(
name="blip2_generate",
path="/image/generate_captions",
summary="Generate captions for images using BLIP2 OPT-2.7B",
outputs=["captions_hf_blip2_opt_2_7b"],
outputs=[
EndpointOutput(name="captions", output="captions_hf_blip2_opt_2_7b")
],
),
Endpoint(
name="blip2_video_generate",
path="/video/generate_captions",
summary="Generate captions for videos using BLIP2 OPT-2.7B",
outputs=["video_captions_hf_blip2_opt_2_7b", "timestamps"],
outputs=[
EndpointOutput(
name="captions", output="videos_captions_hf_blip2_opt_2_7b"
),
EndpointOutput(name="timestamps", output="timestamps"),
],
),
],
"video": [
Endpoint(
name="video_extract_frames",
path="/video/extract_frames",
summary="Extract frames from a video",
outputs=["timestamps", "duration"],
outputs=[
EndpointOutput(name="timestamps", output="timestamps"),
EndpointOutput(name="duration", output="duration"),
],
)
],
"whisper": [
Expand All @@ -53,10 +71,68 @@
path="/video/transcribe",
summary="Transcribe a video using Whisper Medium",
outputs=[
"video_transcriptions_whisper_medium",
"video_transcriptions_segments_whisper_medium",
"video_transcriptions_info_whisper_medium",
EndpointOutput(
name="transcription", output="videos_transcriptions_whisper_medium"
),
EndpointOutput(
name="segments",
output="videos_transcriptions_segments_whisper_medium",
),
EndpointOutput(
name="info", output="videos_transcriptions_info_whisper_medium"
),
],
)
],
"chat_with_video": [
Endpoint(
name="blip2_video_generate",
path="/video/generate_captions",
summary="Generate captions for videos using BLIP2 OPT-2.7B",
outputs=[
EndpointOutput(
name="captions", output="video_captions_hf_blip2_opt_2_7b"
),
EndpointOutput(name="timestamps", output="video_timestamps"),
],
streaming=True,
),
Endpoint(
name="whisper_transcribe",
path="/video/transcribe",
summary="Transcribe a video using Whisper Medium",
outputs=[
EndpointOutput(
name="transcription", output="video_transcriptions_whisper_medium"
),
EndpointOutput(
name="segments",
output="video_transcriptions_segments_whisper_medium",
),
EndpointOutput(
name="info", output="video_transcriptions_info_whisper_medium"
),
],
streaming=True,
),
Endpoint(
name="llm_generate",
path="/llm/generate",
summary="Generate text using LLaMa2 7B Chat",
outputs=[
EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output")
],
),
Endpoint(
name="llm_generate_stream",
path="/llm/generate_stream",
summary="Generate text using LLaMa2 7B Chat (streaming)",
outputs=[
EndpointOutput(
name="completion", output="vllm_llama2_7b_chat_output_stream"
)
],
streaming=True,
),
],
}
Loading

0 comments on commit 1a045c3

Please sign in to comment.