Skip to content

Commit

Permalink
tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
Taka008 committed Jul 29, 2023
1 parent effc4b4 commit 152ed8f
Showing 1 changed file with 0 additions and 8 deletions.
8 changes: 0 additions & 8 deletions src/kwja/modules/components/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,45 +118,37 @@ def get_batch_banned_token_ids(self, prev_input_ids: torch.Tensor) -> List[List[
total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf)
total_permitted_token_ids |= self.underscore_token_ids
elif (input_ids[-2], input_ids[-1]) == (self.new_line_token_id, self.underscore_token_id):
# print("a")
# 「改行 "▁"」の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容.ただしアンダースコア始まりは許容しない
total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf)
elif input_ids[-2] == self.new_line_token_id:
last_token: str = self.tokenizer.convert_ids_to_tokens(input_ids[-1])
if last_token.startswith("▁"):
# print("b")
# 「改行 "▁xxx"」の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容.また,全てのアンダースコア始まりのサブワードも許容
total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf)
total_permitted_token_ids |= self.underscore_token_ids
else:
# print("c")
# 「改行 "xxx"」の次は,任意のサブワードを許容
pass
elif input_ids[-1] == self.underscore_token_id:
# print("d")
self.is_decoding_surf[hypo_idx] = False
# 「"xxx" "▁"」 の次は,任意のサブワードを許容.ただしアンダースコア始まりは許容しない
total_permitted_token_ids |= self.all_token_ids - self.underscore_token_ids
elif input_ids[-1] in self.underscore_token_ids:
# print("e")
self.is_decoding_surf[hypo_idx] = False
# 「"xxx" "▁yyy"」 の次は,任意のサブワードを許容
pass
else:
# print("f")
# 「"xxx" "yyy"」 の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容.また,全てのアンダースコア始まりのサブワードも許容
total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf)
total_permitted_token_ids |= self.underscore_token_ids
else:
if input_ids[-1] == self.new_line_token_id:
# print("g")
self.is_decoding_surf[hypo_idx] = True
# 「改行」の次は,まだ生成していない文字列の先頭からマッチするサブワードを許容
total_permitted_token_ids |= self.get_permitted_token_ids(remaining_surf)
total_permitted_token_ids |= self.get_permitted_underscore_token_ids(remaining_surf)
total_permitted_token_ids.add(self.underscore_token_id)
else:
# print("h")
# surf のデコーディング時以外は,任意のサブワードを許容
pass

Expand Down

0 comments on commit 152ed8f

Please sign in to comment.