Skip to content

Commit

Permalink
Refactor Pydantic serialization and
Browse files Browse the repository at this point in the history
update Whisper options
  • Loading branch information
movchan74 committed Nov 17, 2023
1 parent eb664db commit df73c3c
Show file tree
Hide file tree
Showing 10 changed files with 529 additions and 516 deletions.
15 changes: 9 additions & 6 deletions aana/api/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
from pydantic import BaseModel


def orjson_default(obj: Any) -> Any:
def json_serializer_default(obj: Any) -> Any:
"""
Default function for orjson.dumps to handle pydantic models.
Default function for the JSON serializer to handle pydantic models.
If orjson does not know how to serialize an object, it calls the default function.
If the JSON serializer does not know how to serialize an object, it calls the default function.
If we see that the object is a pydantic model,
we call the dict method to get the dictionary representation of the model that orjson can serialize.
we call the dict method to get a dictionary representation of the model
that the JSON serializer can handle.
If the object is not a pydantic model, we raise a TypeError.
Args:
obj (Any): The object to serialize.
Returns:
Any: The serialized object.
Any: The serializable object.
Raises:
TypeError: If the object is not a pydantic model.
Expand Down Expand Up @@ -50,4 +51,6 @@ def render(self, content: Any) -> bytes:
"""
Override the render method to use orjson.dumps instead of json.dumps.
"""
return orjson.dumps(content, option=self.option, default=orjson_default)
return orjson.dumps(
content, option=self.option, default=json_serializer_default
)
41 changes: 22 additions & 19 deletions aana/deployments/whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ class WhisperComputeType(str, Enum):
"""
The data type used by whisper models.
See [cTranslate2 docs on quantization](https://opennmt.net/CTranslate2/quantization.html#quantize-on-model-conversion)
for more information.
Available types:
- INT8 (int8)
- INT8_FLOAT32 (int8_float32)
- INT8_FLOAT16 (int8_float16)
- INT8_BFLOAT16 (int8_bfloat16)
- INT16 (int16)
- FLOAT16 (float16)
- BFLOAT16 (bfloat16)
- FLOAT32 (float32)
- INT8
- INT8_FLOAT32
- INT8_FLOAT16
- INT8_BFLOAT16
- INT16
- FLOAT16
- BFLOAT16
- FLOAT32
"""

INT8 = "int8"
Expand All @@ -46,17 +49,17 @@ class WhisperModelSize(str, Enum):
The whisper model.
Available models:
- TINY (tiny)
- TINY_EN (tiny.en)
- BASE (base)
- BASE_EN (base.en)
- SMALL (small)
- SMALL_EN (small.en)
- MEDIUM (medium)
- MEDIUM_EN (medium.en)
- LARGE_V1 (large-v1)
- LARGE_V2 (large-v2)
- LARGE (large)
- TINY
- TINY_EN
- BASE
- BASE_EN
- SMALL
- SMALL_EN
- MEDIUM
- MEDIUM_EN
- LARGE_V1
- LARGE_V2
- LARGE
"""

TINY = "tiny"
Expand Down
19 changes: 11 additions & 8 deletions aana/models/pydantic/asr_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
Word as WhisperWord,
TranscriptionInfo as WhisperTranscriptionInfo,
)
from aana.models.pydantic import time_interval

from aana.models.pydantic.base import BaseListModel
from aana.models.pydantic.timestamp import Timestamp
from aana.models.pydantic.time_interval import TimeInterval


class AsrWord(BaseModel):
Expand All @@ -17,12 +18,12 @@ class AsrWord(BaseModel):
Attributes:
word (str): The word text
timestamp (Timestamp): Timestamp of the word
time_interval (TimeInterval): Time interval of the word
alignment_confidence (float): Alignment confidence of the word
"""

word: str = Field(description="The word text")
timestamp: Timestamp = Field(description="Timestamp of the word")
time_interval: TimeInterval = Field(description="Time interval of the word")
alignment_confidence: float = Field(
ge=0.0, le=1.0, description="Alignment confidence of the word"
)
Expand All @@ -34,7 +35,7 @@ def from_whisper(cls, whisper_word: WhisperWord) -> "AsrWord":
"""
return cls(
word=whisper_word.word,
timestamp=Timestamp(start=whisper_word.start, end=whisper_word.end),
time_interval=TimeInterval(start=whisper_word.start, end=whisper_word.end),
alignment_confidence=whisper_word.probability,
)

Expand All @@ -50,14 +51,14 @@ class AsrSegment(BaseModel):
Attributes:
text (str): The text of the segment (transcript/translation)
timestamp (Timestamp): Timestamp of the segment
time_interval (TimeInterval): Time interval of the segment
confidence (float): Confidence of the segment
no_speech_confidence (float): Chance of being a silence segment
words (List[AsrWord]): List of words in the segment
"""

text: str = Field(description="The text of the segment (transcript/translation)")
timestamp: Timestamp = Field(description="Timestamp of the segment")
time_interval: TimeInterval = Field(description="Time interval of the segment")
confidence: float = Field(ge=0.0, le=1.0, description="Confidence of the segment")
no_speech_confidence: float = Field(
ge=0.0, le=1.0, description="Chance of being a silence segment"
Expand All @@ -71,7 +72,9 @@ def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment":
"""
Convert WhisperSegment to AsrSegment.
"""
timestamp = Timestamp(start=whisper_segment.start, end=whisper_segment.end)
time_interval = TimeInterval(
start=whisper_segment.start, end=whisper_segment.end
)
confidence = np.exp(whisper_segment.avg_logprob)
if whisper_segment.words:
words = [AsrWord.from_whisper(word) for word in whisper_segment.words]
Expand All @@ -80,7 +83,7 @@ def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment":

return cls(
text=whisper_segment.text,
timestamp=timestamp,
time_interval=time_interval,
confidence=confidence,
no_speech_confidence=whisper_segment.no_speech_prob,
words=words,
Expand Down
19 changes: 19 additions & 0 deletions aana/models/pydantic/time_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pydantic import BaseModel, Field


class TimeInterval(BaseModel):
    """
    A span of time, in seconds, delimited by a start and an end point.

    Attributes:
        start (float): Start time in seconds (validated as >= 0.0)
        end (float): End time in seconds (validated as >= 0.0)
    """

    # `...` marks each field as required, which is also what happens
    # when no default is supplied at all.
    start: float = Field(..., description="Start time in seconds", ge=0.0)
    end: float = Field(..., description="End time in seconds", ge=0.0)

    class Config:
        # Extra entries for the model's generated schema.
        schema_extra = dict(description="Time interval in seconds")
19 changes: 0 additions & 19 deletions aana/models/pydantic/timestamp.py

This file was deleted.

8 changes: 6 additions & 2 deletions aana/models/pydantic/whisper_params.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
from pydantic import BaseModel, Field, validator
from typing import Optional, Union, List, Tuple

Expand All @@ -12,7 +13,10 @@ class WhisperParams(BaseModel):
beam_size (int): Size of the beam for decoding.
best_of (int): Number of best candidate sentences to consider.
temperature (Union[float, List[float], Tuple[float, ...]]): Controls the sampling
randomness, with a sequence of values indicating fallback temperatures.
randomness. It can be a tuple of temperatures,
which will be successively used upon failures according to either
[compression_ratio_threshold](https://github.com/guillaumekln/faster-whisper/blob/5a0541ea7d054aa3716ac492491de30158c20057/faster_whisper/transcribe.py#L216) or
[log_prob_threshold](https://github.com/guillaumekln/faster-whisper/blob/5a0541ea7d054aa3716ac492491de30158c20057/faster_whisper/transcribe.py#L218C23-L218C23).
word_timestamps (bool): Whether to extract word-level timestamps.
vad_filter (bool): Whether to enable voice activity detection to filter non-speech.
"""
Expand All @@ -27,7 +31,7 @@ class WhisperParams(BaseModel):
best_of: int = Field(
default=5, ge=1, description="Number of best candidate sentences to consider."
)
temperature: Union[float, List[float], Tuple[float, ...]] = Field(
temperature: float | collections.abc.Sequence[float] = Field(
default=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
description=(
"Temperature for sampling. A single value or a sequence indicating fallback temperatures."
Expand Down
6 changes: 3 additions & 3 deletions aana/tests/deployments/test_whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aana.models.core.video import Video
from aana.models.pydantic.whisper_params import WhisperParams
from aana.tests.utils import is_gpu_available, LevenshteinOperator
from aana.utils.general import serialize_pydantic
from aana.utils.general import pydantic_to_dict

EPSILON = 0.01

Expand Down Expand Up @@ -84,7 +84,7 @@ async def test_whisper_deployment(video_file):
output = await handle.transcribe.remote(
media=video, params=WhisperParams(word_timestamps=True)
)
output = serialize_pydantic(output)
output = pydantic_to_dict(output)

compare_transcriptions(expected_output, output)

Expand All @@ -93,7 +93,7 @@ async def test_whisper_deployment(video_file):
batch_output = await handle.transcribe_batch.remote(
media=videos, params=WhisperParams(word_timestamps=True)
)
batch_output = serialize_pydantic(batch_output)
batch_output = pydantic_to_dict(batch_output)

for i in range(len(videos)):
output = {k: v[i] for k, v in batch_output.items()}
Expand Down
Loading

0 comments on commit df73c3c

Please sign in to comment.