diff --git a/src/kwja/modules/components/logits_processor.py b/src/kwja/modules/components/logits_processor.py
index 8c095cbd..fc4b1290 100644
--- a/src/kwja/modules/components/logits_processor.py
+++ b/src/kwja/modules/components/logits_processor.py
@@ -12,7 +12,6 @@
     HALF_SPACE_TOKEN1,
     HALF_SPACE_TOKEN2,
     LEMMA_TOKEN,
-    NO_CANON_TOKEN,
     READING_TOKEN,
     SPECIAL_TO_RARE,
     SURF_TOKEN,
@@ -64,15 +63,7 @@ def __init__(
         char2tokens: Dict[str, Dict[str, int]],
     ) -> None:
         self.tokenizer = tokenizer
-        self.texts: List[str] = []
-        for text in tokenizer.batch_decode(tokenizer.batch_encode_plus(texts).input_ids):
-            if text.endswith(self.tokenizer.eos_token):
-                text = text[: -len(self.tokenizer.eos_token)]
-            for token in [NO_CANON_TOKEN, FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN]:
-                text = text.replace(f"{token} ", token)
-            for token in SPECIAL_TO_RARE:
-                text = text.replace(f"{token} ", token)
-            self.texts.append(text)
+        self.texts: List[str] = texts
         self.num_beams: int = num_beams
         self.reading_candidates: Set[int] = reading_candidates
         self.char2tokens: Dict[str, Dict[str, int]] = char2tokens
diff --git a/src/kwja/utils/seq2seq_format.py b/src/kwja/utils/seq2seq_format.py
index 285b5a62..236cad5c 100644
--- a/src/kwja/utils/seq2seq_format.py
+++ b/src/kwja/utils/seq2seq_format.py
@@ -37,13 +37,21 @@ def tokenize(self, texts: List[str]) -> List[str]:
         return [token for token in self.tokenizer.tokenize(concat_text) if token != "▁"]
 
     def sent_to_text(self, sentence: Sentence) -> str:
-        text: str = sentence.text
+        text: str = sentence.text.strip()
         for k, v in self.word_to_token.items():
             text = text.replace(k, v)
+        text = text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
         for k, v in RARE_TO_SPECIAL.items():
             text = text.replace(k, v)
-        text = text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
-        return text
+
+        tokenized: List[str] = [token for token in self.tokenizer.tokenize(text) if token != "▁"]
+        decoded: str = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokenized))
+        for token in [FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN]:
+            decoded = decoded.replace(f"{token} ", token)
+        for token in SPECIAL_TO_RARE:
+            decoded = decoded.replace(f"{token} ", token)
+        decoded = decoded.replace(" ", HALF_SPACE_TOKEN1)
+        return decoded
 
     @staticmethod
     def sent_to_format(sentence: Sentence) -> List[str]:
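
Not part of the patch, but for context: the change moves the tokenize-then-decode normalization out of the logits processor's constructor and into `sent_to_text`, so target texts arrive already normalized. A minimal standalone sketch of that round-trip is below, assuming a Hugging Face `transformers` SentencePiece tokenizer; the checkpoint name and the placeholder token strings are illustrative only, as the real constants live in `kwja.utils.constants`.

```python
from typing import List

from transformers import AutoTokenizer

# Illustrative placeholders; the real values are defined in kwja.utils.constants,
# and in kwja these markers are registered as additional special tokens.
HALF_SPACE_TOKEN1 = "<hs1>"
SPECIAL_TOKENS: List[str] = ["<full-space>", "<hs1>", "<hs2>", "<3dots>"]

# Hypothetical checkpoint; any SentencePiece-based seq2seq tokenizer behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")


def normalize(text: str) -> str:
    """Round-trip text through the tokenizer so the stored string matches
    what the decoder will actually emit during constrained generation."""
    # Drop bare "▁" pieces, mirroring the `tokenize` method in the second hunk.
    tokens: List[str] = [t for t in tokenizer.tokenize(text) if t != "▁"]
    decoded: str = tokenizer.decode(tokenizer.convert_tokens_to_ids(tokens))
    # decode() re-inserts a space after each special token; strip it back out.
    for token in SPECIAL_TOKENS:
        decoded = decoded.replace(f"{token} ", token)
    # Re-encode any surviving plain spaces as the half-space token.
    return decoded.replace(" ", HALF_SPACE_TOKEN1)
```

Because this normalization now happens once at formatting time, the logits processor in the first hunk can store `texts` as-is instead of repeating the decode-and-cleanup loop per input.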