When using EmbeddingBasedQueryStrategy with some transformers, the model receives an unsupported input token_type_ids when creating embeddings.
#54
Labels: bug (Something isn't working)
Bug description
Requires query_strategy to be a subclass of EmbeddingBasedQueryStrategy, such as EmbeddingKMeans; requires transformer_model to be a model that does not expect token_type_ids in its forward function, such as distilbert-base-uncased.
Steps to reproduce
When performing active learning with the configuration above, embedding creation fails because the model receives the unsupported input token_type_ids.
Expected behavior
The keys of the model input should be adjusted to match what the specific model accepts.
Cause:
In file small_text/integrations/transformers/classifiers/classification.py, function _create_embeddings: the code that builds the model inputs needs to be changed so that the token_type_ids field is removed when the seed model does not expect token_type_ids in its forward function.
Environment:
Python version: 3.11.7
small-text version: 1.3.3
small-text integrations (e.g., transformers): transformers 4.36.2
PyTorch version: 2.1.2
PyTorch-cuda: 11.8
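A possible fix could filter the tokenizer output against the model's forward signature before calling the model. The sketch below illustrates the idea with a stand-in model class; filter_model_inputs and DistilBertLikeModel are hypothetical names, not part of small-text, and the real patch would live inside _create_embeddings:

```python
import inspect


class DistilBertLikeModel:
    """Stand-in for a transformer whose forward() does not accept token_type_ids."""

    def forward(self, input_ids, attention_mask=None):
        # Dummy computation; a real model would return hidden states.
        return [sum(row) / len(row) for row in input_ids]


def filter_model_inputs(model, encoded):
    """Keep only the tokenizer outputs that model.forward() accepts."""
    accepted = set(inspect.signature(model.forward).parameters)
    return {name: tensor for name, tensor in encoded.items() if name in accepted}


model = DistilBertLikeModel()
encoded = {
    "input_ids": [[1, 2], [3, 4]],
    "attention_mask": [[1, 1], [1, 1]],
    "token_type_ids": [[0, 0], [0, 0]],  # unsupported by this model's forward()
}
inputs = filter_model_inputs(model, encoded)
output = model.forward(**inputs)  # no TypeError from an unexpected keyword
```

With a model that does accept token_type_ids, the same helper would pass the field through unchanged, which matches the expected behavior of adjusting the input keys per model.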