diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 3aa529f..b0f1e5c 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -97,7 +97,7 @@ After calling `prepare_model`, you can create [DistributedDataParallel (DDP)](ht **Set up the Analyzer and Fit Factors.** Initialize the `Analyzer` and run `fit_all_factors` to compute all factors that aim to approximate the Hessian -(Gauss-Newton Hessian). The computed factors will be stored on disk. +([Gauss-Newton Hessian](https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2022/readings/L03_metrics.pdf)). The computed factors will be stored on disk. ```python from kronfluence.analyzer import Analyzer @@ -182,8 +182,8 @@ def forward(x: torch.Tensor) -> torch.Tensor: > [!WARNING] > The default arguments assume the module is used only once during the forward pass. -> IIf your model shares parameters (e.g., the module is used in multiple places during the forward pass), set -> `shared_parameters_exist=True` in both `FactorArguments` and `ScoreArguments`. +> If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set +> `shared_parameters_exist=True` in `FactorArguments`. **Why are there so many arguments?** Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments` and `ScoreArguments` @@ -206,6 +206,7 @@ factor_args = FactorArguments( use_empirical_fisher=False, distributed_sync_steps=1000, amp_dtype=None, + shared_parameters_exist=False, # Settings for covariance matrix fitting. covariance_max_examples=100_000, @@ -223,7 +224,6 @@ factor_args = FactorArguments( lambda_module_partition_size=1, lambda_iterative_aggregate=False, cached_activation_cpu_offload=False, - shared_parameters_exist=False, per_sample_gradient_dtype=torch.float32, lambda_dtype=torch.float32, ) @@ -237,6 +237,7 @@ You can change: - `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch) instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`. - `amp_dtype`: Selects the dtype for [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html). Disables AMP if set to `None`. +- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass. ### Fitting Covariance Matrices @@ -306,7 +307,6 @@ This corresponds to **Equation 20** in the paper. You can tune: You can set `cached_activation_cpu_offload=True` to cache these activations in CPU. This is helpful for dealing with OOMs, but will make the overall computation slower. - `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications. This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient. -- `shared_parameters_exist`: Specifies whether the shared parameters exist in the forward pass. - `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can also use `torch.bfloat16` or `torch.float16`. - `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16` @@ -353,7 +353,7 @@ score_args = ScoreArguments( # Configuration for query batching. query_gradient_rank=None, query_gradient_svd_dtype=torch.float32, - num_query_gradient_aggregations=1, + num_query_gradient_accumulations=1, # Configuration for dtype. 
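# Lower-precision options (`torch.bfloat16` or `torch.float16`) for the dtype settings below reduce memory usage at some cost in accuracy.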
score_dtype=torch.float32, @@ -370,11 +370,11 @@ score_args = ScoreArguments( - `module_partition_size`: Number of module partitions for computing influence scores. - `per_module_score`: Whether to return per-module influence scores. Instead of summing over influences across all modules, this will keep track of intermediate module-wise scores. +- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores. - `query_gradient_rank`: The rank for the query batching (low-rank approximation to the preconditioned query gradient; see **Section 3.2.2**). If `None`, no query batching will be used. - `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float64`. -- `num_query_gradient_aggregations`: Number of query gradients to aggregate over. For example, when `num_query_gradient_aggregations=2` with +- `num_query_gradient_accumulations`: Number of query gradients to accumulate over. For example, when `num_query_gradient_accumulations=2` with `query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients. - `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`. - `per_sample_gradient_dtype`: `dtype` for computing per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`. - `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`, diff --git a/dev_requirements.txt b/dev_requirements.txt index ff0b913..7c16350 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,13 +1,13 @@ -torch -torchvision -accelerate -einops -einconv -opt_einsum -safetensors -tqdm -datasets -transformers +torch>=2.1.0 +torchvision>=0.16.0 +accelerate>=0.31.0 +einops>=0.8.0 +einconv>=0.1.0 +opt_einsum>=3.3.0 +safetensors>=0.4.2 +tqdm>=4.66.4 +datasets>=2.20.0 +transformers>=4.41.2 isort==5.13.2 pylint==3.2.3 pytest==8.2.2 diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py index c557adf..db3978d 100644 --- a/kronfluence/arguments.py +++ b/kronfluence/arguments.py @@ -192,14 +192,14 @@ class ScoreArguments(Arguments): "This is useful when performing layer-wise influence analysis." }, ) + num_query_gradient_accumulations: int = field( + default=1, + metadata={"help": "Number of query batches to accumulate over before iterating over training examples."}, + ) query_gradient_rank: Optional[int] = field( default=None, metadata={"help": "Rank for the query gradient. Does not apply low-rank approximation if None."}, ) - num_query_gradient_accumulations: int = field( - default=1, - metadata={"help": "Number of query batches to accumulate over."}, - ) use_measurement_for_self_influence: bool = field( default=False, metadata={"help": "Whether to use the measurement (instead of the loss) for computing self-influence scores."}, diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index ca2a2a9..e6161b1 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -212,7 +212,7 @@ def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) - The input tensor to the module, provided by PyTorch's forward hook.
""" input_activation = input_activation.to(dtype=self.factor_args.activation_covariance_dtype) - flattened_activation, count = self._get_flattened_activation(input_activation) + flattened_activation, count = self._get_flattened_activation(input_activation=input_activation) if self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] is None: dimension = flattened_activation.size(1) @@ -262,9 +262,9 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N PyTorch's backward hook. """ output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) - flattened_gradient, count = self._get_flattened_gradient(output_gradient) + flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient) if self._gradient_scale != 1.0: - flattened_gradient.mul_(self._gradient_scale) + flattened_gradient = flattened_gradient * self._gradient_scale if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None: dimension = flattened_gradient.size(1) @@ -322,9 +322,10 @@ def _covariance_matrices_available(self) -> bool: @torch.no_grad() def synchronize_covariance_matrices(self) -> None: """Aggregates covariance matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and self._covariance_matrices_available(): + if dist.is_initialized() and torch.cuda.is_available() and self._covariance_matrices_available(): # Note that only the main process holds the aggregated covariance matrix. for covariance_factor_name in COVARIANCE_FACTOR_NAMES: + self._storage[covariance_factor_name] = self._storage[covariance_factor_name].cuda() dist.reduce( tensor=self._storage[covariance_factor_name], op=dist.ReduceOp.SUM, @@ -518,8 +519,9 @@ def _lambda_matrix_available(self) -> bool: @torch.no_grad() def synchronize_lambda_matrices(self) -> None: """Aggregates Lambda matrices across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and self._lambda_matrix_available(): + if dist.is_initialized() and torch.cuda.is_available() and self._lambda_matrix_available(): for lambda_factor_name in LAMBDA_FACTOR_NAMES: + self._storage[lambda_factor_name] = self._storage[lambda_factor_name].cuda() torch.distributed.reduce( tensor=self._storage[lambda_factor_name], op=dist.ReduceOp.SUM, @@ -693,7 +695,11 @@ def truncate_preconditioned_gradient(self, keep_size: int) -> None: @torch.no_grad() def synchronize_preconditioned_gradient(self, num_processes: int) -> None: """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" - if dist.is_initialized() and self._storage[PRECONDITIONED_GRADIENT_NAME] is not None: + if ( + dist.is_initialized() + and torch.cuda.is_available() + and self._storage[PRECONDITIONED_GRADIENT_NAME] is not None + ): if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list): for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])): size = self._storage[PRECONDITIONED_GRADIENT_NAME][i].size() diff --git a/kronfluence/utils/common/__init__.py b/kronfluence/utils/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kronfluence/utils/common/factor_arguments.py b/kronfluence/utils/common/factor_arguments.py new file mode 100644 index 0000000..122baea --- /dev/null +++ b/kronfluence/utils/common/factor_arguments.py @@ -0,0 +1,71 @@ +import torch + +from kronfluence import FactorArguments + + +def default_factor_arguments(strategy: str = "ekfac") -> FactorArguments: + """Default factor arguments.""" + factor_args = 
FactorArguments(strategy=strategy) + return factor_args + + +def test_factor_arguments(strategy: str = "ekfac") -> FactorArguments: + """Factor arguments used for unit tests.""" + factor_args = FactorArguments(strategy=strategy) + factor_args.use_empirical_fisher = True + factor_args.activation_covariance_dtype = torch.float64 + factor_args.gradient_covariance_dtype = torch.float64 + factor_args.per_sample_gradient_dtype = torch.float64 + factor_args.lambda_dtype = torch.float32 + return factor_args + + +def smart_low_precision_factor_arguments( + strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16 +) -> FactorArguments: + """Factor arguments with low precision, except for the lambda computations.""" + factor_args = FactorArguments(strategy=strategy) + factor_args.amp_dtype = dtype + factor_args.activation_covariance_dtype = dtype + factor_args.gradient_covariance_dtype = dtype + factor_args.per_sample_gradient_dtype = dtype + factor_args.lambda_dtype = torch.float32 + return factor_args + + +def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: + """Factor arguments with low precision.""" + factor_args = FactorArguments(strategy=strategy) + factor_args.amp_dtype = dtype + factor_args.activation_covariance_dtype = dtype + factor_args.gradient_covariance_dtype = dtype + factor_args.per_sample_gradient_dtype = dtype + factor_args.lambda_dtype = dtype + return factor_args + + +def reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: + """Factor arguments with low precision + iterative lambda update.""" + factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype) + factor_args.lambda_iterative_aggregate = True + return factor_args + + +def extreme_reduce_memory_factor_arguments( + strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16 +) -> FactorArguments: + """Factor arguments for models that are difficult to fit in a single GPU.""" + factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype) + factor_args.lambda_iterative_aggregate = True + factor_args.cached_activation_cpu_offload = True + factor_args.covariance_module_partition_size = 4 + factor_args.lambda_module_partition_size = 4 + return factor_args + + +def large_dataset_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments: + """Factor arguments for large models and datasets.""" + factor_args = smart_low_precision_factor_arguments(strategy=strategy, dtype=dtype) + factor_args.covariance_data_partition_size = 4 + factor_args.lambda_data_partition_size = 4 + return factor_args diff --git a/kronfluence/utils/common/score_arguments.py b/kronfluence/utils/common/score_arguments.py new file mode 100644 index 0000000..e674432 --- /dev/null +++ b/kronfluence/utils/common/score_arguments.py @@ -0,0 +1,96 @@ +from typing import Optional + +import torch + +from kronfluence import ScoreArguments + + +def default_score_arguments( + damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None +) -> ScoreArguments: + """Default score arguments.""" + score_args = ScoreArguments(damping=damping) + score_args.query_gradient_rank = query_gradient_rank + if score_args.query_gradient_rank is not None: + score_args.num_query_gradient_accumulations = 10 + return score_args + + +def test_score_arguments(damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None) -> ScoreArguments: + """Score
arguments used for unit tests.""" + score_args = ScoreArguments(damping=damping) + score_args.query_gradient_svd_dtype = torch.float64 + score_args.score_dtype = torch.float64 + score_args.per_sample_gradient_dtype = torch.float64 + score_args.precondition_dtype = torch.float64 + score_args.query_gradient_rank = query_gradient_rank + return score_args + + +def smart_low_precision_score_arguments( + damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 +) -> ScoreArguments: + """Score arguments with low precision, except for the preconditioning computations.""" + score_args = ScoreArguments(damping=damping) + score_args.amp_dtype = dtype + score_args.query_gradient_svd_dtype = torch.float32 + score_args.score_dtype = dtype + score_args.per_sample_gradient_dtype = dtype + score_args.precondition_dtype = torch.float32 + score_args.query_gradient_rank = query_gradient_rank + if score_args.query_gradient_rank is not None: + score_args.num_query_gradient_accumulations = 10 + return score_args + + +def all_low_precision_score_arguments( + damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 +) -> ScoreArguments: + """Score arguments with low precision.""" + score_args = ScoreArguments(damping=damping) + score_args.amp_dtype = dtype + score_args.query_gradient_svd_dtype = torch.float32 + score_args.score_dtype = dtype + score_args.per_sample_gradient_dtype = dtype + score_args.precondition_dtype = dtype + score_args.query_gradient_rank = query_gradient_rank + if score_args.query_gradient_rank is not None: + score_args.num_query_gradient_accumulations = 10 + return score_args + + +def reduce_memory_score_arguments( + damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 +) -> ScoreArguments: + """Score arguments with low precision + CPU offload.""" + score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype) + score_args.cached_activation_cpu_offload = True + score_args.query_gradient_rank = query_gradient_rank + if score_args.query_gradient_rank is not None: + score_args.num_query_gradient_accumulations = 10 + return score_args + + +def extreme_reduce_memory_score_arguments( + damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 +) -> ScoreArguments: + """Score arguments for models that are difficult to fit in a single GPU.""" + score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype) + score_args.cached_activation_cpu_offload = True + score_args.query_gradient_rank = query_gradient_rank + score_args.module_partition_size = 4 + if score_args.query_gradient_rank is not None: + score_args.num_query_gradient_accumulations = 10 + return score_args + + +def large_dataset_score_arguments( + damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16 +) -> ScoreArguments: + """Score arguments for large models and datasets.""" + score_args = smart_low_precision_score_arguments(damping=damping, dtype=dtype) + score_args.data_partition_size = 4 + score_args.query_gradient_rank = query_gradient_rank + if score_args.query_gradient_rank is not None: + score_args.num_query_gradient_accumulations = 10 + return score_args diff --git a/tests/factors/test_covariances.py b/tests/factors/test_covariances.py index f970f1b..4ab3240 100644 --- a/tests/factors/test_covariances.py +++
b/tests/factors/test_covariances.py @@ -3,7 +3,10 @@ import pytest import torch -from kronfluence.arguments import FactorArguments +from kronfluence.utils.common.factor_arguments import ( + default_factor_arguments, + test_factor_arguments, +) from kronfluence.utils.constants import ( ACTIVATION_COVARIANCE_MATRIX_NAME, COVARIANCE_FACTOR_NAMES, @@ -44,7 +47,7 @@ def test_fit_covariance_matrices( train_size: int, seed: int, ) -> None: - # Make sure that the covariance computations are working properly. + # Makes sure that the covariance computations are working properly. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -56,10 +59,9 @@ def test_fit_covariance_matrices( task=task, ) - factor_args = FactorArguments( - activation_covariance_dtype=activation_covariance_dtype, - gradient_covariance_dtype=gradient_covariance_dtype, - ) + factor_args = default_factor_arguments() + factor_args.activation_covariance_dtype = activation_covariance_dtype + factor_args.gradient_covariance_dtype = gradient_covariance_dtype factors_name = f"pytest_{test_name}_{test_fit_covariance_matrices.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, @@ -106,11 +108,7 @@ def test_covariance_matrices_batch_size_equivalence( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() analyzer.fit_covariance_matrices( factors_name=f"pytest_{test_name}_{test_covariance_matrices_batch_size_equivalence.__name__}_bs1", dataset=train_dataset, @@ -149,12 +147,10 @@ def test_covariance_matrices_batch_size_equivalence( [ "mlp", "conv", - "conv_bn", - "gpt", ], ) -@pytest.mark.parametrize("data_partition_size", [1, 4]) -@pytest.mark.parametrize("module_partition_size", [1, 3]) +@pytest.mark.parametrize("data_partition_size", [2, 4]) +@pytest.mark.parametrize("module_partition_size", [2, 3]) @pytest.mark.parametrize("train_size", [62]) @pytest.mark.parametrize("seed", [2]) def test_covariance_matrices_partition_equivalence( @@ -176,11 +172,7 @@ def test_covariance_matrices_partition_equivalence( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() factors_name = f"pytest_{test_name}_{test_covariance_matrices_partition_equivalence.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, @@ -192,13 +184,8 @@ def test_covariance_matrices_partition_equivalence( ) covariance_factors = analyzer.load_covariance_matrices(factors_name=factors_name) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - covariance_data_partition_size=data_partition_size, - covariance_module_partition_size=module_partition_size, - ) + factor_args.covariance_data_partition_size = data_partition_size + factor_args.covariance_module_partition_size = module_partition_size analyzer.fit_covariance_matrices( factors_name=f"pytest_{test_name}_partitioned_{data_partition_size}_{module_partition_size}", dataset=train_dataset, @@ -228,7 +215,7 @@ def test_covariance_matrices_attention_mask( train_size: int, seed: int, ) -> None: - # Make sure the attention mask is correctly implemented by comparing with the results + # Makes sure the attention mask is correctly implemented by comparing 
with the results # without any padding applied (and batch size of 1). model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, @@ -255,11 +242,7 @@ def test_covariance_matrices_attention_mask( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() factors_name = f"pytest_{test_name}_{test_covariance_matrices_attention_mask.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, @@ -314,7 +297,7 @@ def test_covariance_matrices_automatic_batch_size( train_size: int, seed: int, ) -> None: - # Make sure the automatic batch size search feature is working properly. + # Makes sure the automatic batch size search feature is working properly. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -326,11 +309,7 @@ def test_covariance_matrices_automatic_batch_size( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() factors_name = f"pytest_{test_name}_{test_covariance_matrices_automatic_batch_size.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, @@ -373,7 +352,7 @@ def test_covariance_matrices_max_examples( train_size: int, seed: int, ) -> None: - # Make sure the max covariance data selection is working properly. + # Makes sure the max covariance data selection is working properly. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -386,11 +365,10 @@ def test_covariance_matrices_max_examples( ) MAX_EXAMPLES = 26 - factor_args = FactorArguments( - use_empirical_fisher=True, - covariance_max_examples=MAX_EXAMPLES, - covariance_data_partition_size=data_partition_size, - ) + factor_args = test_factor_arguments() + factor_args.covariance_max_examples = MAX_EXAMPLES + factor_args.covariance_data_partition_size = data_partition_size + factors_name = f"pytest_{test_name}_{test_covariance_matrices_max_examples.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, @@ -435,11 +413,7 @@ def test_covariance_matrices_amp( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() analyzer.fit_covariance_matrices( factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}", dataset=train_dataset, @@ -452,12 +426,7 @@ def test_covariance_matrices_amp( factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}" ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - amp_dtype=torch.float16, - ) + factor_args.amp_dtype = torch.float16 analyzer.fit_covariance_matrices( factors_name=f"pytest_{test_name}_{test_covariance_matrices_amp.__name__}_amp", dataset=train_dataset, @@ -496,11 +465,7 @@ def test_covariance_matrices_gradient_checkpoint( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() analyzer.fit_covariance_matrices( factors_name=f"pytest_{test_covariance_matrices_gradient_checkpoint.__name__}", 
dataset=train_dataset, diff --git a/tests/factors/test_eigens.py b/tests/factors/test_eigens.py index e2cae9b..cb8facb 100644 --- a/tests/factors/test_eigens.py +++ b/tests/factors/test_eigens.py @@ -4,6 +4,7 @@ import torch from kronfluence.arguments import FactorArguments +from kronfluence.utils.common.factor_arguments import test_factor_arguments from kronfluence.utils.constants import ( ACTIVATION_EIGENVECTORS_NAME, EIGENDECOMPOSITION_FACTOR_NAMES, @@ -26,6 +27,7 @@ "test_name", [ "mlp", + "repeated_mlp", "mlp_checkpoint", "conv", "conv_bn", @@ -42,7 +44,7 @@ def test_perform_eigendecomposition( train_size: int, seed: int, ) -> None: - # Make sure that the Eigendecomposition computations are working properly. + # Makes sure that the Eigendecomposition computations are working properly. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -53,17 +55,18 @@ def test_perform_eigendecomposition( model=model, task=task, ) + factor_args = FactorArguments( + eigendecomposition_dtype=eigendecomposition_dtype, + ) factors_name = f"pytest_{test_name}_{test_perform_eigendecomposition.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, + factor_args=factor_args, dataset=train_dataset, per_device_batch_size=4, overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - factor_args = FactorArguments( - eigendecomposition_dtype=eigendecomposition_dtype, - ) analyzer.perform_eigendecomposition( factors_name=factors_name, factor_args=factor_args, @@ -81,6 +84,8 @@ def test_perform_eigendecomposition( "test_name", [ "mlp", + "repeated_mlp", + "mlp_checkpoint", "conv", "conv_bn", "bert", @@ -98,7 +103,7 @@ def test_fit_lambda_matrices( train_size: int, seed: int, ) -> None: - # Make sure that the Lambda computations are working properly. + # Makes sure that the Lambda computations are working properly. 
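+ # The new `repeated_mlp` case reuses the same module several times in one forward pass, so `shared_parameters_exist` is enabled for it below.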
model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -114,6 +119,9 @@ def test_fit_lambda_matrices( lambda_dtype=lambda_dtype, per_sample_gradient_dtype=per_sample_gradient_dtype, ) + if test_name == "repeated_mlp": + factor_args.shared_parameters_exist = True + factors_name = f"pytest_{test_name}_{test_fit_lambda_matrices.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -160,13 +168,7 @@ def test_lambda_matrices_batch_size_equivalence( task=task, ) - factor_args = FactorArguments( - strategy=strategy, - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = test_factor_arguments(strategy=strategy) analyzer.fit_all_factors( factors_name=f"pytest_{test_name}_{test_lambda_matrices_batch_size_equivalence.__name__}_{strategy}_bs1", dataset=train_dataset, @@ -200,13 +202,11 @@ def test_lambda_matrices_batch_size_equivalence( [ "mlp", "conv", - "conv_bn", - "gpt", ], ) @pytest.mark.parametrize("strategy", ["diagonal", "ekfac"]) -@pytest.mark.parametrize("data_partition_size", [1, 4]) -@pytest.mark.parametrize("module_partition_size", [1, 3]) +@pytest.mark.parametrize("data_partition_size", [4]) +@pytest.mark.parametrize("module_partition_size", [3]) @pytest.mark.parametrize("train_size", [81]) @pytest.mark.parametrize("seed", [2]) def test_lambda_matrices_partition_equivalence( @@ -229,13 +229,7 @@ def test_lambda_matrices_partition_equivalence( task=task, ) - factor_args = FactorArguments( - strategy=strategy, - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = test_factor_arguments(strategy=strategy) factors_name = f"pytest_{test_name}_{strategy}_{test_lambda_matrices_partition_equivalence.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -249,15 +243,8 @@ def test_lambda_matrices_partition_equivalence( factors_name=factors_name, ) - factor_args = FactorArguments( - strategy=strategy, - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - lambda_data_partition_size=data_partition_size, - lambda_module_partition_size=module_partition_size, - ) + factor_args.lambda_data_partition_size = data_partition_size + factor_args.lambda_module_partition_size = module_partition_size analyzer.fit_all_factors( factors_name=f"pytest_{test_name}_{strategy}_{data_partition_size}_{module_partition_size}", dataset=train_dataset, @@ -291,26 +278,22 @@ def test_lambda_matrices_iterative_aggregate( train_size: int, seed: int, ) -> None: - # Make sure aggregated lambda computation is working properly. + # Makes sure iterative lambda computation is working properly. 
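+ # `lambda_iterative_aggregate=True` accumulates the Lambda matrices with a for-loop instead of batched matrix multiplications, lowering peak memory; both settings should yield matching factors.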
model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) model, analyzer = prepare_model_and_analyzer( model=model, task=task, ) factors_name = f"pytest_{test_name}_{test_lambda_matrices_iterative_aggregate.__name__}" - factor_args = FactorArguments( - use_empirical_fisher=True, - lambda_iterative_aggregate=False, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = test_factor_arguments() + factor_args.lambda_iterative_aggregate = False analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -323,13 +306,7 @@ def test_lambda_matrices_iterative_aggregate( factors_name=factors_name, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - lambda_iterative_aggregate=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args.lambda_iterative_aggregate = True analyzer.fit_all_factors( factors_name=factors_name + "_iterative", dataset=train_dataset, @@ -359,7 +336,7 @@ def test_lambda_matrices_max_examples( train_size: int, seed: int, ) -> None: - # Make sure the max Lambda data selection is working properly. + # Makes sure the max Lambda data selection is working properly. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, seed=seed, ) @@ -371,7 +348,7 @@ task=task, ) - MAX_EXAMPLES = 28 + MAX_EXAMPLES = 33 factor_args = FactorArguments( use_empirical_fisher=True, lambda_max_examples=MAX_EXAMPLES, lambda_data_partition_size=data_partition_size ) @@ -417,12 +394,7 @@ def test_lambda_matrices_amp( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - lambda_dtype=torch.float64, - ) + factor_args = test_factor_arguments() analyzer.fit_all_factors( factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}", dataset=train_dataset, @@ -435,13 +407,7 @@ factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}" ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - amp_dtype=torch.float16, - lambda_dtype=torch.float64, - ) + factor_args.amp_dtype = torch.float16 analyzer.fit_all_factors( factors_name=f"pytest_{test_name}_{test_lambda_matrices_amp.__name__}_amp", dataset=train_dataset, @@ -475,11 +441,7 @@ def test_lambda_matrices_gradient_checkpoint( task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - ) + factor_args = test_factor_arguments() analyzer.fit_all_factors( factors_name=f"pytest_{test_lambda_matrices_gradient_checkpoint.__name__}", dataset=train_dataset, @@ -523,7 +485,8 @@ def test_lambda_matrices_shared_parameters( train_size: int, seed: int, ) -> None: - # When there are no shared parameters, they should have identical results. + # When there are no shared parameters, runs with and without `shared_parameters_exist` should + # produce identical results.
model, train_dataset, _, data_collator, task = prepare_test( test_name="mlp", train_size=train_size, seed=seed, ) @@ -534,12 +497,7 @@ task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - shared_parameters_exist=False, - ) + factor_args = test_factor_arguments() analyzer.fit_all_factors( factors_name=f"pytest_{test_lambda_matrices_shared_parameters.__name__}", dataset=train_dataset, @@ -561,12 +519,7 @@ task=task, ) - factor_args = FactorArguments( - use_empirical_fisher=True, - activation_covariance_dtype=torch.float64, - gradient_covariance_dtype=torch.float64, - shared_parameters_exist=True, - ) + factor_args.shared_parameters_exist = True analyzer.fit_all_factors( factors_name=f"pytest_{test_lambda_matrices_shared_parameters.__name__}_shared", dataset=train_dataset, diff --git a/tests/gpu_tests/README.md b/tests/gpu_tests/README.md index 4d1dd31..dee6108 100644 --- a/tests/gpu_tests/README.md +++ b/tests/gpu_tests/README.md @@ -2,42 +2,55 @@ This folder contains various GPU tests for Kronfluence. Before running the tests, you need to prepare the baseline results by training an MNIST model and saving the results obtained with a single GPU: + ```bash python prepare_tests.py ``` ### CPU Tests + To test if running on CPU yields the same result as the GPU, run: + ```bash python cpu_test.py ``` ### DDP Tests + To test if running with Distributed Data Parallel (DDP) with 3 GPUs obtains the same result, run: + ```bash torchrun --nnodes=1 --nproc_per_node=3 ddp_test.py ``` ### FSDP Tests + To test if running with Fully Sharded Data Parallel (FSDP) with 3 GPUs obtains the same result, run: + ```bash torchrun --nnodes=1 --nproc_per_node=3 fsdp_test.py ``` ### torch.compile Tests + To test if running with `torch.compile` obtains the same result, run: + ```bash python compile_test.py ``` ### AMP Tests + To test if running with automatic mixed precision (AMP) obtains similar results, run: + ```bash python amp_test.py ``` ### CPU Offload Test + To test if the `cached_activation_cpu_offload` option is properly implemented, run: + ```bash pytest test_offload_cpu.py ``` \ No newline at end of file diff --git a/tests/gpu_tests/amp_test.py b/tests/gpu_tests/amp_test.py index 7c27df8..8f2ad17 100644 --- a/tests/gpu_tests/amp_test.py +++ b/tests/gpu_tests/amp_test.py @@ -69,8 +69,8 @@ def test_covariance_matrices(self) -> None: assert check_tensor_dict_equivalence( covariance_factors[name], new_covariance_factors[name], - atol=1e-5, - rtol=1e-3, + atol=1e-3, + rtol=1e-1, ) def test_lambda_matrices(self): diff --git a/tests/gpu_tests/compile_test.py b/tests/gpu_tests/compile_test.py index 61e9db5..a997ddf 100644 --- a/tests/gpu_tests/compile_test.py +++ b/tests/gpu_tests/compile_test.py @@ -122,7 +122,6 @@ def test_pairwise_scores(self) -> None: ) new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) - torch.set_printoptions(threshold=30_000) print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][10]}") print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][10]}") @@ -152,7 +151,6 @@ def test_self_scores(self) -> None: new_self_scores = self.analyzer.load_self_scores(scores_name=NEW_SCORE_NAME) self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME) - torch.set_printoptions(threshold=30_000) print(f"Previous
score: {self_scores[ALL_MODULE_NAME]}") print(f"Previous shape: {self_scores[ALL_MODULE_NAME].shape}") print(f"New score: {new_self_scores[ALL_MODULE_NAME]}") diff --git a/tests/gpu_tests/ddp_test.py b/tests/gpu_tests/ddp_test.py index d310f53..a05d569 100644 --- a/tests/gpu_tests/ddp_test.py +++ b/tests/gpu_tests/ddp_test.py @@ -230,12 +230,52 @@ def test_lr_pairwise_scores(self) -> None: rtol=1e-1, ) - def test_lr_aggregate_pairwise_scores(self) -> None: + def test_per_module_pairwise_scores(self) -> None: + pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb") + + score_args = ScoreArguments( + per_module_score=True, + score_dtype=torch.float64, + per_sample_gradient_dtype=torch.float64, + precondition_dtype=torch.float64, + query_gradient_svd_dtype=torch.float64, + ) + self.analyzer.compute_pairwise_scores( + scores_name=NEW_SCORE_NAME + "_per_module", + factors_name=OLD_FACTOR_NAME, + query_dataset=self.eval_dataset, + train_dataset=self.train_dataset, + train_indices=list(range(TRAIN_INDICES)), + query_indices=list(range(QUERY_INDICES)), + per_device_query_batch_size=12, + per_device_train_batch_size=512, + score_args=score_args, + overwrite_output_dir=True, + ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="ddp_qb") + + if LOCAL_RANK == 0: + print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") + print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}") + print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}") + print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}") + print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}") + print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}") + assert check_tensor_dict_equivalence( + pairwise_scores, + new_pairwise_scores, + atol=1e-3, + rtol=1e-1, + ) + + def test_lr_accumulate_pairwise_scores(self) -> None: pairwise_scores = self.analyzer.load_pairwise_scores(scores_name="single_gpu_qb") score_args = ScoreArguments( query_gradient_rank=32, - num_query_gradient_aggregations=3, + num_query_gradient_accumulations=3, score_dtype=torch.float64, per_sample_gradient_dtype=torch.float64, precondition_dtype=torch.float64, diff --git a/tests/gpu_tests/test_offload_cpu.py b/tests/gpu_tests/test_offload_cpu.py index 655ef68..0066270 100644 --- a/tests/gpu_tests/test_offload_cpu.py +++ b/tests/gpu_tests/test_offload_cpu.py @@ -1,27 +1,14 @@ # pylint: skip-file -from typing import Tuple import pytest -from torch import nn +import torch from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.arguments import FactorArguments, ScoreArguments -from kronfluence.task import Task from kronfluence.utils.constants import ALL_MODULE_NAME from kronfluence.utils.dataset import DataLoaderKwargs -from tests.utils import prepare_test - - -def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]: - model = prepare_model(model=model, task=task) - analyzer = Analyzer( - analysis_name=f"pytest_{__name__}", - model=model, - task=task, - disable_model_save=True, - ) - return model, analyzer +from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test @pytest.mark.parametrize( @@ -29,6 +16,7 @@ def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, [ "mlp", "repeated_mlp", + "mlp_checkpoint", "conv", "conv_bn", "bert", @@ -38,7 +26,7 @@ def 
prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, @pytest.mark.parametrize("cached_activation_cpu_offload", [True, False]) @pytest.mark.parametrize("query_size", [16]) @pytest.mark.parametrize("train_size", [32]) -@pytest.mark.parametrize("seed", [0]) +@pytest.mark.parametrize("seed", [1]) def test_cpu_offloads( test_name: str, cached_activation_cpu_offload: bool, @@ -53,13 +41,19 @@ def test_cpu_offloads( seed=seed, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) - model, analyzer = prepare_model_and_analyzer( + model = prepare_model(model=model, task=task) + analyzer = Analyzer( + analysis_name=f"pytest_{__name__}", model=model, task=task, + disable_model_save=True, + disable_tqdm=True, ) factor_args = FactorArguments( cached_activation_cpu_offload=cached_activation_cpu_offload, ) + if test_name == "repeated_mlp": + factor_args.shared_parameters_exist = True factors_name = f"pytest_{test_name}_{test_cpu_offloads.__name__}" analyzer.fit_all_factors( factors_name=factors_name, @@ -89,3 +83,107 @@ pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) assert pairwise_scores[ALL_MODULE_NAME].size(0) == query_size assert pairwise_scores[ALL_MODULE_NAME].size(1) == train_size + + +@pytest.mark.parametrize( + "test_name", + [ + "mlp", + "repeated_mlp", + "mlp_checkpoint", + "conv", + ], +) +@pytest.mark.parametrize("per_module_score", [False, True]) +@pytest.mark.parametrize("query_size", [50]) +@pytest.mark.parametrize("train_size", [102]) +@pytest.mark.parametrize("seed", [1]) +def test_cpu_offloads_identical( + test_name: str, + per_module_score: bool, + query_size: int, + train_size: int, + seed: int, +) -> None: + model, train_dataset, test_dataset, data_collator, task = prepare_test( + test_name=test_name, + query_size=query_size, + train_size=train_size, + seed=seed, + ) + kwargs = DataLoaderKwargs(collate_fn=data_collator) + model = model.to(dtype=torch.float64) + model = prepare_model(model=model, task=task) + + analyzer = Analyzer( + analysis_name=f"pytest_{test_cpu_offloads_identical.__name__}_{__name__}", + model=model, + task=task, + disable_model_save=True, + disable_tqdm=True, + ) + factor_args = FactorArguments( + use_empirical_fisher=True, + cached_activation_cpu_offload=False, + activation_covariance_dtype=torch.float64, + gradient_covariance_dtype=torch.float64, + lambda_dtype=torch.float64, + ) + if test_name == "repeated_mlp": + factor_args.shared_parameters_exist = True + factors_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}" + analyzer.fit_all_factors( + factors_name=factors_name, + dataset=train_dataset, + dataloader_kwargs=kwargs, + per_device_batch_size=32, + factor_args=factor_args, + overwrite_output_dir=True, + ) + score_args = ScoreArguments( + cached_activation_cpu_offload=False, + per_sample_gradient_dtype=torch.float64, + score_dtype=torch.float64, + precondition_dtype=torch.float64, + per_module_score=per_module_score, + ) + scores_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_scores" + analyzer.compute_pairwise_scores( + scores_name=scores_name, + factors_name=factors_name, + query_dataset=test_dataset, + per_device_query_batch_size=4, + train_dataset=train_dataset, + per_device_train_batch_size=8, + dataloader_kwargs=kwargs, + score_args=score_args, + overwrite_output_dir=True, + ) + pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) + + factors_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_cached" +
factor_args.cached_activation_cpu_offload = True + analyzer.fit_all_factors( + factors_name=factors_name, + dataset=train_dataset, + dataloader_kwargs=kwargs, + per_device_batch_size=16, + factor_args=factor_args, + overwrite_output_dir=True, + ) + score_args.cached_activation_cpu_offload = True + scores_name = f"pytest_{test_name}_{test_cpu_offloads_identical.__name__}_cached_scores" + analyzer.compute_pairwise_scores( + scores_name=scores_name, + factors_name=factors_name, + query_dataset=test_dataset, + per_device_query_batch_size=6, + train_dataset=train_dataset, + per_device_train_batch_size=8, + dataloader_kwargs=kwargs, + score_args=score_args, + overwrite_output_dir=True, + ) + cached_pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name) + + assert check_tensor_dict_equivalence(pairwise_scores, cached_pairwise_scores, atol=ATOL, rtol=RTOL) diff --git a/tests/modules/test_modules.py b/tests/modules/test_modules.py index 25dde9a..3dc2f3e 100644 --- a/tests/modules/test_modules.py +++ b/tests/modules/test_modules.py @@ -85,7 +85,7 @@ def test_tracked_modules_forward_equivalence( ], ) @pytest.mark.parametrize("train_size", [32]) -@pytest.mark.parametrize("seed", [0]) +@pytest.mark.parametrize("seed", [1]) def test_tracked_modules_backward_equivalence( test_name: str, mode: ModuleMode, diff --git a/tests/modules/test_per_sample_gradients.py b/tests/modules/test_per_sample_gradients.py index 1625b4a..b18e6d9 100644 --- a/tests/modules/test_per_sample_gradients.py +++ b/tests/modules/test_per_sample_gradients.py @@ -14,7 +14,11 @@ from kronfluence.analyzer import Analyzer, prepare_model from kronfluence.arguments import FactorArguments from kronfluence.module.tracked_module import ModuleMode, TrackedModule -from kronfluence.module.utils import set_mode, update_factor_args +from kronfluence.module.utils import ( + finalize_preconditioned_gradient, + set_mode, + update_factor_args, +) from kronfluence.task import Task from kronfluence.utils.constants import LAMBDA_MATRIX_NAME, PRECONDITIONED_GRADIENT_NAME from kronfluence.utils.dataset import DataLoaderKwargs @@ -91,6 +95,7 @@ def for_loop_per_sample_gradient( "test_name", [ "mlp", + "repeated_mlp", "conv", "conv_bn", "bert", @@ -132,6 +137,8 @@ def test_for_loop_per_sample_gradient_equivalence( factor_args = FactorArguments( strategy="identity", ) + if test_name == "repeated_mlp": + factor_args.shared_parameters_exist = True update_factor_args(model=model, factor_args=factor_args) per_sample_gradients = [] @@ -147,6 +154,9 @@ def test_for_loop_per_sample_gradient_equivalence( loss = task.compute_train_loss(batch=batch_lst[i], model=model, sample=False) loss.backward() + if test_name == "repeated_mlp": + finalize_preconditioned_gradient(model=model) + module_gradients = {} for module in model.modules(): if isinstance(module, TrackedModule): @@ -177,6 +187,7 @@ def test_for_loop_per_sample_gradient_equivalence( "test_name", [ "mlp", + "repeated_mlp", "conv", "conv_bn", "bert", @@ -218,6 +229,8 @@ def test_mean_gradient_equivalence( factor_args = FactorArguments( strategy="identity", ) + if test_name == "repeated_mlp": + factor_args.shared_parameters_exist = True update_factor_args(model=model, factor_args=factor_args) per_sample_gradients = [] @@ -233,6 +246,9 @@ def test_mean_gradient_equivalence( loss = task.compute_train_loss(batch=batch_lst[i], model=model, sample=False) loss.backward() + if test_name == "repeated_mlp": + finalize_preconditioned_gradient(model=model) + module_gradients = {} for module in 
model.modules(): if isinstance(module, TrackedModule): @@ -328,6 +344,7 @@ def test_lambda_equivalence( task=task, disable_model_save=True, cpu=True, + disable_tqdm=True, ) kwargs = DataLoaderKwargs(collate_fn=data_collator) factor_args = FactorArguments( diff --git a/tests/scores/test_pairwise_scores.py b/tests/scores/test_pairwise_scores.py index ebdbca8..048309d 100644 --- a/tests/scores/test_pairwise_scores.py +++ b/tests/scores/test_pairwise_scores.py @@ -6,6 +6,8 @@ import torch from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.utils.common.factor_arguments import test_factor_arguments +from kronfluence.utils.common.score_arguments import test_score_arguments from kronfluence.utils.constants import ALL_MODULE_NAME from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ( @@ -21,6 +23,8 @@ "test_name", [ "mlp", + "repeated_mlp", + "mlp_checkpoint", "conv", "conv_bn", "bert", @@ -40,7 +44,7 @@ def test_compute_pairwise_scores( train_size: int, seed: int, ) -> None: - # Make sure that the pairwise influence computations are working properly. + # Makes sure that the pairwise influence computations are working properly. model, train_dataset, test_dataset, data_collator, task = prepare_test( test_name=test_name, query_size=query_size, @@ -52,9 +56,14 @@ def test_compute_pairwise_scores( model=model, task=task, ) + factor_args = test_factor_arguments() + if test_name == "repeated_mlp": + factor_args.shared_parameters_exist = True + factors_name = f"pytest_{test_name}_{test_compute_pairwise_scores.__name__}" analyzer.fit_all_factors( factors_name=factors_name, + factor_args=factor_args, dataset=train_dataset, dataloader_kwargs=kwargs, per_device_batch_size=32, @@ -94,6 +103,7 @@ def test_compute_pairwise_scores( @pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("score_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("query_gradient_rank", [None, 16]) +@pytest.mark.parametrize("damping", [None, 1e-08]) @pytest.mark.parametrize("query_size", [16]) @pytest.mark.parametrize("train_size", [32]) @pytest.mark.parametrize("seed", [6]) @@ -103,11 +113,12 @@ def test_compute_pairwise_scores_dtype( precondition_dtype: torch.dtype, score_dtype: torch.dtype, query_gradient_rank: Optional[int], + damping: Optional[float], query_size: int, train_size: int, seed: int, ) -> None: - # Make sure that the pairwise influence computations are working properly with different dtypes. + # Makes sure that the pairwise influence computations are working properly with different dtypes. 
model, train_dataset, test_dataset, data_collator, task = prepare_test( test_name=test_name, query_size=query_size, @@ -129,6 +140,7 @@ def test_compute_pairwise_scores_dtype( ) score_args = ScoreArguments( + damping=damping, score_dtype=score_dtype, query_gradient_rank=query_gradient_rank, per_sample_gradient_dtype=per_sample_gradient_dtype, @@ -199,12 +211,7 @@ def test_pairwise_scores_batch_size_equivalence( overwrite_output_dir=True, ) - score_args = ScoreArguments( - per_module_score=False, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args = test_score_arguments() analyzer.compute_pairwise_scores( scores_name=f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}_{strategy}_score_bs1", factors_name=factors_name, @@ -311,12 +318,8 @@ def test_pairwise_scores_partition_equivalence( ) scores_name = f"pytest_{test_name}_{test_pairwise_scores_partition_equivalence.__name__}_scores" - score_args = ScoreArguments( - per_module_score=per_module_score, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args = test_score_arguments() + score_args.per_module_score = per_module_score analyzer.compute_pairwise_scores( scores_name=scores_name, factors_name=factors_name, @@ -330,14 +333,8 @@ def test_pairwise_scores_partition_equivalence( ) scores = analyzer.load_pairwise_scores(scores_name=scores_name) - score_args = ScoreArguments( - data_partition_size=data_partition_size, - module_partition_size=module_partition_size, - per_module_score=per_module_score, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args.data_partition_size = data_partition_size + score_args.module_partition_size = module_partition_size analyzer.compute_pairwise_scores( scores_name=f"pytest_{test_name}_partition_{data_partition_size}_{module_partition_size}", factors_name=factors_name, @@ -402,12 +399,7 @@ def test_per_module_scores_equivalence( ) scores_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" - score_args = ScoreArguments( - per_module_score=False, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args = test_score_arguments() analyzer.compute_pairwise_scores( scores_name=scores_name, factors_name=factors_name, @@ -421,12 +413,7 @@ def test_per_module_scores_equivalence( ) scores = analyzer.load_pairwise_scores(scores_name=scores_name) - score_args = ScoreArguments( - per_module_score=True, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args.per_module_score = True analyzer.compute_pairwise_scores( scores_name=scores_name + "_per_module", factors_name=factors_name, @@ -455,19 +442,18 @@ def test_per_module_scores_equivalence( [ "mlp", "conv_bn", - "gpt", ], ) @pytest.mark.parametrize("query_size", [60]) @pytest.mark.parametrize("train_size", [60]) -@pytest.mark.parametrize("seed", [0]) +@pytest.mark.parametrize("seed", [2]) def test_compute_pairwise_scores_with_indices( test_name: str, query_size: int, train_size: int, seed: int, ) -> None: - # Make sure the indices selection is correctly implemented. + # Makes sure the indices selection is correctly implemented. 
model, train_dataset, test_dataset, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -488,7 +474,8 @@ def test_compute_pairwise_scores_with_indices( overwrite_output_dir=True, ) - score_args = ScoreArguments(data_partition_size=2) + score_args = test_score_arguments() + score_args.data_partition_size = 2 scores_name = f"pytest_{test_name}_{test_compute_pairwise_scores_with_indices.__name__}_scores" analyzer.compute_pairwise_scores( scores_name=scores_name, @@ -527,7 +514,7 @@ def test_query_accumulation( num_query_gradient_accumulations: int, seed: int, ) -> None: - # Make sure the query accumulation is correctly implemented. + # Makes sure the query accumulation is correctly implemented. model, train_dataset, test_dataset, data_collator, task = prepare_test( test_name=test_name, query_size=query_size, @@ -551,13 +538,7 @@ def test_query_accumulation( ) scores_name = f"pytest_{test_name}_{test_query_accumulation.__name__}_scores" - score_args = ScoreArguments( - query_gradient_rank=8, - num_query_gradient_accumulations=1, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args = test_score_arguments(query_gradient_rank=8) analyzer.compute_pairwise_scores( scores_name=scores_name, factors_name=factors_name, @@ -571,13 +552,7 @@ def test_query_accumulation( ) scores = analyzer.load_pairwise_scores(scores_name=scores_name) - score_args = ScoreArguments( - query_gradient_rank=8, - num_query_gradient_accumulations=num_query_gradient_accumulations, - score_dtype=torch.float64, - precondition_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - ) + score_args.num_query_gradient_accumulations = num_query_gradient_accumulations analyzer.compute_pairwise_scores( scores_name=f"pytest_{test_name}_{test_query_accumulation.__name__}_accumulated_scores", factors_name=factors_name, @@ -599,3 +574,87 @@ def test_query_accumulation( atol=ATOL, rtol=RTOL, ) + + +@pytest.mark.parametrize( + "test_name", + [ + "mlp", + "conv", + ], +) +@pytest.mark.parametrize("query_size", [50]) +@pytest.mark.parametrize("train_size", [32]) +@pytest.mark.parametrize("seed", [8]) +def test_pairwise_shared_parameters( + test_name: str, + query_size: int, + train_size: int, + seed: int, +) -> None: + # Makes sure the scores are identical with and without `shared_parameters_exist` flag. 
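+ # The model is cast to `torch.float64` below so the two runs can be compared under tight tolerances.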
+ model, train_dataset, test_dataset, data_collator, task = prepare_test( + test_name=test_name, + query_size=query_size, + train_size=train_size, + seed=seed, + ) + model = model.to(dtype=torch.float64) + kwargs = DataLoaderKwargs(collate_fn=data_collator) + model, analyzer = prepare_model_and_analyzer( + model=model, + task=task, + ) + factor_args = test_factor_arguments() + factor_args.shared_parameters_exist = False + factors_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}" + analyzer.fit_all_factors( + factors_name=factors_name, + factor_args=factor_args, + dataset=train_dataset, + dataloader_kwargs=kwargs, + per_device_batch_size=8, + overwrite_output_dir=True, + ) + scores_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}_scores" + analyzer.compute_pairwise_scores( + scores_name=scores_name, + factors_name=factors_name, + query_dataset=test_dataset, + per_device_query_batch_size=4, + train_dataset=train_dataset, + per_device_train_batch_size=8, + dataloader_kwargs=kwargs, + overwrite_output_dir=True, + ) + scores = analyzer.load_pairwise_scores(scores_name=scores_name) + + factor_args.shared_parameters_exist = True + factors_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}_shared" + analyzer.fit_all_factors( + factors_name=factors_name, + factor_args=factor_args, + dataset=train_dataset, + dataloader_kwargs=kwargs, + per_device_batch_size=8, + overwrite_output_dir=True, + ) + scores_name = f"pytest_{test_name}_{test_pairwise_shared_parameters.__name__}_shared_scores" + analyzer.compute_pairwise_scores( + scores_name=scores_name, + factors_name=factors_name, + query_dataset=test_dataset, + per_device_query_batch_size=4, + train_dataset=train_dataset, + per_device_train_batch_size=8, + dataloader_kwargs=kwargs, + overwrite_output_dir=True, + ) + shared_scores = analyzer.load_pairwise_scores(scores_name=scores_name) + + assert check_tensor_dict_equivalence( + scores, + shared_scores, + atol=ATOL, + rtol=RTOL, + ) diff --git a/tests/scores/test_self_scores.py b/tests/scores/test_self_scores.py index 40736e1..e94ebda 100644 --- a/tests/scores/test_self_scores.py +++ b/tests/scores/test_self_scores.py @@ -3,7 +3,12 @@ import pytest import torch -from kronfluence.arguments import FactorArguments, ScoreArguments +from kronfluence.arguments import ScoreArguments +from kronfluence.utils.common.factor_arguments import ( + default_factor_arguments, + test_factor_arguments, +) +from kronfluence.utils.common.score_arguments import test_score_arguments from kronfluence.utils.constants import ALL_MODULE_NAME from kronfluence.utils.dataset import DataLoaderKwargs from tests.utils import ( @@ -19,22 +24,26 @@ "test_name", [ "mlp", + "repeated_mlp", + "mlp_checkpoint", "conv", "conv_bn", "bert", "gpt", ], ) -@pytest.mark.parametrize("score_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("use_measurement_for_self_influence", [False, True]) +@pytest.mark.parametrize("score_dtype", [torch.float32]) @pytest.mark.parametrize("train_size", [22]) @pytest.mark.parametrize("seed", [0]) def test_compute_self_scores( test_name: str, + use_measurement_for_self_influence: bool, score_dtype: torch.dtype, train_size: int, seed: int, ) -> None: - # Make sure that the self-influence computations are working properly. + # Makes sure that the self-influence computations are working properly. 
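+ # The `use_measurement_for_self_influence` parametrization above exercises both the loss-based and measurement-based self-influence paths.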
     model, train_dataset, _, data_collator, task = prepare_test(
         test_name=test_name,
         train_size=train_size,
@@ -45,9 +54,14 @@ def test_compute_self_scores(
         model=model,
         task=task,
     )
+    factor_args = default_factor_arguments()
+    if test_name == "repeated_mlp":
+        factor_args.shared_parameters_exist = True
+
     factors_name = f"pytest_{test_name}_{test_compute_self_scores.__name__}"
     analyzer.fit_all_factors(
         factors_name=factors_name,
+        factor_args=factor_args,
         dataset=train_dataset,
         dataloader_kwargs=kwargs,
         per_device_batch_size=32,
@@ -55,6 +69,7 @@ def test_compute_self_scores(
     )
 
     score_args = ScoreArguments(
+        use_measurement_for_self_influence=use_measurement_for_self_influence,
         score_dtype=score_dtype,
     )
     scores_name = f"pytest_{test_name}_{test_compute_self_scores.__name__}_scores"
@@ -76,9 +91,7 @@
 
 @pytest.mark.parametrize(
     "test_name",
-    [
-        "mlp",
-    ],
+    ["mlp"],
 )
 @pytest.mark.parametrize("per_sample_gradient_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("precondition_dtype", [torch.float32, torch.bfloat16])
@@ -165,9 +178,7 @@ def test_self_scores_batch_size_equivalence(
         task=task,
     )
 
-    factor_args = FactorArguments(
-        strategy=strategy,
-    )
+    factor_args = test_factor_arguments(strategy=strategy)
     factors_name = f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}"
     analyzer.fit_all_factors(
         factors_name=factors_name,
@@ -178,12 +189,7 @@ def test_self_scores_batch_size_equivalence(
         overwrite_output_dir=True,
     )
 
-    score_args = ScoreArguments(
-        per_module_score=False,
-        score_dtype=torch.float64,
-        per_sample_gradient_dtype=torch.float64,
-        precondition_dtype=torch.float64,
-    )
+    score_args = test_score_arguments()
     analyzer.compute_self_scores(
         scores_name=f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}_{strategy}_score_bs1",
         factors_name=factors_name,
@@ -280,11 +286,7 @@ def test_self_scores_partition_equivalence(
     )
 
     scores_name = f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}_scores"
-    score_args = ScoreArguments(
-        score_dtype=torch.float64,
-        per_sample_gradient_dtype=torch.float64,
-        precondition_dtype=torch.float64,
-    )
+    score_args = test_score_arguments()
     analyzer.compute_self_scores(
         scores_name=scores_name,
         factors_name=factors_name,
@@ -296,13 +298,8 @@ def test_self_scores_partition_equivalence(
     )
     scores = analyzer.load_self_scores(scores_name=scores_name)
 
-    score_args = ScoreArguments(
-        data_partition_size=data_partition_size,
-        module_partition_size=module_partition_size,
-        score_dtype=torch.float64,
-        per_sample_gradient_dtype=torch.float64,
-        precondition_dtype=torch.float64,
-    )
+    score_args.data_partition_size = data_partition_size
+    score_args.module_partition_size = module_partition_size
     analyzer.compute_self_scores(
         scores_name=f"pytest_{test_name}_partition_{data_partition_size}_{module_partition_size}",
         factors_name=factors_name,
@@ -362,8 +359,10 @@ def test_per_module_scores_equivalence(
     )
 
     scores_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores"
+    score_args = test_score_arguments()
     analyzer.compute_self_scores(
         scores_name=scores_name,
+        score_args=score_args,
         factors_name=factors_name,
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
@@ -372,7 +371,7 @@ def test_per_module_scores_equivalence(
         overwrite_output_dir=True,
     )
     scores = analyzer.load_self_scores(scores_name=scores_name)
-    score_args = ScoreArguments(per_module_score=True)
+    score_args.per_module_score = True
     analyzer.compute_self_scores(
         scores_name=scores_name + "_per_module",
"_per_module", factors_name=factors_name, @@ -409,7 +408,7 @@ def test_compute_self_scores_with_indices( train_size: int, seed: int, ) -> None: - # Make sure the indices selection is correctly implemented. + # Makes sure the indices selection is correctly implemented. model, train_dataset, _, data_collator, task = prepare_test( test_name=test_name, train_size=train_size, @@ -429,7 +428,8 @@ def test_compute_self_scores_with_indices( overwrite_output_dir=True, ) - score_args = ScoreArguments(data_partition_size=2) + score_args = test_score_arguments() + score_args.data_partition_size = 2 scores_name = f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}_scores" analyzer.compute_self_scores( scores_name=scores_name, @@ -481,12 +481,7 @@ def test_compute_self_scores_with_diagonal_pairwise_equivalence( ) scores_name = f"pytest_{test_name}_{test_compute_self_scores_with_diagonal_pairwise_equivalence.__name__}_scores" - score_args = ScoreArguments( - use_measurement_for_self_influence=False, - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = test_score_arguments() analyzer.compute_self_scores( scores_name=scores_name, factors_name=factors_name, @@ -556,12 +551,8 @@ def test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence( scores_name = ( f"pytest_{test_name}_{test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence.__name__}_scores" ) - score_args = ScoreArguments( - use_measurement_for_self_influence=True, - score_dtype=torch.float64, - per_sample_gradient_dtype=torch.float64, - precondition_dtype=torch.float64, - ) + score_args = test_score_arguments() + score_args.use_measurement_for_self_influence = True analyzer.compute_self_scores( scores_name=scores_name, factors_name=factors_name, @@ -592,3 +583,89 @@ def test_compute_self_measurement_scores_with_diagonal_pairwise_equivalence( atol=ATOL, rtol=RTOL, ) + + +@pytest.mark.parametrize( + "test_name", + [ + "mlp", + "conv", + ], +) +@pytest.mark.parametrize("use_measurement_for_self_influence", [False, True]) +@pytest.mark.parametrize("query_size", [50]) +@pytest.mark.parametrize("train_size", [32]) +@pytest.mark.parametrize("seed", [9]) +def test_self_shared_parameters( + test_name: str, + use_measurement_for_self_influence: bool, + query_size: int, + train_size: int, + seed: int, +) -> None: + # Makes sure the scores are identical with and without `shared_parameters_exist` flag. 
+    model, train_dataset, _, data_collator, task = prepare_test(
+        test_name=test_name,
+        query_size=query_size,
+        train_size=train_size,
+        seed=seed,
+    )
+    model = model.to(dtype=torch.float64)
+    kwargs = DataLoaderKwargs(collate_fn=data_collator)
+    model, analyzer = prepare_model_and_analyzer(
+        model=model,
+        task=task,
+    )
+    factor_args = test_factor_arguments()
+    factor_args.shared_parameters_exist = False
+    score_args = test_score_arguments()
+    score_args.use_measurement_for_self_influence = use_measurement_for_self_influence
+    factors_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}"
+    analyzer.fit_all_factors(
+        factors_name=factors_name,
+        factor_args=factor_args,
+        dataset=train_dataset,
+        dataloader_kwargs=kwargs,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+    )
+    scores_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}_scores"
+    analyzer.compute_self_scores(
+        scores_name=scores_name,
+        score_args=score_args,
+        factors_name=factors_name,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+    scores = analyzer.load_self_scores(scores_name=scores_name)
+
+    factor_args.shared_parameters_exist = True
+    factors_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}_shared"
+    analyzer.fit_all_factors(
+        factors_name=factors_name,
+        factor_args=factor_args,
+        dataset=train_dataset,
+        dataloader_kwargs=kwargs,
+        per_device_batch_size=8,
+        overwrite_output_dir=True,
+    )
+    scores_name = f"pytest_{test_name}_{test_self_shared_parameters.__name__}_shared_scores"
+    analyzer.compute_self_scores(
+        scores_name=scores_name,
+        score_args=score_args,
+        factors_name=factors_name,
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        overwrite_output_dir=True,
+    )
+    shared_scores = analyzer.load_self_scores(scores_name=scores_name)
+
+    assert check_tensor_dict_equivalence(
+        scores,
+        shared_scores,
+        atol=ATOL,
+        rtol=RTOL,
+    )
diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py
index 79e2b33..a66879b 100644
--- a/tests/test_analyzer.py
+++ b/tests/test_analyzer.py
@@ -13,6 +13,7 @@
     "test_name",
     [
         "mlp",
+        "repeated_mlp",
         "conv",
         "conv_bn",
         "bert",
@@ -48,10 +49,15 @@ def test_analyzer(
         model=model,
         task=task,
         disable_model_save=True,
+        disable_tqdm=True,
         cpu=True,
     )
     kwargs = DataLoaderKwargs(collate_fn=data_collator)
 
+    factor_args = FactorArguments(strategy=strategy)
+    if test_name == "repeated_mlp":
+        factor_args.shared_parameters_exist = True
     analyzer.fit_all_factors(
         factors_name=f"pytest_{test_analyzer.__name__}_{test_name}",
         dataset=train_dataset,
+        factor_args=factor_args,
@@ -70,12 +75,24 @@ def test_analyzer(
         dataloader_kwargs=kwargs,
         overwrite_output_dir=True,
     )
+    score_args = ScoreArguments()
     analyzer.compute_self_scores(
         scores_name="self",
         factors_name=f"pytest_{test_analyzer.__name__}_{test_name}",
         train_dataset=train_dataset,
         per_device_train_batch_size=8,
         dataloader_kwargs=kwargs,
+        score_args=score_args,
+        overwrite_output_dir=True,
+    )
+    # Runs the same self-score computation twice to check that existing results are overwritten cleanly.
+    analyzer.compute_self_scores(
+        scores_name="self",
+        factors_name=f"pytest_{test_analyzer.__name__}_{test_name}",
+        train_dataset=train_dataset,
+        per_device_train_batch_size=8,
+        dataloader_kwargs=kwargs,
+        score_args=score_args,
         overwrite_output_dir=True,
     )
 
@@ -94,6 +110,7 @@ def test_default_factor_arguments() -> None:
     assert factor_args.covariance_module_partition_size == 1
     assert factor_args.activation_covariance_dtype == torch.float32
     assert factor_args.gradient_covariance_dtype == torch.float32
 
+    assert factor_args.eigendecomposition_dtype == torch.float64
 
     assert factor_args.lambda_max_examples == 100_000
@@ -116,11 +133,11 @@ def test_default_score_arguments() -> None:
     assert score_args.data_partition_size == 1
     assert score_args.module_partition_size == 1
     assert score_args.per_module_score is False
+    assert score_args.use_measurement_for_self_influence is False
 
     assert score_args.query_gradient_rank is None
     assert score_args.num_query_gradient_accumulations == 1
     assert score_args.query_gradient_svd_dtype == torch.float32
-    assert score_args.use_measurement_for_self_influence is False
 
     assert score_args.score_dtype == torch.float32
     assert score_args.per_sample_gradient_dtype == torch.float32