Merge pull request #26 from pomonam/finalize_documents
Finalize documents
Showing 19 changed files with 700 additions and 290 deletions.
@@ -0,0 +1,71 @@
import torch

from kronfluence import FactorArguments


def default_factor_arguments(strategy: str = "ekfac") -> FactorArguments:
    """Default factor arguments."""
    factor_args = FactorArguments(strategy=strategy)
    return factor_args


def test_factor_arguments(strategy: str = "ekfac") -> FactorArguments:
    """Factor arguments used for unit tests."""
    factor_args = FactorArguments(strategy=strategy)
    factor_args.use_empirical_fisher = True
    factor_args.activation_covariance_dtype = torch.float64
    factor_args.gradient_covariance_dtype = torch.float64
    factor_args.per_sample_gradient_dtype = torch.float64
    factor_args.lambda_dtype = torch.float32
    return factor_args


def smart_low_precision_factor_arguments(
    strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16
) -> FactorArguments:
    """Factor arguments with low precision, except for the lambda computations."""
    factor_args = FactorArguments(strategy=strategy)
    factor_args.amp_dtype = dtype
    factor_args.activation_covariance_dtype = dtype
    factor_args.gradient_covariance_dtype = dtype
    factor_args.per_sample_gradient_dtype = dtype
    factor_args.lambda_dtype = torch.float32
    return factor_args


def all_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
    """Factor arguments with low precision."""
    factor_args = FactorArguments(strategy=strategy)
    factor_args.amp_dtype = dtype
    factor_args.activation_covariance_dtype = dtype
    factor_args.gradient_covariance_dtype = dtype
    factor_args.per_sample_gradient_dtype = dtype
    factor_args.lambda_dtype = dtype
    return factor_args


def reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
    """Factor arguments with low precision + iterative lambda update."""
    factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
    factor_args.lambda_iterative_aggregate = True
    return factor_args


def extreme_reduce_memory_factor_arguments(
    strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16
) -> FactorArguments:
    """Factor arguments for models that are difficult to fit on a single GPU."""
    factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
    factor_args.lambda_iterative_aggregate = True
    factor_args.cached_activation_cpu_offload = True
    factor_args.covariance_module_partition_size = 4
    factor_args.lambda_module_partition_size = 4
    return factor_args


def large_dataset_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
    """Factor arguments for large models and datasets."""
    factor_args = smart_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
    factor_args.covariance_data_partition_size = 4
    factor_args.lambda_data_partition_size = 4
    return factor_args
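
These presets are meant to be passed wherever a FactorArguments object is accepted when fitting influence factors. Below is a minimal usage sketch, assuming the Analyzer, prepare_model, and fit_all_factors interface described in the Kronfluence README; model, task, and train_dataset are placeholder objects the caller must supply.

import torch
from kronfluence.analyzer import Analyzer, prepare_model

# Placeholders: `model` (torch.nn.Module), `task` (kronfluence Task), and
# `train_dataset` (torch.utils.data.Dataset) must already be defined.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="example", model=model, task=task)

# Pick a preset, optionally override individual fields, and fit the factors.
factor_args = extreme_reduce_memory_factor_arguments(strategy="ekfac", dtype=torch.bfloat16)
factor_args.covariance_module_partition_size = 2  # example override of a preset value
analyzer.fit_all_factors(
    factors_name="example_factors",
    dataset=train_dataset,
    factor_args=factor_args,
    per_device_batch_size=None,  # None asks the library to search for a batch size
)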
@@ -0,0 +1,96 @@
from typing import Optional

import torch

from kronfluence import ScoreArguments


def default_score_arguments(
    damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None
) -> ScoreArguments:
    """Default score arguments."""
    score_args = ScoreArguments(damping=damping)
    score_args.query_gradient_rank = query_gradient_rank
    if score_args.query_gradient_rank is not None:
        score_args.num_query_gradient_accumulations = 10
    return score_args


def test_score_arguments(damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None) -> ScoreArguments:
    """Score arguments used for unit tests."""
    score_args = ScoreArguments(damping=damping)
    score_args.query_gradient_svd_dtype = torch.float64
    score_args.score_dtype = torch.float64
    score_args.per_sample_gradient_dtype = torch.float64
    score_args.precondition_dtype = torch.float64
    score_args.query_gradient_rank = query_gradient_rank
    return score_args


def smart_low_precision_score_arguments(
    damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
    """Score arguments with low precision, except for the preconditioning computations."""
    score_args = ScoreArguments(damping=damping)
    score_args.amp_dtype = dtype
    score_args.query_gradient_svd_dtype = torch.float32
    score_args.score_dtype = dtype
    score_args.per_sample_gradient_dtype = dtype
    score_args.precondition_dtype = torch.float32
    score_args.query_gradient_rank = query_gradient_rank
    if score_args.query_gradient_rank is not None:
        score_args.num_query_gradient_accumulations = 10
    return score_args


def all_low_precision_score_arguments(
    damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
    """Score arguments with low precision."""
    score_args = ScoreArguments(damping=damping)
    score_args.amp_dtype = dtype
    score_args.query_gradient_svd_dtype = torch.float32
    score_args.score_dtype = dtype
    score_args.per_sample_gradient_dtype = dtype
    score_args.precondition_dtype = dtype
    score_args.query_gradient_rank = query_gradient_rank
    if score_args.query_gradient_rank is not None:
        score_args.num_query_gradient_accumulations = 10
    return score_args


def reduce_memory_score_arguments(
    damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
    """Score arguments with low precision + CPU offload."""
    score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype)
    score_args.cached_activation_cpu_offload = True
    score_args.query_gradient_rank = query_gradient_rank
    if score_args.query_gradient_rank is not None:
        score_args.num_query_gradient_accumulations = 10
    return score_args


def extreme_reduce_memory_score_arguments(
    damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
    """Score arguments for models that are difficult to fit on a single GPU."""
    score_args = all_low_precision_score_arguments(damping=damping, dtype=dtype)
    score_args.cached_activation_cpu_offload = True
    score_args.query_gradient_rank = query_gradient_rank
    score_args.module_partition_size = 4
    if score_args.query_gradient_rank is not None:
        score_args.num_query_gradient_accumulations = 10
    return score_args


def large_dataset_score_arguments(
    damping: Optional[float] = 1e-08, query_gradient_rank: Optional[int] = None, dtype: torch.dtype = torch.bfloat16
) -> ScoreArguments:
    """Score arguments for large models and datasets."""
    score_args = smart_low_precision_score_arguments(damping=damping, dtype=dtype)
    score_args.data_partition_size = 4
    score_args.query_gradient_rank = query_gradient_rank
    if score_args.query_gradient_rank is not None:
        score_args.num_query_gradient_accumulations = 10
    return score_args
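
A score-argument preset is handed to the score-computation call on an Analyzer whose factors have already been fitted. The sketch below assumes the compute_pairwise_scores and load_pairwise_scores interface from the Kronfluence README; analyzer, query_dataset, and train_dataset are placeholders, and the rank-64 query-gradient setting is purely illustrative.

import torch

# Placeholders: `analyzer` is an Analyzer with fitted factors named "example_factors";
# `query_dataset` and `train_dataset` are torch.utils.data.Dataset objects.
score_args = extreme_reduce_memory_score_arguments(
    damping=1e-08,
    query_gradient_rank=64,  # enables the low-rank query-gradient approximation
    dtype=torch.bfloat16,
)
analyzer.compute_pairwise_scores(
    scores_name="example_scores",
    factors_name="example_factors",
    score_args=score_args,
    query_dataset=query_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=8,
)
# Returns a dictionary of influence-score tensors.
scores = analyzer.load_pairwise_scores("example_scores")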