diff --git a/src/kwja/modules/components/logits_processor.py b/src/kwja/modules/components/logits_processor.py index 3afd7143..26bb0943 100644 --- a/src/kwja/modules/components/logits_processor.py +++ b/src/kwja/modules/components/logits_processor.py @@ -118,45 +118,37 @@ def get_batch_banned_token_ids(self, prev_input_ids: torch.Tensor) -> List[List[ total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf) total_permitted_token_ids |= self.underscore_token_ids elif (input_ids[-2], input_ids[-1]) == (self.new_line_token_id, self.underscore_token_id): - # print("a") # 「改行 "▁"」の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容.ただしアンダースコア始まりは許容しない total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf) elif input_ids[-2] == self.new_line_token_id: last_token: str = self.tokenizer.convert_ids_to_tokens(input_ids[-1]) if last_token.startswith("▁"): - # print("b") # 「改行 "▁xxx"」の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容.また,全てのアンダースコア始まりのサブワードも許容 total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf) total_permitted_token_ids |= self.underscore_token_ids else: - # print("c") # 「改行 "xxx"」の次は,任意のサブワードを許容 pass elif input_ids[-1] == self.underscore_token_id: - # print("d") self.is_decoding_surf[hypo_idx] = False # 「"xxx" "▁"」 の次は,任意のサブワードを許容.ただしアンダースコア始まりは許容しない total_permitted_token_ids |= self.all_token_ids - self.underscore_token_ids elif input_ids[-1] in self.underscore_token_ids: - # print("e") self.is_decoding_surf[hypo_idx] = False # 「"xxx" "▁yyy"」 の次は,任意のサブワードを許容 pass else: - # print("f") # 「"xxx" "yyy"」 の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容.また,全てのアンダースコア始まりのサブワードも許容 total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf) total_permitted_token_ids |= self.underscore_token_ids else: if input_ids[-1] == self.new_line_token_id: - # print("g") self.is_decoding_surf[hypo_idx] = True # 「改行」の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容 total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf) total_permitted_token_ids |= self.get_permitted_underscore_token_ids(remaining_surf) total_permitted_token_ids.add(self.underscore_token_id) else: - # print("h") # surf のデコーディング時以外は,任意のサブワードを許容 pass