From bad1209e250f7e1ebafaaf5ffb2f22e223fb85b2 Mon Sep 17 00:00:00 2001 From: tarepan Date: Mon, 27 Nov 2023 03:15:30 +0900 Subject: [PATCH] =?UTF-8?q?`BasePhoneme`=20=E4=B8=8D=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E3=83=A1=E3=82=BD=E3=83=83=E3=83=89=E3=81=AE=E5=89=8A=E9=99=A4?= =?UTF-8?q?=20(#782)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor BasePhoneme by removing unused methods * Refactor BasePhoneme test by removing unused attr * Refactor unused imports --- test/test_acoustic_feature_extractor.py | 72 ---------------- voicevox_engine/acoustic_feature_extractor.py | 86 ------------------- 2 files changed, 158 deletions(-) diff --git a/test/test_acoustic_feature_extractor.py b/test/test_acoustic_feature_extractor.py index 5d74059ed..9caf2ec71 100644 --- a/test/test_acoustic_feature_extractor.py +++ b/test/test_acoustic_feature_extractor.py @@ -1,6 +1,3 @@ -import os -from pathlib import Path -from typing import List, Type from unittest import TestCase from voicevox_engine.acoustic_feature_extractor import BasePhoneme, OjtPhoneme @@ -13,32 +10,6 @@ def setUp(self): self.base_hello_hiho = [ BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split()) ] - self.lab_str = """ - 0.00 1.00 pau - 1.00 2.00 k - 2.00 3.00 o - 3.00 4.00 N - 4.00 5.00 n - 5.00 6.00 i - 6.00 7.00 ch - 7.00 8.00 i - 8.00 9.00 w - 9.00 10.00 a - 10.00 11.00 pau - 11.00 12.00 h - 12.00 13.00 i - 13.00 14.00 h - 14.00 15.00 o - 15.00 16.00 d - 16.00 17.00 e - 17.00 18.00 s - 18.00 19.00 U - 19.00 20.00 pau - """.replace( - " ", "" - )[ - 1:-1 - ] # ダブルクオーテーションx3で囲われている部分で、空白をすべて置き換え、先頭と最後の"\n"を除外する def test_repr_(self): self.assertEqual( @@ -53,34 +24,6 @@ def test_convert(self): with self.assertRaises(NotImplementedError): BasePhoneme.convert(self.base_hello_hiho) - def test_duration(self): - self.assertEqual(self.base_hello_hiho[1].duration, 1) - - def test_parse(self): - parse_str_1 = "0 1 pau" - parse_str_2 = "32.67543 33.48933 e" - parsed_base_1 = BasePhoneme.parse(parse_str_1) - parsed_base_2 = BasePhoneme.parse(parse_str_2) - self.assertEqual(parsed_base_1.phoneme, "pau") - self.assertEqual(parsed_base_1.start, 0.0) - self.assertEqual(parsed_base_1.end, 1.0) - self.assertEqual(parsed_base_2.phoneme, "e") - self.assertEqual(parsed_base_2.start, 32.68) - self.assertEqual(parsed_base_2.end, 33.49) - - def lab_test_base( - self, - file_path: str, - phonemes: List["BasePhoneme"], - phoneme_class: Type["BasePhoneme"], - ): - phoneme_class.save_lab_list(phonemes, Path(file_path)) - with open(file_path, mode="r") as f: - self.assertEqual(f.read(), self.lab_str) - result_phoneme = phoneme_class.load_lab_list(Path(file_path)) - self.assertEqual(result_phoneme, phonemes) - os.remove(file_path) - class TestOjtPhoneme(TestBasePhoneme): def setUp(self): @@ -118,10 +61,6 @@ def test_equal(self): self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1) self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2) - def test_verify(self): - for phoneme in self.ojt_hello_hiho: - phoneme.verify() - def test_phoneme_id(self): ojt_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.ojt_hello_hiho]) self.assertEqual( @@ -157,14 +96,3 @@ def test_onehot(self): self.assertEqual(phoneme.onehot[j], True) else: self.assertEqual(phoneme.onehot[j], False) - - def test_parse(self): - parse_str_1 = "0 1 pau" - parse_str_2 = "32.67543 33.48933 e" - parsed_ojt_1 = OjtPhoneme.parse(parse_str_1) - parsed_ojt_2 = OjtPhoneme.parse(parse_str_2) - self.assertEqual(parsed_ojt_1.phoneme_id, 0) - self.assertEqual(parsed_ojt_2.phoneme_id, 14) - - def tes_lab_list(self): - self.lab_test_base("./ojt_lab_test", self.ojt_hello_hiho, OjtPhoneme) diff --git a/voicevox_engine/acoustic_feature_extractor.py b/voicevox_engine/acoustic_feature_extractor.py index d3c89d7c1..32f3a59e4 100644 --- a/voicevox_engine/acoustic_feature_extractor.py +++ b/voicevox_engine/acoustic_feature_extractor.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from pathlib import Path from typing import List, Sequence import numpy @@ -41,12 +40,6 @@ def __eq__(self, o: object): self.phoneme == o.phoneme and self.start == o.start and self.end == o.end ) - def verify(self): - """ - 音素クラスとして、データが正しいかassertする - """ - assert self.phoneme in self.phoneme_list, f"{self.phoneme} is not defined." - @property def phoneme_id(self): """ @@ -58,17 +51,6 @@ def phoneme_id(self): """ return self.phoneme_list.index(self.phoneme) - @property - def duration(self): - """ - 音素継続期間を取得する - Returns - ------- - duration : int - 音素継続期間を返す - """ - return self.end - self.start - @property def onehot(self): """ @@ -82,79 +64,11 @@ def onehot(self): array[self.phoneme_id] = True return array - @classmethod - def parse(cls, s: str): - """ - 文字列をパースして音素クラスを作る - Parameters - ---------- - s : str - パースしたい文字列 - - Returns - ------- - phoneme : BasePhoneme - パース結果を用いた音素クラスを返す - - Examples - -------- - >>> BasePhoneme.parse('1.7425000 1.9125000 o:') - Phoneme(phoneme='o:', start=1.74, end=1.91) - """ - words = s.split() - return cls( - start=float(words[0]), - end=float(words[1]), - phoneme=words[2], - ) - @classmethod @abstractmethod def convert(cls, phonemes: List["BasePhoneme"]) -> List["BasePhoneme"]: raise NotImplementedError - @classmethod - def load_lab_list(cls, path: Path): - """ - labファイルを読み込む - Parameters - ---------- - path : Path - 読み込みたいlabファイルのパス - - Returns - ------- - phonemes : List[BasePhoneme] - パース結果を用いた音素クラスを返す - """ - phonemes = [cls.parse(s) for s in path.read_text().split("\n") if len(s) > 0] - phonemes = cls.convert(phonemes) - - for phoneme in phonemes: - phoneme.verify() - return phonemes - - @classmethod - def save_lab_list(cls, phonemes: List["BasePhoneme"], path: Path): - """ - 音素クラスのリストをlabファイル形式で保存する - Parameters - ---------- - phonemes : List[BasePhoneme] - 保存したい音素クラスのリスト - path : Path - labファイルの保存先パス - """ - text = "\n".join( - [ - f"{numpy.round(p.start, decimals=2):.2f}\t" - f"{numpy.round(p.end, decimals=2):.2f}\t" - f"{p.phoneme}" - for p in phonemes - ] - ) - path.write_text(text) - class OjtPhoneme(BasePhoneme): """