ot bar doc
eloitanguy committed Jan 21, 2025
1 parent 6a3eab5 commit 3e8421e
Showing 2 changed files with 87 additions and 23 deletions.
10 changes: 6 additions & 4 deletions examples/barycenters/plot_barycenter_generic_cost.py
@@ -6,9 +6,9 @@
This example illustrates the computation of an Optimal Transport barycenter for a ground
cost that is not a power of a norm. We take the example of ground costs
:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear)
:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
projection onto a circle k. This is an example of the fixed-point barycenter
solver introduced in [74] which generalises [20].
solver introduced in [74] which generalises [20] and [43].
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
@@ -22,6 +22,8 @@
Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
Conference in Machine Learning
[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
"""

# Author: Eloi Tanguy <[email protected]>
@@ -147,8 +149,8 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
b_list,
cost_list,
B,
max_its=fixed_point_its,
stop_threshold=stop_threshold,
numItermax=fixed_point_its,
stopThr=stop_threshold,
)

# %% Plot Barycenter (Iteration 10)
Expand Down
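The docstring above says the ground barycenter function for the circle-projection costs is computed by gradient descent. The block below is a minimal NumPy sketch of that idea, not the example's actual `B`: the helper names (`project_circle`, `objective`, `gradient`, `ground_barycenter`), the analytic gradient, the weighted-mean starting point, and the per-point step halving are all assumptions made for illustration.

```python
import numpy as np


def project_circle(x, center, radius):
    """Project the rows of x (n, 2) onto the circle with given center and radius."""
    u = x - center
    return center + radius * u / np.linalg.norm(u, axis=1, keepdims=True)


def objective(x, y_list, centers, radii, weights):
    """Row-wise value of sum_k lambda_k * ||P_k(x) - y_k||^2, shape (n,)."""
    return sum(
        w * np.sum((project_circle(x, c, r) - y) ** 2, axis=1)
        for y, c, r, w in zip(y_list, centers, radii, weights)
    )


def gradient(x, y_list, centers, radii, weights):
    """Row-wise gradient of the objective with respect to x, shape (n, 2)."""
    grad = np.zeros_like(x)
    for y, c, r, w in zip(y_list, centers, radii, weights):
        u = x - c
        norm = np.linalg.norm(u, axis=1, keepdims=True)
        residual = c + r * u / norm - y
        # The Jacobian of P_k at x is (r / |u|) * (I - u u^T / |u|^2), which is
        # symmetric, so the gradient is 2 * Jacobian @ residual.
        radial = np.sum(residual * u, axis=1, keepdims=True) * u / norm**2
        grad += 2 * w * (r / norm) * (residual - radial)
    return grad


def ground_barycenter(y_list, centers, radii, weights, its=150, lr=0.5):
    """Hypothetical B: gradient descent on x, one independent problem per row."""
    args = (y_list, centers, radii, weights)
    x = sum(w * y for w, y in zip(weights, y_list))  # assumed starting point
    step = np.full((x.shape[0], 1), lr)
    for _ in range(its):
        candidate = x - step * gradient(x, *args)
        # Accept a step only where it lowers the objective, else halve the step.
        improved = (objective(candidate, *args) < objective(x, *args))[:, None]
        x = np.where(improved, candidate, x)
        step = np.where(improved, step, step / 2)
    return x


# Toy data (assumed): two circles and five points per measure.
rng = np.random.default_rng(0)
centers = [np.array([0.0, 0.0]), np.array([3.0, 0.0])]
radii = [1.0, 2.0]
weights = [0.5, 0.5]
y_list = [rng.normal(size=(5, 2)) + 5.0 for _ in centers]

x_start = sum(w * y for w, y in zip(weights, y_list))
x_bar = ground_barycenter(y_list, centers, radii, weights)
```

Because a step is only accepted when it lowers the objective, the returned points are never worse than the starting points. A function like this would then be wrapped to take a single list of arrays before being passed as `B`.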
100 changes: 81 additions & 19 deletions ot/lp/_barycenter_solvers.py
@@ -430,33 +430,78 @@ class StoppingCriterionReached(Exception):

def free_support_barycenter_generic_costs(
X_init,
Y_list,
b_list,
measure_locations,
measure_weights,
cost_list,
B,
max_its=5,
stop_threshold=1e-5,
numItermax=5,
stopThr=1e-5,
log=False,
):
"""
Solves the OT barycenter problem using the fixed point algorithm, iterating
the function B on plans between the current barycentre and the measures.
r"""
Solves the OT barycenter problem for generic costs using the fixed point
algorithm, iterating the ground barycenter function B on transport plans
between the current barycentre and the measures.
The problem is to find an optimal barycenter support `X` of given size (n, d)
(enforced by the initialisation), minimising a sum of pairwise transport
costs for the costs :math:`c_k`:
.. math::
\min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k),
where:
- :math:`X` (n, d) is the barycentre support,
- :math:`a` (n) is the (fixed) barycentre weights,
- :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
- :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
- :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
.. math::
\mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F
s.t. \ \pi \mathbf{1} = \mathbf{a}
\pi^T \mathbf{1} = \mathbf{b_k}
\pi \geq 0
In other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is `ot.emd2(a, b_k,
c_k(X, Y_k))`.
The algorithm requires a given ground barycentre function `B` which computes
a solution of the following minimisation problem given :math:`(y_1, \cdots,
y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`:
.. math::
B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points
:math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
\cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
this function, and for certain costs it can be computed explicitly or
through a numerical solver.
This function implements [74] Algorithm 2, which generalises [20] and [43]
to general costs and comes with convergence guarantees, including for discrete measures.
Parameters
----------
X_init : array-like
Array of shape (n, d) representing initial barycentre points.
Y_list : list of array-like
measure_locations : list of array-like
List of K arrays of measure positions, each of shape (m_k, d_k).
b_list : list of array-like
measure_weights : list of array-like
List of K arrays of measure weights, each of shape (m_k).
cost_list : list of callable
List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k).
List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
B : callable
Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre.
max_its : int, optional
Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays, the k-th of shape :math:`(n, d_k)`, computing the ground barycentre.
numItermax : int, optional
Maximum number of iterations (default is 5).
stop_threshold : float, optional
stopThr : float, optional
If the iterations move less than this, terminate (default is 1e-5).
log : bool, optional
Whether to return the log dictionary (default is False).
@@ -468,9 +513,25 @@ def free_support_barycenter_generic_costs(
log_dict : list of array-like, optional
log containing the exit status, list of iterations and list of
displacements if log is True.
.. _references-free-support-barycenter-generic-costs:
References
----------
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
See Also
--------
ot.lp.free_support_barycenter : Free support solver for the case where
:math:`c_k(x,y) = \|x-y\|_2^2`.
ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
"""
nx = get_backend(X_init, Y_list[0])
K = len(Y_list)
nx = get_backend(X_init, measure_locations[0])
K = len(measure_locations)
n = X_init.shape[0]
a = nx.ones(n) / n
X_list = [X_init] if log else [] # store the iterations
@@ -479,13 +540,14 @@ def free_support_barycenter_generic_costs(
exit_status = "Unknown"

try:
for _ in range(max_its):
for _ in range(numItermax):
pi_list = [ # compute the pairwise transport plans

emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K)
emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
for k in range(K)
]
Y_perm = []
for k in range(K): # compute barycentric projections
Y_perm.append(n * pi_list[k] @ Y_list[k])
Y_perm.append(n * pi_list[k] @ measure_locations[k])
X_next = B(Y_perm)

if log:
@@ -498,7 +560,7 @@ def free_support_barycenter_generic_costs(
if log:
dX_list.append(dX)

if dX < stop_threshold:
if dX < stopThr:
exit_status = "Stationary Point"
raise StoppingCriterionReached

Expand Down
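The loop in this diff alternates three steps: compute a transport plan to each measure, take the barycentric projections `n * pi @ Y_k`, and apply `B`. The block below is a self-contained toy sketch of that iteration, not the library code. It assumes uniform weights and equal support sizes so that a brute-force search over permutations can stand in for `emd`, and it uses the squared Euclidean cost, for which `B` is the mean. The names `sq_euclidean_cost`, `exact_plan_uniform` and `fixed_point_barycenter` are hypothetical, and the displacement `dX` is an assumed definition, since the diff does not show how it is computed.

```python
import itertools

import numpy as np


def sq_euclidean_cost(X, Y):
    """Pairwise squared Euclidean cost matrix of shape (n, m)."""
    return np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)


def exact_plan_uniform(C):
    """Optimal plan between two uniform measures of equal size n.

    Brute force over permutations (small n only); a stand-in for `emd`.
    """
    n = C.shape[0]
    best = min(
        itertools.permutations(range(n)),
        key=lambda p: sum(C[i, p[i]] for i in range(n)),
    )
    pi = np.zeros((n, n))
    pi[np.arange(n), list(best)] = 1.0 / n
    return pi


def fixed_point_barycenter(
    X_init, measure_locations, cost_list, B, numItermax=5, stopThr=1e-5
):
    X = X_init
    n = X.shape[0]
    dX = np.inf
    for _ in range(numItermax):
        # Transport plans between the current barycentre and each measure.
        pi_list = [
            exact_plan_uniform(c(X, Y)) for c, Y in zip(cost_list, measure_locations)
        ]
        # Barycentric projections, as in the diff: n * pi_k @ Y_k.
        Y_perm = [n * pi @ Y for pi, Y in zip(pi_list, measure_locations)]
        X_next = B(Y_perm)
        dX = np.sum((X_next - X) ** 2)  # assumed displacement measure
        X = X_next
        if dX < stopThr:
            break
    return X, dX


# Toy data: a unit square and its translate by (4, 0).
Y1 = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
Y2 = Y1 + np.array([4.0, 0.0])
X_init = Y1 + np.array([1.0, 0.5])

X_bar, dX = fixed_point_barycenter(
    X_init,
    [Y1, Y2],
    [sq_euclidean_cost, sq_euclidean_cost],
    B=lambda Y_perm: np.mean(Y_perm, axis=0),  # ground barycentre for this cost
)
```

On this toy problem both optimal plans are the identity matching, so the iteration reaches the square translated by (2, 0) after one step and stops at the second with `dX == 0`.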
