diff --git a/geomloss/utils.py b/geomloss/utils.py index 242bbe9..f7e6cdb 100644 --- a/geomloss/utils.py +++ b/geomloss/utils.py @@ -9,6 +9,10 @@ def scal(α, f, batch=False): + # f can become inf which would produce NaNs later on. Here we basically + # enforce 0.0 * inf = 0.0. + f = torch.where(α == 0.0, f.new_full((), 0.0), f) + if batch: B = α.shape[0] return (α.view(B, -1) * f.view(B, -1)).sum(1)