From 89775d11f771b11d1d9739ba6ac5ad5824b0d328 Mon Sep 17 00:00:00 2001
From: Sinan
Date: Wed, 26 Jul 2023 21:42:41 +0200
Subject: [PATCH] BOS and EOS (#195)

* Added pad bos eos functionality (and .gitignore)

* Deleted debug lines

* Changed 'pad' to 'add'

* fix

* fix inconsistent indentation
---
 .gitignore   |  2 ++
 tokenizer.py | 19 +++++++++++++++++--
 2 files changed, 19 insertions(+), 2 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..ef1f3072
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+# ignore __pycache__ folder
+__pycache__/
\ No newline at end of file
diff --git a/tokenizer.py b/tokenizer.py
index 8a64c905..c2befc26 100644
--- a/tokenizer.py
+++ b/tokenizer.py
@@ -18,16 +18,23 @@ def __init__(self, tokenizer_model_path):
         self.pad_token_id = 0 # self.tokenizer.pad_id()
         self.newline_token_id = 13
 
-    # Encode string
 
-    def encode(self, text, return_mask = False, max_seq_len = 2048):
+    def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False):
 
         if isinstance(text, list):
 
             # text is a list of strings
 
             list_ids = self.tokenizer.EncodeAsIds(text)
+
+            # pad bos and eos
+
+            if add_bos:
+                for ids in list_ids: ids.insert(0, self.bos_token_id)
+            if add_eos:
+                for ids in list_ids: ids.append(self.eos_token_id)
+
 
             max_length = max([len(ids) for ids in list_ids])
 
             needs_mask = False
@@ -56,6 +63,14 @@ def encode(self, text, return_mask = False, max_seq_len = 2048):
             # text is a single string
 
             ids = self.tokenizer.EncodeAsIds(text)
+
+            # pad bos and eos
+
+            if add_bos:
+                ids = [self.bos_token_id] + ids
+            if add_eos:
+                ids = ids + [self.eos_token_id]
+
             stacked_ids = torch.tensor(ids).unsqueeze(0)
 
             if return_mask: