From 7b53875ebc5aba3afcf45f4d3188ee6318157232 Mon Sep 17 00:00:00 2001 From: codertimo Date: Sun, 21 Oct 2018 00:36:09 +0900 Subject: [PATCH 01/12] Change elif to else to fix #13 issue --- bert_pytorch/dataset/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index 1cf7f38..4a63603 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -58,7 +58,7 @@ def random_word(self, sentence): tokens[i] = random.randrange(len(self.vocab)) # 10% randomly change token to current token - elif prob >= prob * 0.9: + else: tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index) output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index)) From faa7a29bf92fe3a146775c7c36a97263aa39b6d8 Mon Sep 17 00:00:00 2001 From: codertimo Date: Sun, 21 Oct 2018 00:53:24 +0900 Subject: [PATCH 02/12] Adding multi-gpu backward ops using pytorch-encoding --- bert_pytorch/trainer/pretrain.py | 6 ++++-- requirements.txt | 1 + setup.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index 2b99fa5..8520dcb 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -3,6 +3,8 @@ from torch.optim import Adam from torch.utils.data import DataLoader +from encoding.parallel import DataParallelModel, DataParallelCriterion + from ..model import BERTLM, BERT import tqdm @@ -47,7 +49,7 @@ def __init__(self, bert: BERT, vocab_size: int, # Distributed GPU training if CUDA can detect more than 1 GPU if torch.cuda.device_count() > 1: print("Using %d GPUS for BERT" % torch.cuda.device_count()) - self.model = nn.DataParallel(self.model) + self.model = DataParallelModel(self.model) # Setting the train and test data loader self.train_data = train_dataloader @@ -57,7 +59,7 @@ def __init__(self, bert: BERT, vocab_size: int, self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) # Using Negative Log Likelihood Loss function for predicting the masked_token - self.criterion = nn.NLLLoss(ignore_index=0) + self.criterion = DataParallelCriterion(nn.NLLLoss(ignore_index=0)) self.log_freq = log_freq diff --git a/requirements.txt b/requirements.txt index ebffc70..8506741 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ tqdm numpy torch>=0.4.0 +torch-encoding \ No newline at end of file diff --git a/setup.py b/setup.py index 15888d3..2ea38ae 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import os import sys -__version__ = "0.0.1a3" +__version__ = "0.0.1a4" with open("requirements.txt") as f: require_packages = [line[:-1] for line in f] From 00303af54e3e3e6303191b6c32b07eb95aee9cd1 Mon Sep 17 00:00:00 2001 From: codertimo Date: Sun, 21 Oct 2018 00:57:49 +0900 Subject: [PATCH 03/12] Fixing requirements.txt parsing --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2ea38ae..4e721cf 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ __version__ = "0.0.1a4" with open("requirements.txt") as f: - require_packages = [line[:-1] for line in f] + require_packages = [line[:-1] if line[-1] == "\n" else line for line in f] with open("README.md", "r", encoding="utf-8") as f: long_description = f.read() From aaa4425618070316778fbc227b2eacd1040c822e Mon Sep 17 00:00:00 2001 From: codertimo Date: Sun, 21 Oct 2018 01:04:18 +0900 Subject: [PATCH 04/12] Adding selecting the cuda device_ids --- 
bert_pytorch/__main__.py | 4 +++- bert_pytorch/trainer/pretrain.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/bert_pytorch/__main__.py b/bert_pytorch/__main__.py index ed5b95d..c208ecf 100644 --- a/bert_pytorch/__main__.py +++ b/bert_pytorch/__main__.py @@ -23,9 +23,11 @@ def train(): parser.add_argument("-b", "--batch_size", type=int, default=64) parser.add_argument("-e", "--epochs", type=int, default=10) parser.add_argument("-w", "--num_workers", type=int, default=5) + parser.add_argument("--with_cuda", type=bool, default=True) parser.add_argument("--log_freq", type=int, default=10) parser.add_argument("--corpus_lines", type=int, default=None) + parser.add_argument("--cuda_devices", type=int, nargs='+', default=None) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--adam_weight_decay", type=float, default=0.01) @@ -56,7 +58,7 @@ def train(): print("Creating BERT Trainer") trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, - with_cuda=args.with_cuda, log_freq=args.log_freq) + with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq) print("Training Start") for epoch in range(args.epochs): diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index 8520dcb..ef4f617 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -24,7 +24,7 @@ class BERTTrainer: def __init__(self, bert: BERT, vocab_size: int, train_dataloader: DataLoader, test_dataloader: DataLoader = None, lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, - with_cuda: bool = True, log_freq: int = 10): + with_cuda: bool = True, cuda_devices=None, log_freq: int = 10): """ :param bert: BERT model which you want to train :param vocab_size: total word vocab size @@ -47,9 +47,9 @@ def __init__(self, bert: BERT, vocab_size: int, self.model = BERTLM(bert, vocab_size).to(self.device) # Distributed GPU training if CUDA can detect more than 1 GPU - if torch.cuda.device_count() > 1: + if with_cuda and torch.cuda.device_count() > 1: print("Using %d GPUS for BERT" % torch.cuda.device_count()) - self.model = DataParallelModel(self.model) + self.model = DataParallelModel(self.model, device_ids=cuda_devices) # Setting the train and test data loader self.train_data = train_dataloader @@ -59,7 +59,9 @@ def __init__(self, bert: BERT, vocab_size: int, self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) # Using Negative Log Likelihood Loss function for predicting the masked_token - self.criterion = DataParallelCriterion(nn.NLLLoss(ignore_index=0)) + self.criterion = nn.NLLLoss(ignore_index=0) + if with_cuda and torch.cuda.device_count() > 0: + self.criterion = DataParallelCriterion(self.criterion, device_ids=cuda_devices) self.log_freq = log_freq From a453ab80d287a066c1011b867d1027325c4eb05b Mon Sep 17 00:00:00 2001 From: codertimo Date: Sun, 21 Oct 2018 10:56:55 +0900 Subject: [PATCH 05/12] Fix license mis-spell the scatter lab --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 5a9d33c..240374c 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. 
- Copyright 2018 Junseong Kim, Scatter Labs, BERT contributors + Copyright 2018 Junseong Kim, Scatter Lab, BERT contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 913c43a3e7c0057d9b6bcc7dd054d50f553f7ddf Mon Sep 17 00:00:00 2001 From: codertimo Date: Mon, 22 Oct 2018 18:22:08 +0900 Subject: [PATCH 06/12] Fixing Percentage Issue --- bert_pytorch/__main__.py | 54 +++++++++++++++++---------------- bert_pytorch/dataset/dataset.py | 24 ++++++++++----- requirements.txt | 2 +- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/bert_pytorch/__main__.py b/bert_pytorch/__main__.py index c208ecf..d4193f2 100644 --- a/bert_pytorch/__main__.py +++ b/bert_pytorch/__main__.py @@ -10,29 +10,30 @@ def train(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--train_dataset", required=True, type=str) - parser.add_argument("-t", "--test_dataset", type=str, default=None) - parser.add_argument("-v", "--vocab_path", required=True, type=str) - parser.add_argument("-o", "--output_path", required=True, type=str) - - parser.add_argument("-hs", "--hidden", type=int, default=256) - parser.add_argument("-l", "--layers", type=int, default=8) - parser.add_argument("-a", "--attn_heads", type=int, default=8) - parser.add_argument("-s", "--seq_len", type=int, default=20) - - parser.add_argument("-b", "--batch_size", type=int, default=64) - parser.add_argument("-e", "--epochs", type=int, default=10) - parser.add_argument("-w", "--num_workers", type=int, default=5) - - parser.add_argument("--with_cuda", type=bool, default=True) - parser.add_argument("--log_freq", type=int, default=10) - parser.add_argument("--corpus_lines", type=int, default=None) - parser.add_argument("--cuda_devices", type=int, nargs='+', default=None) - - parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--adam_weight_decay", type=float, default=0.01) - parser.add_argument("--adam_beta1", type=float, default=0.9) - parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert") + parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set") + parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab") + parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model") + + parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model") + parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers") + parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads") + parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len") + + parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size") + parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs") + parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size") + + parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false") + parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n") + parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus") + parser.add_argument("--cuda_devices", 
type=int, nargs='+', default=None, help="CUDA device ids") + parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false") + + parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam") + parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value") args = parser.parse_args() @@ -41,11 +42,12 @@ def train(): print("Vocab Size: ", len(vocab)) print("Loading Train Dataset", args.train_dataset) - train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines) + train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, + corpus_lines=args.corpus_lines, on_memory=args.on_memory) print("Loading Test Dataset", args.test_dataset) - test_dataset = BERTDataset(args.test_dataset, vocab, - seq_len=args.seq_len) if args.test_dataset is not None else None + test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \ + if args.test_dataset is not None else None print("Creating Dataloader") train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index 4a63603..d377149 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -5,19 +5,27 @@ class BERTDataset(Dataset): - def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None): + def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): self.vocab = vocab self.seq_len = seq_len + self.on_memory = on_memory + self.corpus_lines = corpus_lines with open(corpus_path, "r", encoding=encoding) as f: - self.datas = [line[:-1].split("\t") - for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] + if self.corpus_lines is None and not on_memory: + for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines): + self.corpus_lines += 1 + + if on_memory: + self.lines = [line[:-1].split("\t") + for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] + self.corpus_lines = len(self.lines) def __len__(self): - return len(self.datas) + return self.corpus_lines def __getitem__(self, item): - t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item) + t1, t2, is_next_label = self.random_sent(item) t1_random, t1_label = self.random_word(t1) t2_random, t2_label = self.random_word(t2) @@ -54,7 +62,7 @@ def random_word(self, sentence): tokens[i] = self.vocab.mask_index # 10% randomly change token to random token - elif prob * 0.8 <= prob < prob * 0.9: + elif 0.15 * 0.8 <= prob < 0.15 * 0.9: tokens[i] = random.randrange(len(self.vocab)) # 10% randomly change token to current token @@ -72,6 +80,6 @@ def random_word(self, sentence): def random_sent(self, index): # output_text, label(isNotNext:0, isNext:1) if random.random() > 0.5: - return self.datas[index][1], 1 + return self.datas[index][0], self.datas[index][1], 1 else: - return self.datas[random.randrange(len(self.datas))][1], 0 + return self.datas[index][0], self.datas[random.randrange(len(self.datas))][1], 0 diff --git a/requirements.txt b/requirements.txt index 8506741..661d08f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ tqdm numpy torch>=0.4.0 
-torch-encoding \ No newline at end of file +torch-encodin \ No newline at end of file From 7b145dc5eea4a992fefcad8b492ce9b9a8cc9b87 Mon Sep 17 00:00:00 2001 From: codertimo Date: Mon, 22 Oct 2018 18:55:43 +0900 Subject: [PATCH 07/12] Adding none memory loading --- bert_pytorch/dataset/dataset.py | 42 ++++++++++++++++++++++++++++++-- bert_pytorch/trainer/pretrain.py | 6 +---- requirements.txt | 3 +-- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index d377149..f607f33 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -8,8 +8,11 @@ class BERTDataset(Dataset): def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): self.vocab = vocab self.seq_len = seq_len + self.on_memory = on_memory self.corpus_lines = corpus_lines + self.corpus_path = corpus_path + self.encoding = encoding with open(corpus_path, "r", encoding=encoding) as f: if self.corpus_lines is None and not on_memory: @@ -21,6 +24,13 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=N for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] self.corpus_lines = len(self.lines) + if not on_memory: + self.file = open(corpus_path, "r", encoding=encoding) + self.random_file = open(corpus_path, "r", encoding=encoding) + + for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): + self.random_file.__next__() + def __len__(self): return self.corpus_lines @@ -78,8 +88,36 @@ def random_word(self, sentence): return tokens, output_label def random_sent(self, index): + t1, t2 = self.get_corpus_line(index) + # output_text, label(isNotNext:0, isNext:1) if random.random() > 0.5: - return self.datas[index][0], self.datas[index][1], 1 + return t1, t2, 1 + else: + return t1, self.get_random_line(), 0 + + def get_corpus_line(self, item): + if self.on_memory: + return self.lines[item][0], self.lines[item][1] else: - return self.datas[index][0], self.datas[random.randrange(len(self.datas))][1], 0 + line = self.file.__next__() + if line is None: + self.file.close() + self.file = open(self.corpus_path, "r", encoding=self.encoding) + line = self.file.__next__() + + t1, t2 = line[:-1].split("\t") + return t1, t2 + + def get_random_line(self): + if self.on_memory: + return self.lines[random.randrange(len(self.lines))][1] + + line = self.file.__next__() + if line is None: + self.file.close() + self.file = open(self.corpus_path, "r", encoding=self.encoding) + for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): + self.random_file.__next__() + line = self.random_file.__next__() + return line[:-1].split("\t")[1] diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index ef4f617..4d9c45c 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -3,8 +3,6 @@ from torch.optim import Adam from torch.utils.data import DataLoader -from encoding.parallel import DataParallelModel, DataParallelCriterion - from ..model import BERTLM, BERT import tqdm @@ -49,7 +47,7 @@ def __init__(self, bert: BERT, vocab_size: int, # Distributed GPU training if CUDA can detect more than 1 GPU if with_cuda and torch.cuda.device_count() > 1: print("Using %d GPUS for BERT" % torch.cuda.device_count()) - self.model = DataParallelModel(self.model, device_ids=cuda_devices) + self.model = nn.DataParallel(self.model, device_ids=cuda_devices) # Setting the train and test 
data loader self.train_data = train_dataloader @@ -60,8 +58,6 @@ def __init__(self, bert: BERT, vocab_size: int, # Using Negative Log Likelihood Loss function for predicting the masked_token self.criterion = nn.NLLLoss(ignore_index=0) - if with_cuda and torch.cuda.device_count() > 0: - self.criterion = DataParallelCriterion(self.criterion, device_ids=cuda_devices) self.log_freq = log_freq diff --git a/requirements.txt b/requirements.txt index 661d08f..3689708 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ tqdm numpy -torch>=0.4.0 -torch-encodin \ No newline at end of file +torch>=0.4.0 \ No newline at end of file From ed68f5a413254d9846ac250858e5d4cc6e7defbf Mon Sep 17 00:00:00 2001 From: Mathis Chenuet Date: Tue, 23 Oct 2018 00:57:31 +0200 Subject: [PATCH 08/12] really fix conditions #13 --- bert_pytorch/dataset/dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index f607f33..7d787f3 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -67,12 +67,14 @@ def random_word(self, sentence): for i, token in enumerate(tokens): prob = random.random() if prob < 0.15: - # 80% randomly change token to make token - if prob < prob * 0.8: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: tokens[i] = self.vocab.mask_index # 10% randomly change token to random token - elif 0.15 * 0.8 <= prob < 0.15 * 0.9: + elif prob < 0.9: tokens[i] = random.randrange(len(self.vocab)) # 10% randomly change token to current token From a4d886f923e0e042ba9df30d5342ac2eca530365 Mon Sep 17 00:00:00 2001 From: jeonsworld <37530102+jeonsworld@users.noreply.github.com> Date: Tue, 23 Oct 2018 10:05:42 +0900 Subject: [PATCH 09/12] update dataset.py if condition prob < prob * 0.8 can't be true --- bert_pytorch/dataset/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index 1cf7f38..69aebb4 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -50,7 +50,7 @@ def random_word(self, sentence): prob = random.random() if prob < 0.15: # 80% randomly change token to make token - if prob < prob * 0.8: + if prob < 0.15 * 0.8: tokens[i] = self.vocab.mask_index # 10% randomly change token to random token From 9e35c63c852572c7ce95a379abc766358d3721b2 Mon Sep 17 00:00:00 2001 From: codertimo Date: Tue, 23 Oct 2018 10:07:48 +0900 Subject: [PATCH 10/12] Fixing #20 issue --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cc0e597..16d3df9 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ bert-vocab -c data/corpus.small -o data/vocab.small ### 2. 
Train your own BERT model ```shell -bert -c data/dataset.small -v data/vocab.small -o output/bert.model +bert -c data/corpus.small -v data/vocab.small -o output/bert.model ``` ## Language Model Pre-training From 6521dfeaddd3ab171e8cbf946a4aab13e1985127 Mon Sep 17 00:00:00 2001 From: Pengjia Zhu Date: Tue, 23 Oct 2018 15:01:30 +1300 Subject: [PATCH 11/12] fixed a bug in position.py --- bert_pytorch/model/embedding/position.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert_pytorch/model/embedding/position.py b/bert_pytorch/model/embedding/position.py index d1d7e81..d55c224 100644 --- a/bert_pytorch/model/embedding/position.py +++ b/bert_pytorch/model/embedding/position.py @@ -13,7 +13,7 @@ def __init__(self, d_model, max_len=512): pe.require_grad = False position = torch.arange(0, max_len).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp() + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) From e31ff4c8c5d34691edf073133f3894b22dcee47a Mon Sep 17 00:00:00 2001 From: codertimo Date: Tue, 23 Oct 2018 13:26:30 +0900 Subject: [PATCH 12/12] Adding optim schedule feature for #17 --- bert_pytorch/trainer/optim_schedule.py | 35 ++++++++++++++++++++++++++ bert_pytorch/trainer/pretrain.py | 8 +++--- 2 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 bert_pytorch/trainer/optim_schedule.py diff --git a/bert_pytorch/trainer/optim_schedule.py b/bert_pytorch/trainer/optim_schedule.py new file mode 100644 index 0000000..5ccd222 --- /dev/null +++ b/bert_pytorch/trainer/optim_schedule.py @@ -0,0 +1,35 @@ +'''A wrapper class for optimizer ''' +import numpy as np + + +class ScheduledOptim(): + '''A simple wrapper class for learning rate scheduling''' + + def __init__(self, optimizer, d_model, n_warmup_steps): + self._optimizer = optimizer + self.n_warmup_steps = n_warmup_steps + self.n_current_steps = 0 + self.init_lr = np.power(d_model, -0.5) + + def step_and_update_lr(self): + "Step with the inner optimizer" + self._update_learning_rate() + self._optimizer.step() + + def zero_grad(self): + "Zero out the gradients by the inner optimizer" + self._optimizer.zero_grad() + + def _get_lr_scale(self): + return np.min([ + np.power(self.n_current_steps, -0.5), + np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) + + def _update_learning_rate(self): + ''' Learning rate scheduling per step ''' + + self.n_current_steps += 1 + lr = self.init_lr * self._get_lr_scale() + + for param_group in self._optimizer.param_groups: + param_group['lr'] = lr diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index 4d9c45c..0b882dd 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -4,6 +4,7 @@ from torch.utils.data import DataLoader from ..model import BERTLM, BERT +from .optim_schedule import ScheduledOptim import tqdm @@ -21,7 +22,7 @@ class BERTTrainer: def __init__(self, bert: BERT, vocab_size: int, train_dataloader: DataLoader, test_dataloader: DataLoader = None, - lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, + lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000, with_cuda: bool = True, cuda_devices=None, log_freq: int = 10): """ :param bert: BERT model which you want to train @@ -55,6 +56,7 @@ def __init__(self, bert: BERT, vocab_size: int, # Setting the Adam optimizer with 
hyper-param self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) + self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) # Using Negative Log Likelihood Loss function for predicting the masked_token self.criterion = nn.NLLLoss(ignore_index=0) @@ -110,9 +112,9 @@ def iteration(self, epoch, data_loader, train=True): # 3. backward and optimization only in train if train: - self.optim.zero_grad() + self.optim_schedule.zero_grad() loss.backward() - self.optim.step() + self.optim_schedule.step_and_update_lr() # next sentence prediction accuracy correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
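
Patches 01, 06, 08, and 09 all chase the same defect in `random_word`: with `prob` drawn once per token, `prob < prob * 0.8` can never be true and `prob >= prob * 0.9` always is, so the intended 80/10/10 split never happened. The form patch 08 lands rescales the draw by 0.15 and branches on 0.8 and 0.9. Below is a minimal, self-contained sketch of that corrected rule; `mask_tokens`, `MASK_TOKEN`, and the plain-list `vocab` are illustrative stand-ins for the repository's `vocab.mask_index` and `vocab.stoi`, not its API.

```python
import random

MASK_TOKEN = "[MASK]"  # stand-in for self.vocab.mask_index in the repository code

def mask_tokens(tokens, vocab, select_prob=0.15):
    """Masked-LM corruption as the rule reads after patch 08: each token is selected
    with probability 0.15; a selected token is replaced by the mask token 80% of the
    time, by a random vocabulary token 10% of the time, and kept unchanged 10% of the
    time. The original token is always the prediction target at a selected position."""
    corrupted, labels = list(tokens), []
    for i, token in enumerate(tokens):
        prob = random.random()
        if prob < select_prob:
            prob /= select_prob            # rescale the draw to [0, 1) within the selected 15%
            if prob < 0.8:                 # 80%: mask token
                corrupted[i] = MASK_TOKEN
            elif prob < 0.9:               # 10%: random token
                corrupted[i] = random.choice(vocab)
            else:                          # 10%: keep the current token
                corrupted[i] = token
            labels.append(token)           # the model must recover the original here
        else:
            labels.append(None)            # unselected positions carry no loss (ignore_index=0 in the trainer)
    return corrupted, labels

print(mask_tokens("the quick brown fox jumps over the lazy dog".split(),
                  vocab="the quick brown fox jumps over lazy dog a an".split()))
```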
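
Patch 07 adds a streaming path (`on_memory=False`) that reads the corpus lazily instead of holding every line in memory. Two details in that hunk are worth flagging when reading it: `random.randint` requires two bounds, so the single-argument calls in `__init__` and `get_random_line` raise `TypeError` (`random.randrange(n)` is the one-argument form), and a file iterator signals end-of-file by raising `StopIteration` rather than returning `None`, so the `if line is None:` reopen branches never fire. Below is a hedged sketch of how those two pieces could be written; both helper names are mine, not functions in the repository.

```python
import random

def skip_random_lines(f, corpus_lines, cap=1000):
    """Advance a freshly opened corpus file by a random number of lines
    (bounded by `cap`), which is what the random_file seek in the patch intends."""
    for _ in range(random.randrange(max(1, min(corpus_lines, cap)))):
        next(f)

def next_line_wrapping(f, path, encoding="utf-8"):
    """Return the next line, reopening the file when it is exhausted.
    End-of-file surfaces as StopIteration, so the wrap-around catches that
    exception instead of testing the returned line against None."""
    try:
        return next(f), f
    except StopIteration:
        f.close()
        f = open(path, "r", encoding=encoding)
        return next(f), f
```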
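
Patch 11's one-line change in `position.py` moves the `.float()` cast ahead of the multiplication. `torch.arange(0, d_model, 2)` yields an integer tensor, and scaling it by the small negative constant `-(log(10000)/d_model)` before casting can truncate the exponents on the PyTorch versions current at the time, whereas casting first keeps `div_term` fractional. A standalone sketch of the corrected computation follows; the function wrapper is mine, and the surrounding `nn.Module` and buffer registration from the original file are omitted.

```python
import math
import torch

def sinusoidal_positional_encoding(d_model=256, max_len=512):
    """Build the (max_len, d_model) sinusoidal table as the patched hunk computes it."""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).float().unsqueeze(1)              # (max_len, 1)
    # Cast to float *before* scaling so the exponents keep their fractional values.
    div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

pe = sinusoidal_positional_encoding()
print(pe.shape)  # torch.Size([512, 256])
```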
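
Patch 12 routes the optimizer through `ScheduledOptim`, the "Noam" schedule from *Attention Is All You Need*: `lr = d_model**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5)`, a linear warm-up followed by inverse-square-root decay, with `zero_grad()`/`step_and_update_lr()` replacing the bare `optim.zero_grad()`/`optim.step()` calls in the training loop. A quick numeric check of that formula is below; the step values and the use of the CLI defaults (hidden size 256, 10,000 warm-up steps) are only for illustration.

```python
import numpy as np

d_model, n_warmup_steps = 256, 10000   # the -hs default and the trainer's default warmup_steps

def noam_lr(step):
    # init_lr * lr_scale, exactly as ScheduledOptim computes it per step.
    init_lr = np.power(d_model, -0.5)
    lr_scale = np.min([np.power(step, -0.5), np.power(n_warmup_steps, -1.5) * step])
    return init_lr * lr_scale

for step in (1, 1000, 10000, 100000):
    print(f"step {step:>6}: lr = {noam_lr(step):.2e}")
# Rises linearly to its peak 256**-0.5 * 10000**-0.5 = 6.25e-04 at step 10000,
# then decays proportionally to 1/sqrt(step).
```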