diff --git a/tests/test_openjtalk.py b/tests/test_openjtalk.py index adfc24e..5a27466 100644 --- a/tests/test_openjtalk.py +++ b/tests/test_openjtalk.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor from pathlib import Path import pyopenjtalk @@ -68,7 +69,10 @@ def test_g2p_kana(): for text, pron in [ ("今日もこんにちは", "キョーモコンニチワ"), ("いやあん", "イヤーン"), - ("パソコンのとりあえず知っておきたい使い方", "パソコンノトリアエズシッテオキタイツカイカタ"), + ( + "パソコンのとりあえず知っておきたい使い方", + "パソコンノトリアエズシッテオキタイツカイカタ", + ), ]: p = pyopenjtalk.g2p(text, kana=True) assert p == pron @@ -108,3 +112,27 @@ def test_userdic(): ]: p = pyopenjtalk.g2p(text) assert p == expected + + +def test_multithreading(): + ojt = pyopenjtalk.openjtalk.OpenJTalk(pyopenjtalk.OPEN_JTALK_DICT_DIR) + texts = [ + "今日もいい天気ですね", + "こんにちは", + "マルチスレッドプログラミング", + "テストです", + "Pythonはプログラミング言語です", + "日本語テキストを音声合成します", + ] * 4 + + # Test consistency between single and multi-threaded runs + # make sure no corruptions happen in OJT internal + results_s = [ojt.run_frontend(text) for text in texts] + results_m = [] + with ThreadPoolExecutor() as e: + results_m = [i for i in e.map(ojt.run_frontend, texts)] + for s, m in zip(results_s, results_m): + assert len(s) == len(m) + for s_, m_ in zip(s, m): + # full context must exactly match + assert s_ == m_