Skip to content

Commit

Permalink
add: .get_engine() で latest version を取得できるように (#1421)
Browse files Browse the repository at this point in the history
* add: `.get_engine()` に None 入力による latest 取得を追加

* refactor: Cancellable の version を core_version へ変更して整理

* refactor: morphing のコア直接依存を削除

* refactor: None → LATEST_VERSION の変換を導入

* refactor: 定数を Final 化
  • Loading branch information
tarepan authored Jun 28, 2024
1 parent 77d8414 commit 000b363
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 34 deletions.
21 changes: 20 additions & 1 deletion test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import HTTPException

from voicevox_engine.dev.tts_engine.mock import MockTTSEngine
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager


def test_tts_engines_register_engine() -> None:
Expand Down Expand Up @@ -48,6 +48,25 @@ def test_tts_engines_get_engine_existing() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine(LATEST_VERSION) で最新版の TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engine3 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")
tts_engines.register_engine(tts_engine3, "0.1.0")
# Expects
true_acquired_tts_engine = tts_engine3
# Outputs
acquired_tts_engine = tts_engines.get_engine(LATEST_VERSION)

# Test
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
Expand Down
6 changes: 2 additions & 4 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,9 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]:
)

app.include_router(
generate_tts_pipeline_router(
tts_engines, core_manager, preset_manager, cancellable_engine
)
generate_tts_pipeline_router(tts_engines, preset_manager, cancellable_engine)
)
app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store))
app.include_router(generate_morphing_router(tts_engines, metas_store))
app.include_router(
generate_preset_router(preset_manager, verify_mutability_allowed)
)
Expand Down
9 changes: 3 additions & 6 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.metas.MetasStore import MetasStore
from voicevox_engine.model import AudioQuery
Expand All @@ -24,7 +23,7 @@
synthesis_morphing_parameter as _synthesis_morphing_parameter,
)
from voicevox_engine.morphing.morphing import synthesize_morphed_wave
from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager
from voicevox_engine.utility.file_utility import try_delete_file

# キャッシュを有効化
Expand All @@ -34,9 +33,7 @@


def generate_morphing_router(
tts_engines: TTSEngineManager,
core_manager: CoreManager,
metas_store: MetasStore,
tts_engines: TTSEngineManager, metas_store: MetasStore
) -> APIRouter:
"""モーフィング API Router を生成する"""
router = APIRouter(tags=["音声合成"])
Expand Down Expand Up @@ -89,7 +86,7 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)

# モーフィングが許可されないキャラクターペアを拒否する
Expand Down
31 changes: 15 additions & 16 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
CancellableEngine,
CancellableEngineInternalError,
)
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AudioQuery
from voicevox_engine.preset.preset_manager import (
Expand All @@ -39,6 +38,7 @@
Score,
)
from voicevox_engine.tts_pipeline.tts_engine import (
LATEST_VERSION,
TalkSingInvalidInputError,
TTSEngineManager,
)
Expand All @@ -65,7 +65,6 @@ def __init__(self, err: ParseKanaError):

def generate_tts_pipeline_router(
tts_engines: TTSEngineManager,
core_manager: CoreManager,
preset_manager: PresetManager,
cancellable_engine: CancellableEngine | None,
) -> APIRouter:
Expand All @@ -85,7 +84,7 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
Expand Down Expand Up @@ -116,7 +115,7 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
presets = preset_manager.load_presets()
Expand Down Expand Up @@ -175,7 +174,7 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
if is_kana:
try:
Expand All @@ -197,7 +196,7 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.update_length_and_pitch(accent_phrases, style_id)

Expand All @@ -211,7 +210,7 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.update_length(accent_phrases, style_id)

Expand All @@ -225,7 +224,7 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.update_pitch(accent_phrases, style_id)

Expand Down Expand Up @@ -253,7 +252,7 @@ def synthesis(
] = True,
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
Expand Down Expand Up @@ -294,8 +293,8 @@ def cancellable_synthesis(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
version = core_version or core_manager.latest_version()
try:
version = core_version or LATEST_VERSION
f_name = cancellable_engine._synthesis_impl(
query, style_id, request, version=version
)
Expand Down Expand Up @@ -331,7 +330,7 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

Expand Down Expand Up @@ -374,7 +373,7 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
Expand Down Expand Up @@ -403,7 +402,7 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[float]:
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
return engine.create_sing_volume_from_phoneme_and_f0(
Expand Down Expand Up @@ -432,7 +431,7 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
try:
wave = engine.frame_synthsize_wave(query, style_id)
Expand Down Expand Up @@ -528,7 +527,7 @@ def initialize_speaker(
指定されたスタイルを初期化します。
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
engine.initialize_synthesis(style_id, skip_reinit=skip_reinit)

Expand All @@ -540,7 +539,7 @@ def is_initialized_speaker(
"""
指定されたスタイルが初期化されているかどうかを返します。
"""
version = core_version or core_manager.latest_version()
version = core_version or LATEST_VERSION
engine = tts_engines.get_engine(version)
return engine.is_synthesis_initialized(style_id)

Expand Down
10 changes: 5 additions & 5 deletions voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .core.core_initializer import initialize_cores
from .metas.Metas import StyleId
from .model import AudioQuery
from .tts_pipeline.tts_engine import make_tts_engines_from_cores
from .tts_pipeline.tts_engine import LatestVersion, make_tts_engines_from_cores


class CancellableEngineInternalError(Exception):
Expand Down Expand Up @@ -149,7 +149,7 @@ def _synthesis_impl(
query: AudioQuery,
style_id: StyleId,
request: Request,
version: str,
version: str | LatestVersion,
) -> str:
"""
音声合成を行う関数
Expand All @@ -163,7 +163,7 @@ def _synthesis_impl(
request: fastapi.Request
接続確立時に受け取ったものをそのまま渡せばよい
https://fastapi.tiangolo.com/advanced/using-request-directly/
version: str
version
Returns
-------
Expand Down Expand Up @@ -245,9 +245,9 @@ def start_synthesis_subprocess(
while True:
try:
query, style_id, version = sub_proc_con.recv()
if tts_engines.has_engine(version):
try:
_engine = tts_engines.get_engine(version)
else:
except Exception:
# バージョンが見つからないエラー
sub_proc_con.send("")
continue
Expand Down
16 changes: 14 additions & 2 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import copy
import math
from typing import Final, Literal, TypeAlias

import numpy as np
from fastapi import HTTPException
from numpy.typing import NDArray
from soxr import resample

from voicevox_engine.utility.core_version_utility import get_latest_version

from ..core.core_adapter import CoreAdapter, DeviceSupport
from ..core.core_initializer import CoreManager
from ..core.core_wrapper import CoreWrapper
Expand Down Expand Up @@ -697,6 +700,10 @@ def frame_synthsize_wave(
return wave


LatestVersion: TypeAlias = Literal["LATEST_VERSION"]
LATEST_VERSION: Final[LatestVersion] = "LATEST_VERSION"


class TTSEngineManager:
"""TTS エンジンの集まりを一括管理するマネージャー"""

Expand All @@ -707,13 +714,18 @@ def versions(self) -> list[str]:
"""登録されたエンジンのバージョン一覧を取得する。"""
return list(self._engines.keys())

def _latest_version(self) -> str:
return get_latest_version(self.versions())

def register_engine(self, engine: TTSEngine, version: str) -> None:
"""エンジンを登録する。"""
self._engines[version] = engine

def get_engine(self, version: str) -> TTSEngine:
def get_engine(self, version: str | LatestVersion) -> TTSEngine:
"""指定バージョンのエンジンを取得する。"""
if version in self._engines:
if version == LATEST_VERSION:
return self._engines[self._latest_version()]
elif version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down

0 comments on commit 000b363

Please sign in to comment.