Skip to content

Commit

Permalink
Merge pull request #186 from ku-nlp/with-cohesion-tools
Browse files Browse the repository at this point in the history
Refactor using `cohesion-tools` package
  • Loading branch information
nobu-g authored Jul 28, 2023
2 parents c8a302b + f8c5945 commit d38a24e
Show file tree
Hide file tree
Showing 23 changed files with 227 additions and 1,320 deletions.
34 changes: 33 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jinf = "^1.0.4"
pure-cdb = "^4.0"
rich = ">=12.4"
pyyaml = "^6.0"
cohesion-tools = { git = "https://github.com/nobu-g/cohesion-tools.git" }
importlib-resources = { version = "^5.10", python = "<3.9" }

[tool.poetry.group.dev.dependencies]
Expand Down
27 changes: 14 additions & 13 deletions src/kwja/callbacks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import numpy as np
import torch
from cohesion_tools.extractors.base import BaseExtractor
from rhoknp import BasePhrase, Document, Morpheme, Phrase, Sentence
from rhoknp.cohesion import ExophoraReferent, RelTag, RelTagList
from rhoknp.cohesion import ExophoraReferent, ExophoraReferentType, RelTag
from rhoknp.props import DepType, NamedEntity, NamedEntityCategory
from transformers import PreTrainedTokenizerBase

from kwja.datamodule.examples import SpecialTokenIndexer
from kwja.utils.cohesion_analysis import CohesionUtils
from kwja.utils.constants import (
BASE_PHRASE_FEATURES,
CONJFORM_TAGS,
Expand Down Expand Up @@ -428,40 +428,41 @@ def _resolve_dependency(base_phrase: BasePhrase, dependency_manager: DependencyM
def add_cohesion(
document: Document,
cohesion_logits: List[List[List[float]]], # (rel, seq, seq)
cohesion_task2utils: Dict[CohesionTask, CohesionUtils],
cohesion_task2extractor: Dict[CohesionTask, BaseExtractor],
cohesion_task2rels: Dict[CohesionTask, List[str]],
restrict_cohesion_target: bool,
special_token_indexer: SpecialTokenIndexer,
) -> None:
rel2logits = dict(
zip(
[r for cohesion_utils in cohesion_task2utils.values() for r in cohesion_utils.rels],
[r for cohesion_rels in cohesion_task2rels.values() for r in cohesion_rels],
cohesion_logits,
)
)
base_phrases = document.base_phrases
for base_phrase in base_phrases:
rel_tags = RelTagList()
for cohesion_utils in cohesion_task2utils.values():
if cohesion_utils.is_target(base_phrase) is False:
base_phrase.rel_tags.clear()
for cohesion_task, cohesion_extractor in cohesion_task2extractor.items():
if restrict_cohesion_target is True and cohesion_extractor.is_target(base_phrase) is False:
continue
for rel in cohesion_utils.rels:
for rel in cohesion_task2rels[cohesion_task]:
rel_tag = _to_rel_tag(
rel,
rel2logits[rel][base_phrase.head.global_index], # (seq, )
base_phrases,
special_token_indexer,
cohesion_utils.exophora_referents,
cohesion_extractor.exophora_referent_types,
)
if rel_tag is not None:
rel_tags.append(rel_tag)
base_phrase.rel_tags = rel_tags
base_phrase.rel_tags.append(rel_tag)


def _to_rel_tag(
rel: str,
rel_logits: List[float], # (seq, )
base_phrases: List[BasePhrase],
special_token_indexer: SpecialTokenIndexer,
exophora_referents: List[ExophoraReferent],
exophora_referent_types: List[ExophoraReferentType],
) -> Optional[RelTag]:
logits = [rel_logits[bp.head.global_index] for bp in base_phrases]
logits += [rel_logits[i] for i in special_token_indexer.get_morpheme_level_indices()]
Expand All @@ -480,7 +481,7 @@ def _to_rel_tag(
# exophora
special_token = special_token_indexer.special_tokens[predicted_base_phrase_global_index - len(base_phrases)]
stripped_special_token = special_token[1:-1] # strip '[' and ']'
if stripped_special_token in [str(er) for er in exophora_referents]: # exclude [NULL], [NA], and [ROOT]
if ExophoraReferent(stripped_special_token).type in exophora_referent_types: # exclude [NULL], [NA], and [ROOT]
return RelTag(
type=rel,
target=stripped_special_token,
Expand Down
4 changes: 3 additions & 1 deletion src/kwja/callbacks/word_module_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def write_on_batch_end(
add_cohesion(
predicted_document,
cohesion_logits,
dataset.cohesion_task2utils,
dataset.cohesion_task2extractor,
dataset.cohesion_task2rels,
dataset.restrict_cohesion_target,
example.special_token_indexer,
)
add_discourse(predicted_document, discourse_predictions)
Expand Down
49 changes: 30 additions & 19 deletions src/kwja/datamodule/datasets/word.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

from cohesion_tools.extractors import BridgingExtractor, CoreferenceExtractor, PasExtractor
from cohesion_tools.extractors.base import BaseExtractor
from omegaconf import ListConfig
from rhoknp import Document, Sentence
from rhoknp.cohesion import ExophoraReferent
Expand All @@ -12,7 +14,7 @@

from kwja.datamodule.datasets.base import BaseDataset, FullAnnotatedDocumentLoaderMixin
from kwja.datamodule.examples import SpecialTokenIndexer, WordExample
from kwja.utils.cohesion_analysis import BridgingUtils, CohesionBasePhrase, CohesionUtils, CoreferenceUtils, PasUtils
from kwja.utils.cohesion_analysis import CohesionBasePhrase
from kwja.utils.constants import (
BASE_PHRASE_FEATURES,
CONJFORM_TAGS,
Expand Down Expand Up @@ -94,11 +96,24 @@ def __init__(
# ---------- cohesion analysis ----------
self.cohesion_tasks: List[CohesionTask] = [CohesionTask(ct) for ct in cohesion_tasks]
self.exophora_referents = [ExophoraReferent(er) for er in exophora_referents]
self.pas_cases: List[str] = list(pas_cases)
self.br_cases: List[str] = list(br_cases)
self.cohesion_task2utils: Dict[CohesionTask, CohesionUtils] = {
ct: self._build_cohesion_utils(ct, restrict_cohesion_target) for ct in self.cohesion_tasks
self.cohesion_task2extractor: Dict[CohesionTask, BaseExtractor] = {
CohesionTask.PAS_ANALYSIS: PasExtractor(
list(pas_cases),
[er.type for er in self.exophora_referents],
verbal_predicate=True,
nominal_predicate=True,
),
CohesionTask.BRIDGING_REFERENCE_RESOLUTION: BridgingExtractor(
list(br_cases), [er.type for er in self.exophora_referents]
),
CohesionTask.COREFERENCE_RESOLUTION: CoreferenceExtractor([er.type for er in self.exophora_referents]),
}
self.cohesion_task2rels: Dict[CohesionTask, List[str]] = {
CohesionTask.PAS_ANALYSIS: list(pas_cases),
CohesionTask.BRIDGING_REFERENCE_RESOLUTION: list(br_cases),
CohesionTask.COREFERENCE_RESOLUTION: ["="],
}
self.restrict_cohesion_target: bool = restrict_cohesion_target

# ---------- dependency parsing & cohesion analysis ----------
self.special_tokens: List[str] = list(special_tokens)
Expand Down Expand Up @@ -145,7 +160,13 @@ def _load_examples(self, doc_id2document: Dict[str, Document]) -> List[WordExamp
special_token_indexer = SpecialTokenIndexer(self.special_tokens, len(encoding.ids), len(document.morphemes))

example = WordExample(example_id, merged_encoding, special_token_indexer)
example.load_document(document, self.reading_aligner, self.cohesion_task2utils)
example.load_document(
document,
self.reading_aligner,
self.cohesion_task2extractor,
self.cohesion_task2rels,
self.restrict_cohesion_target,
)
if discourse_document := self._find_discourse_document(document):
example.load_discourse_document(discourse_document)

Expand Down Expand Up @@ -236,17 +257,17 @@ def encode(self, example: WordExample) -> WordModuleFeatures:
# ---------- cohesion analysis ----------
cohesion_labels: List[List[List[int]]] = [] # (rel, seq, seq)
cohesion_mask: List[List[List[bool]]] = [] # (rel, seq, seq)
for cohesion_task, cohesion_utils in self.cohesion_task2utils.items():
for cohesion_task, cohesion_rels in self.cohesion_task2rels.items():
cohesion_base_phrases = example.cohesion_task2base_phrases[cohesion_task]
for rel in cohesion_utils.rels:
for rel in cohesion_rels:
rel_labels = self._convert_cohesion_base_phrases_into_rel_labels(
cohesion_base_phrases, rel, example.special_token_indexer
)
cohesion_labels.append(rel_labels)
rel_mask = self._convert_cohesion_base_phrases_into_rel_mask(
cohesion_base_phrases, example.special_token_indexer
)
cohesion_mask.extend([rel_mask] * len(cohesion_utils.rels))
cohesion_mask.extend([rel_mask] * len(cohesion_rels))

# ---------- discourse parsing ----------
discourse_labels = [[IGNORE_INDEX] * self.max_seq_length for _ in range(self.max_seq_length)]
Expand Down Expand Up @@ -315,16 +336,6 @@ def _find_discourse_document(self, document: Document) -> Optional[Document]:
logger.warning(f"{discourse_path} is not a valid KNP file")
return None

def _build_cohesion_utils(self, cohesion_task: CohesionTask, restrict_cohesion_target: bool) -> CohesionUtils:
if cohesion_task == CohesionTask.PAS_ANALYSIS:
return PasUtils(self.pas_cases, "all", self.exophora_referents, restrict_cohesion_target)
elif cohesion_task == CohesionTask.BRIDGING_REFERENCE_RESOLUTION:
return BridgingUtils(self.br_cases, self.exophora_referents, restrict_cohesion_target)
elif cohesion_task == CohesionTask.COREFERENCE_RESOLUTION:
return CoreferenceUtils(self.exophora_referents, restrict_cohesion_target)
else:
raise ValueError("invalid cohesion task")

def _convert_cohesion_base_phrases_into_rel_labels(
self,
cohesion_base_phrases: List[CohesionBasePhrase],
Expand Down
43 changes: 23 additions & 20 deletions src/kwja/datamodule/datasets/word_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

from cohesion_tools.extractors import BridgingExtractor, CoreferenceExtractor, PasExtractor
from cohesion_tools.extractors.base import BaseExtractor
from omegaconf import ListConfig
from rhoknp import Document, Sentence
from rhoknp.cohesion import ExophoraReferent
Expand All @@ -13,7 +15,6 @@
from kwja.datamodule.datasets.base import BaseDataset, FullAnnotatedDocumentLoaderMixin
from kwja.datamodule.datasets.word import WordModuleFeatures
from kwja.datamodule.examples import SpecialTokenIndexer, WordInferenceExample
from kwja.utils.cohesion_analysis import BridgingUtils, CohesionUtils, CoreferenceUtils, PasUtils
from kwja.utils.constants import SPLIT_INTO_WORDS_MODEL_NAMES, CohesionTask
from kwja.utils.logging_util import track
from kwja.utils.sub_document import extract_target_sentences
Expand Down Expand Up @@ -58,11 +59,24 @@ def __init__(
# ---------- cohesion analysis ----------
self.cohesion_tasks = [CohesionTask(ct) for ct in cohesion_tasks]
self.exophora_referents = [ExophoraReferent(er) for er in exophora_referents]
self.pas_cases: List[str] = list(pas_cases)
self.br_cases: List[str] = list(br_cases)
self.cohesion_task2utils: Dict[CohesionTask, CohesionUtils] = {
ct: self._build_cohesion_utils(ct, restrict_cohesion_target) for ct in self.cohesion_tasks
self.cohesion_task2extractor: Dict[CohesionTask, BaseExtractor] = {
CohesionTask.PAS_ANALYSIS: PasExtractor(
list(pas_cases),
[er.type for er in self.exophora_referents],
verbal_predicate=True,
nominal_predicate=True,
),
CohesionTask.BRIDGING_REFERENCE_RESOLUTION: BridgingExtractor(
list(br_cases), [er.type for er in self.exophora_referents]
),
CohesionTask.COREFERENCE_RESOLUTION: CoreferenceExtractor([er.type for er in self.exophora_referents]),
}
self.cohesion_task2rels: Dict[CohesionTask, List[str]] = {
CohesionTask.PAS_ANALYSIS: list(pas_cases),
CohesionTask.BRIDGING_REFERENCE_RESOLUTION: list(br_cases),
CohesionTask.COREFERENCE_RESOLUTION: ["="],
}
self.restrict_cohesion_target: bool = restrict_cohesion_target

# ---------- dependency parsing & cohesion analysis ----------
self.special_tokens: List[str] = list(special_tokens)
Expand Down Expand Up @@ -144,18 +158,17 @@ def encode(self, example: WordInferenceExample) -> WordModuleFeatures:
# ---------- cohesion analysis ----------
cohesion_mask: List[List[List[bool]]] = [] # (rel, seq, seq)
morphemes = document.morphemes
for cohesion_task, cohesion_utils in self.cohesion_task2utils.items():
for cohesion_task, cohesion_extractor in self.cohesion_task2extractor.items():
cohesion_rels = self.cohesion_task2rels[cohesion_task]
rel_mask: List[List[bool]] = [[False] * self.max_seq_length for _ in range(self.max_seq_length)]
for morpheme in morphemes:
for antecedent_candidate_morpheme in cohesion_utils.get_antecedent_candidate_morphemes(
morpheme, morphemes
):
for antecedent_candidate_morpheme in cohesion_extractor.get_candidates(morpheme, morphemes):
rel_mask[morpheme.global_index][antecedent_candidate_morpheme.global_index] = True
for morpheme_global_index in example.special_token_indexer.get_morpheme_level_indices(
only_cohesion=True
):
rel_mask[morpheme.global_index][morpheme_global_index] = True
cohesion_mask.extend([rel_mask] * len(cohesion_utils.rels))
cohesion_mask.extend([rel_mask] * len(cohesion_rels))
return WordModuleFeatures(
example_ids=example.example_id,
input_ids=example.encoding.ids,
Expand Down Expand Up @@ -196,13 +209,3 @@ def _generate_subword_map(
for token_index, morpheme_global_index in special_token_indexer.token_and_morpheme_level_indices:
subword_map[morpheme_global_index][token_index] = True
return subword_map

def _build_cohesion_utils(self, cohesion_task: CohesionTask, restrict_cohesion_target: bool) -> CohesionUtils:
if cohesion_task == CohesionTask.PAS_ANALYSIS:
return PasUtils(self.pas_cases, "all", self.exophora_referents, restrict_cohesion_target)
elif cohesion_task == CohesionTask.BRIDGING_REFERENCE_RESOLUTION:
return BridgingUtils(self.br_cases, self.exophora_referents, restrict_cohesion_target)
elif cohesion_task == CohesionTask.COREFERENCE_RESOLUTION:
return CoreferenceUtils(self.exophora_referents, restrict_cohesion_target)
else:
raise ValueError("invalid cohesion task")
Loading

0 comments on commit d38a24e

Please sign in to comment.