diff --git a/src/kwja/utils/seq2seq_format.py b/src/kwja/utils/seq2seq_format.py index 54749662..0374a977 100644 --- a/src/kwja/utils/seq2seq_format.py +++ b/src/kwja/utils/seq2seq_format.py @@ -51,11 +51,8 @@ def tokenize(self, mrph_lines: List[List[str]], tgt_mrphs: Dict[str, Dict[str, s tokenized: List[str] = [x for x in self.tokenizer.tokenize(mrph) if x != "▁"] + [ special_tokens[idx_in_mrph] ] - if is_partial: - if partial_anno_type == "canon" or (partial_anno_type == "norm" and idx_in_mrph in {0, 2}): - output.extend(tokenized) - else: - output.extend([self.pad_token] * len(tokenized)) + if is_partial and partial_anno_type == "": + output.extend([self.pad_token] * len(tokenized)) else: output.extend(tokenized) return output