Skip to content

Commit

Permalink
feat: Add option to generate LM image and GC via two separate jobs
Browse files Browse the repository at this point in the history
Closes #430
  • Loading branch information
NeoLegends committed Aug 21, 2023
1 parent f9a9f39 commit cb5565e
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions recognition/advanced_tree_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"BuildGlobalCacheJob",
]

import copy
from sisyphus import *

Path = setup_path(__package__)
Expand Down Expand Up @@ -167,6 +168,7 @@ def __init__(
lmgc_mem: float = 12.0,
lmgc_alias: Optional[str] = None,
lmgc_scorer: Optional[rasr.FeatureScorer] = None,
separate_lmi_gc_generation: bool = False,
model_combination_config: Optional[rasr.RasrConfig] = None,
model_combination_post_config: Optional[rasr.RasrConfig] = None,
extra_config: Optional[rasr.RasrConfig] = None,
Expand Down Expand Up @@ -286,18 +288,40 @@ def create_config(
lmgc_mem: float,
lmgc_alias: Optional[str],
lmgc_scorer: Optional[rasr.FeatureScorer],
separate_lmi_gc_generation: bool,
model_combination_config: Optional[rasr.RasrConfig],
model_combination_post_config: Optional[rasr.RasrConfig],
extra_config: Optional[rasr.RasrConfig],
extra_post_config: Optional[rasr.RasrConfig],
**kwargs,
):
lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
)
if lmgc_alias is not None:
lm_gc.add_alias(lmgc_alias)
lm_gc.rqmt["mem"] = lmgc_mem
def specialize_lm_config(crp, lm_config):
crp = copy.deepcopy(crp)
crp.language_model = lm_config
return crp

if separate_lmi_gc_generation:
gc = BuildGlobalCacheJob(crp, extra_config, extra_post_config).out_global_cache

arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms(
crp.language_model, post_config.lm if post_config is not None else None
)
lm_images = {
(i + 1): lm.CreateLmImageJob(
specialize_lm_config(crp, lm), extra_config=extra_config, extra_post_config=extra_post_config
).out_lm
for i, lm in enumerate(arpa_lms)
}
else:
lm_gc = AdvancedTreeSearchLmImageAndGlobalCacheJob(
crp, lmgc_scorer if lmgc_scorer is not None else feature_scorer, extra_config, extra_post_config
)
if lmgc_alias is not None:
lm_gc.add_alias(lmgc_alias)
lm_gc.rqmt["mem"] = lmgc_mem

gc = lm_gc.out_global_cache
lm_images = lm_gc.out_lm_images

search_parameters = cls.update_search_parameters(search_parameters)

Expand Down Expand Up @@ -397,14 +421,14 @@ def create_config(
]

post_config.flf_lattice_tool.global_cache.read_only = True
post_config.flf_lattice_tool.global_cache.file = lm_gc.out_global_cache
post_config.flf_lattice_tool.global_cache.file = gc

arpa_lms = AdvancedTreeSearchLmImageAndGlobalCacheJob.find_arpa_lms(
config.flf_lattice_tool.network.recognizer.lm,
post_config.flf_lattice_tool.network.recognizer.lm,
)
for i, lm_config in enumerate(arpa_lms):
lm_config[1].image = lm_gc.out_lm_images[i + 1]
lm_config[1].image = lm_images[i + 1]

# Remaining Flf-network

Expand Down Expand Up @@ -438,11 +462,11 @@ def create_config(
config._update(extra_config)
post_config._update(extra_post_config)

return config, post_config, lm_gc
return config, post_config

@classmethod
def hash(cls, kwargs):
config, post_config, lm_gc = cls.create_config(**kwargs)
config, post_config = cls.create_config(**kwargs)
return super().hash(
{
"config": config,
Expand Down

0 comments on commit cb5565e

Please sign in to comment.