upload source code
timchen0618 committed May 15, 2023
1 parent b4acd70 commit 06d4d36
Showing 79 changed files with 144,067 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
data/tydiqa-v1.0-dev.jsonl.gz filter=lfs diff=lfs merge=lfs -text
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
.DS_Store
data/tydiqa-v1.0-dev.jsonl.gz

Binary file added data/Dev.jsonl.gz
Binary file not shown.
Binary file added data/full-test-long-term.jsonl.gz
Binary file not shown.
Binary file added data/full-test-parallel.jsonl.gz
Binary file not shown.
Binary file added data/static-test.jsonl.gz
Binary file not shown.
7 changes: 7 additions & 0 deletions data/test_feedback.txt
@@ -0,0 +1,7 @@
feedback data/Dev-400.jsonl.gz
feedback data/full-test-long-term.jsonl.gz
feedback data/static-test.jsonl.gz
feedback data/full-test-parallel.jsonl.gz
tydi data/tydiqa-v1.0-dev.jsonl.gz
squad2 -
squad -
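
Note: data/test_feedback.txt appears to be a two-column manifest mapping a dataset name to a local gzipped JSON-lines path, with `-` seemingly marking datasets that have no bundled file. A minimal parsing sketch under that assumption; the helper name `load_test_manifest` is hypothetical and not part of this commit:

from collections import defaultdict

def load_test_manifest(path="data/test_feedback.txt"):
    """Parse the manifest: each line is `<dataset-name> <path-or-dash>` (assumed format)."""
    manifest = defaultdict(list)
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 2:
                continue  # skip blank or malformed lines
            name, data_path = parts
            if data_path != "-":  # '-' appears to mark datasets without a local file
                manifest[name].append(data_path)
    return dict(manifest)
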
Additional binary files added in this commit; file names and contents are not shown in this view.
179 changes: 179 additions & 0 deletions generate_prob.py
@@ -0,0 +1,179 @@
import json
import csv
import string
import re
import gzip
import collections
import math
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from transformers import BertTokenizer, DebertaTokenizerFast, DebertaV2TokenizerFast, AutoTokenizer

from src.data import read_feedback_examples_and_features, get_feedback_data
from model import BertForQuestionAnswering, DebertaSQuAD2, BertForQuestionAnsweringSequence
from src_utils.merge_data import merge
from src.eval import RawResult, normalize_answer
import argparse


def get_batch_log_prob(start_probs, end_probs, start_samples, end_samples):
    """Return the log-probability of each sampled (start, end) span in the batch."""
    bs = start_samples.shape[0]
    ignored_index = start_probs.size(1)
    start_samples.clamp_(0, ignored_index)
    end_samples.clamp_(0, ignored_index)
    log_prob = (start_probs[torch.arange(bs), start_samples].log()
                + end_probs[torch.arange(bs), end_samples].log())
    return log_prob


def load_initialization(model, ckpt_name):
ckpt = torch.load(ckpt_name)

model.load_state_dict(ckpt['model_state_dict'])
print("Loaded the model state from a saved checkpoint {}".format(ckpt_name))
return model

def main(train_batches, model, device, add_classifier):
total = 0
log_probs = []
class_log_probs = []
for step, batch in enumerate(train_batches):
batch = tuple(t.to(device) for t in batch)

input_ids, input_mask, segment_ids, start_samples, end_samples, rewards = batch
        with torch.no_grad():
            start_probs, end_probs, class_prob = model(batch=batch[:3], return_prob=True,
                                                       classifier=add_classifier)
        if add_classifier:
            class_log_prob = class_prob.log()
            # greedy decision of the answerable classifier for this example
            class_sample = class_log_prob.argmax(dim=-1).item()

        log_prob = get_batch_log_prob(start_probs, end_probs, start_samples, end_samples)

        log_probs.append(log_prob)
        if add_classifier:
            class_log_probs.append(class_log_prob[:, class_sample])
        total += input_ids.size(0)

print('='*50)
print('[logging] Total: %d'%(total))
print('='*50)

if add_classifier:
return torch.cat(log_probs, dim=0), torch.cat(class_log_probs, dim=0)
else:
return torch.cat(log_probs, dim=0), None





if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model", default='deepset/deberta-v3-base-squad2', type=str)
parser.add_argument("--data_file", default=None, type=str, required=True, help='data you wish to generate prob from')
parser.add_argument("--checkpoint", default=None, type=str, required=True)
parser.add_argument("--add_classifier", action='store_true')
parser.add_argument(
"--outfile",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints and predictions will be written.")
args = parser.parse_args()


#### initialization ####
model_type = args.model
data_file = args.data_file
outfile = args.outfile
checkpoint = args.checkpoint
add_classifier = args.add_classifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 1


# tokenization and dataset
if model_type == 'deepset/deberta-v3-base-squad2':
tokenizer = AutoTokenizer.from_pretrained(model_type, return_offsets_mapping=True)
elif 'v3' in model_type:
tokenizer = DebertaV2TokenizerFast.from_pretrained(model_type, return_offsets_mapping=True)
else:
tokenizer = DebertaTokenizerFast.from_pretrained(model_type, return_offsets_mapping=True)

train_dataset = get_feedback_data(data_file) # original train data

# load model
if model_type == "deepset/deberta-v3-base-squad2":
model = DebertaSQuAD2(model_type=model_type)
elif model_type == 'microsoft/deberta-v3-base':
if args.add_classifier:
model = BertForQuestionAnsweringSequence(model_type=model_type)
else:
model = BertForQuestionAnswering(model_type=model_type)
if checkpoint:
model = load_initialization(model, checkpoint)
model = model.to(device)

# processing examples
train_examples, train_features = read_feedback_examples_and_features(input_data=train_dataset,
negative_reward=-0.1,
partial_reward=0.5,
reward_wrong_unans=-1,
reward_correct_span=1,
reward_correct_unans=1,
reward_class_wrong=0,
reward_class_correct_ans=1,
tokenizer=tokenizer,
max_seq_length=512,
prepend_title=True,
load_log_prob=False)

# read_feedback_examples_and_features(train_dataset, -0.1, 0.5, -1, 1, 1, tokenizer, 512, True)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)

all_start_samples = torch.tensor([f.start_sample for f in train_features], dtype=torch.long)
all_end_samples = torch.tensor([f.end_sample for f in train_features], dtype=torch.long)
all_rewards = torch.tensor([f.reward for f in train_features], dtype=torch.float)

data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_samples, all_end_samples, all_rewards)
print("***** Train *****")
print(" Num examples = %d"%len(train_features))
print(" Batch size = %d"%batch_size)

train_dataloader = DataLoader(data, batch_size=batch_size)
train_batches = [batch for batch in train_dataloader]


# main function
log_probs, class_log_probs = main(train_batches, model, device, add_classifier)
print(log_probs.size(), len(train_dataset))
assert log_probs.size(0) == len(train_dataset)
if args.add_classifier:
assert class_log_probs.size(0) == len(train_dataset)
print(class_log_probs.size())

for i, inst in enumerate(train_dataset):
inst['log_prob'] = log_probs[i].item()
# print(class_log_probs[i])
if args.add_classifier:
inst['class_log_prob'] = class_log_probs[i].item()
else:
inst['class_log_prob'] = 0

print(train_dataset[0])

# write data
    with open(outfile, 'w') as fw:
        for inst in train_dataset:
            fw.write(json.dumps(inst) + '\n')




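Note: generate_prob.py is driven by its argparse flags (`--data_file`, `--checkpoint`, `--outfile`, plus optional `--model` and `--add_classifier`); the file written to `--outfile` is JSON lines, one per feedback instance, with `log_prob` and `class_log_prob` fields added. A minimal sketch of reading that output back; the file name `dev_with_probs.jsonl` is only a placeholder:

import json

def load_generated_probs(path="dev_with_probs.jsonl"):  # placeholder path
    """Read the JSON-lines file written by generate_prob.py."""
    instances = []
    with open(path) as f:
        for line in f:
            inst = json.loads(line)
            # each record keeps the original feedback fields and gains the span
            # log-prob (and classifier log-prob, 0 if --add_classifier was not set)
            instances.append(inst)
    return instances

# example: collect the stored span log-probabilities
# probs = [inst["log_prob"] for inst in load_generated_probs()]
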
128 changes: 128 additions & 0 deletions model.py
@@ -0,0 +1,128 @@
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, AutoModel, AutoModelForQuestionAnswering
import torch


class BertForQuestionAnswering(nn.Module):
def __init__(self, model_type: str):
super(BertForQuestionAnswering, self).__init__()
if 'deberta' in model_type:
self.bert = AutoModel.from_pretrained(model_type)
elif 'bert-' in model_type:
self.bert = BertModel.from_pretrained(model_type)
else:
            raise ValueError('Unsupported model type: {}'.format(model_type))

self.qa_outputs = nn.Linear(self.bert.config.hidden_size, 2) # [N, L, H] => [N, L, 2]
self.classifier_coeff = 10
self.entropy_penalty = 0
# print(self.classifier_coeff)
self.softmax = nn.Softmax(dim=-1)

def forward(self, batch, classifier=False, return_prob=False, **kwargs):
        '''
        Each batch is a list of 5 items (training) or 3 items (inference):
        - input_ids: token ids of the input sequence
        - attention_mask: mask of the sequence (1 for real tokens, 0 for padding)
        - token_type_ids: segment-type indicator
            - e.g. in QA, whether a token belongs to the question or the document
        - (training only) start_positions: start positions of the answer span
        - (training only) end_positions: end positions of the answer span
        '''

input_ids, attention_masks, token_type_ids = batch[:3]
# pooler_output, last_hidden_state
output = self.bert(input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_masks)
sequence_output = output.last_hidden_state
logits = self.qa_outputs(sequence_output) # (bs, max_input_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) # (bs, max_input_len)
end_logits = end_logits.squeeze(-1) # (bs, max_input_len)

if len(batch) == 5:
start_positions, end_positions = batch[3:]
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2


            if classifier:
                # classifier=True relies on the `classification` head defined in
                # BertForQuestionAnsweringSequence (subclass below)
                answerable_mask = (start_positions != 0) | (end_positions != 0)
loss_fct = CrossEntropyLoss()

answerable_logits = self.classification(sequence_output)[:, 0]
classifier_loss = loss_fct(answerable_logits, answerable_mask.long())
total_loss += self.classifier_coeff * classifier_loss
answerable_prob = torch.softmax(answerable_logits, dim=-1)

total_loss += self.entropy_penalty * (-torch.mean(torch.sum(-answerable_prob * torch.log(answerable_prob), dim=-1)))
return total_loss, torch.softmax(self.classification(sequence_output[:, 0]), dim=-1)
return total_loss, None

elif len(batch) == 3 and not classifier:
if not return_prob:
return start_logits, end_logits, None
else:
return self.softmax(start_logits), self.softmax(end_logits), None
elif len(batch) == 3 and classifier:
if return_prob:
return self.softmax(start_logits), self.softmax(
end_logits), self.softmax(self.classification(sequence_output[:, 0]))
else:
return start_logits, end_logits, self.classification(sequence_output[:, 0])
else:
raise NotImplementedError()


class BertForQuestionAnsweringSequence(BertForQuestionAnswering):
def __init__(self, model_type: str):
super(BertForQuestionAnsweringSequence, self).__init__(model_type=model_type)
        self.classification = nn.Linear(self.bert.config.hidden_size, 2)  # answerable/unanswerable head: [..., H] => [..., 2]





class DebertaSQuAD2(nn.Module):
def __init__(self, model_type: str):
super(DebertaSQuAD2, self).__init__()
if model_type == 'deepset/deberta-v3-base-squad2':
self.bert = AutoModelForQuestionAnswering.from_pretrained(model_type)
else:
            raise ValueError('Unsupported model type: {}'.format(model_type))

def forward(self, batch, return_prob=False, **kwargs):
        '''
        Each batch is a list of 3 items (only inference is implemented here):
        - input_ids: token ids of the input sequence
        - attention_mask: mask of the sequence (1 for real tokens, 0 for padding)
        - token_type_ids: segment-type indicator
            - e.g. in QA, whether a token belongs to the question or the document
        '''

input_ids, attention_masks, token_type_ids = batch[:3]
# pooler_output, last_hidden_state
output = self.bert(input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_masks)

start_logits, end_logits = output.start_logits, output.end_logits

if len(batch) == 3:
if not return_prob:
return start_logits, end_logits
else:
return torch.softmax(start_logits, dim=-1), torch.softmax(end_logits, dim=-1)

else:
raise NotImplementedError()
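
Note: the docstrings above distinguish a 3-item inference batch (`input_ids`, `attention_mask`, `token_type_ids`) from a 5-item training batch. A minimal inference sketch for `BertForQuestionAnswering`, not part of the commit; the question/context strings and greedy span decoding are illustrative assumptions, and in practice `load_initialization` from generate_prob.py would restore a trained checkpoint first (otherwise the span head is randomly initialized):

import torch
from transformers import AutoTokenizer
from model import BertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = BertForQuestionAnswering(model_type="microsoft/deberta-v3-base")
model.eval()

enc = tokenizer("Who wrote the report?", "The report was written by the annotators.",
                return_tensors="pt", return_token_type_ids=True)
batch = (enc["input_ids"], enc["attention_mask"], enc["token_type_ids"])

with torch.no_grad():
    # a 3-item batch takes the inference path; with return_prob=True the model
    # returns softmax distributions over start and end positions
    start_probs, end_probs, _ = model(batch=batch, return_prob=True)

start_idx = start_probs.argmax(dim=-1).item()
end_idx = end_probs.argmax(dim=-1).item()
print(tokenizer.decode(enc["input_ids"][0, start_idx:end_idx + 1]))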