Skip to content

Commit

Permalink
[MRG] Add a geomloss wrapper for sinkhorn solver (#571)
Browse files Browse the repository at this point in the history
* add first geomlowss wrapper

* pep8

* working geomlos wrapper

* pep8

* small edit

* test for geomloss wrapper

* test for geomloss wrapper

* ad geomloss to tests

* pep8 test

* add option in solve_sample

* limyt to rceent python for geomloss

* add keops as depedency
  • Loading branch information
rflamary authored Nov 21, 2023
1 parent cffb6cf commit 299f560
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 8 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.

[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing.

[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
+ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551)
+ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563)
+ Wrapper for `geomloss`` solver on empirical samples (PR #571)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)

Expand Down
6 changes: 4 additions & 2 deletions ot/bregman/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@

from ._dictionary import (unmix)

from ._geomloss import (empirical_sinkhorn2_geomloss, geomloss)


__all__ = ['geometricBar', 'geometricMean', 'projR', 'projC',
'sinkhorn', 'sinkhorn2', 'sinkhorn_knopp', 'sinkhorn_log',
'greenkhorn', 'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling',
'barycenter', 'barycenter_sinkhorn', 'free_support_sinkhorn_barycenter',
'barycenter_stabilized', 'barycenter_debiased', 'jcpot_barycenter',
'convolutional_barycenter2d', 'convolutional_barycenter2d_debiased',
'empirical_sinkhorn', 'empirical_sinkhorn2',
'empirical_sinkhorn_divergence',
'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn2_geomloss'
'empirical_sinkhorn_divergence', 'geomloss',
'screenkhorn',
'unmix'
]
216 changes: 216 additions & 0 deletions ot/bregman/_geomloss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
"""
Wrapper functions for geomloss
"""

# Author: Remi Flamary <[email protected]>
#
# License: MIT License

import numpy as np
try:
import geomloss
from geomloss import SamplesLoss
import torch
from torch.autograd import grad
from ..utils import get_backend, LazyTensor, dist
except ImportError:
geomloss = False


def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', blur=0.1, nx=None):
""" Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T)
Parameters
----------
X_a : array-like, shape (n_samples_a, dim)
samples in the source domain
X_torch: array-like, shape (n_samples_b, dim)
samples in the target domain
f : array-like, shape (n_samples_a,)
First dual potentials (log space)
g : array-like, shape (n_samples_b,)
Second dual potentials (log space)
metric : str, default='sqeuclidean'
Metric used for the cost matrix computation
blur : float, default=1e-1
blur term (blur=sqrt(reg)) >0
nx : Backend(), default=None
Numerical backend used
Returns
-------
T : LazyTensor
Lowrank tensor T = exp((f+g^T-C)/reg)*(ab^T)
"""

if nx is None:
nx = get_backend(X_a, X_b, f, g)

shape = (X_a.shape[0], X_b.shape[0])

def func(i, j, X_a, X_b, f, g, a, b, metric, blur):
if metric == 'sqeuclidean':
C = dist(X_a[i], X_b[j], metric=metric) / 2
else:
C = dist(X_a[i], X_b[j], metric=metric)
return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (a[i, None] * b[None, j])

T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur)

return T


def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', scaling=0.95,
verbose=False, debias=False, log=False, backend='auto'):
r""" Solve the entropic regularization optimal transport problem with geomloss
The function solves the following optimization problem:
.. math::
\gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
s.t. \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- :math:`C` is the cost matrix such that :math:`C_{i,j}=d(x_i^s,x_j^t)` and
:math:`d` is a metric.
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j}\gamma_{i,j}\log(\gamma_{i,j})-\gamma_{i,j}+1`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in and computed in log space for
better stability and epsilon-scaling. The solution is computed ina lzy way
using the Geomloss [60] and the KeOps library [61].
Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
a : array-like, shape (n_samples_a,), default=None
samples weights in the source domain
b : array-like, shape (n_samples_b,), default=None
samples weights in the target domain
metric : str, default='sqeuclidean'
Metric used for the cost matrix computation Only acepted values are
'sqeuclidean' and 'euclidean'.
scaling : float, default=0.95
Scaling parameter used for epsilon scaling. Value close to one promote
precision while value close to zero promote speed.
verbose : bool, default=False
Print information
debias : bool, default=False
Use the debiased version of Sinkhorn algorithm [12]_.
log : bool, default=False
Return log dictionary containing all computed objects
backend : str, default='auto'
Numerical backend for geomloss. Only 'auto' and 'tensorized' 'online'
and 'multiscale' are accepted values.
Returns
-------
value : float
OT value
log : dict
Log dictionary return only if log==True in parameters
References
----------
.. [60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast
and scalable optimal transport for brain tractograms. In Medical Image
Computing and Computer Assisted Intervention–MICCAI 2019: 22nd
International Conference, Shenzhen, China, October 13–17, 2019,
Proceedings, Part III 22 (pp. 636-644). Springer International
Publishing.
.. [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G.
(2021). Kernel operations on the gpu, with autodiff, without memory
overflows. The Journal of Machine Learning Research, 22(1), 3457-3462.
"""

if geomloss:

nx = get_backend(X_s, X_t, a, b)

if nx.__name__ not in ['torch', 'numpy']:
raise ValueError('geomloss only support torch or numpy backend')

if a is None:
a = nx.ones(X_s.shape[0], type_as=X_s) / X_s.shape[0]
if b is None:
b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0]

if nx.__name__ == 'numpy':
X_s_torch = torch.tensor(X_s)
X_t_torch = torch.tensor(X_t)

a_torch = torch.tensor(a)
b_torch = torch.tensor(b)

else:
X_s_torch = X_s
X_t_torch = X_t

a_torch = a
b_torch = b

# after that we are all in torch

# set blur value and p
if metric == 'sqeuclidean':
p = 2
blur = np.sqrt(reg / 2) # because geomloss divides cost by two
elif metric == 'euclidean':
p = 1
blur = np.sqrt(reg)
else:
raise ValueError('geomloss only supports sqeuclidean and euclidean metrics')

# force gradients for computing dual
a_torch.requires_grad = True
b_torch.requires_grad = True

loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend=backend, debias=debias, scaling=scaling, verbose=verbose)

# compute value
value = loss(a_torch, X_s_torch, b_torch, X_t_torch) # linear + entropic/KL reg?

# get dual potentials
f, g = grad(value, [a_torch, b_torch])

if metric == 'sqeuclidean':
value *= 2 # because geomloss divides cost by two

if nx.__name__ == 'numpy':
f = f.cpu().detach().numpy()
g = g.cpu().detach().numpy()
value = value.cpu().detach().numpy()

if log:
log = {}
log['f'] = f
log['g'] = g
log['value'] = value

log['lazy_plan'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx)

return value, log

else:
return value

else:
raise ImportError('geomloss not installed')
35 changes: 32 additions & 3 deletions ot/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .lp import emd2, wasserstein_1d
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log, empirical_sinkhorn2
from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
Expand All @@ -23,6 +23,8 @@
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport

lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']


def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
Expand Down Expand Up @@ -865,7 +867,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,

def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
unbalanced=None,
unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100,
unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95,
potentials_init=None, X_init=None, tol=None, verbose=False):
r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
Expand Down Expand Up @@ -922,6 +924,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
Maximum number of iteration, by default None (default values in each solvers)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
rank : int, optional
Rank of the OT matrix for lazy solers (method='factored'), by default 100
scaling : float, optional
Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
Initialization of the OT dual potentials for iterative methods, by default None
tol : _type_, optional
Expand All @@ -939,6 +945,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan
- res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method)
See :any:`OTResult` for more information.
Expand Down Expand Up @@ -1148,7 +1155,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
"""

if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']:
if method is not None and method.lower() in lst_method_lazy:
lazy0 = lazy
lazy = True

Expand Down Expand Up @@ -1221,6 +1228,28 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]

elif method.startswith('geomloss'): # Geomloss solver for entropi OT

split_method = method.split('_')
if len(split_method) == 2:
backend = split_method[1]
else:
if lazy0 is None:
backend = 'auto'
elif lazy0:
backend = 'online'
else:
backend = 'tensorized'

value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend)

lazy_plan = log['lazy_plan']
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]

# return scaled potentials (to be consistent with other solvers)
potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2))

elif reg is None or reg == 0: # exact OT

if unbalanced is None: # balanced EMD solver not available for lazy
Expand Down
2 changes: 1 addition & 1 deletion ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
"""

if nx is None:
nx = get_backend(a[0])
nx = get_backend(a[0:1])

if axis is None:
res = 0.0
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ jaxlib
tensorflow
pytest
torch_geometric
cvxpy
cvxpy
geomloss
pykeops
Loading

0 comments on commit 299f560

Please sign in to comment.