Skip to content

Commit

Permalink
feat: refactor init to load custom models
Browse files Browse the repository at this point in the history
  • Loading branch information
dblencowe committed Jan 22, 2024
1 parent 6b9ea42 commit bf8d0ca
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 65 deletions.
130 changes: 69 additions & 61 deletions ovos_tts_plugin_piper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ class PiperTTSPlugin(TTS):

def __init__(self, lang="en-us", config=None):
super(PiperTTSPlugin, self).__init__(lang, config)
self.voice = self.config.get("model", self.voice)

if self.voice == "default":
if self.lang.startswith("en"):
# alan pope is the default english voice of mycroft/OVOS
Expand All @@ -262,19 +264,14 @@ def __init__(self, lang="en-us", config=None):
self.length_scale = self.config.get("length-scale") # Phoneme length
self.noise_w = self.config.get("noise-w") # Phoneme width noise

# pre-load models
preload_voices = self.config.get("preload_voices") or []
preload_langs = self.config.get("preload_langs") or [self.lang]

for lang in preload_langs:
if lang not in self.lang2voices:
lang = lang.split("-")[0]
voice = self.lang2voices.get(lang)
if voice and voice not in preload_voices:
preload_voices.append(voice)
self.get_model(voice=self.voice)

for voice in preload_voices:
self.get_model(voice=voice)
def get_model_name(self, src: str) -> str:
    """Derive a model's directory name from a voice alias or a model url.

    The name is the url's final path segment with everything from the
    first dot onward stripped (so ``.onnx`` and ``.tar.gz`` both vanish).

    Arguments:
        src (str): a key of ``self.voice2url`` or an http(s) url to a
            ``.onnx`` / ``.tar.gz`` model
    Returns:
        str: the bare model name
    Raises:
        ValueError: if *src* is neither a known alias nor an http url
    """
    if src in self.voice2url:
        # known alias - resolve to its download url first
        src = self.voice2url[src]
    elif not src.startswith("http"):
        raise ValueError("Must be predefined voice or url to onnx / gz model")
    tail = src[src.rfind('/') + 1:]
    return tail.partition('.')[0]

def get_model(self, lang=None, voice=None, speaker=None):

Expand Down Expand Up @@ -310,57 +307,67 @@ def get_model(self, lang=None, voice=None, speaker=None):
if voice in PiperTTSPlugin.engines:
return PiperTTSPlugin.engines[voice], speaker

# find requested voice
if voice in self.voice2url:
xdg_p = f"{xdg_data_home()}/piper_tts/{voice}"
if not os.path.isdir(xdg_p):
xdg_p = f"{xdg_data_home()}/piper_tts/{self.get_model_name(voice)}"
if not os.path.isdir(xdg_p):
url = voice
if voice in self.voice2url:
url = self.voice2url[voice]

m = url.split("/")[-1]
xdg_p = f"{xdg_data_home()}/piper_tts/{m.split('.')[0]}"

model_file = f"{xdg_p}/{m}"
if not os.path.isfile(model_file):
LOG.info(f"downloading piper model: {url}")
os.makedirs(xdg_p, exist_ok=True)
# TODO - streaming download
data = requests.get(url)
with open(model_file, "wb") as f:
f.write(data.content)

if url.endswith(".onnx"):
json_data = requests.get(url + '.json')
with open(model_file + '.json', "wb") as f:
f.write(json_data.content)
else:
with tarfile.open(model_file) as file:
file.extractall(xdg_p)

for f in os.listdir(xdg_p):
if f.endswith(".onnx"):
model = f"{xdg_p}/{f}"

with open(model + ".json", "r", encoding="utf-8") as config_file:
config_dict = json.load(config_file)

engine = PiperVoice(
config=PiperConfig.from_dict(config_dict),
session=onnxruntime.InferenceSession(
str(model),
sess_options=onnxruntime.SessionOptions(),
providers=["CPUExecutionProvider"]
if not self.use_cuda
else ["CUDAExecutionProvider"],
),
)

LOG.debug(f"loaded model: {model}")
PiperTTSPlugin.engines[voice] = engine
return engine, speaker
else:
raise FileNotFoundError("onnx model not found")
self.download_model(xdg_p, url)

engine = self.load_model_directory(xdg_p)
LOG.debug(f"loaded model: {xdg_p}")
PiperTTSPlugin.engines[voice] = engine
return engine, speaker

def download_model(self, model_basepath: str, url: str):
    """Download the file(s) for a model into a model directory.

    A plain ``.onnx`` url is downloaded together with its ``.onnx.json``
    config (expected to live at the same location); any other url is
    assumed to be a tarball and is extracted into *model_basepath*.
    Files are written directly into *model_basepath*, which is the same
    directory ``load_model_directory`` later scans for the ``.onnx`` file.

    Arguments:
        model_basepath (str): the directory to store this model's files in
        url (str): the url to the .onnx or .tar.gz model
    Raises:
        ValueError: if *url* is not an http(s) url
        requests.HTTPError: if the server responds with an error status
    """
    if not url.startswith("http"):
        raise ValueError("model url must start with http")

    # NOTE: rsplit("/", -1) returned a *list*, not the trailing path
    # segment, so the model was saved under a repr-of-list file name
    file_name = url.rsplit("/", 1)[-1]
    os.makedirs(model_basepath, exist_ok=True)
    LOG.info(f"downloading piper model: {url}")
    # TODO - streaming download
    data = requests.get(url, timeout=120)
    data.raise_for_status()  # fail loudly instead of saving an error page
    model_file = f"{model_basepath}/{file_name}"
    with open(model_file, "wb") as f:
        f.write(data.content)

    if url.endswith(".onnx"):
        # the voice config lives next to the model as <name>.onnx.json
        json_data = requests.get(url + '.json', timeout=120)
        json_data.raise_for_status()
        with open(f"{model_file}.json", "wb") as f:
            f.write(json_data.content)
    else:
        # NOTE(review): extractall on a downloaded archive trusts the
        # remote tarball's member paths - confirm sources are trusted
        with tarfile.open(model_file) as file:
            file.extractall(model_basepath)

def load_model_directory(self, model_dir: str) -> PiperVoice:
    """Create a PiperVoice from a directory containing an ``.onnx``
    model file and its ``.json`` definition.

    Arguments:
        model_dir (str): directory holding ``<model>.onnx`` and
            ``<model>.onnx.json``
    Returns:
        PiperVoice: the loaded voice engine
    Raises:
        FileNotFoundError: if the directory contains no ``.onnx`` model
            (the original raised UnboundLocalError in this case because
            ``model`` was only assigned inside the loop)
    """
    for f in os.listdir(model_dir):
        # require the full ".onnx" extension; bare endswith("onnx")
        # also matched names like "foo_onnx"
        if not f.endswith(".onnx"):
            continue
        model = f"{model_dir}/{f}"
        with open(model + ".json", "r", encoding="utf-8") as config_file:
            config_dict = json.load(config_file)
        # load the first model found instead of scanning the whole
        # directory and silently keeping only the last match
        return PiperVoice(
            config=PiperConfig.from_dict(config_dict),
            session=onnxruntime.InferenceSession(
                str(model),
                sess_options=onnxruntime.SessionOptions(),
                providers=["CPUExecutionProvider"]
                if not self.use_cuda
                else ["CUDAExecutionProvider"],
            ),
        )
    raise FileNotFoundError("onnx model not found")


def get_tts(self, sentence, wav_file, lang=None, voice=None, speaker=None):
"""Generate WAV and phonemes.
Expand All @@ -376,6 +383,7 @@ def get_tts(self, sentence, wav_file, lang=None, voice=None, speaker=None):
tuple ((str) file location, (str) generated phonemes)
"""
lang = lang or self.lang
voice = voice or self.voice
# HACK: bug in some neon-core versions - neon_audio.tts.neon:_get_tts:198 - INFO - Legacy Neon TTS signature found
if isinstance(speaker, dict):
LOG.warning("Legacy Neon TTS signature found, pass speaker as a str")
Expand Down
17 changes: 13 additions & 4 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@ OVOS TTS plugin for [piper](https://github.com/rhasspy/piper)

## Configuration

download models from https://github.com/rhasspy/piper/releases/tag/v0.0.2
Models can be loaded from the voices built into Piper, from a list of pre-defined downloadable voices, or from the device's local storage.
Models are stored in `$XDG_HOME/piper_tts/$model_name"` and consist of a .onnx and .json file, ex:
```bash
ls -1 $XDG_HOME/piper_tts/example/
example.onnx
example.onnx.json
```

you can also pass an url for a .tar.gz model, and it will be auto downloaded
Available aliases can be found in [this list](https://github.com/OpenVoiceOS/ovos-tts-plugin-piper/blob/dev/ovos_tts_plugin_piper/__init__.py#L154)
A list of downloadable models can be found [here](https://github.com/rhasspy/piper/releases/tag/v0.0.2) or [here](https://huggingface.co/rhasspy/piper-voices/tree/main), to use one just link to the .onnx file in the `voice` parameter of the configuration

if no model is passed it will be auto selected based on language
Passed URLs can point either to a `.onnx` file that has an appropriately named `.json` definition file in the same location, or to a `.tar.gz` archive containing both files

you can pass a model name alias, eg "alan-low"
if no model is passed it will be auto selected based on language

```json
"tts": {
Expand All @@ -24,3 +31,5 @@ you can pass a model name alias, eg "alan-low"
}
}
```


0 comments on commit bf8d0ca

Please sign in to comment.