Merge pull request #26 from pomonam/finalize_documents
Finalize documents
pomonam authored Jun 24, 2024
2 parents 2b7dbd3 + acd6529 commit 180c2ec
Showing 19 changed files with 700 additions and 290 deletions.
16 changes: 8 additions & 8 deletions DOCUMENTATION.md
@@ -97,7 +97,7 @@ After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](ht

**Set up the Analyzer and Fit Factors.**
Initialize the `Analyzer` and run `fit_all_factors` to compute all factors that aim to approximate the Hessian
(Gauss-Newton Hessian). The computed factors will be stored on disk.
([Gauss-Newton Hessian](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2022/readings/L03_metrics.pdf)). The computed factors will be stored on disk.
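A hedged sketch of this step (the example in the diff below is collapsed after its first import; `task`, `train_dataset`, and the keyword names `analysis_name`, `factors_name`, and `dataset` are assumptions rather than a verbatim copy of the repository's example):

```python
from kronfluence import FactorArguments
from kronfluence.analyzer import Analyzer

# `model` is assumed to have already gone through `prepare_model`, and
# `task` / `train_dataset` to be defined, as in the preceding sections.
analyzer = Analyzer(analysis_name="example", model=model, task=task)

# Fit all factors approximating the (Gauss-Newton) Hessian; the results
# are written to disk under the given factors name.
analyzer.fit_all_factors(
    factors_name="ekfac_factors",
    dataset=train_dataset,
    factor_args=FactorArguments(strategy="ekfac"),
)
```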

```python
from kronfluence.analyzer import Analyzer
@@ -182,8 +182,8 @@ def forward(x: torch.Tensor) -> torch.Tensor:

> [!WARNING]
> The default arguments assume the module is used only once during the forward pass.
> IIf your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
> `shared_parameters_exist=True` in both `FactorArguments` and `ScoreArguments`.
> If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
> `shared_parameters_exist=True` in `FactorArguments`.
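A minimal sketch of the updated recommendation (using the package-level import that the commit's own utility files use):

```python
from kronfluence import FactorArguments

# A module whose parameters are reused within one forward pass (e.g., weight
# tying) only needs the flag on the factor side after this change.
factor_args = FactorArguments(shared_parameters_exist=True)
```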
**Why are there so many arguments?**
Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments`
@@ -206,6 +206,7 @@ factor_args = FactorArguments(
use_empirical_fisher=False,
distributed_sync_steps=1000,
amp_dtype=None,
shared_parameters_exist=False,

# Settings for covariance matrix fitting.
covariance_max_examples=100_000,
@@ -223,7 +224,6 @@ factor_args = FactorArguments(
lambda_module_partition_size=1,
lambda_iterative_aggregate=False,
cached_activation_cpu_offload=False,
shared_parameters_exist=False,
per_sample_gradient_dtype=torch.float32,
lambda_dtype=torch.float32,
)
@@ -237,6 +237,7 @@ You can change:
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch)
instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`.
- `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`.
- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass.
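As a brief aside on `use_empirical_fisher` (standard definitions, not text from the repository): with per-example log-likelihood $\log p(y \mid x; \theta)$, the empirical Fisher uses the dataset labels while the true Fisher samples labels from the model,

$$
\hat{F}_{\text{emp}} = \frac{1}{N} \sum_{n=1}^{N} \nabla_\theta \log p(y_n \mid x_n; \theta)\, \nabla_\theta \log p(y_n \mid x_n; \theta)^\top,
\qquad
F = \frac{1}{N} \sum_{n=1}^{N} \mathbb{E}_{\hat{y} \sim p(\cdot \mid x_n; \theta)} \left[ \nabla_\theta \log p(\hat{y} \mid x_n; \theta)\, \nabla_\theta \log p(\hat{y} \mid x_n; \theta)^\top \right],
$$

which is why the documentation recommends leaving `use_empirical_fisher=False`.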

### Fitting Covariance Matrices

@@ -306,7 +307,6 @@ This corresponds to **Equation 20** in the paper. You can tune:
You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications.
This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16`
or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
@@ -353,7 +353,7 @@ score_args = ScoreArguments(
# Configuration for query batching.
query_gradient_rank=None,
query_gradient_svd_dtype=torch.float32,
num_query_gradient_aggregations=1,
num_query_gradient_accumulations=1,

# Configuration for dtype.
score_dtype=torch.float32,
@@ -370,11 +370,11 @@ score_args = ScoreArguments(
- `module_partition_size`: Number of module partitions for computing influence scores.
- `per_module_score`: Whether to return a per-module influence scores. Instead of summing over influences across
all modules, this will keep track of intermediate module-wise scores.
- - `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float64`.
- `num_query_gradient_aggregations`: Number of query gradients to aggregate over. For example, when `num_query_gradient_aggregations=2` with
- `num_query_gradient_accumulations`: Number of query gradients to accumulate over. For example, when `num_query_gradient_accumulations=2` with
`query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`,
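To make the query-batching and dtype options above concrete, a hedged sketch (the values are illustrative, not recommendations from the repository; the field names are the ones listed above):

```python
import torch

from kronfluence import ScoreArguments

# Rank-64 low-rank approximation of each preconditioned query gradient,
# accumulating 10 query batches before sweeping over training gradients.
score_args = ScoreArguments(
    query_gradient_rank=64,
    query_gradient_svd_dtype=torch.float32,
    num_query_gradient_accumulations=10,
    score_dtype=torch.bfloat16,
    per_sample_gradient_dtype=torch.bfloat16,
    precondition_dtype=torch.float32,
)
```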
20 changes: 10 additions & 10 deletions dev_requirements.txt
@@ -1,13 +1,13 @@
torch
torchvision
accelerate
einops
einconv
opt_einsum
safetensors
tqdm
datasets
transformers
torch>=2.1.0
torchvision>=0.16.0
accelerate>=0.31.0
einops>=0.8.0
einconv>=0.1.0
opt_einsum>=3.3.0
safetensors>=0.4.2
tqdm>=4.66.4
datasets>=2.20.0
transformers>=4.41.2
isort==5.13.2
pylint==3.2.3
pytest==8.2.2
8 changes: 4 additions & 4 deletions kronfluence/arguments.py
@@ -192,14 +192,14 @@ class ScoreArguments(Arguments):
"This is useful when performing layer-wise influence analysis."
},
)
num_query_gradient_accumulations: int = field(
default=1,
metadata={"help": "Number of query batches to accumulate over before iterating over training examples."},
)
query_gradient_rank: Optional[int] = field(
default=None,
metadata={"help": "Rank for the query gradient. Does not apply low-rank approximation if None."},
)
num_query_gradient_accumulations: int = field(
default=1,
metadata={"help": "Number of query batches to accumulate over."},
)
use_measurement_for_self_influence: bool = field(
default=False,
metadata={"help": "Whether to use the measurement (instead of the loss) for computing self-influence scores."},
18 changes: 12 additions & 6 deletions kronfluence/module/tracked_module.py
@@ -212,7 +212,7 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -
The input tensor to the module, provided by the PyTorch's forward hook.
"""
input_activation = input_activation.to(dtype=self.factor_args.activation_covariance_dtype)
flattened_activation, count = self._get_flattened_activation(input_activation)
flattened_activation, count = self._get_flattened_activation(input_activation=input_activation)

if self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] is None:
dimension = flattened_activation.size(1)
@@ -262,9 +262,9 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
PyTorch's backward hook.
"""
output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype)
flattened_gradient, count = self._get_flattened_gradient(output_gradient)
flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient)
if self._gradient_scale != 1.0:
flattened_gradient.mul_(self._gradient_scale)
flattened_gradient = flattened_gradient * self._gradient_scale

if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None:
dimension = flattened_gradient.size(1)
@@ -322,9 +322,10 @@ def _covariance_matrices_available(self) -> bool:
@torch.no_grad()
def synchronize_covariance_matrices(self) -> None:
"""Aggregates covariance matrices across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and self._covariance_matrices_available():
if dist.is_initialized() and torch.cuda.is_available() and self._covariance_matrices_available():
# Note that only the main process holds the aggregated covariance matrix.
for covariance_factor_name in COVARIANCE_FACTOR_NAMES:
self._storage[covariance_factor_name] = self._storage[covariance_factor_name].cuda()
dist.reduce(
tensor=self._storage[covariance_factor_name],
op=dist.ReduceOp.SUM,
@@ -518,8 +519,9 @@ def _lambda_matrix_available(self) -> bool:
@torch.no_grad()
def synchronize_lambda_matrices(self) -> None:
"""Aggregates Lambda matrices across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and self._lambda_matrix_available():
if dist.is_initialized() and torch.cuda.is_available() and self._lambda_matrix_available():
for lambda_factor_name in LAMBDA_FACTOR_NAMES:
self._storage[lambda_factor_name] = self._storage[lambda_factor_name].cuda()
torch.distributed.reduce(
tensor=self._storage[lambda_factor_name],
op=dist.ReduceOp.SUM,
@@ -693,7 +695,11 @@ def truncate_preconditioned_gradient(self, keep_size: int) -> None:
@torch.no_grad()
def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
"""Stacks preconditioned gradient across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and self._storage[PRECONDITIONED_GRADIENT_NAME] is not None:
if (
dist.is_initialized()
and torch.cuda.is_available()
and self._storage[PRECONDITIONED_GRADIENT_NAME] is not None
):
if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list):
for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])):
size = self._storage[PRECONDITIONED_GRADIENT_NAME][i].size()
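The tracked_module.py changes above guard each synchronization on `torch.cuda.is_available()` and move the stored factors onto the GPU before reduction. A minimal standalone sketch of that pattern (the function and tensor names here are hypothetical):

```python
import torch
import torch.distributed as dist


def reduce_factor_to_main_process(factor: torch.Tensor) -> torch.Tensor:
    """Sum a factor across ranks; only rank 0 holds the aggregated result."""
    if dist.is_initialized() and torch.cuda.is_available():
        # NCCL reductions operate on CUDA tensors, so move the factor first.
        factor = factor.cuda()
        dist.reduce(tensor=factor, dst=0, op=dist.ReduceOp.SUM)
    return factor
```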
Empty file.
71 changes: 71 additions & 0 deletions kronfluence/utils/common/factor_arguments.py
@@ -0,0 +1,71 @@
import torch

from kronfluence import FactorArguments


def default_factor_arguments(strategy: str = "ekfac") -> FactorArguments:
"""Default factor arguments."""
factor_args = FactorArguments(strategy=strategy)
return factor_args


def test_factor_arguments(strategy: str = "ekfac") -> FactorArguments:
"""Factor arguments used for unit tests."""
factor_args = FactorArguments(strategy=strategy)
factor_args.use_empirical_fisher = True
factor_args.activation_covariance_dtype = torch.float64
factor_args.gradient_covariance_dtype = torch.float64
factor_args.per_sample_gradient_dtype = torch.float64
factor_args.lambda_dtype = torch.float32
return factor_args


def smart_low_precision_factor_arguments(
strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16
) -> FactorArguments:
"""Factor arguments with low precision, except for the lambda computations."""
factor_args = FactorArguments(strategy=strategy)
factor_args.amp_dtype = dtype
factor_args.activation_covariance_dtype = dtype
factor_args.gradient_covariance_dtype = dtype
factor_args.per_sample_gradient_dtype = dtype
factor_args.lambda_dtype = torch.float32
return factor_args


def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
"""Factor arguments with low precision."""
factor_args = FactorArguments(strategy=strategy)
factor_args.amp_dtype = dtype
factor_args.activation_covariance_dtype = dtype
factor_args.gradient_covariance_dtype = dtype
factor_args.per_sample_gradient_dtype = dtype
factor_args.lambda_dtype = dtype
return factor_args


def reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
"""Factor arguments with low precision + iterative lambda update."""
factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
factor_args.lambda_iterative_aggregate = True
return factor_args


def extreme_reduce_memory_factor_arguments(
strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16
) -> FactorArguments:
"""Factor arguments for models that is difficult to fit in a single GPU."""
factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
factor_args.lambda_iterative_aggregate = True
factor_args.cached_activation_cpu_offload = True
factor_args.covariance_module_partition_size = 4
factor_args.lambda_module_partition_size = 4
return factor_args


def large_dataset_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
"""Factor arguments for large models and datasets."""
factor_args = smart_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
factor_args.covariance_data_partition_size = 4
factor_args.lambda_data_partition_size = 4
return factor_args
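Assuming these helpers are importable from `kronfluence.utils.common.factor_arguments` (the path of the file added above), a short usage sketch:

```python
from kronfluence.utils.common.factor_arguments import (
    extreme_reduce_memory_factor_arguments,
)

# Aggressive memory reduction: low precision throughout, iterative Lambda
# updates, CPU-offloaded cached activations, and 4-way module partitions.
factor_args = extreme_reduce_memory_factor_arguments(strategy="ekfac")
# The resulting arguments would then be passed to the factor-fitting call
# (e.g., `analyzer.fit_all_factors(..., factor_args=factor_args)`), though
# that call site is an assumption, not part of this file.
```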
96 changes: 96 additions & 0 deletions kronfluence/utils/common/score_arguments.py
@@ -0,0 +1,96 @@
from typing import Optional

import torch

from kronfluence import ScoreArguments


def default_score_arguments(
damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None
) -> ScoreArguments:
"""Default score arguments."""
score_args = ScoreArguments(damping=damping)
score_args.query_gradient_rank = query_gradient_rank
if score_args.query_gradient_rank is not None:
score_args.num_query_gradient_accumulations = 10
return score_args


def test_score_arguments(damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None) -> ScoreArguments:
"""Score arguments used for unit tests."""
score_args = ScoreArguments(damping=damping)
score_args.query_gradient_svd_dtype = torch.float64
score_args.score_dtype = torch.float64
score_args.per_sample_gradient_dtype = torch.float64
score_args.precondition_dtype = torch.float64
score_args.query_gradient_rank = query_gradient_rank
return score_args


def smart_low_precision_score_arguments(
damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
"""Score arguments with low precision, except for the preconditioning computations."""
score_args = ScoreArguments(damping=damping)
score_args.amp_dtype = dtype
score_args.query_gradient_svd_dtype = torch.float32
score_args.score_dtype = dtype
score_args.per_sample_gradient_dtype = dtype
score_args.precondition_dtype = torch.float32
score_args.query_gradient_rank = query_gradient_rank
if score_args.query_gradient_rank is not None:
score_args.num_query_gradient_accumulations = 10
return score_args


def all_low_precision_score_arguments(
damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
"""Score arguments with low precision."""
score_args = ScoreArguments(damping=damping)
score_args.amp_dtype = dtype
score_args.query_gradient_svd_dtype = torch.float32
score_args.score_dtype = dtype
score_args.per_sample_gradient_dtype = dtype
score_args.precondition_dtype = dtype
score_args.query_gradient_rank = query_gradient_rank
if score_args.query_gradient_rank is not None:
score_args.num_query_gradient_accumulations = 10
return score_args


def reduce_memory_score_arguments(
damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
"""Score arguments with low precision + CPU offload."""
score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype)
score_args.cached_activation_cpu_offload = True
score_args.query_gradient_rank = query_gradient_rank
if score_args.query_gradient_rank is not None:
score_args.num_query_gradient_accumulations = 10
return score_args


def extreme_reduce_memory_score_arguments(
damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
"""Score arguments for models that is difficult to fit in a single GPU."""
score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype)
score_args.cached_activation_cpu_offload = True
score_args.query_gradient_rank = query_gradient_rank
score_args.module_partition_size = 4
if score_args.query_gradient_rank is not None:
score_args.num_query_gradient_accumulations = 10
return score_args


def large_dataset_score_arguments(
damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
"""Score arguments for large models and datasets."""
score_args = smart_low_precision_score_arguments(damping=damping, dtype=dtype)
score_args.data_partition_size = 4
score_args.query_gradient_rank = query_gradient_rank
if score_args.query_gradient_rank is not None:
score_args.num_query_gradient_accumulations = 10
return score_args
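Similarly, a short usage sketch for the score-argument helpers (import path follows the file added above; note that each helper bumps `num_query_gradient_accumulations` to 10 whenever a query gradient rank is supplied):

```python
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments

# bfloat16 throughout, with rank-32 query batching; the helper also sets
# num_query_gradient_accumulations = 10 because a rank was provided.
score_args = all_low_precision_score_arguments(query_gradient_rank=32)
```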
