Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 10, 2024
1 parent 5392a9d commit a7242cb
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 24 deletions.
4 changes: 2 additions & 2 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def forward(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tens
The output of the forward pass.
"""
outputs = self.original_module(inputs, *args, **kwargs)
# if outputs.requires_grad and self.gradient_scale == 1.0:
# return outputs
if outputs.requires_grad:
return outputs
return outputs + self._constant

def prepare_storage(self, device: torch.device) -> None:
Expand Down
3 changes: 0 additions & 3 deletions kronfluence/module/tracker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,12 @@ def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torc
The preprocessed gradient.
"""
original_dtype = output_gradient.dtype
print(original_dtype)
output_gradient = output_gradient.to(dtype=target_dtype)
if self.module.gradient_scale != 1.0:
if original_dtype != target_dtype:
output_gradient.mul_(self.module.gradient_scale)
else:
output_gradient = output_gradient * self.module.gradient_scale
print("After")
print(output_gradient.dtype)
return output_gradient

def register_hooks(self) -> None:
Expand Down
21 changes: 11 additions & 10 deletions kronfluence/module/tracker/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def _update_gradient_covariance_matrix(
)
self._gradient_covariance_initialized = True
self.module.storage[NUM_GRADIENT_COVARIANCE_PROCESSED].add_(count)
self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient)
alpha = 1
if self.module.gradient_scale != 1.0:
alpha = self.module.gradient_scale ** 2.
self.module.storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(output_gradient.t(), output_gradient, alpha=alpha)

def register_hooks(self) -> None:
"""Sets up hooks to compute activation and gradient covariance matrices."""
Expand All @@ -112,9 +115,7 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.
def backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.factor_args.gradient_covariance_dtype
)
output_gradient = output_gradient.detach().to(self.module.factor_args.gradient_covariance_dtype)
# Computes and updates pseudo-gradient covariance during backward pass.
output_gradient, count = self.module.get_flattened_gradient(output_gradient=output_gradient)
self._update_gradient_covariance_matrix(output_gradient=output_gradient, count=count)
Expand Down Expand Up @@ -259,13 +260,13 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(self.module.factor_args.per_sample_gradient_dtype)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
output_gradient=output_gradient,
).to(dtype=self.module.factor_args.lambda_dtype)
if self.module.gradient_scale != 1.0:
per_sample_gradient.mul_(self.module.gradient_scale)
self.clear_all_cache()
del output_gradient
# Computes and updates lambda matrix during backward pass.
Expand All @@ -275,14 +276,14 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.factor_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(self.module.factor_args.per_sample_gradient_dtype)
cached_activation = self.cached_activations.pop()
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
output_gradient=output_gradient,
)
if self.module.gradient_scale != 1.0:
per_sample_gradient.mul_(self.module.gradient_scale)
if self.cached_per_sample_gradient is None:
self.cached_per_sample_gradient = torch.zeros_like(per_sample_gradient, requires_grad=False)
# Aggregates per-sample gradients during backward pass.
Expand Down
9 changes: 6 additions & 3 deletions kronfluence/module/tracker/pairwise_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def _compute_pairwise_score_with_gradient(self, per_sample_gradient: torch.Tenso
per_sample_gradient,
)

if self.module.gradient_scale != 1.0:
scores.mul_(self.module.gradient_scale)

if self.module.storage[PAIRWISE_SCORE_MATRIX_NAME] is not None:
self.module.storage[PAIRWISE_SCORE_MATRIX_NAME].add_(scores)
else:
Expand Down Expand Up @@ -76,9 +79,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient.detach(), target_dtype=self.module.score_args.score_dtype
)
output_gradient = output_gradient.detach().to(self.module.factor_args.score_dtype)
if isinstance(self.cached_activations, list):
cached_activation = self.cached_activations.pop()
else:
Expand All @@ -91,6 +92,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
output_gradient=output_gradient,
)
del cached_activation, output_gradient
if self.module.gradient_scale != 1.0:
self.module.storage[PAIRWISE_SCORE_MATRIX_NAME].mul_(self.module.gradient_scale)
self.clear_all_cache()
else:
per_sample_gradient = self.module.compute_per_sample_gradient(
Expand Down
12 changes: 6 additions & 6 deletions kronfluence/module/tracker/precondition.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
self._raise_cache_not_found_exception()
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(self.module.score_args.per_sample_gradient_dtype)
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=self.cached_activations.to(device=output_gradient.device),
output_gradient=output_gradient,
Expand All @@ -120,15 +118,15 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
storage=self.module.storage,
)
del per_sample_gradient
if self.module.gradient_scale != 1.0:
preconditioned_gradient.mul_(self.module.gradient_scale)
self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient)

@torch.no_grad()
def shared_backward_hook(output_gradient: torch.Tensor) -> None:
handle = self.cached_hooks.pop()
handle.remove()
output_gradient = self._preprocess_gradient(
output_gradient=output_gradient.detach(), target_dtype=self.module.score_args.per_sample_gradient_dtype
)
output_gradient = output_gradient.detach().to(self.module.score_args.per_sample_gradient_dtype)
cached_activation = self.cached_activations.pop()
per_sample_gradient = self.module.compute_per_sample_gradient(
input_activation=cached_activation.to(device=output_gradient.device),
Expand All @@ -153,6 +151,8 @@ def finalize_iteration(self) -> None:
storage=self.module.storage,
)
self.cached_per_sample_gradient = None
if self.module.gradient_scale != 1.0:
preconditioned_gradient.mul_(self.module.gradient_scale)
self._process_preconditioned_gradient(preconditioned_gradient=preconditioned_gradient)
self.clear_all_cache()

Expand Down

0 comments on commit a7242cb

Please sign in to comment.