
Commit

Merge pull request #183 from ku-nlp/improve-disambiguation
Improve disambiguation
Taka008 authored Jul 6, 2023
2 parents 35a0480 + f484c4d commit c8a302b
Showing 6 changed files with 263 additions and 18 deletions.
configs/datamodule/seq2seq.yaml (3 additions, 0 deletions)
@@ -6,18 +6,21 @@ defaults:
       - seq2seq_fuman.yaml
       - seq2seq_wac.yaml
       - seq2seq_norm.yaml
+      - seq2seq_canon.yaml
   - valid:
       - seq2seq_kyoto.yaml
       - seq2seq_kwdlc.yaml
       - seq2seq_fuman.yaml
       - seq2seq_wac.yaml
       - seq2seq_norm.yaml
+      - seq2seq_canon.yaml
   - test:
       - seq2seq_kyoto.yaml
       - seq2seq_kwdlc.yaml
       - seq2seq_fuman.yaml
       - seq2seq_wac.yaml
       - seq2seq_norm.yaml
+      - seq2seq_canon.yaml
   - predict:
       - seq2seq_inference.yaml
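Each split's defaults list now composes seq2seq_canon.yaml, so the new canon dataset is loaded for train, valid, and test alongside the existing corpora. The three per-split configs below bind it to ${oc.env:DATA_DIR}/canon/<split>; a sketch of how that interpolation resolves follows them.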

configs/datamodule/test/seq2seq_canon.yaml (new file, 5 additions)
@@ -0,0 +1,5 @@
defaults:
  - /datamodule/base/seq2seq@canon

canon:
  path: ${oc.env:DATA_DIR}/canon/test
configs/datamodule/train/seq2seq_canon.yaml (new file, 5 additions)
@@ -0,0 +1,5 @@
defaults:
  - /datamodule/base/seq2seq@canon

canon:
  path: ${oc.env:DATA_DIR}/canon/train
configs/datamodule/valid/seq2seq_canon.yaml (new file, 5 additions)
@@ -0,0 +1,5 @@
defaults:
  - /datamodule/base/seq2seq@canon

canon:
  path: ${oc.env:DATA_DIR}/canon/valid
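These per-split files rely on two Hydra/OmegaConf mechanisms. The @canon suffix in the defaults entry is Hydra's package override: it mounts the shared /datamodule/base/seq2seq config under the canon key, whose path the canon: block then fills in. The ${oc.env:DATA_DIR} value is an OmegaConf interpolation that reads the environment variable when the value is accessed. A minimal sketch of the resolution (the DATA_DIR value is hypothetical, not from this PR):

    import os
    from omegaconf import OmegaConf

    os.environ["DATA_DIR"] = "/data/kwja"  # hypothetical location
    # Mirrors the canon node above; the ${oc.env:...} interpolation stays
    # unresolved until the value is read.
    cfg = OmegaConf.create({"canon": {"path": "${oc.env:DATA_DIR}/canon/valid"}})
    print(cfg.canon.path)  # -> /data/kwja/canon/valid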
scripts/preprocessors/preprocess_canon.py (new file, 198 additions)
@@ -0,0 +1,198 @@
import json
import logging
import random
import re
import shutil
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

from jinf import Jinf
from rhoknp import Morpheme, Sentence
from rhoknp.utils.reader import chunk_by_sentence
from tqdm import tqdm

from kwja.utils.logging_util import filter_logs

filter_logs(environment="production")
logging.basicConfig(format="")

logger = logging.getLogger("kwja_cli")
logger.setLevel(logging.INFO)

jinf = Jinf()

STOP_SURFS: Set[str] = {'"', "\u3000"}
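# Sentences containing any of these substrings are excluded from sampling
# (see get_is_excluded below).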


def is_hiragana(value: str) -> bool:
    # Hiragana plus the long-vowel mark "ー" (U+30FC).
    return re.match(r"^[\u3040-\u309F\u30FC]+$", value) is not None


def get_canons(morpheme: Morpheme) -> List[str]:
    canons: List[str] = []
    if morpheme.canon is not None:
        canons.append(morpheme.canon)
    for feature in morpheme.features:
        if feature[:4] == "ALT-":
            # Take the first token inside the quotes and strip the 代表表記: prefix.
            canons.append(feature.split('"')[1].split(" ")[0].replace("代表表記:", ""))
    return canons
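# Example of an ALT- feature as assumed by the parsing above (illustrative;
# the exact field layout is not confirmed by this PR):
#   ALT-硬い-かたい-硬い-3-18-1-8-"代表表記:硬い/かたい"
# from which get_canons extracts 硬い/かたい as an additional candidate.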


def set_canon(sentence: Sentence, morpheme_index: int) -> None:
    # Recompute 代表表記 (canonical form) for every morpheme except the
    # rewritten target.
    for morpheme in sentence.morphemes:
        if morpheme.index == morpheme_index:
            continue
        if morpheme.conjtype == "*" or morpheme.pos == "特殊":
            canon: str = f"{morpheme.lemma}/{morpheme.reading}"
        elif is_hiragana(morpheme.lemma):
            canon = f"{morpheme.lemma}/{morpheme.lemma}"
        else:
            canon_right: str = jinf(morpheme.reading, morpheme.conjtype, morpheme.conjform, "基本形")
            canon = f"{morpheme.lemma}/{canon_right}"
        morpheme.semantics["代表表記"] = canon
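# jinf converts a reading between conjugation forms, e.g. (illustrative):
#   jinf("はしった", "子音動詞ラ行", "タ形", "基本形") -> "はしる"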


def get_is_excluded(sentence: Sentence) -> bool:
    is_excluded: bool = False
    for morpheme in sentence.morphemes:
        for stop_surf in STOP_SURFS:
            if stop_surf in morpheme.surf:
                is_excluded = True
                break
        # Exclude sentences with any non-hiragana reading (特殊 morphemes aside).
        if (morpheme.pos != "特殊") and (not is_hiragana(morpheme.reading)):
            is_excluded = True
            break
    return is_excluded


def main():
    parser = ArgumentParser()
    parser.add_argument("-i", "--input-dirs", nargs="*", required=True)
    parser.add_argument("-o", "--output-dir", type=str, required=True)
    parser.add_argument("-max", "--max-samples", type=int, default=3)
    args = parser.parse_args()

    random.seed(42)

    sentences: List[Sentence] = []
    canon2freq: Dict[str, int] = {}
    excluded_nums: Dict[str, int] = {}
    for input_dir in args.input_dirs:
        for input_path in tqdm(Path(input_dir).glob("**/*.txt")):
            with input_path.open() as f:
                for knp in chunk_by_sentence(f):
                    try:
                        sentence: Sentence = Sentence.from_knp(knp)
                    except ValueError:
                        excluded_nums["sent_from_knp"] = excluded_nums.get("sent_from_knp", 0) + 1
                        continue
                    # Count canons that occur as one of several candidates.
                    for morpheme in sentence.morphemes:
                        canons: List[str] = get_canons(morpheme)
                        if len(canons) > 1:
                            for canon in canons:
                                canon2freq[canon] = canon2freq.get(canon, 0) + 1
                    sentences.append(sentence)
    print(f"num_sentences: {len(sentences)}")

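    # Sampling pass: pick unambiguous morphemes whose canon also occurs as an
    # ambiguous candidate elsewhere (canon2freq >= 2), rewrite their surface
    # and lemma to the kana reading via a jinf round trip, and cap the number
    # of samples per canon at --max-samples.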
    sid2knp_str: Dict[str, str] = {}
    sid2changes: Dict[str, Dict[str, Union[int, Dict[str, str]]]] = {}
    sampled_canon2freq: Dict[str, int] = {}
    for sentence in sentences:
        if get_is_excluded(sentence):
            continue
        for morpheme in sentence.morphemes:
            canons: List[str] = get_canons(morpheme)
            if len(canons) > 1 or sampled_canon2freq.get(morpheme.canon, 0) >= args.max_samples:
                continue

            if canon2freq.get(morpheme.canon, 0) >= 2:
                if morpheme.conjtype == "*":
                    surf: str = morpheme.reading
                    lemma: str = morpheme.reading
                else:
                    try:
                        lemma = jinf(morpheme.reading, morpheme.conjtype, morpheme.conjform, "基本形")
                        surf = jinf(lemma, morpheme.conjtype, "基本形", morpheme.conjform)
                    except (ValueError, NotImplementedError):
                        excluded_nums["jinf"] = excluded_nums.get("jinf", 0) + 1
                        continue
                if surf == morpheme.text and lemma == morpheme.lemma:
                    continue
                changes: Dict[str, Union[int, Dict[str, str]]] = {
                    "morpheme_index": morpheme.index,
                    "before": {
                        "surf": morpheme.text,
                        "lemma": morpheme.lemma,
                    },
                    "after": {
                        "surf": surf,
                        "lemma": lemma,
                    },
                }
                # Overwrite the surface with its kana spelling; rhoknp caches an
                # escaped copy of the text, so the private attribute is updated too.
                morpheme.text = surf
                morpheme._text_escaped = surf
                morpheme.lemma = lemma
                try:
                    set_canon(sentence, morpheme_index=morpheme.index)
                except (ValueError, NotImplementedError):
                    excluded_nums["set_canon"] = excluded_nums.get("set_canon", 0) + 1
                    continue
                try:
                    sid2knp_str[sentence.sid] = sentence.to_knp()
                    sampled_canon2freq[morpheme.canon] = sampled_canon2freq.get(morpheme.canon, 0) + 1
                    sid2changes[sentence.sid] = changes
                    break
                except AttributeError:
                    excluded_nums["to_knp"] = excluded_nums.get("to_knp", 0) + 1
                    continue
        # Stop once 5,000 rewritten sentences have been collected.
        if len(sid2knp_str) >= 5000:
            break

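    # Shuffle, then split 90/5/5; e.g. 5,000 sampled sentences yield
    # 4,500 train, 250 valid, and 250 test sentences.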
    sid2knp_str_list: List[Tuple[str, str]] = list(sid2knp_str.items())
    random.shuffle(sid2knp_str_list)
    train_size: int = int(len(sid2knp_str_list) * 0.9)
    valid_size: int = int(len(sid2knp_str_list) * 0.05)
    train_list: List[Tuple[str, str]] = sid2knp_str_list[:train_size]
    valid_list: List[Tuple[str, str]] = sid2knp_str_list[train_size : train_size + valid_size]
    test_list: List[Tuple[str, str]] = sid2knp_str_list[train_size + valid_size :]

    output_dir: Path = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(output_dir / "changes.json", "w") as f:
        json.dump(sid2changes, f, indent=2, ensure_ascii=False)

    train_dir: Path = output_dir / "train"
    if train_dir.exists():
        shutil.rmtree(str(train_dir))
    train_dir.mkdir(exist_ok=True)
    for name, knp_str in train_list:
        with (train_dir / f"{name}.knp").open("w") as f:
            f.write(f"{knp_str}\n")

    valid_dir: Path = output_dir / "valid"
    if valid_dir.exists():
        shutil.rmtree(str(valid_dir))
    valid_dir.mkdir(exist_ok=True)
    for name, knp_str in valid_list:
        with (valid_dir / f"{name}.knp").open("w") as f:
            f.write(f"{knp_str}\n")

    test_dir: Path = output_dir / "test"
    if test_dir.exists():
        shutil.rmtree(str(test_dir))
    test_dir.mkdir(exist_ok=True)
    for name, knp_str in test_list:
        with (test_dir / f"{name}.knp").open("w") as f:
            f.write(f"{knp_str}\n")

print(f"train: {len(train_list)}")
print(f"valid: {len(valid_list)}")
print(f"test: {len(test_list)}")
print(f"total: {len(train_list) + len(valid_list) + len(test_list)}")
print(f"excluded_nums: {excluded_nums}")


if __name__ == "__main__":
main()
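Together with the three split directories, the script emits changes.json, which view_seq2seq_results.py below uses to locate the rewritten morpheme in each sentence. A typical invocation (corpus paths hypothetical):

    python scripts/preprocessors/preprocess_canon.py -i /path/to/corpus1 /path/to/corpus2 -o $DATA_DIR/canon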
scripts/view_seq2seq_results.py (47 additions, 18 deletions)
@@ -1,10 +1,11 @@
 import difflib
+import json
 import logging
 import re
 from argparse import ArgumentParser
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Set, Tuple, cast
+from typing import Dict, List, Optional, Set, Tuple, cast

 import jaconv
 from rhoknp import Document, Jumanpp, Morpheme, Sentence
@@ -80,25 +81,43 @@ def append(self, diff_type: DiffType, sys_parts, gold_parts) -> None:


 class MorphologicalAnalysisScorer:
-    def __init__(self, sys_sentences: List[Sentence], gold_sentences: List[Sentence], eval_norm: bool = False) -> None:
+    def __init__(
+        self,
+        sys_sentences: List[Sentence],
+        gold_sentences: List[Sentence],
+        dataset_dir: Path,
+        eval_norm: bool = False,
+        eval_canon: bool = False,
+    ) -> None:
         self.eval_norm: bool = eval_norm
+        self.eval_canon: bool = eval_canon
         self.tp: Dict[str, int] = dict(surf=0, reading=0, lemma=0, canon=0)
         self.fp: Dict[str, int] = dict(surf=0, reading=0, lemma=0, canon=0)
         self.fn: Dict[str, int] = dict(surf=0, reading=0, lemma=0, canon=0)
         self.sys_sentences: List[Sentence] = sys_sentences
         self.gold_sentences: List[Sentence] = gold_sentences

+        with (dataset_dir / Path("canon/changes.json")).open() as f:
+            self.canon_changes: Dict[str, Dict] = json.load(f)
+
         self.num_diff_texts: int = 0
         self.norm_types: Set[str] = set()
         self.diffs: List[Diff] = self._search_diffs(sys_sentences, gold_sentences)

-    def _convert(self, sentence: Sentence, norm_morphemes: List[Morpheme]) -> List[str]:
+    def _convert(
+        self,
+        sentence: Sentence,
+        norm_morphemes: List[Morpheme],
+        canon_morpheme: Optional[Morpheme] = None,
+    ) -> List[str]:
         converteds: List[str] = []
         norm_surfs: Set[str] = set(mrph.surf for mrph in norm_morphemes)
         for mrph in sentence.morphemes:
             surf: str = jaconv.h2z(mrph.surf.replace("<unk>", "$"), ascii=True, digit=True)
             if self.eval_norm and surf not in norm_surfs:
                 continue
+            if self.eval_canon and canon_morpheme is not None and mrph.surf != canon_morpheme.surf:
+                continue
             reading: str = jaconv.h2z(mrph.reading.replace("<unk>", "$"), ascii=True, digit=True)
             lemma: str = jaconv.h2z(mrph.lemma.replace("<unk>", "$"), ascii=True, digit=True)
             if mrph.canon is None or mrph.canon == "None":
@@ -246,8 +265,12 @@ def _search_diffs(self, sys_sentences: List[Sentence], gold_sentences: List[Sent
             for gold_morpheme in gold_sentence.morphemes:
                 if "非標準表記" in gold_morpheme.semantics:
                     norm_morphemes.append(gold_morpheme)
-            sys_converted: List[str] = self._convert(sys_sentence, norm_morphemes)
-            gold_converted: List[str] = self._convert(gold_sentence, norm_morphemes)
+            canon_morpheme: Optional[Morpheme] = None
+            if self.eval_canon:
+                target_morpheme_index: int = self.canon_changes[gold_sentence.sid]["morpheme_index"]
+                canon_morpheme = gold_sentence.morphemes[target_morpheme_index]
+            sys_converted: List[str] = self._convert(sys_sentence, norm_morphemes, canon_morpheme)
+            gold_converted: List[str] = self._convert(gold_sentence, norm_morphemes, canon_morpheme)
             diff: Diff = self._search_diff(sys_converted, gold_converted)
             diffs.append(diff)
         return diffs
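The canon_changes table is the changes.json emitted by preprocess_canon.py above: each entry maps a sentence ID to the morpheme_index that was rewritten plus its before/after surf and lemma, e.g. (sentence ID and values illustrative):

    "w201106-0000060050-1": {
      "morpheme_index": 3,
      "before": {"surf": "固い", "lemma": "固い"},
      "after": {"surf": "かたい", "lemma": "かたい"}
    }

Only morpheme_index is consulted here, so scoring can be restricted to exactly the disambiguation target.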
@@ -305,7 +328,7 @@ def main():
         sid_to_seq2seq_sent[sentence.sid] = sentence

     output_dir: Path = Path(args.output_dir)
-    for corpus in ["kyoto", "kwdlc", "fuman", "wac", "norm"]:
+    for corpus in ["kyoto", "kwdlc", "fuman", "wac", "norm", "canon"]:
         sid_to_gold_sent: Dict[str, Sentence] = dict()
         sid_to_juman_sent: Dict[str, Sentence] = dict()
         juman_dir: Path = output_dir / "juman" / corpus
@@ -340,29 +363,35 @@ def main():
             golds.append(gold_sent)
         assert len(jumans) == len(seq2seqs) == len(golds)

-        if corpus != "norm":
-            print(" jumanpp")
-            juman_scorer = MorphologicalAnalysisScorer(jumans, golds)
-            juman_scorer.compute_score()
-
         if corpus == "norm":
             print(" seq2seq")
-            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds)
+            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds, dataset_dir=args.dataset_dir)
             system_scorer.compute_score()
             print(f" # of different texts for seq2seq: {system_scorer.num_diff_texts}")
             print(
-                f" Ratio of same texts for seq2seq = {(len(seq2seqs) - system_scorer.num_diff_texts) / len(seq2seqs) * 100:.2f}\n"
+                f" Ratio of same texts for seq2seq = {(len(seq2seqs) - system_scorer.num_diff_texts) / len(seq2seqs) * 100:.2f}"
             )
+            print(" seq2seq (only target morpheme)")
+            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds, dataset_dir=args.dataset_dir, eval_norm=True)
+            system_scorer.compute_score()
+            print()
+        elif corpus == "canon":
+            print(" seq2seq (only target morpheme)")
+            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds, dataset_dir=args.dataset_dir, eval_canon=True)
+            system_scorer.compute_score()
+            print()
         else:
+            print(" jumanpp")
+            juman_scorer = MorphologicalAnalysisScorer(jumans, golds, dataset_dir=args.dataset_dir)
+            juman_scorer.compute_score()
+
             print(" seq2seq")
-            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds)
+            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds, dataset_dir=args.dataset_dir)
             system_scorer.compute_score()
             print(f" # of different texts for seq2seq: {system_scorer.num_diff_texts}")
             print(
-                f" Ratio of same texts for seq2seq = {(len(seq2seqs) - system_scorer.num_diff_texts) / len(seq2seqs) * 100:.2f}"
+                f" Ratio of same texts for seq2seq = {(len(seq2seqs) - system_scorer.num_diff_texts) / len(seq2seqs) * 100:.2f}\n"
             )
-            print(" seq2seq (only target morpheme)")
-            system_scorer = MorphologicalAnalysisScorer(seq2seqs, golds, eval_norm=True)
-            system_scorer.compute_score()


 if __name__ == "__main__":
