Commit 06d4d36 (1 parent: b4acd70), showing 79 changed files with 144,067 additions and 0 deletions.
@@ -0,0 +1 @@
data/tydiqa-v1.0-dev.jsonl.gz filter=lfs diff=lfs merge=lfs -text
@@ -1,2 +1,3 @@
.DS_Store
data/tydiqa-v1.0-dev.jsonl.gz
@@ -0,0 +1,7 @@
feedback data/Dev-400.jsonl.gz
feedback data/full-test-long-term.jsonl.gz
feedback data/static-test.jsonl.gz
feedback data/full-test-parallel.jsonl.gz
tydi data/tydiqa-v1.0-dev.jsonl.gz
squad2 -
squad -
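The file above is a small manifest mapping dataset names to their data files, with "-" marking datasets that ship no local file (squad, squad2). As a hedged illustration only (this helper is not part of the commit, and the code that actually consumes the manifest is not shown here), it could be read along these lines:

import collections

def read_manifest(path):
    """Hypothetical reader for a manifest like the one above.

    Assumes two whitespace-separated columns (dataset name, file path) and
    treats '-' as "no local file"."""
    datasets = collections.defaultdict(list)
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 2:
                continue
            name, data_path = parts
            if data_path != '-':
                datasets[name].append(data_path)
    return dict(datasets)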
Binary file added (+150 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-1.jsonl.gz
Binary file added (+145 KB): ...el/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-10.jsonl.gz
Binary file added (+141 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-2.jsonl.gz
Binary file added (+144 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-3.jsonl.gz
Binary file added (+147 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-4.jsonl.gz
Binary file added (+147 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-5.jsonl.gz
Binary file added (+149 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-6.jsonl.gz
Binary file added (+145 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-7.jsonl.gz
Binary file added (+149 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-8.jsonl.gz
Binary file added (+145 KB): ...lel/round1/sensitivity_all800/train-data-parallel-round1-200-wprob-sensitivity-9.jsonl.gz
Binary file added (+174 KB): data/train_parallel/round1/train-data-parallel-round1-200-wprob-128.jsonl.gz
Binary file added (+576 KB): data/train_parallel/round1/train-data-parallel-round1-200-wprob-512-all.jsonl.gz
Binary file added (+146 KB): data/train_parallel/round1/train-data-parallel-round1-200-wprob-512-woclass.jsonl.gz
Binary file added (+72.8 KB): data/train_parallel/round1/train-data-parallel-round1-200-wprob-fewer.jsonl.gz
Binary file added (+131 KB): data/train_parallel/round1/train-data-parallel-round1-200-wprob-news.jsonl.gz
Binary file added (+143 KB): data/train_parallel/round1/train-data-parallel-round1-200-wprob-short.jsonl.gz
Binary file added (+70.3 KB): data/train_parallel/round10/train-data-parallel-round10-200-wprob-fewer.jsonl.gz
Binary file added (+143 KB): data/train_parallel/round2/train-data-parallel-round2-200-wprob-128.jsonl.gz
Binary file added (+144 KB): data/train_parallel/round2/train-data-parallel-round2-200-wprob-512-woclass.jsonl.gz
Binary file added (+70.1 KB): data/train_parallel/round2/train-data-parallel-round2-200-wprob-fewer.jsonl.gz
Binary file added (+151 KB): data/train_parallel/round2/train-data-parallel-round2-200-wprob-news-force10.jsonl.gz
Binary file added (+138 KB): data/train_parallel/round2/train-data-parallel-round2-200-wprob-news.jsonl.gz
Binary file added (+160 KB): data/train_parallel/round2/train-data-parallel-round2-200-wprob-short.jsonl.gz
Binary file added (+148 KB): data/train_parallel/round3/train-data-parallel-round3-200-wprob-128.jsonl.gz
Binary file added (+136 KB): data/train_parallel/round3/train-data-parallel-round3-200-wprob-512-woclass.jsonl.gz
Binary file added (+68.4 KB): data/train_parallel/round3/train-data-parallel-round3-200-wprob-fewer.jsonl.gz
Binary file added (+135 KB): data/train_parallel/round3/train-data-parallel-round3-200-wprob-news.jsonl.gz
Binary file added (+154 KB): data/train_parallel/round3/train-data-parallel-round3-200-wprob-short.jsonl.gz
Binary file added (+163 KB): data/train_parallel/round4/train-data-parallel-round4-200-wprob-128.jsonl.gz
Binary file added (+153 KB): data/train_parallel/round4/train-data-parallel-round4-200-wprob-512-woclass.jsonl.gz
Binary file added (+85.2 KB): data/train_parallel/round4/train-data-parallel-round4-200-wprob-fewer.jsonl.gz
Binary file added (+129 KB): data/train_parallel/round4/train-data-parallel-round4-200-wprob-news.jsonl.gz
Binary file added (+140 KB): data/train_parallel/round4/train-data-parallel-round4-200-wprob-short.jsonl.gz
Binary file added (+161 KB): data/train_parallel/round5/train-data-parallel-round5-200-wprob-128.jsonl.gz
Binary file added (+168 KB): data/train_parallel/round5/train-data-parallel-round5-200-wprob-512-woclass.jsonl.gz
Binary file added (+60.3 KB): data/train_parallel/round5/train-data-parallel-round5-200-wprob-fewer.jsonl.gz
Binary file added (+139 KB): data/train_parallel/round5/train-data-parallel-round5-200-wprob-news.jsonl.gz
Binary file added (+135 KB): data/train_parallel/round5/train-data-parallel-round5-200-wprob-short.jsonl.gz
Binary file added (+75.8 KB): data/train_parallel/round6/train-data-parallel-round6-200-wprob-fewer.jsonl.gz
Binary file added (+73 KB): data/train_parallel/round7/train-data-parallel-round7-200-wprob-fewer.jsonl.gz
Binary file added (+78.5 KB): data/train_parallel/round8/train-data-parallel-round8-200-wprob-fewer.jsonl.gz
Binary file added (+78.1 KB): data/train_parallel/round9/train-data-parallel-round9-200-wprob-fewer.jsonl.gz
@@ -0,0 +1,179 @@
import json
import csv
import string
import re
import gzip
import collections
import math
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from transformers import BertTokenizer, DebertaTokenizerFast, DebertaV2TokenizerFast, AutoTokenizer

from src.data import read_feedback_examples_and_features, get_feedback_data
from model import BertForQuestionAnswering, DebertaSQuAD2, BertForQuestionAnsweringSequence
from src_utils.merge_data import merge
from src.eval import RawResult, normalize_answer


def get_batch_log_prob(start_probs, end_probs, start_samples, end_samples):
    """Log-probability of each recorded span: log p(start) + log p(end)."""
    bs = start_samples.shape[0]
    ignored_index = start_probs.size(1)
    start_samples.clamp_(0, ignored_index)
    end_samples.clamp_(0, ignored_index)
    log_prob = (start_probs[torch.arange(bs), start_samples].log()
                + end_probs[torch.arange(bs), end_samples].log())
    return log_prob


def load_initialization(model, ckpt_name):
    ckpt = torch.load(ckpt_name)
    model.load_state_dict(ckpt['model_state_dict'])
    print("Loaded the model state from a saved checkpoint {}".format(ckpt_name))
    return model


def main(train_batches, model, device, add_classifier):
    total = 0
    log_probs = []
    class_log_probs = []
    for step, batch in enumerate(train_batches):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, start_samples, end_samples, rewards = batch

        with torch.no_grad():
            # BertForQuestionAnswering(Sequence) returns (start, end, class_prob);
            # DebertaSQuAD2 returns only (start, end), so pad with None in that case.
            outputs = model(batch=batch[:3], return_prob=True, classifier=add_classifier)
            if len(outputs) == 3:
                start_probs, end_probs, class_prob = outputs
            else:
                start_probs, end_probs = outputs
                class_prob = None

        if add_classifier:
            class_log_prob = class_prob.log()
            # most likely answerable/unanswerable class for this example (batch size is 1)
            class_sample = class_log_prob.argmax(dim=-1).item()

        log_prob = get_batch_log_prob(start_probs, end_probs, start_samples, end_samples)

        log_probs.append(log_prob)
        if add_classifier:
            class_log_probs.append(class_log_prob[:, class_sample])
        total += input_ids.size(0)

    print('=' * 50)
    print('[logging] Total: %d' % total)
    print('=' * 50)

    if add_classifier:
        return torch.cat(log_probs, dim=0), torch.cat(class_log_probs, dim=0)
    return torch.cat(log_probs, dim=0), None


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='deepset/deberta-v3-base-squad2', type=str)
    parser.add_argument("--data_file", default=None, type=str, required=True,
                        help='data you wish to generate probabilities from')
    parser.add_argument("--checkpoint", default=None, type=str, required=True)
    parser.add_argument("--add_classifier", action='store_true')
    parser.add_argument("--outfile", default=None, type=str, required=True,
                        help="Output file where the examples annotated with log probabilities will be written.")
    args = parser.parse_args()

    #### initialization ####
    model_type = args.model
    data_file = args.data_file
    outfile = args.outfile
    checkpoint = args.checkpoint
    add_classifier = args.add_classifier

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 1

    # tokenization and dataset
    if model_type == 'deepset/deberta-v3-base-squad2':
        tokenizer = AutoTokenizer.from_pretrained(model_type, return_offsets_mapping=True)
    elif 'v3' in model_type:
        tokenizer = DebertaV2TokenizerFast.from_pretrained(model_type, return_offsets_mapping=True)
    else:
        tokenizer = DebertaTokenizerFast.from_pretrained(model_type, return_offsets_mapping=True)

    train_dataset = get_feedback_data(data_file)  # original train data

    # load model
    if model_type == "deepset/deberta-v3-base-squad2":
        model = DebertaSQuAD2(model_type=model_type)
    elif model_type == 'microsoft/deberta-v3-base':
        if args.add_classifier:
            model = BertForQuestionAnsweringSequence(model_type=model_type)
        else:
            model = BertForQuestionAnswering(model_type=model_type)
    if checkpoint:
        model = load_initialization(model, checkpoint)
    model = model.to(device)

    # processing examples
    train_examples, train_features = read_feedback_examples_and_features(input_data=train_dataset,
                                                                         negative_reward=-0.1,
                                                                         partial_reward=0.5,
                                                                         reward_wrong_unans=-1,
                                                                         reward_correct_span=1,
                                                                         reward_correct_unans=1,
                                                                         reward_class_wrong=0,
                                                                         reward_class_correct_ans=1,
                                                                         tokenizer=tokenizer,
                                                                         max_seq_length=512,
                                                                         prepend_title=True,
                                                                         load_log_prob=False)

    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)

    all_start_samples = torch.tensor([f.start_sample for f in train_features], dtype=torch.long)
    all_end_samples = torch.tensor([f.end_sample for f in train_features], dtype=torch.long)
    all_rewards = torch.tensor([f.reward for f in train_features], dtype=torch.float)

    data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                         all_start_samples, all_end_samples, all_rewards)
    print("***** Train *****")
    print("  Num examples = %d" % len(train_features))
    print("  Batch size = %d" % batch_size)

    train_dataloader = DataLoader(data, batch_size=batch_size)
    train_batches = [batch for batch in train_dataloader]

    # main function: score every example with the loaded model
    log_probs, class_log_probs = main(train_batches, model, device, add_classifier)
    print(log_probs.size(), len(train_dataset))
    assert log_probs.size(0) == len(train_dataset)
    if args.add_classifier:
        assert class_log_probs.size(0) == len(train_dataset)
        print(class_log_probs.size())

    # attach the log probabilities to the original examples
    for i, inst in enumerate(train_dataset):
        inst['log_prob'] = log_probs[i].item()
        if args.add_classifier:
            inst['class_log_prob'] = class_log_probs[i].item()
        else:
            inst['class_log_prob'] = 0

    print(train_dataset[0])

    # write data
    with open(outfile, 'w') as fw:
        for l in train_dataset:
            fw.write(json.dumps(l) + '\n')
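The script above loads a (possibly fine-tuned) QA model, rescans a feedback dataset, and writes each example back out with a `log_prob` (and optionally `class_log_prob`) field; it is driven by the --data_file, --checkpoint, --outfile and --add_classifier arguments shown in the argparse block. The core quantity it records is the span log-probability computed in get_batch_log_prob: log p(start) + log p(end) under the model's softmax distributions. A minimal, self-contained illustration with made-up tensors (not taken from the repository):

import torch

# Toy stand-ins for the model's softmax outputs (batch size 1, sequence length 3).
start_probs = torch.tensor([[0.1, 0.7, 0.2]])
end_probs = torch.tensor([[0.2, 0.2, 0.6]])
start_samples = torch.tensor([1])   # recorded start index for the example
end_samples = torch.tensor([2])     # recorded end index for the example

bs = start_samples.shape[0]
log_prob = (start_probs[torch.arange(bs), start_samples].log()
            + end_probs[torch.arange(bs), end_samples].log())
print(log_prob)  # tensor([-0.8675]) == log(0.7) + log(0.6)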
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, AutoModel, AutoModelForQuestionAnswering


class BertForQuestionAnswering(nn.Module):
    def __init__(self, model_type: str):
        super(BertForQuestionAnswering, self).__init__()
        if 'deberta' in model_type:
            self.bert = AutoModel.from_pretrained(model_type)
        elif 'bert-' in model_type:
            self.bert = BertModel.from_pretrained(model_type)
        else:
            raise ValueError('Model type!')

        self.qa_outputs = nn.Linear(self.bert.config.hidden_size, 2)  # [N, L, H] => [N, L, 2]
        self.classifier_coeff = 10
        self.entropy_penalty = 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, batch, classifier=False, return_prob=False, **kwargs):
        '''
        Each batch is a list of 5 items (training) or 3 items (inference):
        - input_ids: token ids of the input sequence
        - attention_mask: mask of the sequence (1 for present, 0 for blank)
        - token_type_ids: indicator of the type of sequence,
          e.g. in QA, whether a token belongs to the question or the document
        - (training) start_positions: list of start positions of the span
        - (training) end_positions: list of end positions of the span

        Note: the classifier=True paths use self.classification, which is only
        defined in the BertForQuestionAnsweringSequence subclass.
        '''
        input_ids, attention_masks, token_type_ids = batch[:3]
        output = self.bert(input_ids=input_ids,
                           token_type_ids=token_type_ids,
                           attention_mask=attention_masks)
        sequence_output = output.last_hidden_state
        logits = self.qa_outputs(sequence_output)  # (bs, max_input_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)    # (bs, max_input_len)
        end_logits = end_logits.squeeze(-1)        # (bs, max_input_len)

        if len(batch) == 5:
            # training: span cross-entropy loss
            start_positions, end_positions = batch[3:]
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if classifier:
                # answerability loss on the [CLS] position
                answerable_mask = (start_positions != 0) | (end_positions != 0)
                loss_fct = CrossEntropyLoss()
                answerable_logits = self.classification(sequence_output)[:, 0]
                classifier_loss = loss_fct(answerable_logits, answerable_mask.long())
                total_loss += self.classifier_coeff * classifier_loss
                answerable_prob = torch.softmax(answerable_logits, dim=-1)

                # optional entropy regularization (entropy_penalty is 0 by default)
                total_loss += self.entropy_penalty * (
                    -torch.mean(torch.sum(-answerable_prob * torch.log(answerable_prob), dim=-1)))
                return total_loss, torch.softmax(self.classification(sequence_output[:, 0]), dim=-1)
            return total_loss, None

        elif len(batch) == 3 and not classifier:
            if not return_prob:
                return start_logits, end_logits, None
            else:
                return self.softmax(start_logits), self.softmax(end_logits), None
        elif len(batch) == 3 and classifier:
            if return_prob:
                return self.softmax(start_logits), self.softmax(end_logits), \
                    self.softmax(self.classification(sequence_output[:, 0]))
            else:
                return start_logits, end_logits, self.classification(sequence_output[:, 0])
        else:
            raise NotImplementedError()


class BertForQuestionAnsweringSequence(BertForQuestionAnswering):
    def __init__(self, model_type: str):
        super(BertForQuestionAnsweringSequence, self).__init__(model_type=model_type)
        # answerable / unanswerable head applied to the [CLS] representation
        self.classification = nn.Linear(self.bert.config.hidden_size, 2)  # [N, H] => [N, 2]


class DebertaSQuAD2(nn.Module):
    def __init__(self, model_type: str):
        super(DebertaSQuAD2, self).__init__()
        if model_type == 'deepset/deberta-v3-base-squad2':
            self.bert = AutoModelForQuestionAnswering.from_pretrained(model_type)
        else:
            raise ValueError('Model type!')

    def forward(self, batch, return_prob=False, **kwargs):
        '''
        Inference only: batch is a list of 3 items
        (input_ids, attention_mask, token_type_ids).
        Unlike the classes above, this model returns only (start, end) logits or
        probabilities, with no third classifier output.
        '''
        input_ids, attention_masks, token_type_ids = batch[:3]
        output = self.bert(input_ids=input_ids,
                           token_type_ids=token_type_ids,
                           attention_mask=attention_masks)

        start_logits, end_logits = output.start_logits, output.end_logits

        if len(batch) == 3:
            if not return_prob:
                return start_logits, end_logits
            else:
                return torch.softmax(start_logits, dim=-1), torch.softmax(end_logits, dim=-1)
        else:
            raise NotImplementedError()
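For context, a hedged usage sketch of the classifier-headed model above in inference mode. It assumes model.py is importable as `model` and that the microsoft/deberta-v3-base weights and tokenizer are available; the question/context strings are made up, and the freshly initialized QA and classifier heads will give arbitrary outputs until fine-tuned.

# Hedged usage sketch (not part of this commit).
import torch
from transformers import AutoTokenizer
from model import BertForQuestionAnsweringSequence

tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')
qa_model = BertForQuestionAnsweringSequence(model_type='microsoft/deberta-v3-base').eval()

enc = tokenizer("Who wrote Hamlet?",
                "Hamlet is a tragedy written by William Shakespeare.",
                return_tensors='pt', return_token_type_ids=True)
batch = (enc['input_ids'], enc['attention_mask'], enc['token_type_ids'])

with torch.no_grad():
    start_probs, end_probs, class_probs = qa_model(batch=batch, return_prob=True, classifier=True)

start = start_probs.argmax(dim=-1).item()
end = end_probs.argmax(dim=-1).item()
print(tokenizer.decode(enc['input_ids'][0][start:end + 1]))  # predicted span (untrained heads)
print(class_probs)  # shape (1, 2): answerable / unanswerable probabilities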