From 7169ceb6c8ae549d1a5d74d73f404f3b1741668f Mon Sep 17 00:00:00 2001
From: Taka008
Date: Sat, 5 Aug 2023 01:28:52 +0900
Subject: [PATCH] fix sent_to_text

---
 src/kwja/modules/components/logits_processor.py | 11 +----------
 src/kwja/utils/seq2seq_format.py                | 14 +++++++++++---
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/kwja/modules/components/logits_processor.py b/src/kwja/modules/components/logits_processor.py
index 8c095cbd..fc4b1290 100644
--- a/src/kwja/modules/components/logits_processor.py
+++ b/src/kwja/modules/components/logits_processor.py
@@ -12,7 +12,6 @@
     HALF_SPACE_TOKEN1,
     HALF_SPACE_TOKEN2,
     LEMMA_TOKEN,
-    NO_CANON_TOKEN,
     READING_TOKEN,
     SPECIAL_TO_RARE,
     SURF_TOKEN,
@@ -64,15 +63,7 @@ def __init__(
         char2tokens: Dict[str, Dict[str, int]],
     ) -> None:
         self.tokenizer = tokenizer
-        self.texts: List[str] = []
-        for text in tokenizer.batch_decode(tokenizer.batch_encode_plus(texts).input_ids):
-            if text.endswith(self.tokenizer.eos_token):
-                text = text[: -len(self.tokenizer.eos_token)]
-            for token in [NO_CANON_TOKEN, FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN]:
-                text = text.replace(f"{token} ", token)
-            for token in SPECIAL_TO_RARE:
-                text = text.replace(f"{token} ", token)
-            self.texts.append(text)
+        self.texts: List[str] = texts
         self.num_beams: int = num_beams
         self.reading_candidates: Set[int] = reading_candidates
         self.char2tokens: Dict[str, Dict[str, int]] = char2tokens
diff --git a/src/kwja/utils/seq2seq_format.py b/src/kwja/utils/seq2seq_format.py
index 285b5a62..236cad5c 100644
--- a/src/kwja/utils/seq2seq_format.py
+++ b/src/kwja/utils/seq2seq_format.py
@@ -37,13 +37,21 @@ def tokenize(self, texts: List[str]) -> List[str]:
         return [token for token in self.tokenizer.tokenize(concat_text) if token != "▁"]
 
     def sent_to_text(self, sentence: Sentence) -> str:
-        text: str = sentence.text
+        text: str = sentence.text.strip()
         for k, v in self.word_to_token.items():
             text = text.replace(k, v)
+        text = text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
         for k, v in RARE_TO_SPECIAL.items():
             text = text.replace(k, v)
-        text = text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
-        return text
+
+        tokenized: List[str] = [token for token in self.tokenizer.tokenize(text) if token != "▁"]
+        decoded: str = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokenized))
+        for token in [FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN]:
+            decoded = decoded.replace(f"{token} ", token)
+        for token in SPECIAL_TO_RARE:
+            decoded = decoded.replace(f"{token} ", token)
+        decoded = decoded.replace(" ", HALF_SPACE_TOKEN1)
+        return decoded
 
     @staticmethod
     def sent_to_format(sentence: Sentence) -> List[str]:
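
For context, the round-trip that the revised sent_to_text performs can be sketched in isolation. The following is a minimal, hypothetical example, assuming a stand-in HuggingFace sentencepiece tokenizer ("t5-small") and a placeholder value for HALF_SPACE_TOKEN1; kwja's real constants and tokenizer configuration differ.

from transformers import AutoTokenizer

HALF_SPACE_TOKEN1 = "<hs1>"  # placeholder; the real token lives in kwja's constants

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # stand-in tokenizer

def roundtrip_normalize(text: str) -> str:
    # Tokenize and drop bare "▁" pieces, mirroring the tokenize() helper in
    # seq2seq_format.py, so the target text matches what the model's
    # tokenizer will actually produce during generation.
    tokens = [t for t in tokenizer.tokenize(text) if t != "▁"]
    decoded = tokenizer.decode(tokenizer.convert_tokens_to_ids(tokens))
    # decode() emits a space after each special token; collapse those spaces,
    # as the patched sent_to_text does for its full list of special tokens.
    decoded = decoded.replace(f"{HALF_SPACE_TOKEN1} ", HALF_SPACE_TOKEN1)
    # Re-encode any remaining literal spaces as the half-space token.
    return decoded.replace(" ", HALF_SPACE_TOKEN1)

As the diff suggests, this normalization previously happened inside the logits processor's constructor; moving it into sent_to_text means the processor receives texts that already match the tokenizer's output, so the tokenize/decode round trip is applied exactly once, on the formatting side.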