diff --git a/meta_trainer.py b/meta_trainer.py
index 07e8761..dc215fc 100644
--- a/meta_trainer.py
+++ b/meta_trainer.py
@@ -158,7 +158,7 @@ def init_args():
 
 # concatenate individual datasets into a single dataset
 combined_train_dataset = ConcatDataset(train_datasets)
-combined_dev_dataset = ConcatDataset(dev_datasets)
+combined_dev_dataset = ConcatDataset(train_datasets)
 
 # convert to metadataset which is suitable for sampling tasks in an episode
 train_dataset = l2l.data.MetaDataset(combined_train_dataset)
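For context, the episode-sampling pipeline that meta_trainer.py builds on top of these concatenated datasets (and that the new `meta_evaluate` in test.py mirrors further down) looks roughly like the sketch below. The helper name `build_task_generator` and the `ways`/`shots`/`num_tasks` values are illustrative only; the learn2learn calls are the ones used in the diff, and each dataset item is assumed to carry its dataset index as the label so that "ways" selects datasets per task.

```python
import learn2learn as l2l
from torch.utils.data import ConcatDataset

def build_task_generator(datasets, ways=3, shots=8, num_tasks=1000):
    # merge the per-language datasets into one indexable dataset
    combined = ConcatDataset(datasets)
    # MetaDataset indexes examples by label so tasks can be sampled per episode
    meta_dataset = l2l.data.MetaDataset(combined)
    return l2l.data.TaskDataset(
        meta_dataset,
        num_tasks=num_tasks,
        task_transforms=[
            l2l.data.transforms.FusedNWaysKShots(meta_dataset, n=ways, k=shots),
            l2l.data.transforms.LoadData(meta_dataset),
        ],
    )

# task_gen = build_task_generator(train_datasets)
# task = task_gen.sample()  # one task = "shots" examples from each of "ways" datasets
```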
diff --git a/simple_trainer.py b/simple_trainer.py
index a4339c3..0546bd2 100644
--- a/simple_trainer.py
+++ b/simple_trainer.py
@@ -18,13 +18,12 @@
 from typing import Dict, NamedTuple, Optional
 from seqeval.metrics import f1_score, precision_score, recall_score
 
 from data_utils import PosDataset, Split, get_data_config, read_examples_from_file
-from simple_tagger import PosTagger, get_model_config
+from simple_tagger import BERT, Classifier, get_model_config
 
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
 )
 logger = logging.getLogger(__name__)
-wandb.init(project="nlp-meta-learning")
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -55,12 +54,15 @@ def get_optimizers(model, num_warmup_steps, num_training_steps, lr=5e-5):
     return optimizer, scheduler
 
 
-def _training_step(model, inputs, optimizer):
+def _training_step(bert, model, inputs, optimizer):
     model.train()
     for k, v in inputs.items():
         inputs[k] = v.to(DEVICE)
     with autocast():
-        outputs = model(**inputs)
+        # outputs = model(**inputs)
+        with torch.no_grad():
+            bert_output = bert(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
+        outputs = model(bert_output, labels=inputs["labels"], attention_mask=inputs["attention_mask"])
         loss = outputs[0]
     grad_scaler.scale(loss).backward()
     return loss.item()
@@ -87,7 +89,7 @@ def compute_metrics(p, label_map):
     }
 
 
-def _prediction_loop(model, dataloader, description):
+def _prediction_loop(bert, model, dataloader, label_map, description):
     batch_size = dataloader.batch_size
     logger.info("***** Running %s *****", description)
     logger.info("  Num examples = %d", len(dataloader.dataset))
@@ -103,7 +105,9 @@ def _prediction_loop(model, dataloader, description):
             inputs[k] = v.to(DEVICE)
 
         with torch.no_grad():
-            outputs = model(**inputs)
+            # outputs = model(**inputs)
+            bert_output = bert(inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
+            outputs = model(bert_output, labels=inputs["labels"], attention_mask=inputs["attention_mask"])
             if has_labels:
                 step_eval_loss, logits = outputs[:2]
                 eval_losses += [step_eval_loss.mean().item()]
@@ -141,8 +145,8 @@
     return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
 
-def evaluate(loader):
-    output = _prediction_loop(model, loader, description="Evaluation")
+def evaluate(loader, label_map, bert, model):
+    output = _prediction_loop(bert, model, loader, label_map, description="Evaluation")
     return output.metrics
@@ -185,6 +189,7 @@ def init_args():
 
 data_config = get_data_config()
 model_config = get_model_config()
+wandb.init(project="nlp-meta-learning")
 wandb.config.update(model_config)
 wandb.config.update(vars(args))
@@ -260,7 +265,12 @@ def init_args():
 )
 label_map = {i: label for i, label in enumerate(labels)}
 
-model = PosTagger(model_config["model_type"], len(labels), model_config["hidden_dropout_prob"])
+# model = PosTagger(model_config["model_type"], len(labels), model_config["hidden_dropout_prob"])
+bert = BERT(model_config["model_type"])
+bert.eval()
+bert = bert.to(DEVICE)
+model = Classifier(len(labels), model_config["hidden_dropout_prob"], bert.get_hidden_size())
+wandb.watch(model)
 model = model.to(DEVICE)
@@ -275,7 +285,7 @@ def init_args():
     running_loss = 0.0
     epoch_iterator = tqdm(train_loader, desc="Training")
     for training_step, inputs in enumerate(epoch_iterator):
-        step_loss = _training_step(model, inputs, optimizer)
+        step_loss = _training_step(bert, model, inputs, optimizer)
         running_loss += step_loss
         torch.nn.utils.clip_grad_norm_(model.parameters(), model_config["max_grad_norm"])
         grad_scaler.step(optimizer)
@@ -285,9 +295,9 @@ def init_args():
     logger.info(f"Finished epoch {epoch+1} with avg. training loss: {running_loss/len(inputs)}")
     wandb.log({"loss/running_loss": running_loss / len(inputs)})
 
-    # train_metrics = evaluate(train_loader)
+    # train_metrics = evaluate(train_loader, label_map, bert, model)
     # write_logs(train_metrics, epoch, "train")
-    dev_metrics = evaluate(dev_loader)
+    dev_metrics = evaluate(dev_loader, label_map, bert, model)
     write_logs(dev_metrics, epoch, "validation")
     logger.info("Validation f1: {}".format(dev_metrics["eval_f1"]))
@@ -306,8 +316,8 @@ def init_args():
 model.load_state_dict(torch.load(os.path.join(wandb.run.dir, "best_model.th")))
 model = model.to(DEVICE)
 
-train_metrics = evaluate(train_loader)
-dev_metrics = evaluate(dev_loader)
+train_metrics = evaluate(train_loader, label_map, bert, model)
+dev_metrics = evaluate(dev_loader, label_map, bert, model)
 test_metrics = {}
 for idx, test_data in enumerate(test_datasets):
     logging.info("Testing on {}...".format(args.datasets[idx]))
@@ -317,7 +327,7 @@ def init_args():
         sampler=SequentialSampler(test_data),
         collate_fn=DefaultDataCollator().collate_batch,
     )
-    test_metrics[args.datasets[idx]] = evaluate(test_loader)
+    test_metrics[args.datasets[idx]] = evaluate(test_loader, label_map, bert, model)
 
 # dump results to file and stdout
 final_result = {
@@ -326,7 +336,7 @@ def init_args():
     "test": test_metrics,
     # "num_epochs": epoch,
 }
-wandb.run.summary = final_result
+wandb.run.summary["final_results"] = final_result
 final_result = json.dumps(final_result, indent=2)
 with open(os.path.join(wandb.run.dir, "result.json"), "w") as f:
     f.write(final_result)
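The simple_trainer.py changes replace the monolithic `PosTagger` with a frozen BERT feature extractor plus a trainable `Classifier` head from simple_tagger.py. The sketch below is only a plausible shape for those two classes, inferred from how they are called in this diff (`bert(input_ids, attention_mask, token_type_ids)`, `bert.get_hidden_size()`, and `model(bert_output, labels=..., attention_mask=...)` returning `(loss, logits)`); the actual simple_tagger.py implementation may differ, and masking the loss with `attention_mask == 1` is an assumption.

```python
import torch.nn as nn
from transformers import AutoModel

class BERT(nn.Module):
    """Frozen encoder; the trainer calls .eval() and wraps calls in torch.no_grad()."""
    def __init__(self, model_type):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_type)

    def get_hidden_size(self):
        return self.encoder.config.hidden_size

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        return outputs[0]  # (batch, seq_len, hidden_size) token embeddings

class Classifier(nn.Module):
    """Per-token classification head trained on top of the frozen features."""
    def __init__(self, num_labels, hidden_dropout_prob, hidden_size):
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out = nn.Linear(hidden_size, num_labels)

    def forward(self, bert_output, labels=None, attention_mask=None):
        logits = self.out(self.dropout(bert_output))
        if labels is None:
            return (logits,)
        loss_fct = nn.CrossEntropyLoss()
        if attention_mask is not None:
            # only score positions that are real tokens (assumption: mask == 1 marks them)
            active = attention_mask.view(-1) == 1
            loss = loss_fct(logits.view(-1, logits.size(-1))[active], labels.view(-1)[active])
        else:
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits
```

Splitting the model this way lets the BERT forward pass run under `torch.no_grad()` in `_training_step`, so only the small classifier head receives gradients.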
else "cpu") @@ -101,7 +105,36 @@ def compute_loss(task, bert_model, learner, batch_size): return loss / (iters + 1), metrics -def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_lr, num_episodes): +def meta_evaluate(datsets, bert_model, postagger): + # concatenate individual datasets into a single dataset + combined_dataset = ConcatDataset(datasets) + + # convert to metadataset which is suitable for sampling tasks in an episode + meta_dataset = l2l.data.MetaDataset(combined_dataset) + + # shots = number of examples per task, ways = number of classes per task + shots = model_config["shots"] + ways = model_config["ways"] + + # create task generators + data_gen = l2l.data.TaskDataset( + meta_dataset, + num_tasks=model_config["num_tasks"], + task_transforms=[ + l2l.data.transforms.FusedNWaysKShots(meta_dataset, n=ways, k=shots), + l2l.data.transforms.LoadData(meta_dataset), + ], + ) + num_episodes = model_config["num_episodes"] + task_bs = model_config["task_batch_size"] + inner_loop_steps = model_config["inner_loop_steps"] + inner_lr = model_config["inner_lr"] + + if model_config["is_fomaml"]: + meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr, first_order=True) + else: + meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr) + task_support_error, task_query_error = 0.0, [] tqdm_bar = tqdm(range(num_episodes)) all_metrics = {} @@ -122,7 +155,7 @@ def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_ query_error, metrics = compute_loss(query_task, bert_model, learner, batch_size=task_bs) task_query_error.append(query_error) tqdm_bar.set_description("Query Loss: {:.3f}".format(query_error.item())) - wandb.log({"support_loss": support_error / inner_loop_steps, "query_loss": query_error}) + # wandb.log({"support_loss": support_error / inner_loop_steps, "query_loss": query_error}) for class_index in metrics.keys(): if class_index in all_metrics: all_metrics[class_index]["p"].append(metrics[class_index]["precision"]) @@ -134,7 +167,7 @@ def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_ "r": [metrics[class_index]["recall"]], "f": [metrics[class_index]["f1"]], } - wandb.log({"{}_{}".format(dataset_names[class_index], k): v for k, v in metrics[class_index].items()}) + # wandb.log({"{}_{}".format(dataset_names[class_index], k): v for k, v in metrics[class_index].items()}) summary_metrics = {} for class_index in all_metrics.keys(): @@ -144,22 +177,23 @@ def evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_ "r_stdev": stat.stdev(all_metrics[class_index]["r"]), "r": stat.mean(all_metrics[class_index]["r"]), "f_stdev": stat.stdev(all_metrics[class_index]["f"]), - "f": stat.mean(all_metrics[class_index]["f"]) + "f": stat.mean(all_metrics[class_index]["f"]), } summary_metrics["loss"] = torch.tensor(task_query_error).mean().item() - wandb.run.summary["summary_metrics"] = summary_metrics + # wandb.run.summary["summary_metrics"] = summary_metrics return summary_metrics def init_args(): parser = argparse.ArgumentParser(description="Test POS tagging on various UD datasets") - parser.add_argument("datasets", metavar="datasets", type=str, nargs="+", help="Datasets to meta-test on") - parser.add_argument("-s", "--split", help="Type of data to evaluate on", default="test") + parser.add_argument("datasets", metavar="datasets", type=str, nargs="+", help="Datasets to test on") + parser.add_argument("-s", "--split", help="Data split to evaluate on", default="test") + parser.add_argument( + "-e", 
"--eval_type", help="Type of evaluation (meta/regular)", choices=["meta", "regular", "both"], default="both" + ) # pylint: disable=unused-argument req_named_params = parser.add_argument_group("required named arguments") - req_named_params.add_argument( - "-m", "--model_path", help="WandB runpath of the model to load (//)", required=True - ) + req_named_params.add_argument("-m", "--model_path", help="path of the model to load", required=True) return parser.parse_args() @@ -189,8 +223,6 @@ def init_args(): data_split = Split.dev labels = set() - datasets = [] - # build label set for all datasets for dataset_path in dataset_paths: _, l = read_examples_from_file(dataset_path, Split.train, model_config["max_seq_length"]) @@ -198,61 +230,53 @@ def init_args(): labels = sorted(list(labels)) label_map = {i: label for i, label in enumerate(labels)} - # load individual datasets - for class_index, dataset_path in enumerate(dataset_paths): - dataset = PosDataset( - class_index, - dataset_path, - labels, - tokenizer, - model_config["model_type"], - model_config["max_seq_length"], - mode=data_split, - ) - datasets.append(dataset) - - # concatenate individual datasets into a single dataset - combined_dataset = ConcatDataset(datasets) - - # convert to metadataset which is suitable for sampling tasks in an episode - meta_dataset = l2l.data.MetaDataset(combined_dataset) - - # shots = number of examples per task, ways = number of classes per task - shots = model_config["shots"] - ways = model_config["ways"] - - # create task generators - data_gen = l2l.data.TaskDataset( - meta_dataset, - num_tasks=model_config["num_tasks"], - task_transforms=[ - l2l.data.transforms.FusedNWaysKShots(meta_dataset, n=ways, k=shots), - l2l.data.transforms.LoadData(meta_dataset), - ], - ) - - # define the bert and postagger model - bert_model = BERT(model_config["model_type"]) - bert_model.eval() - bert_model = bert_model.to(DEVICE) - - postagger = Classifier(len(labels), model_config["hidden_dropout_prob"], bert_model.get_hidden_size()) - - load_path = os.path.join(args.model_path, "best_model.th") - logging.info("Loading model from path: f{load_path}") - postagger.load_state_dict(torch.load(load_path)) - postagger.to(DEVICE) - - num_episodes = model_config["num_episodes"] - task_bs = model_config["task_batch_size"] - inner_loop_steps = model_config["inner_loop_steps"] - inner_lr = model_config["inner_lr"] - - if model_config["is_fomaml"]: - meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr, first_order=True) + if args.eval_type == "both": + args.eval_type = ["regular", "meta"] else: - meta_model = l2l.algorithms.MAML(postagger, lr=inner_lr) - - summary_metrics = evaluate(data_gen, meta_model, bert_model, task_bs, inner_loop_steps, inner_lr, num_episodes) - logger.info(json.dumps(summary_metrics, indent=2)) + args.eval_type = [args.eval_type] + + for eval_type in args.eval_type: + datasets = [] + # load individual datasets + for class_index, dataset_path in enumerate(dataset_paths): + dataset_args = [ + class_index, + dataset_path, + labels, + tokenizer, + model_config["model_type"], + model_config["max_seq_length"], + data_split, + ] + if eval_type == "regular": + dataset = RegularPosDataset(*dataset_args[1:]) + else: + dataset = MetaPosDataset(*dataset_args) + datasets.append(dataset) + + # define the bert and postagger model + bert_model = BERT(model_config["model_type"]) + bert_model.eval() + bert_model = bert_model.to(DEVICE) + postagger = Classifier(len(labels), model_config["hidden_dropout_prob"], 
+
+    load_path = os.path.join(args.model_path, "best_model.th")
+    logging.info("Loading model from path: {}".format(load_path))
+    postagger.load_state_dict(torch.load(load_path))
+    postagger.to(DEVICE)
+
+    logging.info("Running {} evaluation".format(eval_type))
+    if eval_type == "regular":
+        summary_metrics = {}
+        for idx, data in enumerate(datasets):
+            loader = DataLoader(
+                data,
+                batch_size=model_config["batch_size"],
+                sampler=SequentialSampler(data),
+                collate_fn=DefaultDataCollator().collate_batch,
+            )
+            summary_metrics[args.datasets[idx]] = regular_evaluate(loader, label_map, bert_model, postagger)
+    else:
+        summary_metrics = meta_evaluate(datasets, bert_model, postagger)
+    logger.info(json.dumps(summary_metrics, indent=2))
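With the new `--eval_type` flag, a single local checkpoint directory (one containing `best_model.th`) can be scored with the regular per-dataset loop, the meta (MAML) protocol, or both in one run, for example `python test.py <dataset names> -m <run directory> -e both`; the placeholders are illustrative, and the concrete dataset identifiers depend on the data config rather than anything shown in this diff.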