Skip to content

Commit

Permalink
ScliteJob, precision_ndigit option (#442)
Browse files (browse the repository at this point in the history)
Co-authored-by: vieting <[email protected]>
  • Loading branch information
albertz and vieting authored Aug 20, 2023
1 parent 6bb893f commit f9a9f39
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 27 deletions.
95 changes: 69 additions & 26 deletions recognition/scoring.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import collections
import re
from typing import List, Optional
from typing import List, Optional, Dict, Tuple

from sisyphus import *
from i6_core.lib.corpus import *
Expand Down Expand Up @@ -51,7 +51,7 @@ class ScliteJob(Job):
- out_*: the job also outputs many variables, please look in the init code for a list
"""

__sis_hash_exclude__ = {"sctk_binary_path": None}
__sis_hash_exclude__ = {"sctk_binary_path": None, "precision_ndigit": 1}

def __init__(
self,
Expand All @@ -61,6 +61,7 @@ def __init__(
sort_files: bool = False,
additional_args: Optional[List[str]] = None,
sctk_binary_path: Optional[tk.Path] = None,
precision_ndigit: Optional[int] = 1,
):
"""
:param ref: reference stm text file
Expand All @@ -69,6 +70,12 @@ def __init__(
:param sort_files: sort ctm and stm before scoring
:param additional_args: additional command line arguments passed to the Sclite binary call
:param sctk_binary_path: set an explicit binary path.
:param precision_ndigit: number of digits after decimal point for the precision
of the percentages in the output variables.
If None, no rounding is done.
sclite itself always reports percentages with exactly one digit after the decimal point
(https://github.com/usnistgov/SCTK/blob/f48376a203ab17f/src/sclite/sc_dtl.c#L343),
thus we recalculate the percentages here from the absolute counts.
"""
self.set_vis_name("Sclite - %s" % ("CER" if cer else "WER"))

Expand All @@ -78,6 +85,7 @@ def __init__(
self.sort_files = sort_files
self.additional_args = additional_args
self.sctk_binary_path = sctk_binary_path
self.precision_ndigit = precision_ndigit

self.out_report_dir = self.output_path("reports", True)

Expand Down Expand Up @@ -149,31 +157,66 @@ def run(self, output_to_report_dir=True):

if output_to_report_dir: # run as real job
with open(f"{output_dir}/sclite.dtl", "rt", errors="ignore") as f:
# Example:
"""
Percent Total Error = 5.3% (2709)
...
Percent Word Accuracy = 94.7%
...
Ref. words = (50948)
"""

# key -> percentage, absolute
output_variables: Dict[str, Tuple[Optional[tk.Variable], Optional[tk.Variable]]] = {
"Percent Total Error": (self.out_wer, self.out_num_errors),
"Percent Correct": (self.out_percent_correct, self.out_num_correct),
"Percent Substitution": (self.out_percent_substitution, self.out_num_substitution),
"Percent Deletions": (self.out_percent_deletions, self.out_num_deletions),
"Percent Insertions": (self.out_percent_insertions, self.out_num_insertions),
"Percent Word Accuracy": (self.out_percent_word_accuracy, None),
"Ref. words": (None, self.out_ref_words),
"Hyp. words": (None, self.out_hyp_words),
"Aligned words": (None, self.out_aligned_words),
}

outputs_absolute: Dict[str, int] = {}
for line in f:
s = line.split()
if line.startswith("Percent Total Error"):
self.out_wer.set(float(s[4][:-1]))
self.out_num_errors.set(int("".join(s[5:])[1:-1]))
elif line.startswith("Percent Correct"):
self.out_percent_correct.set(float(s[3][:-1]))
self.out_num_correct.set(int("".join(s[4:])[1:-1]))
elif line.startswith("Percent Substitution"):
self.out_percent_substitution.set(float(s[3][:-1]))
self.out_num_substitution.set(int("".join(s[4:])[1:-1]))
elif line.startswith("Percent Deletions"):
self.out_percent_deletions.set(float(s[3][:-1]))
self.out_num_deletions.set(int("".join(s[4:])[1:-1]))
elif line.startswith("Percent Insertions"):
self.out_percent_insertions.set(float(s[3][:-1]))
self.out_num_insertions.set(int("".join(s[4:])[1:-1]))
elif line.startswith("Percent Word Accuracy"):
self.out_percent_word_accuracy.set(float(s[4][:-1]))
elif line.startswith("Ref. words"):
self.out_ref_words.set(int("".join(s[3:])[1:-1]))
elif line.startswith("Hyp. words"):
self.out_hyp_words.set(int("".join(s[3:])[1:-1]))
elif line.startswith("Aligned words"):
self.out_aligned_words.set(int("".join(s[3:])[1:-1]))
key: Optional[str] = ([key for key in output_variables if line.startswith(key)] or [None])[0]
if not key:
continue
pattern = rf"^{re.escape(key)}\s*=\s*((\S+)%)?\s*(\(\s*(\d+)\))?$"
m = re.match(pattern, line)
assert m, f"Could not parse line: {line!r}, does not match to pattern r'{pattern}'"
absolute_s = m.group(4)
if not absolute_s:
assert not output_variables[key][1], f"Expected absolute value for {key}"
continue
outputs_absolute[key] = int(absolute_s)
if key == "Aligned words":
break # that should be the last key, can stop now

assert "Ref. words" in outputs_absolute, "Expected absolute numbers for Ref. words"
num_ref_words = outputs_absolute["Ref. words"]
assert "Percent Total Error" in outputs_absolute, "Expected absolute numbers for Percent Total Error"
outputs_absolute["Percent Word Accuracy"] = num_ref_words - outputs_absolute["Percent Total Error"]

outputs_percentage: Dict[str, float] = {}
for key, absolute in outputs_absolute.items():
if num_ref_words > 0:
percentage = 100.0 * absolute / num_ref_words
else:
percentage = float("nan")
outputs_percentage[key] = (
round(percentage, self.precision_ndigit) if self.precision_ndigit is not None else percentage
)

for key, (percentage_var, absolute_var) in output_variables.items():
if percentage_var is not None:
assert key in outputs_percentage, f"Expected percentage value for {key}"
percentage_var.set(outputs_percentage[key])
if absolute_var is not None:
assert key in outputs_absolute, f"Expected absolute value for {key}"
absolute_var.set(outputs_absolute[key])

def calc_wer(self):
wer = None
Expand Down
27 changes: 26 additions & 1 deletion tests/job_tests/recognition/test_scoring.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_sclite_job():
assert sclite_job.out_num_deletions.get() == 2, "Wrong num deletions, %s instead of 2" % str(
sclite_job.out_num_deletions.get()
)
assert sclite_job.out_percent_insertions.get() == 5.9, "Wrong percent insertions, %s instead of 4.5" % str(
assert sclite_job.out_percent_insertions.get() == 5.9, "Wrong percent insertions, %s instead of 5.9" % str(
sclite_job.out_percent_insertions.get()
)
assert sclite_job.out_num_insertions.get() == 1, "Wrong num insertions, %s instead of 1" % str(
Expand All @@ -88,3 +88,28 @@ def test_sclite_job():
assert sclite_job.out_aligned_words.get() == 18, "Wrong num aligned words, %s instead of 18" % str(
sclite_job.out_aligned_words.get()
)

# Now test custom precision.

sclite_job = ScliteJob(ref=ref, hyp=hyp, sctk_binary_path=sctk_binary, precision_ndigit=2)
sclite_job._sis_setup_directory()
sclite_job.run()

assert sclite_job.out_wer.get() == 58.82, "Wrong WER, %s instead of 58.82" % str(sclite_job.out_wer.get())

assert sclite_job.out_percent_correct.get() == 47.06, "Wrong percent correct, %s instead of 47.06" % str(
sclite_job.out_percent_correct.get()
)
assert (
sclite_job.out_percent_substitution.get() == 41.18
), "Wrong percent substitution, %s instead of 41.18" % str(sclite_job.out_percent_substitution.get())

assert sclite_job.out_percent_deletions.get() == 11.76, "Wrong percent deletions, %s instead of 11.76" % str(
sclite_job.out_percent_deletions.get()
)
assert sclite_job.out_percent_insertions.get() == 5.88, "Wrong percent insertions, %s instead of 5.88" % str(
sclite_job.out_percent_insertions.get()
)
assert (
sclite_job.out_percent_word_accuracy.get() == 41.18
), "Wrong percent word accuracy, %s instead of 41.18" % str(sclite_job.out_percent_word_accuracy.get())

0 comments on commit f9a9f39

Please sign in to comment.