Commit

TakeshiMusgrave committed Aug 8, 2020
2 parents 67fdd14 + db597ed commit 20d65ef
Showing 11 changed files with 578 additions and 155 deletions.
14 changes: 14 additions & 0 deletions docs/common_functions.md
@@ -0,0 +1,14 @@
# Common Functions

## TorchInitWrapper
A simple wrapper that converts torch weight initialization functions into class form, so that they can be used within loss functions.

Example usage:
```python
from pytorch_metric_learning.utils import common_functions as c_f
import torch

# use kaiming_uniform, with a=1 and mode='fan_out'
weight_init_func = c_f.TorchInitWrapper(torch.nn.init.kaiming_uniform_, a=1, mode='fan_out')
loss_func = SomeClassificationLoss(..., weight_init_func=weight_init_func)
```
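
For intuition, here is a minimal sketch of what such a wrapper might look like (an illustration, not the library's actual implementation):
```python
import torch

# Hypothetical sketch: store the init function and its keyword arguments,
# then apply them when the wrapper is called on a tensor.
class SimpleInitWrapper:
    def __init__(self, init_func, **kwargs):
        self.init_func = init_func
        self.kwargs = kwargs

    def __call__(self, tensor):
        self.init_func(tensor, **self.kwargs)

weight_init_func = SimpleInitWrapper(torch.nn.init.kaiming_uniform_, a=1, mode='fan_out')
weight_init_func(torch.empty(10, 10))
```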
101 changes: 101 additions & 0 deletions docs/distances.md
@@ -0,0 +1,101 @@
# Distances

Distances are classes that, when given a batch of embeddings, return a matrix representing the pairwise distances/similarities between all the embeddings.

Consider the TripletMarginLoss in its default form:
```python
from pytorch_metric_learning.losses import TripletMarginLoss
loss_func = TripletMarginLoss(margin=0.2)
```
This loss function attempts to minimize [d<sub>ap</sub> - d<sub>an</sub> + margin]<sub>+</sub>.

In other words, it tries to make the anchor-positive distances (d<sub>ap</sub>) smaller than the anchor-negative distances (d<sub>an</sub>).
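
A quick numeric illustration of the hinge term (toy values, not library code):
```python
import torch

# margin = 0.2, as in the example above
d_ap, d_an, margin = torch.tensor(0.5), torch.tensor(0.9), 0.2
print(torch.relu(d_ap - d_an + margin))  # tensor(0.): the triplet is satisfied

d_ap, d_an = torch.tensor(0.9), torch.tensor(0.5)
print(torch.relu(d_ap - d_an + margin))  # tensor(0.6000): a violating triplet
```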

Typically, d<sub>ap</sub> and d<sub>an</sub> represent Euclidean or L2 distances. But what if we want to use a squared L2 distance, or an unnormalized L1 distance, or a completely different distance measure, like signal-to-noise ratio? With the distances module, you can try out these ideas easily:
```python
### TripletMarginLoss with squared L2 distance ###
from pytorch_metric_learning.distances import LpDistance
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(power=2))

### TripletMarginLoss with unnormalized L1 distance ###
loss_func = TripletMarginLoss(margin=0.2, distance=LpDistance(normalize_embeddings=False, p=1))

### TripletMarginLoss with signal-to-noise ratio ###
from pytorch_metric_learning.distances import SNRDistance
loss_func = TripletMarginLoss(margin=0.2, distance=SNRDistance())
```

You can also use similarity measures rather than distances, and the loss function will make the necessary adjustments:
```python
### TripletMarginLoss with cosine similarity ###
from pytorch_metric_learning.distances import CosineSimilarity
loss_func = TripletMarginLoss(margin=0.2, distance=CosineSimilarity())
```
With a similarity measure, the TripletMarginLoss internally swaps the anchor-positive and anchor-negative terms: [s<sub>an</sub> - s<sub>ap</sub> + margin]<sub>+</sub>. In other words, it will try to make the anchor-negative similarities smaller than the anchor-positive similarities.
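
The swap is driven by the ```is_inverted``` attribute described under [BaseDistance](distances.md#basedistance) below. A quick way to inspect it (assuming distance objects expose the attribute directly):
```python
from pytorch_metric_learning.distances import CosineSimilarity, LpDistance

print(CosineSimilarity().is_inverted)  # True: larger values mean more similar
print(LpDistance().is_inverted)        # False: smaller values mean closer
```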

All **losses, miners, and regularizers** accept a ```distance``` argument. So you can try out the ```MultiSimilarityMiner``` using ```SNRDistance```, or the ```NTXentLoss``` using ```LpDistance(p=1)```, and so on. Note that some losses/miners/regularizers have restrictions on the types of distance they can accept. For example, some classification losses only allow ```CosineSimilarity``` or ```DotProductSimilarity``` as their distance measure between embeddings and weights. To view the restrictions for specific loss functions, see the [losses page](losses.md).
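
For example, the combinations mentioned above might look like this (a sketch assuming the default hyperparameters are acceptable):
```python
from pytorch_metric_learning.losses import NTXentLoss
from pytorch_metric_learning.miners import MultiSimilarityMiner
from pytorch_metric_learning.distances import SNRDistance, LpDistance

miner = MultiSimilarityMiner(distance=SNRDistance())
loss_func = NTXentLoss(distance=LpDistance(p=1))
```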

## BaseDistance

All distances extend this class and therefore inherit its ```__init__``` parameters.

```python
distances.BaseDistance(collect_stats=True,
                       normalize_embeddings=True,
                       p=2,
                       power=1,
                       is_inverted=False)
```

**Parameters**:

* **collect_stats**: If True, will collect various statistics that may be useful to analyze during experiments. If False, these computations will be skipped.
* **normalize_embeddings**: If True, embeddings will be normalized to have an Lp norm of 1, before the distance/similarity matrix is computed.
* **p**: The distance norm.
* **power**: If not 1, each element of the distance/similarity matrix will be raised to this power.
* **is_inverted**: Should be set by child classes. If False, then small values represent embeddings that are close together. If True, then large values represent embeddings that are similar to each other.

**Required Implementations**:
```python
# Must return a matrix where mat[j,k] represents
# the distance/similarity between query_emb[j] and ref_emb[k]
def compute_mat(self, query_emb, ref_emb):
raise NotImplementedError

# Must return a tensor where output[j] represents
# the distance/similarity between query_emb[j] and ref_emb[j]
def pairwise_distance(self, query_emb, ref_emb):
raise NotImplementedError
```
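
As an illustration, a hypothetical custom similarity implementing this interface might look like the following (a sketch, assuming ```BaseDistance``` handles normalization and the ```is_inverted``` bookkeeping):
```python
import torch
from pytorch_metric_learning.distances import BaseDistance

# Hypothetical example: similarity = |dot product|
class AbsoluteDotProduct(BaseDistance):
    def __init__(self, **kwargs):
        super().__init__(is_inverted=True, **kwargs)

    def compute_mat(self, query_emb, ref_emb):
        # mat[j,k] = |<query_emb[j], ref_emb[k]>|
        return torch.abs(torch.matmul(query_emb, ref_emb.t()))

    def pairwise_distance(self, query_emb, ref_emb):
        # output[j] = |<query_emb[j], ref_emb[j]>|
        return torch.abs(torch.sum(query_emb * ref_emb, dim=1))
```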


## CosineSimilarity
```python
distances.CosineSimilarity(**kwargs)
```

The returned ```mat[i,j]``` is the cosine similarity between ```query_emb[i]``` and ```ref_emb[j]```. This class is equivalent to [```DotProductSimilarity(normalize_embeddings=True)```](distances.md#dotproductsimilarity).
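
A quick sanity check of this equivalence (assuming distance objects can be called directly on a batch of embeddings):
```python
import torch
from pytorch_metric_learning.distances import CosineSimilarity, DotProductSimilarity

emb = torch.randn(8, 16)
assert torch.allclose(CosineSimilarity()(emb),
                      DotProductSimilarity(normalize_embeddings=True)(emb),
                      atol=1e-6)
```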

## DotProductSimilarity
```python
distances.DotProductSimilarity(**kwargs)
```
The returned ```mat[i,j]``` is equal to ```torch.sum(query_emb[i] * ref_emb[j])```.


## LpDistance
```python
distances.LpDistance(**kwargs)
```
The returned ```mat[i,j]``` is the Lp distance between ```query_emb[i]``` and ```ref_emb[j]```. With default parameters, this is the Euclidean distance.
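
A sanity check against ```torch.cdist``` (embeddings are L2-normalized first, since ```normalize_embeddings``` defaults to True; exact numerical agreement is an assumption, so a small tolerance is used):
```python
import torch
import torch.nn.functional as F
from pytorch_metric_learning.distances import LpDistance

emb = torch.randn(8, 16)
normed = F.normalize(emb, p=2, dim=1)
assert torch.allclose(LpDistance()(emb), torch.cdist(normed, normed), atol=1e-4)
```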

## SNRDistance
[Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf){target=_blank}
```python
distances.SNRDistance(**kwargs)
```
The returned ```mat[i,j]``` is equal to:

```python
torch.var(query_emb[i] - ref_emb[j]) / torch.var(query_emb[i])
```
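
A naive reference implementation of the full matrix might look like this (an illustration that ignores the ```normalize_embeddings``` option):
```python
import torch

def snr_matrix(query_emb, ref_emb):
    # noise[i,j] = var(query_emb[i] - ref_emb[j]), variance over the embedding dim
    noise = torch.var(query_emb.unsqueeze(1) - ref_emb.unsqueeze(0), dim=2)
    # signal[i] = var(query_emb[i])
    signal = torch.var(query_emb, dim=1, keepdim=True)
    return noise / signal
```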
Binary file added docs/imgs/zero_mean_regularizer_equation.png
11 changes: 4 additions & 7 deletions docs/inference_models.md
@@ -16,7 +16,7 @@ InferenceModel(trunk,

* **trunk**: Your trained model for computing embeddings.
* **embedder**: Optional. This is if your model is split into two components (trunk and embedder). If None, then the embedder will simply return the trunk's output.
-* **match_finder**: A [MatchFinder](inference_models.md#matchfinder) object. If ```None```, it will be set to ```MatchFinder(mode="sim", threshold=0.9)```.
+* **match_finder**: A [MatchFinder](inference_models.md#matchfinder) object. If ```None```, it will be set to ```MatchFinder(distance=CosineSimilarity(), threshold=0.9)```.
* **indexer**: The object used for computing k-nearest-neighbors. If ```None```, it will be set to ```FaissIndexer()```.
* **normalize_embeddings**: If True, embeddings will be normalized to have Euclidean norm of 1.
* **batch_size**: The batch size used to compute embeddings, when training the indexer for k-nearest-neighbor retrieval.
@@ -26,16 +26,13 @@ InferenceModel(trunk,
## MatchFinder
```python
from pytorch_metric_learning.utils.inference import MatchFinder
-MatchFinder(mode="dist", threshold=None)
+MatchFinder(distance=None, threshold=None)
```

**Parameters**:

-* **mode**: One of:
-    * ```dist```: Use the Euclidean distance between vectors
-    * ```squared_dist```: Use the squared Euclidean distance between vectors
-    * ```sim```: Use the dot product of vectors
-* **threshold**: Optional. Pairs will be a match if they fall under this threshold for distance modes, or over this value for the similarity mode. If not provided, then a threshold must be provided during function calls.
+* **distance**: A [distance](distances.md) object.
+* **threshold**: Optional. Pairs will be a match if they fall under this threshold for non-inverted distances, or over this value for inverted distances. If not provided, then a threshold must be provided during function calls.
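
For illustration, constructing a MatchFinder with this API might look like the following (mirroring the default described above):
```python
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.inference import MatchFinder

match_finder = MatchFinder(distance=CosineSimilarity(), threshold=0.9)
```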


## FaissIndexer