
Commit 7169ceb

fix sent_to_text
Taka008 committed Aug 4, 2023
1 parent 795d99c commit 7169ceb
Showing 2 changed files with 12 additions and 13 deletions.
11 changes: 1 addition & 10 deletions src/kwja/modules/components/logits_processor.py
@@ -12,7 +12,6 @@
     HALF_SPACE_TOKEN1,
     HALF_SPACE_TOKEN2,
     LEMMA_TOKEN,
-    NO_CANON_TOKEN,
     READING_TOKEN,
     SPECIAL_TO_RARE,
     SURF_TOKEN,
@@ -64,15 +63,7 @@ def __init__(
         char2tokens: Dict[str, Dict[str, int]],
     ) -> None:
         self.tokenizer = tokenizer
-        self.texts: List[str] = []
-        for text in tokenizer.batch_decode(tokenizer.batch_encode_plus(texts).input_ids):
-            if text.endswith(self.tokenizer.eos_token):
-                text = text[: -len(self.tokenizer.eos_token)]
-            for token in [NO_CANON_TOKEN, FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN]:
-                text = text.replace(f"{token} ", token)
-            for token in SPECIAL_TO_RARE:
-                text = text.replace(f"{token} ", token)
-            self.texts.append(text)
+        self.texts: List[str] = texts
         self.num_beams: int = num_beams
         self.reading_candidates: Set[int] = reading_candidates
         self.char2tokens: Dict[str, Dict[str, int]] = char2tokens
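The deleted __init__ loop re-encoded and re-decoded every input text just to strip the space that decode() inserts after each special token; with this commit the texts are stored as-is, on the assumption that they were already normalized upstream (see the sent_to_text change below). A minimal sketch of that stripping step, using a hypothetical placeholder instead of kwja's real NO_CANON_TOKEN value:

# Minimal sketch, not kwja's actual code; NO_CANON_TOKEN's real value lives
# in kwja's constants module, "<no_canon>" is only a stand-in.
from typing import List

NO_CANON_TOKEN = "<no_canon>"  # hypothetical placeholder value

def strip_decoder_spaces(text: str, special_tokens: List[str]) -> str:
    """Drop the space a tokenizer's decode() inserts after each special
    token, e.g. "<no_canon> 行く" -> "<no_canon>行く"."""
    for token in special_tokens:
        text = text.replace(f"{token} ", token)
    return text

print(strip_decoder_spaces(f"{NO_CANON_TOKEN} 行く", [NO_CANON_TOKEN]))  # <no_canon>行く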
14 changes: 11 additions & 3 deletions src/kwja/utils/seq2seq_format.py
@@ -37,13 +37,21 @@ def tokenize(self, texts: List[str]) -> List[str]:
         return [token for token in self.tokenizer.tokenize(concat_text) if token != "▁"]
 
     def sent_to_text(self, sentence: Sentence) -> str:
-        text: str = sentence.text
+        text: str = sentence.text.strip()
         for k, v in self.word_to_token.items():
             text = text.replace(k, v)
+        text = text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
         for k, v in RARE_TO_SPECIAL.items():
             text = text.replace(k, v)
-        text = text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
-        return text
+
+        tokenized: List[str] = [token for token in self.tokenizer.tokenize(text) if token != "▁"]
+        decoded: str = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokenized))
+        for token in [FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN]:
+            decoded = decoded.replace(f"{token} ", token)
+        for token in SPECIAL_TO_RARE:
+            decoded = decoded.replace(f"{token} ", token)
+        decoded = decoded.replace(" ", HALF_SPACE_TOKEN1)
+        return decoded
 
     @staticmethod
     def sent_to_format(sentence: Sentence) -> List[str]:
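The reworked sent_to_text now round-trips the target text through the tokenizer (tokenize, then decode), so the training target matches what the model's own tokenizer can actually emit at generation time; the trailing replace() calls then undo the spacing the decoder adds. A self-contained sketch of that post-processing, using hypothetical stand-ins for kwja's token constants:

# Self-contained sketch of the decode post-processing; the token values
# below are hypothetical stand-ins for kwja's real constants.
HALF_SPACE_TOKEN1 = "<hs1>"  # stand-in for kwja's half-space token
FULL_SPACE_TOKEN = "<fs>"    # stand-in for kwja's full-space token

def normalize_decoded(decoded: str) -> str:
    # decode() tends to emit "<hs1> 走る"; collapse it to "<hs1>走る".
    for token in (FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1):
        decoded = decoded.replace(f"{token} ", token)
    # Any space still left came from the tokenizer itself, so rewrite it as
    # an explicit half-space token, mirroring the diff's final replace().
    return decoded.replace(" ", HALF_SPACE_TOKEN1)

print(normalize_decoded("<hs1> 走る 止まる"))  # <hs1>走る<hs1>止まる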
