Convert Code from pytorch_transformers to huggingface transformers #7

Open · wants to merge 5 commits into base: master
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,5 +1,6 @@
.ipynb_checkpoints
data/
model/
__pycache__
KR*
tempo/
@@ -8,3 +9,5 @@ tempo/
qualitative/
outputs/
*.ipynb

*.json
39 changes: 35 additions & 4 deletions README.md
@@ -1,5 +1,36 @@
# SOM-DST

Converts the code from `pytorch-transformers` to Hugging Face `transformers`.

```
# Fixed Requirements

# pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
torch==1.7.1+cu110
transformers==3.0.2
wget==3.2
jsonlines
tqdm
```
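
The main API change visible in this diff is the import path: `BertTokenizer` and `BertConfig` now come from `transformers` instead of `pytorch_transformers` (see the `evaluation.py` diff below). A minimal smoke test for the migration, assuming the repo's `assets/` files are in place:

```
# Migration smoke test: the same classes, constructed exactly as in
# evaluation.py, but imported from transformers instead of pytorch_transformers.
from transformers import BertTokenizer, BertConfig

config = BertConfig.from_json_file("assets/bert_config_base_uncased.json")
tokenizer = BertTokenizer("assets/vocab.txt", do_lower_case=True)

print(config.hidden_size)
print(tokenizer.tokenize("i need a cheap hotel in the east"))
```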

```
# Same parameter settings (MultiWOZ 2.1); joint accuracy drops slightly (0.5309 -> 0.5275)
------------------------------
op_code: 4, is_gt_op: False, is_gt_p_state: False, is_gt_gen: False
Epoch 0 joint accuracy : 0.5275515743756786
Epoch 0 slot turn accuracy : 0.9732401375316211
Epoch 0 slot turn F1: 0.9175307139165523
Epoch 0 op accuracy : 0.9737830256966589
Epoch 0 op F1 : {'delete': 0.018656716417910446, 'update': 0.8015826338020638, 'dontcare': 0.3235668789808917, 'carryover': 0.9862940159245958}
Epoch 0 op hit count : {'delete': 15, 'update': 7496, 'dontcare': 127, 'carryover': 207607}
Epoch 0 op all count : {'delete': 1576, 'update': 10595, 'dontcare': 581, 'carryover': 208288}
Final Joint Accuracy : 0.3713713713713714
Final slot turn F1 : 0.9101975987924662
Latency Per Prediction : 24.244383 ms
-----------------------------
```
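
For reference, the joint accuracy above counts a turn as correct only when the predicted dialogue state matches the gold state exactly, compared as order-insensitive sets of `domain-slot-value` strings; the separate "Final Joint Accuracy" applies the same check to each dialogue's last turn only. A sketch of the per-turn check, mirroring `model_evaluation` below (the example states are made up):

```
# Illustrative only: how one turn is scored for joint accuracy.
pred_state = ["hotel-area-east", "hotel-stars-4"]  # hypothetical prediction
gold_state = ["hotel-stars-4", "hotel-area-east"]  # hypothetical gold label

# A turn counts iff predicted and gold states are equal as sets.
joint_hit = set(pred_state) == set(gold_state)
print(joint_hit)  # True
```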

## The original README.md is as follows

This code is the official pytorch implementation of [Efficient Dialogue State Tracking by Selectively Overwriting Memory](https://arxiv.org/abs/1911.03906).<br>
> [Sungdong Kim](https://github.com/dsksd), [Sohee Yang](https://github.com/soheeyang), [Gyuwan Kim](mailto:[email protected]), [Sang-woo Lee](https://scholar.google.co.kr/citations?user=TMTTMuQAAAAJ)<br>
@@ -96,10 +127,10 @@ taxi 0.5903426791277259 0.9803219106957396
### Main results on MultiWOZ dataset (Joint Goal Accuracy)


|Model |MultiWOZ 2.0 |MultiWOZ 2.1|
|-------------|------------|------------|
|SOM-DST Base | 51.72 | 53.01 |
|SOM-DST Large| 52.32 | 53.68 |
| Model | MultiWOZ 2.0 | MultiWOZ 2.1 |
| ------------- | ------------ | ----------- |
| SOM-DST Base | 51.72 | 53.01 |
| SOM-DST Large | 52.32 | 53.68 |


## Citation
197 changes: 132 additions & 65 deletions evaluation.py
@@ -5,9 +5,17 @@
"""

from utils.data_utils import prepare_dataset, MultiWozDataset
from utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing
from utils.data_utils import (
make_slot_meta,
domain2id,
OP_SET,
make_turn_label,
postprocessing,
)
from utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
from pytorch_transformers import BertTokenizer, BertConfig

# from pytorch_transformers import BertTokenizer, BertConfig
from transformers import BertTokenizer, BertConfig

from model import SomDST
import torch.nn as nn
@@ -23,51 +31,82 @@
import json
from copy import deepcopy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
ontology = json.load(open(os.path.join(args.data_root, args.ontology_data)))
slot_meta, _ = make_slot_meta(ontology)
tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)
data = prepare_dataset(os.path.join(args.data_root, args.test_data),
tokenizer,
slot_meta, args.n_history, args.max_seq_length, args.op_code)
data = prepare_dataset(
os.path.join(args.data_root, args.test_data),
tokenizer,
slot_meta,
args.n_history,
args.max_seq_length,
args.op_code,
)

model_config = BertConfig.from_json_file(args.bert_config_path)
model_config.dropout = 0.1
op2id = OP_SET[args.op_code]
model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'])
ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
model = SomDST(model_config, len(op2id), len(domain2id), op2id["update"])
ckpt = torch.load(args.model_ckpt_path, map_location="cpu")
model.load_state_dict(ckpt)

model.eval()
model.to(device)

if args.eval_all:
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
False, False, False)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
False, False, True)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
False, True, False)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
False, True, True)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
True, False, False)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
True, True, False)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
True, False, True)
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
True, True, True)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, False, False, False
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, False, False, True
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, False, True, False
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, False, True, True
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, True, False, False
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, True, True, False
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, True, False, True
)
model_evaluation(
model, data, tokenizer, slot_meta, 0, args.op_code, True, True, True
)
else:
model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
args.gt_op, args.gt_p_state, args.gt_gen)


def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
is_gt_op=False, is_gt_p_state=False, is_gt_gen=False):
model_evaluation(
model,
data,
tokenizer,
slot_meta,
0,
args.op_code,
args.gt_op,
args.gt_p_state,
args.gt_gen,
)


def model_evaluation(
model,
test_data,
tokenizer,
slot_meta,
epoch,
op_code="4",
is_gt_op=False,
is_gt_p_state=False,
is_gt_gen=False,
):
model.eval()
op2id = OP_SET[op_code]
id2op = {v: k for k, v in op2id.items()}
@@ -91,32 +130,35 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',

if is_gt_p_state is False:
i.last_dialog_state = deepcopy(last_dialog_state)
i.make_instance(tokenizer, word_dropout=0.)
i.make_instance(tokenizer, word_dropout=0.0)
else: # ground-truth previous dialogue state
last_dialog_state = deepcopy(i.gold_p_state)
i.last_dialog_state = deepcopy(last_dialog_state)
i.make_instance(tokenizer, word_dropout=0.)
i.make_instance(tokenizer, word_dropout=0.0)

input_ids = torch.LongTensor([i.input_id]).to(device)
input_mask = torch.FloatTensor([i.input_mask]).to(device)
segment_ids = torch.LongTensor([i.segment_id]).to(device)
state_position_ids = torch.LongTensor([i.slot_position]).to(device)

d_gold_op, _, _ = make_turn_label(slot_meta, last_dialog_state, i.gold_state,
tokenizer, op_code, dynamic=True)
d_gold_op, _, _ = make_turn_label(
slot_meta, last_dialog_state, i.gold_state, tokenizer, op_code, dynamic=True
)
gold_op_ids = torch.LongTensor([d_gold_op]).to(device)

start = time.perf_counter()
MAX_LENGTH = 9
with torch.no_grad():
# ground-truth state operation
gold_op_inputs = gold_op_ids if is_gt_op else None
d, s, g = model(input_ids=input_ids,
token_type_ids=segment_ids,
state_positions=state_position_ids,
attention_mask=input_mask,
max_value=MAX_LENGTH,
op_ids=gold_op_inputs)
d, s, g = model(
input_ids=input_ids,
token_type_ids=segment_ids,
state_positions=state_position_ids,
attention_mask=input_mask,
max_value=MAX_LENGTH,
op_ids=gold_op_inputs,
)

_, op_ids = s.view(-1, len(op2id)).max(-1)

@@ -133,20 +175,29 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',

if is_gt_gen:
# ground_truth generation
gold_gen = {'-'.join(ii.split('-')[:2]): ii.split('-')[-1] for ii in i.gold_state}
gold_gen = {
"-".join(ii.split("-")[:2]): ii.split("-")[-1] for ii in i.gold_state
}
else:
gold_gen = {}
generated, last_dialog_state = postprocessing(slot_meta, pred_ops, last_dialog_state,
generated, tokenizer, op_code, gold_gen)
generated, last_dialog_state = postprocessing(
slot_meta,
pred_ops,
last_dialog_state,
generated,
tokenizer,
op_code,
gold_gen,
)
end = time.perf_counter()
wall_times.append(end - start)
pred_state = []
for k, v in last_dialog_state.items():
pred_state.append('-'.join([k, v]))
pred_state.append("-".join([k, v]))

if set(pred_state) == set(i.gold_state):
joint_acc += 1
key = str(i.id) + '_' + str(i.turn_id)
key = str(i.id) + "_" + str(i.turn_id)
results[key] = [pred_state, i.gold_state]

# Compute prediction slot accuracy
@@ -159,7 +210,9 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
slot_F1_count += count

# Compute operation accuracy
temp_acc = sum([1 if p == g else 0 for p, g in zip(pred_ops, gold_ops)]) / len(pred_ops)
temp_acc = sum([1 if p == g else 0 for p, g in zip(pred_ops, gold_ops)]) / len(
pred_ops
)
op_acc += temp_acc

if i.is_last_turn:
@@ -191,14 +244,20 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
tp = tp_dic[k]
fn = fn_dic[k]
fp = fp_dic[k]
precision = tp / (tp+fp) if (tp+fp) != 0 else 0
recall = tp / (tp+fn) if (tp+fn) != 0 else 0
F1 = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0
precision = tp / (tp + fp) if (tp + fp) != 0 else 0
recall = tp / (tp + fn) if (tp + fn) != 0 else 0
F1 = (
2 * precision * recall / float(precision + recall)
if (precision + recall) != 0
else 0
)
op_F1_score[k] = F1

print("------------------------------")
print('op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s' % \
(op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen)))
print(
"op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s"
% (op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen))
)
print("Epoch %d joint accuracy : " % epoch, joint_acc_score)
print("Epoch %d slot turn accuracy : " % epoch, turn_acc_score)
print("Epoch %d slot turn F1: " % epoch, slot_F1_score)
@@ -210,31 +269,39 @@ def model_evaluation(model, test_data, tokenizer, slot_meta, epoch, op_code='4',
print("Final slot turn F1 : ", final_slot_F1_score)
print("Latency Per Prediction : %f ms" % latency)
print("-----------------------------\n")
json.dump(results, open('preds_%d.json' % epoch, 'w'))
json.dump(results, open("preds_%d.json" % epoch, "w"))
per_domain_join_accuracy(results, slot_meta)

scores = {'epoch': epoch, 'joint_acc': joint_acc_score,
'slot_acc': turn_acc_score, 'slot_f1': slot_F1_score,
'op_acc': op_acc_score, 'op_f1': op_F1_score, 'final_slot_f1': final_slot_F1_score}
scores = {
"epoch": epoch,
"joint_acc": joint_acc_score,
"slot_acc": turn_acc_score,
"slot_f1": slot_F1_score,
"op_acc": op_acc_score,
"op_f1": op_F1_score,
"final_slot_f1": final_slot_F1_score,
}
return scores


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", default='data/mwz2.1', type=str)
parser.add_argument("--test_data", default='test_dials.json', type=str)
parser.add_argument("--ontology_data", default='ontology.json', type=str)
parser.add_argument("--vocab_path", default='assets/vocab.txt', type=str)
parser.add_argument("--bert_config_path", default='assets/bert_config_base_uncased.json', type=str)
parser.add_argument("--model_ckpt_path", default='outputs/model_best.bin', type=str)
parser.add_argument("--data_root", default="data/mwz2.1", type=str)
parser.add_argument("--test_data", default="test_dials.json", type=str)
parser.add_argument("--ontology_data", default="ontology.json", type=str)
parser.add_argument("--vocab_path", default="assets/vocab.txt", type=str)
parser.add_argument(
"--bert_config_path", default="assets/bert_config_base_uncased.json", type=str
)
parser.add_argument("--model_ckpt_path", default="outputs/model_best.bin", type=str)
parser.add_argument("--n_history", default=1, type=int)
parser.add_argument("--max_seq_length", default=256, type=int)
parser.add_argument("--op_code", default="4", type=str)

parser.add_argument("--gt_op", default=False, action='store_true')
parser.add_argument("--gt_p_state", default=False, action='store_true')
parser.add_argument("--gt_gen", default=False, action='store_true')
parser.add_argument("--eval_all", default=False, action='store_true')
parser.add_argument("--gt_op", default=False, action="store_true")
parser.add_argument("--gt_p_state", default=False, action="store_true")
parser.add_argument("--gt_gen", default=False, action="store_true")
parser.add_argument("--eval_all", default=False, action="store_true")

args = parser.parse_args()
main(args)
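
For reference, a sketch of how the evaluation sweep can be reproduced using the argparse defaults from this file; `--eval_all` triggers the eight `model_evaluation` calls covering every `(gt_op, gt_p_state, gt_gen)` combination:

```
# Sketch, assuming data/mwz2.1 and outputs/model_best.bin are populated as in
# the original repo; all paths below are the argparse defaults above.
import subprocess

subprocess.run(
    [
        "python", "evaluation.py",
        "--data_root", "data/mwz2.1",
        "--model_ckpt_path", "outputs/model_best.bin",
        "--op_code", "4",
        "--eval_all",
    ],
    check=True,
)
```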