Merge pull request #3066 from flairNLP/smaller-training-vocab

integrate transformer-smaller-training-vocab

alanakbik authored Mar 3, 2023
2 parents bef9641 + fb55e61 commit cb6b0e5
Showing 7 changed files with 196 additions and 101 deletions.
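
This PR centers on one new hook, get_used_tokens(corpus), which lets the transformer-smaller-training-vocab package temporarily shrink a transformer's embedding matrix to just the tokens that occur in the training data. A minimal sketch of how the hook can be consumed; the model/corpus setup and the reduce_train_vocab call are assumptions for illustration, not code from this diff:

    # Assumes `model` is any Flair classifier and `corpus` any flair.data.Corpus.
    used_tokens = set()
    for token_list in model.get_used_tokens(corpus):  # hook added in this PR
        used_tokens.update(token_list)
    print(f"training data uses {len(used_tokens)} distinct tokens")

    # transformer-smaller-training-vocab can then prune the embedding matrix for
    # the duration of training and restore it afterwards, roughly like this:
    # with reduce_train_vocab(model=hf_model, tokenizer=tokenizer, texts=texts):
    #     trainer.train(...)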
20 changes: 11 additions & 9 deletions flair/models/pairwise_classification_model.py
@@ -1,10 +1,11 @@
import typing
from typing import List

import torch

import flair.embeddings
import flair.nn
from flair.data import Sentence, TextPair
from flair.data import Corpus, Sentence, TextPair, _iter_dataset


class TextPairClassifier(flair.nn.DefaultClassifier[TextPair, TextPair]):
@@ -37,6 +38,7 @@ def __init__(
**classifierargs,
embeddings=embeddings,
final_embedding_size=2 * embeddings.embedding_length if embed_separately else embeddings.embedding_length,
should_embed_sentence=False,
)

self._label_type = label_type
@@ -47,11 +49,11 @@ def __init__
# set separator to concatenate two sentences
self.sep = " "
if isinstance(
self.document_embeddings,
self.embeddings,
flair.embeddings.document.TransformerDocumentEmbeddings,
):
if self.document_embeddings.tokenizer.sep_token:
self.sep = " " + str(self.document_embeddings.tokenizer.sep_token) + " "
if self.embeddings.tokenizer.sep_token:
self.sep = " " + str(self.embeddings.tokenizer.sep_token) + " "
else:
self.sep = " [SEP] "

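For illustration: with a BERT-style TransformerDocumentEmbeddings, tokenizer.sep_token is "[SEP]", so the two texts of a pair are joined into a single string. A sketch, not code from the diff:

    sep = " " + "[SEP]" + " "                       # from tokenizer.sep_token
    combined = "How are you?" + sep + "I am fine."
    # -> "How are you? [SEP] I am fine."
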
@@ -92,9 +94,6 @@ def _get_state_dict(self):
"document_embeddings": self.embeddings.save_embeddings(use_state_dict=False),
"label_dictionary": self.label_dictionary,
"label_type": self.label_type,
"multi_label": self.multi_label,
"multi_label_threshold": self.multi_label_threshold,
"weight_dict": self.weight_dict,
"embed_separately": self.embed_separately,
}
return model_state
@@ -106,8 +105,11 @@ def _init_model_with_state_dict(cls, state, **kwargs):
embeddings=state.get("document_embeddings"),
label_dictionary=state.get("label_dictionary"),
label_type=state.get("label_type"),
multi_label=state.get("multi_label_threshold", 0.5),
loss_weights=state.get("weight_dict"),
embed_separately=state.get("embed_separately"),
**kwargs,
)

def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.second]
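
The override yields the tokens of both halves of every pair, so the reduced vocabulary covers the first and the second sentence. A small sketch, assuming the default tokenizer (expected output is approximate):

    from flair.data import Sentence, TextPair

    pair = TextPair(Sentence("How are you?"), Sentence("I am fine."))
    print([t.text for t in pair.first])   # ['How', 'are', 'you', '?']
    print([t.text for t in pair.second])  # ['I', 'am', 'fine', '.']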
19 changes: 18 additions & 1 deletion flair/models/relation_classifier_model.py
@@ -1,5 +1,6 @@
import itertools
import logging
import typing
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
@@ -20,7 +21,16 @@
from torch.utils.data.dataset import Dataset

import flair
from flair.data import Corpus, Dictionary, Label, Relation, Sentence, Span, Token
from flair.data import (
Corpus,
Dictionary,
Label,
Relation,
Sentence,
Span,
Token,
_iter_dataset,
)
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import DocumentEmbeddings, TransformerDocumentEmbeddings
from flair.tokenization import SpaceTokenizer
@@ -707,6 +717,13 @@ def zero_tag_value(self) -> str:
def allow_unk_tag(self) -> bool:
return self._allow_unk_tag

def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
yield from super().get_used_tokens(corpus)
for sentence in _iter_dataset(corpus.get_all_sentences()):
for span in sentence.get_spans(self.label_type):
yield self.encoding_strategy.encode_head(span, span.get_label(self.label_type)).split(" ")
yield self.encoding_strategy.encode_tail(span, span.get_label(self.label_type)).split(" ")

@classmethod
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "RelationClassifier":
from typing import cast
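The extra yields exist because the relation classifier does not feed the corpus text to the transformer verbatim: the encoding strategy wraps head and tail entities in marker strings, and those markers become real input tokens. Illustrative only, since the exact marker format depends on the chosen encoding strategy:

    encoded = "[HEAD] Albert Einstein [/HEAD] was born in [TAIL] Ulm [/TAIL]"
    print(encoded.split(" "))  # marker tokens must survive vocab reduction
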
10 changes: 9 additions & 1 deletion flair/models/tars_model.py
@@ -1,4 +1,5 @@
import logging
import typing
from abc import ABC
from collections import OrderedDict
from pathlib import Path
@@ -11,7 +12,7 @@
from tqdm import tqdm

import flair
from flair.data import Dictionary, Sentence, Span
from flair.data import Corpus, Dictionary, Sentence, Span
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import (
TokenEmbeddings,
@@ -32,6 +33,7 @@ def __init__(self):
self._task_specific_attributes = {}
self.label_nearest_map = None
self.tars_model: flair.nn.Classifier[Sentence]
self.separator: str

super(FewshotClassifier, self).__init__()

@@ -308,6 +310,12 @@ def predict_zero_shot(

return

def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
yield from super().get_used_tokens(corpus)
for label in self.get_current_label_dictionary().idx2item:
yield [label.decode("utf-8")]
yield [self.separator]

@classmethod
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "FewshotClassifier":
from typing import cast
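TARS prepends verbalized label names (joined by self.separator) to each input sentence, so the labels themselves are model input and must stay in the vocabulary. Items of a flair Dictionary are stored as bytes, hence the decode in the override above. A self-contained sketch:

    from flair.data import Dictionary

    labels = Dictionary(add_unk=False)
    labels.add_item("happy")
    labels.add_item("sad")
    for raw in labels.idx2item:        # items are bytes
        print([raw.decode("utf-8")])   # ['happy'] then ['sad']
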
11 changes: 8 additions & 3 deletions flair/models/text_regression_model.py
@@ -1,4 +1,5 @@
import logging
import typing
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -9,22 +10,22 @@

import flair
import flair.embeddings
from flair.data import Dictionary, Sentence
from flair.data import Corpus, Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings.base import load_embeddings
from flair.nn.model import ReduceTransformerVocabMixin
from flair.training_utils import MetricRegression, Result, store_embeddings

log = logging.getLogger("flair")


class TextRegressor(flair.nn.Model[Sentence]):
class TextRegressor(flair.nn.Model[Sentence], ReduceTransformerVocabMixin):
def __init__(
self,
document_embeddings: flair.embeddings.DocumentEmbeddings,
label_name: str = "label",
):
super().__init__()
log.info("Using REGRESSION - experimental")

self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
self.label_name = label_name
@@ -234,3 +235,7 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "TextRegressor":
from typing import cast

return cast("TextRegressor", super().load(model_path=model_path))

def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
for sentence in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence]
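
TextRegressor is not a Classifier subclass, so it mixes in ReduceTransformerVocabMixin directly. Generic training code can therefore test for the capability rather than for a concrete class; a sketch assuming `model` and `corpus` are already constructed:

    from flair.nn.model import ReduceTransformerVocabMixin

    if isinstance(model, ReduceTransformerVocabMixin):  # True for TextRegressor too
        used = list(model.get_used_tokens(corpus))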
21 changes: 17 additions & 4 deletions flair/nn/model.py
@@ -14,7 +14,7 @@
from tqdm import tqdm

import flair
from flair.data import DT, DT2, Dictionary, Sentence
from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import Embeddings
from flair.embeddings.base import load_embeddings
@@ -234,7 +234,13 @@ def print_model_card(self):
)


class Classifier(Model[DT], typing.Generic[DT], ABC):
class ReduceTransformerVocabMixin(ABC):
@abstractmethod
def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
pass


class Classifier(Model[DT], typing.Generic[DT], ReduceTransformerVocabMixin, ABC):
"""Abstract base class for all Flair models that do classification,
both single- and multi-label. It inherits from flair.nn.Model and adds an
unified evaluate() function so that all classification models use the same
@@ -535,13 +541,17 @@ def _print_predictions(self, batch, gold_label_type):
correct_string = " -> MISMATCH!\n" if g != p else ""
# print info
eval_line = (
f"{datapoint.to_original_text()}\n"
f"{datapoint.text}\n"
f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n"
f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n{correct_string}\n"
)
lines.append(eval_line)
return lines

def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
for sentence in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence]

@classmethod
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Classifier":
from typing import cast
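
The pattern used across this PR: Classifier supplies the default token stream (every token of every sentence), and subclasses extend it with their model-specific input strings via super(). A hypothetical subclass for illustration, using names as they appear in flair/nn/model.py:

    class MarkerClassifier(Classifier[Sentence]):
        def get_used_tokens(self, corpus: Corpus) -> typing.Iterable[List[str]]:
            yield from super().get_used_tokens(corpus)  # corpus tokens
            yield ["[MY-MARKER]"]                       # strings this model injects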
@@ -572,6 +582,7 @@ def __init__(
decoder: Optional[torch.nn.Module] = None,
inverse_model: bool = False,
train_on_gold_pairs_only: bool = False,
should_embed_sentence: bool = True,
):
super().__init__()

@@ -600,6 +611,7 @@ def __init__(
self.dropout: torch.nn.Dropout = torch.nn.Dropout(dropout)
self.locked_dropout = flair.nn.LockedDropout(locked_dropout)
self.word_dropout = flair.nn.WordDropout(word_dropout)
self.should_embed_sentence = should_embed_sentence

# loss weights and loss function
self.weight_dict = loss_weights
@@ -693,7 +705,8 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor:

def _encode_data_points(self, sentences: List[DT], data_points: List[DT2]):
# embed sentences
self.embeddings.embed(sentences)
if self.should_embed_sentence:
self.embeddings.embed(sentences)

# get a tensor of data points
data_point_tensor = torch.stack([self._get_embedding_for_data_point(data_point) for data_point in data_points])
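The new should_embed_sentence flag exists for models whose real transformer input is built elsewhere: TextPairClassifier (above) concatenates each pair into a fresh Sentence and embeds that inside _get_embedding_for_data_point, so the default embedding pass would be wasted work. A hypothetical DefaultClassifier subclass opting out, mirroring the TextPairClassifier change:

    class CombiningClassifier(flair.nn.DefaultClassifier[Sentence, Sentence]):
        def __init__(self, embeddings, **kwargs):
            super().__init__(
                embeddings=embeddings,
                final_embedding_size=embeddings.embedding_length,
                should_embed_sentence=False,  # new flag from this PR
                **kwargs,
            )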