
Release 0.10

@alanakbik released this 18 Nov 08:33

This release adds several new features, such as built-in "model cards" for all Flair models, the first pre-trained models for Relation Extraction, better support for fine-tuning, and a refactoring of the model training methods for more flexibility. It also fixes a number of critical bugs introduced by the refactorings in Flair 0.9.

Model Trainer Enhancements

Breaking change: We changed the ModelTrainer such that you no longer pass the optimizer during initialization. Instead, it is passed as a parameter of the train or fine_tune method.

Old syntax:

# 1. initialize trainer with AdamW optimizer
trainer = ModelTrainer(classifier, corpus, optimizer=torch.optim.AdamW)

# 2. run training with small learning rate and mini-batch size
trainer.train('resources/taggers/question-classification-with-transformer',
              learning_rate=5.0e-5,
              mini_batch_size=4,
             )

New syntax (the optimizer is now a parameter of the train method):

# 1. initialize trainer 
trainer = ModelTrainer(classifier, corpus)

# 2. run training with AdamW, small learning rate and mini-batch size
trainer.train('resources/taggers/question-classification-with-transformer',
              learning_rate=5.0e-5,
              mini_batch_size=4,
              optimizer=torch.optim.AdamW,
             )

Convenience function for fine-tuning (#2439)

Adds a fine_tune routine that sets default parameters used for fine-tuning (AdamW optimizer, small learning rate, few epochs, cyclic learning rate scheduling, etc.). Uses the new linear scheduler with warmup (#2415).

New syntax with the fine_tune method:

from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# 1. get the corpus
corpus: Corpus = TREC_6()

# 2. what label do we want to predict?
label_type = 'question_class'

# 3. create the label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)

# 4. initialize transformer document embeddings (many models are available)
document_embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=True)

# 5. create the text classifier
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)

# 6. initialize trainer
trainer = ModelTrainer(classifier, corpus)

# 7. run training with fine-tuning
trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                  learning_rate=5.0e-5,
                  mini_batch_size=4,
                  )

Model Cards (#2457)

When you train any Flair model, a "model card" is now automatically saved alongside it, storing all training parameters and library versions used to train the model. When you later load a Flair model, you can print this model card to see how the model was trained.

The following example trains a small POS tagger and prints the model card at the end:

from flair.datasets import UD_ENGLISH
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# initialize corpus and make label dictionary for POS tags
corpus = UD_ENGLISH().downsample(0.01)
tag_type = "pos"
tag_dictionary = corpus.make_label_dictionary(tag_type)

# simple sequence tagger
tagger = SequenceTagger(hidden_size=256,
                        embeddings=WordEmbeddings("glove"),
                        tag_dictionary=tag_dictionary,
                        tag_type=tag_type)

# initialize model trainer and experiment path
trainer = ModelTrainer(tagger, corpus)
path = 'resources/taggers/model-card'

# train for a few epochs
trainer.train(path,
              max_epochs=20,
              )

# load best model and print "model card"
trained_model = SequenceTagger.load(path + '/best-model.pt')
trained_model.print_model_card()

This should print a model card like:

------------------------------------
--------- Flair Model Card ---------
------------------------------------
- this Flair model was trained with:
-- Flair version 0.9
-- PyTorch version 1.7.1
-- Transformers version 4.8.1
------------------------------------
------- Training Parameters: -------
------------------------------------
-- base_path = resources/taggers/model-card
-- learning_rate = 0.1
-- mini_batch_size = 32
-- mini_batch_chunk_size = None
-- max_epochs = 20
-- train_with_dev = False
-- train_with_test = False
[... shortened ...]
------------------------------------

Resume training any model (#2457)

Previously, we distinguished between checkpoints and model files. Now all models can function as checkpoints, meaning you can load them and continue training them. Say you want to load the model above (trained to epoch 20) and continue training it to epoch 25. Do it like this:

# resume training best model, but this time until epoch 25
trainer.resume(trained_model,
               base_path=path + '-resume',
               max_epochs=25,
               )

Pass optimizer and scheduler instances

You can now also pass initialized optimizer and scheduler instances to the train and fine_tune methods.
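
For example, a minimal sketch that reuses the classifier and trainer from the fine-tuning example above (the OneCycleLR scheduler and its step count are illustrative choices, and scheduler= is our assumed keyword; check the method signature for the exact interface):

import torch

# build optimizer and scheduler instances yourself ...
optimizer = torch.optim.AdamW(classifier.parameters(), lr=5.0e-5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5.0e-5, total_steps=1000)

# ... and pass both instances directly to the train method
trainer.train('resources/taggers/question-classification-with-transformer',
              learning_rate=5.0e-5,
              mini_batch_size=4,
              optimizer=optimizer,
              scheduler=scheduler,
              )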

Multi-Label Predictions and Confidence Threshold in TARS models (#2430)

This adds the possibility to set a confidence threshold for multi-label predictions in TARS, and to specify whether a problem is single-label or multi-label:

from flair.models import TARSClassifier
from flair.data import Sentence

# 1. Load our pre-trained TARS model for English
tars: TARSClassifier = TARSClassifier.load('tars-base')

# switch to a multi-label task (emotion detection)
tars.switch_to_task('GO_EMOTIONS')

# sentence with two emotions
sentence = Sentence("I am happy and sad")

# predict normally
tars.predict(sentence)
print(sentence)

# predict with lower label threshold (you can set this to 0. to get all labels)
tars.predict(sentence, label_threshold=0.01)
print(sentence)

# predict and enforce a single-label prediction
tars.predict(sentence, label_threshold=0.01, multi_label=False)
print(sentence)

Relation Extraction (#2471 #2492)

We refactored the RelationExtractor for more options, better code clarity, and small speed improvements.

We also added two new relation extraction models, trained over a modified version of TACRED: relations and relations-fast. To use these models, you also need an entity tagger: the tagger identifies entities, then the relation extractor predicts relations between them.

For instance, use this code:

from flair.data import Sentence
from flair.models import RelationExtractor, SequenceTagger

# 1. make example sentence
sentence = Sentence("George was born in Washington")

# 2. load entity tagger and predict entities
tagger = SequenceTagger.load('ner-fast')
tagger.predict(sentence)

# check which entities have been found in the sentence
entities = sentence.get_labels('ner')
for entity in entities:
    print(entity)

# 3. load relation extractor
extractor: RelationExtractor = RelationExtractor.load('relations-fast')

# predict relations
extractor.predict(sentence)

# check which relations have been found
relations = sentence.get_labels('relation')
for relation in relations:
    print(relation)

Embeddings

  • Refactoring of WordEmbeddings to avoid gensim version issues and enable further fine-tuning of pre-trained embeddings (#2491); see the sketch after this list
  • Refactoring of OneHotEmbeddings to fix errors caused by some corpora and enable "stable embeddings" (#2490)
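
A minimal sketch of the new WordEmbeddings fine-tuning option (fine_tune is our reading of the flag added in #2491; check the PR for the exact interface):

from flair.embeddings import WordEmbeddings

# load classic GloVe embeddings; with fine_tune=True the embedding
# weights remain trainable and are updated during model training
# (fine_tune is our assumed name for the flag from #2491)
embeddings = WordEmbeddings("glove", fine_tune=True)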

Other Enhancements and Bug Fixes

  • Compatibility with gensim 4 and Python 3.9 (#2496)
  • Fix TransformerWordEmbeddings if model_max_length not set in Tokenizer (#2502)
  • Fix TransformerWordEmbeddings handling of lang ids (#2417)
  • Fix attention mask for special Transformer architectures (#2485)
  • Fix regression model (#2424)
  • Fix problems caused by refactoring of Dictionary (#2429 #2435 #2453)
  • Fix infinite loop in Span::to_original_text (#2462)
  • Fix result object in ModelTrainer (#2519)
  • Fix bug in wsd_ufsac corpus (#2521)
  • Fix bugs in TARS and simple sequence tagger (#2468)
  • Add Amharic FLAIR EMBEDDING model (#2494)
  • Add MultiCoNer Dataset (#2507)
  • Add Korean Flair Tutorials (#2516 #2517)
  • Remove hyperparameter features (#2518)
  • Make it optional to create logfiles and loss files (#2421)
  • Small simplification of TransformerWordEmbeddings (#2425)