diff --git a/lpips/__init__.py b/lpips/__init__.py index 62b9a079..21d579d4 100755 --- a/lpips/__init__.py +++ b/lpips/__init__.py @@ -39,9 +39,9 @@ # return self.model.forward(target, pred) -def normalize_tensor(in_feat,eps=1e-10): - norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) - return in_feat/(norm_factor+eps) +def normalize_tensor(in_feat,eps=1e-8): + norm_factor = torch.sqrt(eps + torch.sum(in_feat**2,dim=1,keepdim=True)) + return in_feat/norm_factor def l2(p0, p1, range=255.): return .5*np.mean((p0 / range - p1 / range)**2)