flair-ner-trainer-ft.py
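"""Fine-tune Flair SequenceTagger models with transformer word embeddings for NER.

All dataset, model and hyper-parameter settings are read from a JSON
configuration file passed as the first command-line argument. One fine-tuning
run is started for every combination of seed, batch size, number of epochs and
learning rate listed in the configuration.
"""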
import click
import json
import sys
import flair
import torch
from typing import List
from flair.datasets import ColumnCorpus
from flair.embeddings import (
    TokenEmbeddings,
    StackedEmbeddings,
    TransformerWordEmbeddings
)
from flair import set_seed
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer


def run_experiment(seed, batch_size, epoch, learning_rate, json_config):
    # Config values
    # Replace it with more Pythonic solutions later!
    hf_model = json_config["hf_model"]
    context_size = json_config["context_size"]
    layers = json_config["layers"] if "layers" in json_config else "-1"
    use_crf = json_config["use_crf"] if "use_crf" in json_config else False
    task_name = json_config["task_name"]

    # Dataset-related
    data_folder = json_config["data_folder"]
    train_file = json_config["train_file"]
    dev_file = json_config["dev_file"]
    test_file = json_config["test_file"]

    # Set seed for reproducibility
    set_seed(seed)

    if context_size == 0:
        context_size = False

    print("FLERT Context:", context_size)
    print("Layers:", layers)
    print("Use CRF:", use_crf)

    # Configuration
    column_format = {0: "text", 1: "ner"}
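    # The train/dev/test files are expected in CoNLL-style two-column format:
    # one token and its NER tag per line, sentences separated by blank lines,
    # e.g. (illustrative):
    #
    #   Berlin  B-LOC
    #   ,       O
    #   1910    O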
    # Corpus
    corpus = ColumnCorpus(data_folder=data_folder,
                          column_format=column_format,
                          train_file=train_file,
                          dev_file=dev_file,
                          test_file=test_file,
                          tag_to_bioes="ner",
                          )

    # Corpus configuration
    tag_type = "ner"
    tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
    print(tag_dictionary.idx2item)

    # Embeddings
    embeddings = TransformerWordEmbeddings(
        model=hf_model,
        layers=layers,
        subtoken_pooling="first",
        fine_tune=True,
        use_context=context_size,
    )

    tagger: SequenceTagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=tag_dictionary,
        tag_type=tag_type,
        use_crf=use_crf,
        use_rnn=False,
        reproject_embeddings=False,
    )
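    # With use_rnn=False and reproject_embeddings=False, the tagger reduces to
    # a classification head (plus an optional CRF) on top of the fine-tuned
    # transformer embeddings.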
    # Trainer
    trainer: ModelTrainer = ModelTrainer(tagger, corpus)

    trainer.fine_tune(
        f"histo-flert-fine-tuning-{task_name}-{hf_model}-bs{batch_size}-ws{context_size}-e{epoch}-lr{learning_rate}-layers{layers}-crf{use_crf}-{seed}",
        learning_rate=learning_rate,
        mini_batch_size=batch_size,
        max_epochs=epoch,
        shuffle=True,
        embeddings_storage_mode='none',
        weight_decay=0.,
        use_final_model_for_eval=False,
    )


if __name__ == "__main__":
    # Read JSON configuration
    filename = sys.argv[1]
    with open(filename, "rt") as f_p:
        json_config = json.load(f_p)

    seeds = json_config["seeds"]
    batch_sizes = json_config["batch_sizes"]
    epochs = json_config["epochs"]
    learning_rates = json_config["learning_rates"]
    cuda = json_config["cuda"]

    flair.device = torch.device(f'cuda:{cuda}')

    for seed in seeds:
        for batch_size in batch_sizes:
            for epoch in epochs:
                for learning_rate in learning_rates:
                    run_experiment(seed, batch_size, epoch, learning_rate, json_config)  # pylint: disable=no-value-for-parameter
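# A minimal sketch of how this script can be invoked and configured. The file
# name and all values below are illustrative assumptions; only the keys match
# what the script actually reads:
#
#   python flair-ner-trainer-ft.py configs/example.json
#
#   {
#       "hf_model": "bert-base-cased",
#       "context_size": 64,
#       "layers": "-1",
#       "use_crf": false,
#       "task_name": "example-ner",
#       "data_folder": "data/example",
#       "train_file": "train.txt",
#       "dev_file": "dev.txt",
#       "test_file": "test.txt",
#       "seeds": [1, 2, 3],
#       "batch_sizes": [8, 16],
#       "epochs": [10],
#       "learning_rates": [5e-05],
#       "cuda": 0
#   }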