Commit

Initial commit.
fkdosilovic committed Oct 19, 2024
1 parent 2993108 commit caea8bb
Showing 2 changed files with 30 additions and 16 deletions.
27 changes: 11 additions & 16 deletions flair/embeddings/transformer.py
@@ -8,18 +8,18 @@
 from abc import abstractmethod
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Type, Union

 import torch
 import transformers
 from packaging.version import Version
 from torch.jit import ScriptModule
 from transformers import (
-    CONFIG_MAPPING,
     AutoConfig,
     AutoFeatureExtractor,
     AutoModel,
     AutoTokenizer,
+    CONFIG_MAPPING,
     FeatureExtractionMixin,
     LayoutLMTokenizer,
     LayoutLMTokenizerFast,
@@ -32,13 +32,8 @@
 from transformers.utils import PaddingStrategy

 import flair
-from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
-    DocumentEmbeddings,
-    Embeddings,
-    TokenEmbeddings,
-    register_embeddings,
-)
+from flair.data import log, Sentence, Token
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, register_embeddings, TokenEmbeddings

 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"

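The reshuffled imports in the two hunks above (and the peft import further down) all follow one rule: names are sorted case-insensitively, so cast lands between Any and Dict, and CONFIG_MAPPING moves after AutoTokenizer. A minimal sketch of that ordering, assuming plain str.casefold as the sort key (the commit does not state which formatter setting enforces it):

names = ["Dict", "CONFIG_MAPPING", "cast", "AutoTokenizer", "Any"]
print(sorted(names, key=str.casefold))
# ['Any', 'AutoTokenizer', 'cast', 'CONFIG_MAPPING', 'Dict']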
@@ -198,24 +193,28 @@ def fill_mean_token_embeddings(


 @torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )

     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
+
+    return result


 @torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )

     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
+
+    return result


 def _legacy_reconstruct_word_ids(
     embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
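The substantive fix in the hunk above is the two added return result lines: both helpers previously filled result but fell off the end of the function, so they implicitly returned None and mean/max document pooling yielded no usable embedding. A minimal sketch of the intended behavior (the shapes and names below are assumptions for illustration, not from the commit):

import torch

# Dummy batch: 2 sentences, padded to 5 subtokens, hidden size 4.
hidden_states = torch.randn(2, 5, 4)
lengths = torch.tensor([3, 5])  # true subtoken count per sentence

# Mean over each sentence's unpadded prefix, mirroring document_mean_pooling.
mean_pooled = torch.stack(
    [hidden_states[i, : lengths[i]].mean(dim=0) for i in range(hidden_states.shape[0])]
)

# Element-wise max over the same prefix, mirroring document_max_pooling.
max_pooled = torch.stack(
    [hidden_states[i, : lengths[i]].max(dim=0).values for i in range(hidden_states.shape[0])]
)

# One fixed-size vector per sentence, regardless of padding.
assert mean_pooled.shape == max_pooled.shape == (2, 4)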
@@ -1127,11 +1126,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if peft_config is not None:
             # add adapters for finetuning
             try:
-                from peft import (
-                    TaskType,
-                    get_peft_model,
-                    prepare_model_for_kbit_training,
-                )
+                from peft import get_peft_model, prepare_model_for_kbit_training, TaskType
             except ImportError:
                 log.error("You cannot use the PEFT finetuning without peft being installed")
                 raise
19 changes: 19 additions & 0 deletions tests/embeddings/test_transformer_document_embeddings.py
@@ -1,7 +1,10 @@
 import pytest

+from flair.data import Dictionary
 from flair.embeddings import TransformerDocumentEmbeddings
+from flair.models import TextClassifier
 from flair.nn import Classifier
+
 from tests.embedding_test_utils import BaseEmbeddingsTest


@@ -37,3 +40,19 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     # check that context_length and use_context_separator is the same for both
     assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+    from flair.data import Sentence
+    from flair.embeddings import TransformerDocumentEmbeddings
+
+    embeddings = TransformerDocumentEmbeddings(
+        model="xlm-roberta-base",
+        layers="-1",
+        cls_pooling=cls_pooling,
+        allow_long_sentences=True,
+    )
+    sentence = Sentence("Today is a good day.")
+    embeddings.embed(sentence)
+    assert sentence.embedding is not None

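The new test exercises all three cls_pooling modes end to end against xlm-roberta-base; before the return result fix above, the mean and max variants would not have produced a usable embedding. One further check that should hold is a sketch below, under the assumption that flair's embedding_length property reports the pooled vector's size (this is not asserted by the commit itself):

from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

embeddings = TransformerDocumentEmbeddings(model="xlm-roberta-base", layers="-1", cls_pooling="mean")
sentence = Sentence("Today is a good day.")
embeddings.embed(sentence)
# The pooled document vector has the embedding's declared dimensionality.
assert sentence.embedding.shape[0] == embeddings.embedding_length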