
Commit

small changes
Rahul committed Oct 2, 2020
1 parent 934a6ce commit 19e7149
Showing 3 changed files with 122 additions and 88 deletions.
2 changes: 1 addition & 1 deletion meta_trainer.py
@@ -158,7 +158,7 @@ def init_args():

# concatenate individual datasets into a single dataset
combined_train_dataset = ConcatDataset(train_datasets)
combined_dev_dataset = ConcatDataset(dev_datasets)
combined_dev_dataset = ConcatDataset(train_datasets)

# convert to metadataset which is suitable for sampling tasks in an episode
train_dataset = l2l.data.MetaDataset(combined_train_dataset)
42 changes: 26 additions & 16 deletions simple_trainer.py
@@ -18,13 +18,12 @@
from typing import Dict, NamedTuple, Optional
from seqeval.metrics import f1_score, precision_score, recall_score
from data_utils import PosDataset, Split, get_data_config, read_examples_from_file
from simple_tagger import PosTagger, get_model_config
from simple_tagger import BERT, Classifier, get_model_config

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
wandb.init(project="nlp-meta-learning")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@@ -55,12 +54,15 @@ def get_optimizers(model, num_warmup_steps, num_training_steps, lr=5e-5):
return optimizer, scheduler
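Note: the body of get_optimizers is collapsed in this hunk; only its return line is shown. For reference, a minimal sketch of a typical implementation, under the assumption that it pairs AdamW with transformers' linear warmup schedule (the diff itself does not confirm either choice):

from transformers import AdamW, get_linear_schedule_with_warmup

def get_optimizers(model, num_warmup_steps, num_training_steps, lr=5e-5):
    # Single parameter group; the real code may split out a no-weight-decay group.
    optimizer = AdamW(model.parameters(), lr=lr)
    # Linear warmup for num_warmup_steps, then linear decay over num_training_steps.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )
    return optimizer, scheduler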


def _training_step(model, inputs, optimizer):
def _training_step(bert, model, inputs, optimizer):
model.train()
for k, v in inputs.items():
inputs[k] = v.to(DEVICE)
with autocast():
outputs = model(**inputs)
# outputs = model(**inputs)
with torch.no_grad():
bert_output = bert(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
outputs = model(bert_output, labels=inputs["labels"], attention_mask=inputs["attention_mask"])
loss = outputs[0]
grad_scaler.scale(loss).backward()
return loss.item()
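Note: _training_step relies on autocast and a module-level grad_scaler whose definitions sit outside the hunks shown here, and gradient clipping plus the scaler's step happen in the epoch loop further down. A sketch of the assumed mixed-precision setup and of one full optimizer step (the unscale_ call before clipping is a common AMP pattern, not something visible in this diff):

import torch
from torch.cuda.amp import autocast, GradScaler

grad_scaler = GradScaler()  # assumed module-level scaler shared by the training code

def full_training_step(bert, model, inputs, optimizer, max_grad_norm):
    model.train()
    with autocast():
        # Encoder stays frozen: no gradients flow into BERT, only into the classifier.
        with torch.no_grad():
            bert_output = bert(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
        loss = model(bert_output, labels=inputs["labels"], attention_mask=inputs["attention_mask"])[0]
    grad_scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    grad_scaler.unscale_(optimizer)      # restore true gradient scale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    grad_scaler.step(optimizer)          # skips the update if gradients contain inf/nan
    grad_scaler.update()                 # adjust the scale factor for the next iteration
    optimizer.zero_grad()
    return loss.item()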
@@ -87,7 +89,7 @@ def compute_metrics(p, label_map)
}
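Note: only the closing brace of compute_metrics is visible above. It is assumed to do the usual token-classification bookkeeping: argmax over the logits, drop positions whose gold label is the ignore index, map ids to tag strings through label_map, and score with seqeval. A hypothetical reconstruction under those assumptions:

import numpy as np
from seqeval.metrics import f1_score, precision_score, recall_score

def compute_metrics(p, label_map):
    # p: EvalPrediction with predictions (n, seq_len, num_labels) and label_ids (n, seq_len).
    # -100 is assumed to mark padding/subword positions, matching CrossEntropyLoss's default ignore_index.
    ignore_index = -100
    preds = np.argmax(p.predictions, axis=2)
    gold_list, pred_list = [], []
    for gold_row, pred_row in zip(p.label_ids, preds):
        gold_seq, pred_seq = [], []
        for gold_id, pred_id in zip(gold_row, pred_row):
            if gold_id != ignore_index:
                gold_seq.append(label_map[gold_id])
                pred_seq.append(label_map[pred_id])
        gold_list.append(gold_seq)
        pred_list.append(pred_seq)
    # "eval_f1" is the only key the rest of the diff demonstrably reads; the other two are assumed.
    return {
        "eval_precision": precision_score(gold_list, pred_list),
        "eval_recall": recall_score(gold_list, pred_list),
        "eval_f1": f1_score(gold_list, pred_list),
    }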


def _prediction_loop(model, dataloader, description):
def _prediction_loop(bert, model, dataloader, label_map, description):
batch_size = dataloader.batch_size
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", len(dataloader.dataset))
@@ -103,7 +105,9 @@ def _prediction_loop(model, dataloader, description):
inputs[k] = v.to(DEVICE)

with torch.no_grad():
outputs = model(**inputs)
# outputs = model(**inputs)
bert_output = bert(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
outputs = model(bert_output, labels=inputs["labels"], attention_mask=inputs["attention_mask"])
if has_labels:
step_eval_loss, logits = outputs[:2]
eval_losses += [step_eval_loss.mean().item()]
@@ -141,8 +145,8 @@ def _prediction_loop(model, dataloader, description):
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)


def evaluate(loader):
output = _prediction_loop(model, loader, description="Evaluation")
def evaluate(loader, label_map, bert, model):
output = _prediction_loop(bert, model, loader, label_map, description="Evaluation")
return output.metrics


@@ -185,6 +189,7 @@ def init_args():
data_config = get_data_config()
model_config = get_model_config()

wandb.init(project="nlp-meta-learning")
wandb.config.update(model_config)
wandb.config.update(vars(args))

@@ -260,7 +265,12 @@ def init_args():
)

label_map = {i: label for i, label in enumerate(labels)}
model = PosTagger(model_config["model_type"], len(labels), model_config["hidden_dropout_prob"])
# model = PosTagger(model_config["model_type"], len(labels), model_config["hidden_dropout_prob"])
bert = BERT(model_config["model_type"])
bert.eval()
bert = bert.to(DEVICE)
model = Classifier(len(labels), model_config["hidden_dropout_prob"], bert.get_hidden_size())

wandb.watch(model)
model = model.to(DEVICE)
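Note: simple_tagger.py is not part of this commit, so the new BERT and Classifier classes are visible only through their call sites: BERT(model_type), bert(input_ids, attention_mask, token_type_ids), bert.get_hidden_size(), and Classifier(num_labels, dropout, hidden_size) returning (loss, logits) when labels are supplied. A minimal sketch of what they are assumed to look like:

import torch.nn as nn
from transformers import AutoModel

class BERT(nn.Module):
    """Pretrained encoder kept in eval mode; returns per-token hidden states."""

    def __init__(self, model_type):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_type)

    def get_hidden_size(self):
        return self.encoder.config.hidden_size

    def forward(self, input_ids, attention_mask, token_type_ids):
        # [batch, seq_len, hidden_size]
        return self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]

class Classifier(nn.Module):
    """Token-level classification head trained on top of the frozen encoder."""

    def __init__(self, num_labels, hidden_dropout_prob, hidden_size):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.linear = nn.Linear(hidden_size, num_labels)

    def forward(self, bert_output, labels=None, attention_mask=None):
        logits = self.linear(self.dropout(bert_output))
        outputs = (logits,)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            if attention_mask is not None:
                # score only real tokens; padded positions are masked out
                active = attention_mask.view(-1) == 1
                loss = loss_fct(logits.view(-1, self.num_labels)[active], labels.view(-1)[active])
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs  # (loss, logits), matching outputs[0] / outputs[:2] above
        return outputs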

@@ -275,7 +285,7 @@ def init_args():
running_loss = 0.0
epoch_iterator = tqdm(train_loader, desc="Training")
for training_step, inputs in enumerate(epoch_iterator):
step_loss = _training_step(model, inputs, optimizer)
step_loss = _training_step(bert, model, inputs, optimizer)
running_loss += step_loss
torch.nn.utils.clip_grad_norm_(model.parameters(), model_config["max_grad_norm"])
grad_scaler.step(optimizer)
@@ -285,9 +295,9 @@ def init_args():
logger.info(f"Finished epoch {epoch+1} with avg. training loss: {running_loss/len(inputs)}")
wandb.log({"loss/running_loss": running_loss / len(inputs)})

# train_metrics = evaluate(train_loader)
# train_metrics = evaluate(train_loader, label_map, bert, model)
# write_logs(train_metrics, epoch, "train")
dev_metrics = evaluate(dev_loader)
dev_metrics = evaluate(dev_loader, label_map, bert, model)
write_logs(dev_metrics, epoch, "validation")

logger.info("Validation f1: {}".format(dev_metrics["eval_f1"]))
@@ -306,8 +316,8 @@ def init_args():

model.load_state_dict(torch.load(os.path.join(wandb.run.dir, "best_model.th")))
model = model.to(DEVICE)
train_metrics = evaluate(train_loader)
dev_metrics = evaluate(dev_loader)
train_metrics = evaluate(train_loader, label_map, bert, model)
dev_metrics = evaluate(dev_loader, label_map, bert, model)
test_metrics = {}
for idx, test_data in enumerate(test_datasets):
logging.info("Testing on {}...".format(args.datasets[idx]))
@@ -317,7 +327,7 @@ def init_args():
sampler=SequentialSampler(test_data),
collate_fn=DefaultDataCollator().collate_batch,
)
test_metrics[args.datasets[idx]] = evaluate(test_loader)
test_metrics[args.datasets[idx]] = evaluate(test_loader, label_map, bert, model)

# dump results to file and stdout
final_result = {
@@ -326,7 +336,7 @@ def init_args():
"test": test_metrics,
# "num_epochs": epoch,
}
wandb.run.summary = final_result
wandb.run.summary["final_results"] = final_result
final_result = json.dumps(final_result, indent=2)
with open(os.path.join(wandb.run.dir, "result.json"), "w") as f:
f.write(final_result)
166 changes: 95 additions & 71 deletions test.py
@@ -3,7 +3,8 @@
import argparse
import torch
import logging
import wandb

# import wandb
import statistics as stat
import torch.nn as nn
import numpy as np
@@ -19,17 +20,20 @@

from typing import Dict, NamedTuple, Optional
from seqeval.metrics import f1_score, precision_score, recall_score
from meta_data_utils import PosDataset, Split, get_data_config, read_examples_from_file
from data_utils import PosDataset as RegularPosDataset
from meta_data_utils import PosDataset as MetaPosDataset
from meta_data_utils import Split, get_data_config, read_examples_from_file
from simple_tagger import BERT, Classifier, get_model_config
from simple_trainer import compute_metrics, EvalPrediction
from simple_trainer import evaluate as regular_evaluate

import learn2learn as l2l

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
wandb.init(project="nlp-meta-learning")
# wandb.init(project="nlp-meta-learning")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@@ -101,7 +105,36 @@ def compute_loss(task, bert_model, learner, batch_size):
return loss / (iters + 1), metrics


def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_lr, num_episodes):
def meta_evaluate(datasets, bert_model, postagger):
# concatenate individual datasets into a single dataset
combined_dataset = ConcatDataset(datasets)

# convert to metadataset which is suitable for sampling tasks in an episode
meta_dataset = l2l.data.MetaDataset(combined_dataset)

# shots = number of examples per task, ways = number of classes per task
shots = model_config["shots"]
ways = model_config["ways"]

# create task generators
data_gen = l2l.data.TaskDataset(
meta_dataset,
num_tasks=model_config["num_tasks"],
task_transforms=[
l2l.data.transforms.FusedNWaysKShots(meta_dataset, n=ways, k=shots),
l2l.data.transforms.LoadData(meta_dataset),
],
)
num_episodes = model_config["num_episodes"]
task_bs = model_config["task_batch_size"]
inner_loop_steps = model_config["inner_loop_steps"]
inner_lr = model_config["inner_lr"]

if model_config["is_fomaml"]:
meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr, first_order=True)
else:
meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr)

task_support_error, task_query_error = 0.0, []
tqdm_bar = tqdm(range(num_episodes))
all_metrics = {}
@@ -122,7 +155,7 @@ def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_
query_error, metrics = compute_loss(query_task, bert_model, learner, batch_size=task_bs)
task_query_error.append(query_error)
tqdm_bar.set_description("Query Loss: {:.3f}".format(query_error.item()))
wandb.log({"support_loss": support_error / inner_loop_steps, "query_loss": query_error})
# wandb.log({"support_loss": support_error / inner_loop_steps, "query_loss": query_error})
for class_index in metrics.keys():
if class_index in all_metrics:
all_metrics[class_index]["p"].append(metrics[class_index]["precision"])
@@ -134,7 +167,7 @@ def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_
"r": [metrics[class_index]["recall"]],
"f": [metrics[class_index]["f1"]],
}
wandb.log({"{}_{}".format(dataset_names[class_index], k): v for k, v in metrics[class_index].items()})
# wandb.log({"{}_{}".format(dataset_names[class_index], k): v for k, v in metrics[class_index].items()})

summary_metrics = {}
for class_index in all_metrics.keys():
@@ -144,22 +177,23 @@ def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_
"r_stdev": stat.stdev(all_metrics[class_index]["r"]),
"r": stat.mean(all_metrics[class_index]["r"]),
"f_stdev": stat.stdev(all_metrics[class_index]["f"]),
"f": stat.mean(all_metrics[class_index]["f"])
"f": stat.mean(all_metrics[class_index]["f"]),
}
summary_metrics["loss"] = torch.tensor(task_query_error).mean().item()
wandb.run.summary["summary_metrics"] = summary_metrics
# wandb.run.summary["summary_metrics"] = summary_metrics
return summary_metrics
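Note: the support-set half of the episode loop above sits in a collapsed region of the diff; only the query-side lines are visible. A sketch of how a single meta-test episode is typically structured with learn2learn, assuming compute_loss returns a (loss, per-class metrics) pair as its visible tail suggests (drawing support and query sets as two separate task samples is also an assumption here):

def run_episode(data_gen, meta_model, bert_model, task_bs, inner_loop_steps):
    # Task-specific copy of the classifier; updates from adapt() stay on the clone.
    learner = meta_model.clone()
    support_task = data_gen.sample()   # examples used for inner-loop adaptation
    query_task = data_gen.sample()     # held-out examples used for the reported loss/metrics

    support_error = 0.0
    for _ in range(inner_loop_steps):
        loss, _ = compute_loss(support_task, bert_model, learner, batch_size=task_bs)
        learner.adapt(loss)            # one inner-loop update at the MAML inner_lr
        support_error += loss.item()

    query_error, metrics = compute_loss(query_task, bert_model, learner, batch_size=task_bs)
    return support_error, query_error, metrics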


def init_args():
parser = argparse.ArgumentParser(description="Test POS tagging on various UD datasets")
parser.add_argument("datasets", metavar="datasets", type=str, nargs="+", help="Datasets to meta-test on")
parser.add_argument("-s", "--split", help="Type of data to evaluate on", default="test")
parser.add_argument("datasets", metavar="datasets", type=str, nargs="+", help="Datasets to test on")
parser.add_argument("-s", "--split", help="Data split to evaluate on", default="test")
parser.add_argument(
"-e", "--eval_type", help="Type of evaluation (meta/regular)", choices=["meta", "regular", "both"], default="both"
)
# pylint: disable=unused-argument
req_named_params = parser.add_argument_group("required named arguments")
req_named_params.add_argument(
"-m", "--model_path", help="WandB runpath of the model to load (<user>/<project>/<run>)", required=True
)
req_named_params.add_argument("-m", "--model_path", help="path of the model to load", required=True)
return parser.parse_args()


@@ -189,70 +223,60 @@ def init_args():
data_split = Split.dev

labels = set()
datasets = []

# build label set for all datasets
for dataset_path in dataset_paths:
_, l = read_examples_from_file(dataset_path, Split.train, model_config["max_seq_length"])
labels.update(l)
labels = sorted(list(labels))
label_map = {i: label for i, label in enumerate(labels)}

# load individual datasets
for class_index, dataset_path in enumerate(dataset_paths):
dataset = PosDataset(
class_index,
dataset_path,
labels,
tokenizer,
model_config["model_type"],
model_config["max_seq_length"],
mode=data_split,
)
datasets.append(dataset)

# concatenate individual datasets into a single dataset
combined_dataset = ConcatDataset(datasets)

# convert to metadataset which is suitable for sampling tasks in an episode
meta_dataset = l2l.data.MetaDataset(combined_dataset)

# shots = number of examples per task, ways = number of classes per task
shots = model_config["shots"]
ways = model_config["ways"]

# create task generators
data_gen = l2l.data.TaskDataset(
meta_dataset,
num_tasks=model_config["num_tasks"],
task_transforms=[
l2l.data.transforms.FusedNWaysKShots(meta_dataset, n=ways, k=shots),
l2l.data.transforms.LoadData(meta_dataset),
],
)

# define the bert and postagger model
bert_model = BERT(model_config["model_type"])
bert_model.eval()
bert_model = bert_model.to(DEVICE)

postagger = Classifier(len(labels), model_config["hidden_dropout_prob"], bert_model.get_hidden_size())

load_path = os.path.join(args.model_path, "best_model.th")
logging.info("Loading model from path: f{load_path}")
postagger.load_state_dict(torch.load(load_path))
postagger.to(DEVICE)

num_episodes = model_config["num_episodes"]
task_bs = model_config["task_batch_size"]
inner_loop_steps = model_config["inner_loop_steps"]
inner_lr = model_config["inner_lr"]

if model_config["is_fomaml"]:
meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr, first_order=True)
if args.eval_type == "both":
args.eval_type = ["regular", "meta"]
else:
meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr)

summary_metrics = evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_lr, num_episodes)
logger.info(json.dumps(summary_metrics, indent=2))
args.eval_type = [args.eval_type]

for eval_type in args.eval_type:
datasets = []
# load individual datasets
for class_index, dataset_path in enumerate(dataset_paths):
dataset_args = [
class_index,
dataset_path,
labels,
tokenizer,
model_config["model_type"],
model_config["max_seq_length"],
data_split,
]
if eval_type == "regular":
dataset = RegularPosDataset(*dataset_args[1:])
else:
dataset = MetaPosDataset(*dataset_args)
datasets.append(dataset)

# define the bert and postagger model
bert_model = BERT(model_config["model_type"])
bert_model.eval()
bert_model = bert_model.to(DEVICE)
postagger = Classifier(len(labels), model_config["hidden_dropout_prob"], bert_model.get_hidden_size())

load_path = os.path.join(args.model_path, "best_model.th")
logging.info("Loading model from path: {}".format(load_path))
postagger.load_state_dict(torch.load(load_path))
postagger.to(DEVICE)

logging.info("Running {} evaluation".format(eval_type))
if eval_type == "regular":
summary_metrics = {}
for idx, data in enumerate(datasets):
loader = DataLoader(
data,
batch_size=model_config["batch_size"],
sampler=SequentialSampler(data),
collate_fn=DefaultDataCollator().collate_batch,
)
summary_metrics[args.datasets[idx]] = regular_evaluate(loader, label_map, bert_model, postagger)
else:
summary_metrics = meta_evaluate(datasets, bert_model, postagger)
logger.info(json.dumps(summary_metrics, indent=2))
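Note: get_model_config() is likewise outside this commit. Judging by the keys read across simple_trainer.py and test.py, the dict it returns is assumed to carry at least the fields below; the values are illustrative placeholders, not the project's actual settings:

model_config = {
    # shared settings
    "model_type": "bert-base-multilingual-cased",  # placeholder; passed to BERT() and the tokenizer
    "max_seq_length": 128,
    "hidden_dropout_prob": 0.1,                    # Classifier dropout
    "max_grad_norm": 1.0,
    "batch_size": 32,                              # regular (non-episodic) evaluation loaders
    # episodic / meta-learning settings
    "shots": 8,                                    # examples per class in a sampled task
    "ways": 3,                                     # classes (datasets) per sampled task
    "num_tasks": 1000,                             # size of the l2l TaskDataset
    "num_episodes": 100,
    "task_batch_size": 16,
    "inner_loop_steps": 5,
    "inner_lr": 1e-3,
    "is_fomaml": False,                            # use first-order MAML when True
}

With a config of this shape in place, test.py takes one or more dataset names as positional arguments, a --split to evaluate on, an --eval_type of meta, regular, or both, and a required --model_path pointing at a directory containing best_model.th.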
