Skip to content

Commit

Permalink
Merge pull request #12 from liluo2/eps_whitendata_modify
Browse files Browse the repository at this point in the history
modify the generation of eye matrix of the function whiten_data
  • Loading branch information
brandstetter-johannes authored Jan 25, 2024
2 parents 9248979 + 77c7622 commit 35891e5
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion cliffordlayers/nn/functional/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def whiten_data(
cov = running_cov.permute(2, 0, 1)

# Upper triangle Cholesky decomposition of covariance matrix: U^T U = Cov
eye = eps * torch.eye(I, device=cov.device, dtype=cov.dtype).unsqueeze(0)
# eye = eps * torch.eye(I, device=cov.device, dtype=cov.dtype).unsqueeze(0)
# Modified the scale of eps to help prevent the occurence of negative-definite matrices
# 1e-5 may not fit the scale of matrices with large numbers
max_values = torch.amax(cov, dim=(1, 2))
A = torch.eye(cov.shape[-1], device=cov.device, dtype=cov.dtype)
eye = eps * torch.einsum('ij,k->kij', A, max_values)
U = torch.linalg.cholesky(cov + eye).mH
# Invert Cholesky decomposition, returns tensor of shape [B, C, *D, I]
x_whiten = torch.linalg.solve_triangular(
Expand Down

0 comments on commit 35891e5

Please sign in to comment.