Problem statement
Flair has a decoder (PrototypicalDecoder) inspired by the paper Prototypical Networks for Few-shot Learning, but there is a notable difference in how the prototypes are calculated.
The original paper states that we should "take a class’s prototype to be the mean of its support set in the embedding space". The PrototypicalDecoder, however, treats class prototypes as learnable model parameters that simply get updated during backpropagation. This approach has some drawbacks:
It compromises important theoretical properties, such as the equivalence to mixture density estimation on the support set.
Randomly initializing prototypes fails to leverage the knowledge captured in pre-trained embeddings, potentially slowing down convergence and reducing performance.
It performs worse on few-shot classification and incremental learning tasks than the class-means approach.
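For concreteness, here is a minimal sketch of the mean-based prototype scheme the original paper describes, as opposed to learned prototype parameters. The function names and the choice of Euclidean distance are my own illustration, not Flair's API:

```python
import torch


def class_mean_prototypes(embeddings: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Each class prototype is the mean of that class's embeddings (its support set)."""
    prototypes = torch.zeros(num_classes, embeddings.size(1))
    for c in range(num_classes):
        prototypes[c] = embeddings[labels == c].mean(dim=0)
    return prototypes


def nearest_prototype(queries: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """Classify each query by its nearest prototype in embedding space."""
    distances = torch.cdist(queries, prototypes)  # pairwise Euclidean distances
    return distances.argmin(dim=1)
```

Because the prototypes are computed directly from the embeddings rather than initialized randomly, they inherit whatever structure the pre-trained encoder already provides.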
TL;DR
It would be nice to have a model in Flair that uses prototypes such that each class is represented by the mean of its examples.
Solution
I want to add a model to Flair that uses a class-mean update rule like a Prototypical Network, but forgoes episodic training in order to remain compatible with the existing model trainers. There happens to be a paper, Deep Nearest Class Mean Classifiers, which defines update rules that do exactly that. I have experimented with this approach and produced models that significantly outperform the PrototypicalDecoder on certain few-shot training tasks.
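The key ingredient that makes this compatible with standard (non-episodic) training is an online update of the class means across mini-batches. A sketch of such an incremental-mean update follows; the class name and interface are hypothetical, and this shows only one of the update-rule variants discussed in that line of work:

```python
import torch


class OnlineClassMeans:
    """Maintains running per-class mean embeddings across mini-batches."""

    def __init__(self, num_classes: int, embedding_dim: int):
        self.means = torch.zeros(num_classes, embedding_dim)
        self.counts = torch.zeros(num_classes)

    def update(self, embeddings: torch.Tensor, labels: torch.Tensor) -> None:
        for x, c in zip(embeddings, labels.tolist()):
            self.counts[c] += 1
            # incremental mean: mu <- mu + (x - mu) / n
            self.means[c] += (x - self.means[c]) / self.counts[c]
```

Since the means are updated from the data rather than by gradient descent, the decoder stays a drop-in component for Flair's existing trainer loop, while classification can still score queries against `self.means` like a Prototypical Network.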
This model has been very useful for me, and I'm sure others would benefit too if it was supported by Flair.