Adding a Deep Nearest Class Means Classifier model to Flair #3532

Merged: plonerma merged 8 commits into flairNLP:master on Jan 7, 2025

sheldon-roberts
Contributor

This PR adds a DeepNCMClassifier to flair.models.
My reasons for adding this model are outlined in the issue: #3531

This model requires a TrainerPlugin because it makes the prototype updates using an after_training_batch hook. Please let me know if there is a cleaner way to handle this.
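
A minimal, framework-free sketch of the pattern described above: the forward pass only collects per-class statistics, and a hook that fires after each training batch folds them into the class prototypes. `PrototypeStore`, `accumulate`, and `alpha` are hypothetical names, and the exponential moving average is a stand-in update rule, not the PR's implementation or its "condensation" method. In Flair the hook would be triggered through a TrainerPlugin rather than called by hand.

```python
from collections import defaultdict


class PrototypeStore:
    """Keeps one running-mean prototype per class (hypothetical helper)."""

    def __init__(self, dim, alpha=0.9):
        self.dim = dim
        self.alpha = alpha  # weight kept on the previous prototype
        self.prototypes = {}  # label -> list of floats
        self._sums = defaultdict(lambda: [0.0] * dim)
        self._counts = defaultdict(int)

    def accumulate(self, embeddings, labels):
        """Forward-pass side: collect per-class sums for the current batch."""
        for vec, label in zip(embeddings, labels):
            self._sums[label] = [s + v for s, v in zip(self._sums[label], vec)]
            self._counts[label] += 1

    def after_training_batch(self):
        """Hook side: fold the batch means into the prototypes, then reset."""
        for label, total in self._sums.items():
            batch_mean = [s / self._counts[label] for s in total]
            old = self.prototypes.get(label)
            if old is None:
                # First time this class is seen: start from the batch mean.
                self.prototypes[label] = batch_mean
            else:
                # Assumed update rule: exponential moving average.
                self.prototypes[label] = [
                    self.alpha * o + (1 - self.alpha) * b
                    for o, b in zip(old, batch_mean)
                ]
        self._sums.clear()
        self._counts.clear()


store = PrototypeStore(dim=2, alpha=0.5)
store.accumulate([[1.0, 0.0], [3.0, 0.0]], ["a", "a"])
store.after_training_batch()  # prototype for "a" starts at the batch mean
store.accumulate([[4.0, 2.0]], ["a"])
store.after_training_batch()  # prototype moves halfway toward the new mean
```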

Example Script:

from flair.data import Corpus
from flair.datasets import TREC_50
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import DeepNCMClassifier
from flair.trainers import ModelTrainer
from flair.trainers.plugins import DeepNCMPlugin

# load the TREC dataset
corpus: Corpus = TREC_50()

# make a transformer document embedding
document_embeddings = TransformerDocumentEmbeddings("roberta-base", fine_tune=True)

# create the classifier
classifier = DeepNCMClassifier(
    document_embeddings,
    label_dictionary=corpus.make_label_dictionary(label_type="class"),
    label_type="class",
    use_encoder=False,
    mean_update_method="condensation",
)

# initialize the trainer
trainer = ModelTrainer(classifier, corpus)

# train the model
trainer.fine_tune(
    "resources/taggers/deepncm_trec",
    plugins=[DeepNCMPlugin()],
)

@plonerma
Collaborator

plonerma commented Aug 19, 2024

Hello @sheldon-roberts,

Thanks a lot for your contribution! This had been buried deep in the backlog of things to implement.

I also don't see how this could be implemented without a TrainerPlugin.

What do you think about implementing this as a decoder (such as the PrototypicalDecoder), so that it can be used with the default classifier? Then it could be used with all model types (e.g. span and text classification).

Additionally, what do you think about supporting different distance functions, similar to the PrototypicalDecoder?
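
To illustrate what a pluggable distance function means for a nearest-class-mean decoder, here is a small pure-Python sketch. The function names, the `DISTANCES` registry, and its keys are hypothetical and are not taken from the PrototypicalDecoder; squared Euclidean and cosine distance are chosen only as two familiar examples.

```python
import math


def squared_euclidean(x, p):
    """Squared Euclidean distance between an embedding and a prototype."""
    return sum((a - b) ** 2 for a, b in zip(x, p))


def cosine_distance(x, p):
    """One minus cosine similarity; assumes neither vector is all zeros."""
    dot = sum(a * b for a, b in zip(x, p))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in p))
    return 1.0 - dot / norm


# Hypothetical registry so the distance can be selected by name.
DISTANCES = {"euclidean": squared_euclidean, "cosine": cosine_distance}


def scores(embedding, prototypes, distance="euclidean"):
    """Score each class as the negative distance to its prototype."""
    fn = DISTANCES[distance]
    return {label: -fn(embedding, proto) for label, proto in prototypes.items()}


def predict(embedding, prototypes, distance="euclidean"):
    """Return the label whose prototype is closest under the chosen distance."""
    s = scores(embedding, prototypes, distance)
    return max(s, key=s.get)


prototypes = {"a": [1.0, 0.0], "b": [0.0, 5.0]}
x = [0.0, 1.0]
by_euclidean = predict(x, prototypes, "euclidean")
by_cosine = predict(x, prototypes, "cosine")
```

In this toy case the two distances disagree: the point is closer to the "a" prototype in Euclidean terms but points in the same direction as the "b" prototype, which is why the choice of distance is worth exposing as an option.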

@sheldon-roberts
Contributor Author

Hi @plonerma, thanks for taking a look!

I really like both of these ideas! I will look into making these changes soon.

@MattGPT-ai
Contributor

To avoid using a trainer plugin, could we just add a no-op function like def after_training_epoch(): pass to the base Model class and call it right before or after self.dispatch("after_training_epoch", epoch=epoch) in the train_custom function?

I think this would work while this is a model class, but it might not work once it is changed to a decoder.
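
The suggestion above can be sketched without Flair as follows. `Model`, `PrototypeModel`, and `Trainer` here are simplified, hypothetical stand-ins for illustration only; they do not reproduce Flair's classes or the real train_custom loop.

```python
class Model:
    """Stand-in base class with a do-nothing hook that subclasses may override."""

    def after_training_epoch(self, epoch):
        pass  # no-op by default, so existing models are unaffected


class PrototypeModel(Model):
    """Stand-in model that needs to do some work at the end of every epoch."""

    def __init__(self):
        self.updates = []

    def after_training_epoch(self, epoch):
        # A real model would update its class prototypes here.
        self.updates.append(epoch)


class Trainer:
    """Stand-in trainer that calls the model hook next to its own event dispatch."""

    def __init__(self, model):
        self.model = model
        self.events = []

    def dispatch(self, event, **kwargs):
        # Records the event; a real trainer would notify its plugins.
        self.events.append((event, kwargs))

    def train(self, max_epochs):
        for epoch in range(1, max_epochs + 1):
            # ... training batches would run here ...
            self.model.after_training_epoch(epoch)
            self.dispatch("after_training_epoch", epoch=epoch)


model = PrototypeModel()
trainer = Trainer(model)
trainer.train(max_epochs=2)
```

Note that the trainer only calls the hook on the top-level model. A decoder nested inside that model would presumably not receive the call unless the model forwards it, which may be the concern raised about the decoder variant.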

@MattGPT-ai
Contributor

I am currently working on converting this class to a simpler decoder. I have gotten it to work, but it requires some changes to other classes: the label tensors have to be provided to the forward passes so they can go into the decoder call. Specifically, in DefaultClassifier.forward_loss, you need to have scores = self.decoder(data_point_tensor, label_tensor). In predict, this isn't necessary because you don't need to calculate the prototype updates.

Would it make sense to always pass this in, but just have most base cases ignore the parameter? Another alternative would be to have the class set self.label_tensor before the call so it doesn't need to be an input param at all. Not sure if anyone else has a suggestion of how to design this. I will be pushing up the specific code soon, but just looking for opinions.
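
The first option (always pass the labels and let most decoders ignore them) could look roughly like this. All class and function names below are hypothetical, and plain Python lists stand in for tensors; this is a sketch of the calling convention, not the code that was merged.

```python
class BaseDecoder:
    """Stand-in decoder interface: labels are optional and default to None."""

    def __call__(self, embeddings, labels=None):
        return self.forward(embeddings, labels)

    def forward(self, embeddings, labels=None):
        raise NotImplementedError


class IdentityDecoder(BaseDecoder):
    """Stand-in for an ordinary decoder that has no use for the labels."""

    def forward(self, embeddings, labels=None):
        return embeddings  # labels are accepted but ignored


class PrototypeDecoder(BaseDecoder):
    """Stand-in for a decoder that needs labels while training."""

    def __init__(self):
        self.pending = []  # (embedding, label) pairs awaiting a prototype update

    def forward(self, embeddings, labels=None):
        if labels is not None:
            # Training path: remember what is needed for the next update.
            self.pending.extend(zip(embeddings, labels))
        return embeddings


def forward_loss(decoder, embeddings, labels):
    """Training path: always hands the labels to the decoder."""
    return decoder(embeddings, labels)


def predict(decoder, embeddings):
    """Inference path: no labels, so no prototype bookkeeping happens."""
    return decoder(embeddings)


decoder = PrototypeDecoder()
forward_loss(decoder, [[0.1, 0.9], [0.8, 0.2]], ["b", "a"])
predict(decoder, [[0.5, 0.5]])
```

The alternative mentioned above, setting an attribute on the decoder before the call, keeps the signature unchanged but introduces hidden state that has to be cleared again before prediction.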

@MattGPT-ai force-pushed the deepncm-classifier branch 3 times, most recently from c92f501 to b19e700, on November 24, 2024 00:12
@MattGPT-ai
Contributor

This has been updated to be a decoder. Overall it is a lot less code and simpler, although it required some small changes to the DefaultClassifier class and still requires a plugin. I am definitely open to any suggestions on how to better integrate this.

@MattGPT-ai force-pushed the deepncm-classifier branch 2 times, most recently from 088aac0 to f94a56f, on November 24, 2024 01:59
@MattGPT-ai
Contributor

It looks like the tests are passing except for a couple of MyPy checks that aren't directly related to the changes in this PR; I think they are just in files that this PR touches. Do you have any suggestions for fixing these typing problems?

@MattGPT-ai
Contributor

Would it be better to move this class into flair/nn/decoder.py now that it is a decoder?

@MattGPT-ai
Contributor

@plonerma Are you able to re-review this before the next release?

sheldon-roberts and others added 3 commits December 18, 2024 15:24
Add tests for DeepNCMClassifier

Remove old test

Add multi label support

Add type hints and doc strings
…ifferent model types. make small changes to DefaultClassifier forward_loss to pass label tensor when needed. update tests
@MattGPT-ai
Contributor

I've moved this to decoder.py to be more consistent with the other decoders.

@plonerma
Collaborator

plonerma commented Jan 3, 2025

@sheldon-roberts and @MattGPT-ai: Thanks a lot for your collaborative effort!

I made a few minor changes and merged the current master branch into the PR. All checks now pass.

@MattGPT-ai
Contributor

This looks good now. Can we merge?

@plonerma merged commit 0864b22 into flairNLP:master on Jan 7, 2025
1 check passed
@alanakbik
Collaborator

Thanks a lot for adding this @MattGPT-ai and for reviewing @plonerma!

4 participants