Releases: KevinMusgrave/pytorch-metric-learning
v0.9.94
Various bug fixes and improvements
- A list or dictionary of miners can be passed into `MultipleLosses` (see the sketch after this list). #212
- Fixed bug where `MultipleLosses` failed in list mode. #213
- Fixed bug where `IntraPairVarianceLoss` and `MarginLoss` were overriding `sub_loss_names` instead of `_sub_loss_names`. This likely caused embedding regularizers to have no effect for these two losses. #215
- `ModuleWithRecordsAndReducer` now creates copies of the input reducer when necessary. #216
- Moved `cos.clone()` inside `torch.no_grad()` in `RegularFaceRegularizer`. Should be more efficient? #219
- In `utils.inference`, moved the faiss import inside of `FaissIndexer`, since that is the only class that requires it. #222
- Added a `copy_weights` init argument to `LogitGetter`, to make copying optional. #223
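For reference, a minimal sketch of the dictionary form mentioned in the first item; the key names and the loss/miner pairings here are hypothetical, not from the release notes:

```python
from pytorch_metric_learning import losses, miners

# Dictionary form: the miner keys mirror the loss keys (names are hypothetical).
loss_func = losses.MultipleLosses(
    losses={"contrastive": losses.ContrastiveLoss(), "triplet": losses.TripletMarginLoss()},
    miners={"contrastive": miners.MultiSimilarityMiner(), "triplet": miners.TripletMarginMiner()},
)
```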
v0.9.93
Small update
- Optimized `get_random_triplet_indices`, so if you were using `DistanceWeightedMiner`, or if you ever set the `triplets_per_anchor` argument to something other than `"all"` anywhere in your code, it should run a lot faster now. Thanks @AlexSchuy (a sketch of the affected setting is below).
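For context, a hedged sketch of the kind of code that benefits; `triplets_per_anchor` is an existing `TripletMarginLoss` init argument, and the value 100 is illustrative:

```python
from pytorch_metric_learning.losses import TripletMarginLoss

# Any value other than "all" triggers random triplet sampling,
# which is the code path that was optimized.
loss_func = TripletMarginLoss(triplets_per_anchor=100)
```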
v0.9.92
New Features
DistributedLossWrapper and DistributedMinerWrapper
Added DistributedLossWrapper and DistributedMinerWrapper. Wrap a loss or miner with these when using PyTorch's DistributedDataParallel (i.e. multiprocessing). Most of the code is by @JohnGiorgi (https://github.com/JohnGiorgi/DeCLUTR).
```python
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

loss_func = pml_dist.DistributedLossWrapper(loss=losses.ContrastiveLoss())
miner = pml_dist.DistributedMinerWrapper(miner=miners.MultiSimilarityMiner())
```
For a working example, see the "Multiprocessing with DistributedDataParallel" notebook.
Added `enqueue_idx` to CrossBatchMemory
Now you can make CrossBatchMemory work with MoCo. This adds a great deal of flexibility to the MoCo framework, because you can use any tuple loss and tuple miner in CrossBatchMemory.
Previously this wasn't possible, because all embeddings passed into CrossBatchMemory would go into the memory queue. In contrast, MoCo only queues the momentum encoder's embeddings.
The new `enqueue_idx` argument lets you do this, by specifying which embeddings should be added to memory. Here's a modified snippet from the MoCo on CIFAR10 notebook:
```python
import torch
from pytorch_metric_learning.losses import CrossBatchMemory, NTXentLoss

loss_fn = CrossBatchMemory(loss=NTXentLoss(), embedding_size=64, memory_size=16384)

### snippet from the training loop ###
for images, _ in train_loader:
    ...
    previous_max_label = torch.max(loss_fn.label_memory)
    num_pos_pairs = encQ_out.size(0)
    labels = torch.arange(0, num_pos_pairs)
    labels = torch.cat((labels, labels)).to(device)
    ### add an offset so that the labels do not overlap with any labels in the memory queue ###
    labels += previous_max_label + 1
    ### we want to enqueue the output of encK, which is the 2nd half of the batch ###
    enqueue_idx = torch.arange(num_pos_pairs, num_pos_pairs * 2)
    all_enc = torch.cat([encQ_out, encK_out], dim=0)
    ### now only encK_out will be added to the memory queue ###
    loss = loss_fn(all_enc, labels, enqueue_idx=enqueue_idx)
    ...
```
Check out the MoCo on CIFAR10 notebook to see the entire script.
TuplesToWeightsSampler
This is a simple offline miner. It does the following:
- Takes a random subset of your dataset, if you provide `subset_size`.
- Uses the specified miner to mine tuples from the subset dataset.
- Computes weights based on how often an element appears in the mined tuples.
- Randomly samples, using the weights as probabilities.
```python
from pytorch_metric_learning.samplers import TuplesToWeightsSampler
from pytorch_metric_learning.miners import MultiSimilarityMiner

miner = MultiSimilarityMiner(epsilon=-0.2)
sampler = TuplesToWeightsSampler(model, miner, dataset, subset_size=5000)
# then pass the sampler into your DataLoader
```
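For instance, a minimal sketch of the DataLoader hookup (the batch size is arbitrary):

```python
from torch.utils.data import DataLoader

# shuffle must stay False (the default) when a sampler is provided
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
```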
LogitGetter
Added utils.inference.LogitGetter to make it easier to compute logits of classifier loss functions.
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.utils.inference import LogitGetter
loss_fn = ArcFaceLoss(num_classes = 100, embedding_size = 512)
LG = LogitGetter(loss_fn)
logits = LG(embeddings)
Other
- Added optional `batch_size` argument to MPerClassSampler. If you pass in this argument, then each batch is guaranteed to have `m` samples per class. Otherwise, most batches will have `m` samples per class, but it's not guaranteed for every batch. Note that there are restrictions on the values of `m` and `batch_size`; for example, `batch_size` must be a multiple of `m`. For all the restrictions, see the documentation. (A sketch follows this list.)
- Added `trainable_attributes` to BaseTrainer, to standardize the `set_to_train` and `set_to_eval` functions.
- Added `save_models` init argument to HookContainer. If set to `False`, then models will not be saved.
- Added `losses_sizes` as a stat for BaseReducer.
- Added a type check and conversion in `common_functions.labels_to_indices`, to go from torch tensor to numpy.
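A hedged sketch of the MPerClassSampler item above; the label source and values are illustrative, and must satisfy the documented restrictions (e.g. `batch_size` divisible by `m`):

```python
from pytorch_metric_learning.samplers import MPerClassSampler

labels = dataset.targets  # hypothetical: your dataset's label list
# m=4 samples per class; batch_size=32 is a multiple of m, as required
sampler = MPerClassSampler(labels, m=4, batch_size=32)
```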
v0.9.91
Bug Fixes and Improvements
- Fixed CircleLoss bug, by improving the `logsumexp` `keep_mask` implementation. See #173
- Fixed `convert_to_weights` bug, which caused a runtime error when an empty `indices_tuple` was passed in. See #174
- ProxyAnchorLoss now adds miner weights to the exponents which are fed to `logsumexp`. This is equivalent to scaling each loss component by `e^(miner_weight)`; the previous behavior was to scale each loss component by just `miner_weight`. (A one-line derivation of the equivalence follows this list.)
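To see the equivalence (a one-line derivation, not from the release notes): adding a weight inside the exponent factors out as a multiplier, since `exp(x + w) = e^w * e^x`. So

`logsumexp_i(x_i + w_i) = log( sum_i e^(w_i) * e^(x_i) )`

i.e. each component `e^(x_i)` gets scaled by `e^(w_i)` inside the sum.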
Other updates
- Added an example notebook which shows how to use a customized loss + miner in a simple training loop.
- A new arXiv paper and an improved README give a better high-level explanation of the library.
- Added better explanations for the testers and the default accuracy metrics.
v0.9.90
********** Summary **********
The main update is the new distances module, which adds an extra level of modularity to loss functions. It is a pretty big design change, which is why so many arguments have become obsolete. See the documentation for a description of the new module.
Other updates include support for half-precision, new regularizers and mixins, improved documentation, and default values for most initialization parameters.
********** Breaking Changes **********
Dependencies
This library now requires PyTorch >= 1.6.0. Previously there was no explicit version requirement.
Losses and Miners
All loss functions
`normalize_embeddings` has been removed
- If you never used this argument, nothing needs to be done.
- `normalize_embeddings = True`: just remove the argument.
- `normalize_embeddings = False`: remove the argument and instead pass it into a `distance` object. For example:
```python
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import TripletMarginLoss

loss_func = TripletMarginLoss(distance=LpDistance(normalize_embeddings=False))
```
ContrastiveLoss, GenericPairLoss, BatchHardMiner, HDCMiner, PairMarginMiner
`use_similarity` has been removed
- If you never used this argument, nothing needs to be done.
- `use_similarity = True`: remove the argument and:
```python
from pytorch_metric_learning.losses import ContrastiveLoss

### if you had set normalize_embeddings = False ###
from pytorch_metric_learning.distances import DotProductSimilarity
loss_func = ContrastiveLoss(distance=DotProductSimilarity(normalize_embeddings=False))

### otherwise ###
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = ContrastiveLoss(distance=CosineSimilarity())
```
`squared_distances` has been removed
- If you never used this argument, nothing needs to be done.
- `squared_distances = True`: remove the argument and instead pass `power=2` into a `distance` object. For example:
```python
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import ContrastiveLoss

loss_func = ContrastiveLoss(distance=LpDistance(power=2))
```
- `squared_distances = False`: just remove the argument.
ContrastiveLoss, TripletMarginLoss
`power` has been removed
- If you never used this argument, nothing needs to be done.
- `power = 1`: just remove the argument.
- `power = X, where X != 1`: remove the argument and instead pass it into a `distance` object. For example:
```python
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import TripletMarginLoss

loss_func = TripletMarginLoss(distance=LpDistance(power=2))
```
TripletMarginLoss
`distance_norm` has been removed
- If you never used this argument, nothing needs to be done.
- `distance_norm = 2`: just remove the argument.
- `distance_norm = X, where X != 2`: remove the argument and instead pass it as `p` into a `distance` object. For example:
```python
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import TripletMarginLoss

loss_func = TripletMarginLoss(distance=LpDistance(p=1))
```
NPairsLoss
`l2_reg_weight` has been removed
- If you never used this argument, nothing needs to be done.
- `l2_reg_weight = 0`: just remove the argument.
- `l2_reg_weight = X, where X > 0`: remove the argument and instead pass in an `LpRegularizer` and weight:
```python
from pytorch_metric_learning.losses import NPairsLoss
from pytorch_metric_learning.regularizers import LpRegularizer

loss_func = NPairsLoss(embedding_regularizer=LpRegularizer(), embedding_reg_weight=0.123)
```
SignalToNoiseRatioContrastiveLoss
`regularizer_weight` has been removed
- If you never used this argument, nothing needs to be done.
- `regularizer_weight = 0`: just remove the argument.
- `regularizer_weight = X, where X > 0`: remove the argument and instead pass in a `ZeroMeanRegularizer` and weight:
```python
from pytorch_metric_learning.losses import SignalToNoiseRatioContrastiveLoss
from pytorch_metric_learning.regularizers import ZeroMeanRegularizer

loss_func = SignalToNoiseRatioContrastiveLoss(embedding_regularizer=ZeroMeanRegularizer(), embedding_reg_weight=0.123)
```
SoftTripleLoss
`reg_weight` has been removed
- If you never used this argument, do the following to obtain the same default behavior:
```python
from pytorch_metric_learning.losses import SoftTripleLoss
from pytorch_metric_learning.regularizers import SparseCentersRegularizer

weight_regularizer = SparseCentersRegularizer(num_classes, centers_per_class)
SoftTripleLoss(..., weight_regularizer=weight_regularizer, weight_reg_weight=0.2)
```
- `reg_weight = X`: remove the argument, and use the `SparseCentersRegularizer` as shown above.
WeightRegularizerMixin and all classification loss functions
- If you never specified `regularizer` or `reg_weight`, nothing needs to be done.
- `regularizer = X`: replace with `weight_regularizer = X`
- `reg_weight = X`: replace with `weight_reg_weight = X`
Classification losses
- For all losses and miners, default values have been set for as many arguments as possible. This has caused a change in the ordering of positional arguments for several of the classification losses. The typical form is now:
```python
loss_func = SomeClassificationLoss(num_classes, embedding_size, <keyword arguments>)
```
See the documentation for specifics.
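For instance, a hedged sketch with ArcFaceLoss; the margin and scale shown are believed to be its defaults, passed explicitly for illustration:

```python
from pytorch_metric_learning.losses import ArcFaceLoss

# num_classes and embedding_size are positional; the rest go by keyword
loss_func = ArcFaceLoss(100, 512, margin=28.6, scale=64)
```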
Reducers
ThresholdReducer
`threshold` has been replaced by `low` and `high`
- Replace `threshold = X` with `low = X`
Regularizers
All regularizers
`normalize_weights` has been removed
- If you never used this argument, nothing needs to be done.
- `normalize_weights = True`: just remove the argument.
- `normalize_weights = False`: remove the argument and instead pass `normalize_embeddings = False` into a `distance` object. For example:
```python
from pytorch_metric_learning.distances import DotProductSimilarity
from pytorch_metric_learning.regularizers import RegularFaceRegularizer

loss_func = RegularFaceRegularizer(distance=DotProductSimilarity(normalize_embeddings=False))
```
Inference
MatchFinder
`mode` has been removed
- Replace `mode="sim"` with either `distance=CosineSimilarity()` or `distance=DotProductSimilarity()`
- Replace `mode="dist"` with `distance=LpDistance()`
- Replace `mode="squared_dist"` with `distance=LpDistance(power=2)`
********** New Features **********
Distances
Distances bring an additional level of modularity to building loss functions. Here's an example of how they work.
Consider the TripletMarginLoss in its default form:
```python
from pytorch_metric_learning.losses import TripletMarginLoss

loss_func = TripletMarginLoss(margin=0.2)
```
This loss function attempts to minimize [d_ap - d_an + margin]+.
In other words, it tries to make the anchor-positive distances (d_ap) smaller than the anchor-negative distances (d_an).
Typically, d_ap and d_an represent Euclidean or L2 distances. But what if we want to use a squared L2 distance, or an unnormalized L1 distance, or a completely different distance measure like signal-to-noise ratio? With the distances module, you can try out these ideas easily:
```python
### TripletMarginLoss with squared L2 distance ###
from pytorch_metric_learning.distances import LpDistance
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(power=2))

### TripletMarginLoss with unnormalized L1 distance ###
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(normalize_embeddings=False, p=1))

### TripletMarginLoss with signal-to-noise ratio ###
from pytorch_metric_learning.distances import SNRDistance
loss_func = TripletMarginLoss(margin=0.2, distance=SNRDistance())
```
You can also use similarity measures rather than distances, and the loss function will make the necessary adjustments:
```python
### TripletMarginLoss with cosine similarity ###
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = TripletMarginLoss(margin=0.2, distance=CosineSimilarity())
```
With a similarity measure, the TripletMarginLoss internally swaps the anchor-positive and anchor-negative terms: [s_an - s_ap + margin]+. In other words, it will try to make the anchor-negative similarities smaller than the anchor-positive similarities.
All losses, miners, and regularizers accept a `distance` argument. So you can try out the `MultiSimilarityMiner` using `SNRDistance`, or the `NTXentLoss` using `LpDistance(p=1)`, and so on. Note that some losses/miners/regularizers have restrictions on the type of distances they can accept. For example, some classification losses only allow `CosineSimilarity` or `DotProductSimilarity` as their distance measure between embeddings and weights. To view restrictions for specific loss functions, see the documentation.
There are four distances implemented (LpDistance, SNRDistance, CosineSimilarity, DotProductSimilarity), but of course you can extend the BaseDistance class and write a custom distance measure if you want. See the documentation for more.
EmbeddingRegularizerMixin
All loss functions now extend `EmbeddingRegularizerMixin`, which means you can optionally pass in (to any loss function) an embedding regularizer and its weight. The embedding regularizer will compute some loss based on the embeddings alone, ignoring labels and tuples. For example:
```python
from pytorch_metric_learning.losses import MultiSimilarityLoss
from pytorch_metric_learning.regularizers import LpRegularizer

loss_func = MultiSimilarityLoss(embedding_regularizer=LpRegularizer(), embedding_reg_weight=0.123)
```
WeightRegularizerMixin is now a subclass of WeightMixin
As in previous versions, classification losses extend `WeightRegularizerMixin`, which means you can optionally pass in a regularizer for the loss's weight matrix, along with its weight.
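A hedged sketch of that option, using the `weight_regularizer` and `weight_reg_weight` argument names from the breaking-changes list above; the loss/regularizer pairing and the weight value are illustrative:

```python
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.regularizers import RegularFaceRegularizer

loss_func = ArcFaceLoss(
    num_classes=100,
    embedding_size=512,
    weight_regularizer=RegularFaceRegularizer(),
    weight_reg_weight=0.1,  # illustrative value
)
```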
v0.9.89
CrossBatchMemory
- Fixed bug where CrossBatchMemory would use self-comparisons as positive pairs. This was uniquely a CrossBatchMemory problem because of the nature of adding each current batch to the queue.
- Fixed bug where DistanceWeightedMiner would not work with CrossBatchMemory due to missing `ref_label`.
- Changed the 3rd keyword argument of forward() from `input_indices_tuple` to `indices_tuple`, to be consistent with all other losses.
AccuracyCalculator
- Fixed bug in AccuracyCalculator where it would return NaN if the reference set contained none of the query set labels. Now it will log a warning and return 0.
BaseTester
- Fixed bug where the "compared_to_training_set" mode of BaseTester failed due to a `list(None)` bug.
InferenceModel
- New `get_nearest_neighbors` function will return the nearest neighbors of a query. By @btseytlin
Loss and miner utils
- Switched to `fill_diagonal_` in the `get_all_pairs_indices` and `get_all_triplets_indices` code, instead of creating `torch.eye`.
v0.9.88
v0.9.87
v0.9.87 comes with some major changes that may cause your existing code to break.
BREAKING CHANGES
Losses
- The `avg_non_zero_only` init argument has been removed from `ContrastiveLoss`, `TripletMarginLoss`, and `SignalToNoiseRatioContrastiveLoss`. Here's how to translate from old to new code:
  - `avg_non_zero_only=True`: Just remove this input parameter. Nothing else needs to be done, as this is the default behavior.
  - `avg_non_zero_only=False`: Remove this input parameter and replace it with `reducer=reducers.MeanReducer()`. You'll need to add this to your imports: `from pytorch_metric_learning import reducers`
- `learnable_param_names` and `num_class_per_param` have been removed from `BaseMetricLossFunction` due to lack of use. MarginLoss is the only built-in loss function that is affected by this. Here's how to translate from old to new code:
  - `learnable_param_names=["beta"]`: Remove this input parameter and instead pass in `learn_beta=True`.
  - `num_class_per_param=N`: Remove this input parameter and instead pass in `num_classes=N`.
AccuracyCalculator
- The `average_per_class` init argument is now `avg_of_avgs`. The new name better reflects the functionality.
- The old way to import was `from pytorch_metric_learning.utils import AccuracyCalculator`. This will no longer work. The new way is `from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator`. The reason for this change is to avoid an unnecessary import of the Faiss library, especially when this library is used in other packages.
New feature: Reducers
Reducers specify how to go from many loss values to a single loss value. For example, the ContrastiveLoss computes a loss for every positive and negative pair in a batch. A reducer will take all these per-pair losses, and reduce them to a single value. Here's where reducers fit in this library's flow of filters and computations:
Your Data --> Sampler --> Miner --> Loss --> Reducer --> Final loss value
Reducers are passed into loss functions like this:
```python
from pytorch_metric_learning import losses, reducers

reducer = reducers.SomeReducer()
loss_func = losses.SomeLoss(reducer=reducer)
loss = loss_func(embeddings, labels)  # in your training for-loop
```
Internally, the loss function creates a dictionary that contains the losses and other information. The reducer takes this dictionary, performs the reduction, and returns a single value on which `.backward()` can be called. Most reducers are written such that they can be passed into any loss function.
See the documentation for details.
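As a concrete sketch, here is the `MeanReducer` mentioned above attached to a real loss; the pairing is illustrative:

```python
from pytorch_metric_learning import losses, reducers

# MeanReducer averages over all pair losses, including the zero ones,
# unlike ContrastiveLoss's default of averaging only non-zero losses.
reducer = reducers.MeanReducer()
loss_func = losses.ContrastiveLoss(reducer=reducer)
```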
Other updates
Utils
Inference
`InferenceModel` has been added to the library. It is a model wrapper that makes it convenient to find matching pairs within a batch, or from a set of pairs. Take a look at this notebook to see example usage.
AccuracyCalculator
- The `k` value for k-nearest neighbors can optionally be specified as an init argument.
- k-nn based metrics now receive knn distances in their kwargs. See #118 by @marijnl
Other stuff
Unit tests were added for almost all losses, miners, regularizers, and reducers.
Bug fixes
Trainers
Loss and miner utils
- Fixed bug where `convert_to_triplets` could encounter a RuntimeError. See #95
v0.9.86
Losses + miners
- Added assertions to make sure the number of input embeddings is equal to the number of input labels.
- MarginLoss
  - Fixed bug where the loss explodes if `self.nu > 0` and the number of active pairs is 0. See #98 (comment)
Trainers
- Added `freeze_these` to the init arguments of BaseTrainer. This optional argument takes a list or tuple of strings as input. The strings must correspond to the names of models or loss functions, and these models/losses will have their parameters frozen during training. Their corresponding optimizers will also not be stepped.
- Fixed indices shifting bug in the TwoStreamMetricLoss trainer. By @marijnl
Testers
- BaseTester
  - Pass in epoch to `visualizer_hook`
  - Added `eval` option to `get_all_embeddings`. By default it is True, and will set the input trunk and embedder to eval() mode.
Utils
- HookContainer
  - Allow training to resume from the best model, rather than just the latest model.
  - The best models are now saved as `<model_name>_best<epoch>.pth` rather than `<model_name>_best.pth`. To easily get the new suffix for loading the best model, you can do:
```python
from pytorch_metric_learning.utils import common_functions as c_f

_, best_model_suffix = c_f.latest_version(your_model_folder, best=True)
best_trunk = "trunk_{}.pth".format(best_model_suffix)
best_embedder = "embedder_{}.pth".format(best_model_suffix)
```
v0.9.85
Trainers
- Added TwoStreamMetricLoss. By @marijnl.
- All `BaseTrainer` child classes now accept `*args` and pass it to `BaseTrainer`, so that you can use positional arguments when you init those child classes, rather than just keyword arguments.
- Fixed a key verification bug in CascadedEmbeddings that made it impossible to pass in an optimizer for the metric loss.
Testers
- Added GlobalTwoStreamEmbeddingSpaceTester. By @marijnl
- BaseTester
  - The input visualizer should now implement the `fit_transform` method, rather than `fit` and `transform` separately.
  - Fixed various bugs related to `label_hierarchy_level`
- WithSameParentLabelTester
  - Fixed bugs that were causing this tester to encounter a runtime error.
Utils
- HookContainer
  - Added methods for retrieving loss and accuracy history.
  - Fixed bug where the value for `best_epoch` could be `None`.
- AccuracyCalculator
  - Got rid of a bug that returned NaN when dealing with classes containing only one sample.
  - Added `average_per_class` option, which computes the average accuracy per class, and then returns the average of those averages. This can be useful when evaluating datasets with unbalanced classes. (See the sketch after this list.)
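A minimal sketch of the `average_per_class` option, assuming the v0.9.85-era import path (changed in v0.9.87) and treating the exact `get_accuracy` signature as an assumption:

```python
from pytorch_metric_learning.utils import AccuracyCalculator

calculator = AccuracyCalculator(average_per_class=True)
accuracies = calculator.get_accuracy(
    query_embeddings,
    reference_embeddings,
    query_labels,
    reference_labels,
    embeddings_come_from_same_source=False,
)
```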
Other stuff
- Added the `with-hooks` and `with-hooks-cpu` pip install options. The following will install record-keeper, faiss-gpu, and tensorboard, in addition to pytorch-metric-learning:
```
pip install pytorch-metric-learning[with-hooks]
```
If you don't have a GPU you can do:
```
pip install pytorch-metric-learning[with-hooks-cpu]
```
- Added more tests for AccuracyCalculator