Skip to content

Commit

Permalink
Merge pull request #3289 from flairNLP/3284-bug-support-transformers-4310
Browse files Browse the repository at this point in the history

3284 bug support transformers 4310
  • Loading branch information
alanakbik authored Aug 7, 2023
2 parents 419f13a + 3695fe7 commit f54e6d6
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 3,386 deletions.
2 changes: 1 addition & 1 deletion flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def save(self, savefile):
def __setstate__(self, d):
self.__dict__ = d
# set 'add_unk' if the dictionary was created with a version of Flair older than 0.9
if "add_unk" not in self.__dict__.keys():
if "add_unk" not in self.__dict__:
self.__dict__["add_unk"] = b"<unk>" in self.__dict__["idx2item"]

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion flair/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def embedding_type(self) -> str:
def _everything_embedded(self, data_points: Sequence[Sentence]) -> bool:
for sentence in data_points:
for token in sentence.tokens:
if self.name not in token._embeddings.keys():
if self.name not in token._embeddings:
return False
return True

Expand Down
2 changes: 1 addition & 1 deletion flair/embeddings/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class PrecomputedImageEmbeddings(ImageEmbeddings):
def __init__(self, url2tensor_dict, name) -> None:
self.url2tensor_dict = url2tensor_dict
self.name = name
self.__embedding_length = len(list(self.url2tensor_dict.values())[0])
self.__embedding_length = len(next(iter(self.url2tensor_dict.values())))
self.static_embeddings = True
super().__init__()

Expand Down
2 changes: 1 addition & 1 deletion flair/embeddings/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,7 +1330,7 @@ def embed(self, sentences: Union[List[Sentence], Sentence]):
sentences = [sentences]

for sentence in sentences:
if self.name not in sentence._embeddings.keys():
if self.name not in sentence._embeddings:
everything_embedded = False

if not everything_embedded:
Expand Down
14 changes: 14 additions & 0 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

import torch
import transformers
from semver import Version
from torch.jit import ScriptModule
from transformers import (
CONFIG_MAPPING,
Expand Down Expand Up @@ -1105,6 +1107,16 @@ def embedding_length(self) -> int:

return self.embedding_length_internal

def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Drop the stale ``position_ids`` entry before delegating to the default loader.

    Checkpoints saved with older transformers releases contain a
    ``model.embeddings.position_ids`` buffer that newer releases no longer
    register, which would otherwise surface as an unexpected key during
    loading.

    NOTE(review): this compares ``transformers.__version__`` (a ``str``)
    against a semver ``Version`` and relies on semver's reflected comparison
    coercing the string — verify it behaves for dev/rc version strings.
    """
    on_new_transformers = transformers.__version__ >= Version(4, 31, 0)
    if on_new_transformers:
        assert isinstance(state_dict, dict)
        stale_key = f"{prefix}model.embeddings.position_ids"
        # pop with a default so checkpoints without the buffer load unchanged
        state_dict.pop(stale_key, None)
    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    )

def _has_initial_cls_token(self) -> bool:
# most models have CLS token as last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT is initial
if self.tokenizer_needs_ocr_boxes:
Expand Down Expand Up @@ -1191,6 +1203,8 @@ def __setstate__(self, state):
self.__dict__[key] = embedding.__dict__[key]

if model_state_dict:
if transformers.__version__ >= Version(4, 31, 0):
model_state_dict.pop("embeddings.position_ids", None)
self.model.load_state_dict(model_state_dict)

@classmethod
Expand Down
7 changes: 1 addition & 6 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,7 @@ def micro_avg_f_score(self):
return self.mean_squared_error()

def to_tsv(self):
return "{}\t{}\t{}\t{}".format(
self.mean_squared_error(),
self.mean_absolute_error(),
self.pearsonr(),
self.spearmanr(),
)
return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}"

@staticmethod
def tsv_header(prefix=None):
Expand Down
3,374 changes: 0 additions & 3,374 deletions poetry.lock

This file was deleted.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ ignore = [
"D105",
"D107",
"E501", # Ignore line too long
"RUF012",
]

unfixable = [
Expand All @@ -96,6 +97,8 @@ unfixable = [

[tool.ruff.per-file-ignores]
"flair/embeddings/legacy.py" = ["D205"]
"scripts/*" = ["INP001"] # no need for __init__.py for scripts
"flair/datasets/*" = ["D417"] # need to fix datasets in a unified way later.

[tool.ruff.pydocstyle]
convention = "google"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ tqdm>=4.63.0
transformer-smaller-training-vocab>=0.2.3
transformers[sentencepiece]>=4.18.0,<5.0.0
urllib3<2.0.0,>=1.0.0 # pin below 2 to make dependency resolution faster.
wikipedia-api>=0.5.7
wikipedia-api>=0.5.7
semver<4.0.0,>=3.0.0
2 changes: 1 addition & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def test_masakhane_corpus(tasks_base_path):
"bam": {"train": 4462, "dev": 638, "test": 1274},
"bbj": {"train": 3384, "dev": 483, "test": 966},
"ewe": {"train": 3505, "dev": 501, "test": 1001},
"fon": {"train": 4343, "dev": 621, "test": 1240},
"fon": {"train": 4343, "dev": 623, "test": 1228},
"hau": {"train": 5716, "dev": 816, "test": 1633},
"ibo": {"train": 7634, "dev": 1090, "test": 2181},
"kin": {"train": 7825, "dev": 1118, "test": 2235},
Expand Down

0 comments on commit f54e6d6

Please sign in to comment.