Merge pull request #27 from pomonam/document_example
Document example
pomonam authored Jun 27, 2024
2 parents 180c2ec + 7278052 commit e81e653
Showing 52 changed files with 2,579 additions and 553 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -23,7 +23,8 @@
---

> **Kronfluence** is a PyTorch package designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296), *Studying Large Language Model Generalization with Influence Functions*.
For detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296), *Studying Large Language Model Generalization with Influence Functions*.

---

> [!WARNING]
33 changes: 33 additions & 0 deletions examples/README.md
@@ -0,0 +1,33 @@
# Kronfluence: Examples

For detailed technical documentation of Kronfluence, please refer to the [Technical Documentation](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page.

## Getting Started

To run all examples, install the necessary packages:

```bash
pip install -r requirements.txt
```

Alternatively, navigate to each example folder and run `pip install -r requirements.txt`.


## List of Tasks

Our examples cover the following tasks:

<div align="center">

| Task | Example datasets |
|----------------------|:------------------------:|
| Regression | UCI |
| Image Classification | CIFAR-10 / ImageNet |
| Text Classification | GLUE |
| Multiple-Choice | SWAG |
| Language Modeling | WikiText-2 / OpenWebText |

</div>

These examples demonstrate various use cases of Kronfluence, including the use of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel).
Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please open an issue.
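
Each example folder wires these pieces together in its own `analyze.py`. For orientation, here is a minimal sketch of the common Kronfluence workflow on a toy model and dataset. The toy `Task`, model, and data below are invented purely for illustration; the `Analyzer` calls mirror those in `examples/cifar/analyze.py` from this commit:

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import TensorDataset

from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs


class ToyClassificationTask(Task):
    """Toy task; the method names mirror the CIFAR-10 example in this commit."""

    def compute_train_loss(self, batch, model, sample=False):
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            return F.cross_entropy(logits, labels, reduction="sum")
        with torch.no_grad():
            probs = F.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch, model):
        # The CIFAR-10 example uses a margin-based measurement; the training loss
        # is reused here only to keep the toy example short.
        return self.compute_train_loss(batch, model)


# Toy data and model stand in for the per-example pipelines.
train_dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
eval_dataset = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

task = ToyClassificationTask()
model = prepare_model(model, task)

analyzer = Analyzer(analysis_name="toy", model=model, task=task)
analyzer.set_dataloader_kwargs(DataLoaderKwargs(num_workers=0))

# Fit influence factors (EKFAC here, as in most examples).
analyzer.fit_all_factors(
    factors_name="ekfac",
    dataset=train_dataset,
    per_device_batch_size=None,
    factor_args=FactorArguments(strategy="ekfac"),
)

# Compute pairwise influence scores between query and training examples.
analyzer.compute_pairwise_scores(
    scores_name="ekfac",
    factors_name="ekfac",
    query_dataset=eval_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=16,
)
scores = analyzer.load_pairwise_scores("ekfac")["all_modules"]
print(scores.shape)  # (num_query_examples, num_train_examples)
```
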
123 changes: 110 additions & 13 deletions examples/cifar/README.md
@@ -1,57 +1,154 @@
# CIFAR-10 & ResNet-9 Example

This directory contains scripts for training ResNet-9 on CIFAR-10. The pipeline is motivated from
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb).
This directory contains scripts for training ResNet-9 and computing influence scores on the CIFAR-10 dataset. The pipeline is motivated by the
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). To get started, please install the necessary packages by running the following command:

```bash
pip install -r requirements.txt
```

## Training

To train ResNet-9 on CIFAR-10 dataset, run the following command:
To train ResNet-9 on the CIFAR-10 dataset, run the following command:

```bash
python train.py --dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--weight_decay 0.001 \
--num_train_epochs 25 \
--seed 1004
```

This will train the model using the specified hyperparameters and save the trained checkpoint in the `./checkpoints` directory.
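
The analysis scripts below load this checkpoint internally. If you want to inspect the trained model yourself, a minimal sketch might look like the following (the checkpoint filename is a placeholder, since the exact name is chosen by `train.py`):

```python
import torch

from examples.cifar.pipeline import construct_resnet9  # also used by analyze.py

model = construct_resnet9()
# Placeholder path; substitute the checkpoint file written by train.py.
state_dict = torch.load("./checkpoints/model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```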

## Computing Pairwise Influence Scores

To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command:
To compute pairwise influence scores on 2000 query data points using the `ekfac` factorization strategy, run the following command:

```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the
pairwise scores (including computing EKFAC factors).

In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors):

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 112.83 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Pairwise Score | 47.989 | 1 | 47.989 | 42.532 |
| Fit Lambda | 34.639 | 1 | 34.639 | 30.7 |
| Fit Covariance | 21.841 | 1 | 21.841 | 19.357 |
| Save Pairwise Score | 3.5998 | 1 | 3.5998 | 3.1905 |
| Perform Eigendecomposition | 2.7724 | 1 | 2.7724 | 2.4572 |
| Save Covariance | 0.85695 | 1 | 0.85695 | 0.75951 |
| Save Eigendecomposition | 0.85628 | 1 | 0.85628 | 0.75892 |
| Save Lambda | 0.12327 | 1 | 0.12327 | 0.10925 |
| Load Eigendecomposition | 0.056494 | 1 | 0.056494 | 0.05007 |
| Load All Factors | 0.048981 | 1 | 0.048981 | 0.043412 |
| Load Covariance | 0.046798 | 1 | 0.046798 | 0.041476 |
----------------------------------------------------------------------------------------------------------------------------------
```
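
Once the command finishes, the scores can be inspected from the same `Analyzer` setup used in `analyze.py`. The short sketch below assumes the `analyzer` object from that script and the default `ekfac` strategy; the shape comment reflects the 2,000 query images and the 50,000 CIFAR-10 training images:

```python
# Assumes the `analyzer` object constructed in analyze.py.
scores = analyzer.load_pairwise_scores("ekfac")["all_modules"]
print(scores.shape)  # expected: (2000, 50000) - query examples x training examples
```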

To use half precision (AMP) when computing influence factors and scores, run:

```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac \
--use_half_precision
```

This reduces computation time to about 40 seconds on an A100 (80GB) GPU:

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 42.316 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Pairwise Score | 19.565 | 1 | 19.565 | 46.235 |
| Fit Lambda | 9.173 | 1 | 9.173 | 21.677 |
| Fit Covariance | 7.3723 | 1 | 7.3723 | 17.422 |
| Perform Eigendecomposition | 2.6613 | 1 | 2.6613 | 6.2891 |
| Save Pairwise Score | 2.0156 | 1 | 2.0156 | 4.7633 |
| Save Covariance | 0.71699 | 1 | 0.71699 | 1.6944 |
| Save Eigendecomposition | 0.52561 | 1 | 0.52561 | 1.2421 |
| Load Covariance | 0.15732 | 1 | 0.15732 | 0.37177 |
| Save Lambda | 0.063394 | 1 | 0.063394 | 0.14981 |
| Load Eigendecomposition | 0.051395 | 1 | 0.051395 | 0.12146 |
| Load All Factors | 0.014144 | 1 | 0.014144 | 0.033425 |
----------------------------------------------------------------------------------------------------------------------------------
```

You can run `half_precision_analysis.py` to verify that the scores computed with half precision (AMP) correlate highly with those computed under the default configuration.

<p align="center">
<a href="#"><img width="380" src="figure/half_precision.png" alt="Half Precision"/></a>
</p>
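
For a quick manual check of the same effect, one can load both score matrices and compare them directly. The sketch below assumes the `analyzer` object from `analyze.py` and uses a simple Pearson correlation via `torch.corrcoef`; the bundled `half_precision_analysis.py` may use a different metric:

```python
import torch

# Score names follow analyze.py: the factor strategy, plus "_half" for the AMP run.
full_scores = analyzer.load_pairwise_scores("ekfac")["all_modules"].float()
half_scores = analyzer.load_pairwise_scores("ekfac_half")["all_modules"].float()

# Correlation between the flattened score matrices.
stacked = torch.stack([full_scores.flatten(), half_scores.flatten()])
print(torch.corrcoef(stacked)[0, 1])
```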

## Visualizing Influential Training Images

[This Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing) provides a tutorial on visualizing the top influential training images.
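
If you prefer to stay in a local script rather than Colab, a condensed sketch of the same idea is shown below. It assumes the `scores` tensor and `train_dataset` from `analyze.py`, that `matplotlib` is installed, and it does not undo the normalization applied by the data pipeline (so colors may look washed out):

```python
import matplotlib.pyplot as plt
import torch

query_idx = 0  # index into the 2,000 query examples
top_scores, top_indices = torch.topk(scores[query_idx], k=5)

fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for ax, idx, score in zip(axes, top_indices.tolist(), top_scores.tolist()):
    image, label = train_dataset[idx]
    # CIFAR-10 tensors are (C, H, W); move channels last for plotting.
    ax.imshow(image.permute(1, 2, 0).clamp(0, 1).numpy())
    ax.set_title(f"label={label}, score={score:.1f}")
    ax.axis("off")
plt.show()
```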

## Mislabeled Data Detection

We can use self-influence scores (see Section 5.4 for the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of training examples mislabeled by running the following command:
We can use self-influence scores (see **Section 5.4** of the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of the training examples mislabeled by running:

```bash
python train.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--weight_decay 0.001 \
--num_train_epochs 25 \
--seed 1004
```

Then, compute self-influence scores with the following command:
Then, compute the self-influence scores with:

```bash
python detect_mislabeled_dataset.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```

On A100 (80GB), it takes roughly 1.5 minutes to compute the self-influence scores.
We can detect around 82% of mislabeled data points by inspecting 10% of the dataset (96% by inspecting 20%).
On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence scores:

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 122.28 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Self-Influence Score | 61.999 | 1 | 61.999 | 50.701 |
| Fit Lambda | 34.629 | 1 | 34.629 | 28.319 |
| Fit Covariance | 21.807 | 1 | 21.807 | 17.833 |
| Perform Eigendecomposition | 1.8041 | 1 | 1.8041 | 1.4754 |
| Save Covariance | 0.86378 | 1 | 0.86378 | 0.70638 |
| Save Eigendecomposition | 0.84935 | 1 | 0.84935 | 0.69458 |
| Save Lambda | 0.18367 | 1 | 0.18367 | 0.1502 |
| Load Eigendecomposition | 0.052867 | 1 | 0.052867 | 0.043233 |
| Load Covariance | 0.051723 | 1 | 0.051723 | 0.042298 |
| Load All Factors | 0.031986 | 1 | 0.031986 | 0.026158 |
| Save Self-Influence Score | 0.010352 | 1 | 0.010352 | 0.0084653 |
----------------------------------------------------------------------------------------------------------------------------------
```

Around 80% of mislabeled data points can be detected by inspecting 10% of the dataset (97% by inspecting 20%).

<p align="center">
<a href="#"><img width="380" src="figure/mislabel.png" alt="Mislabeled Data Detection"/></a>
</p>
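
The detection rates above can be reproduced in a few lines once the self-influence scores are saved. The sketch below assumes the `analyzer` and `train_dataset` objects from `detect_mislabeled_dataset.py`, where (as in that script) the corrupted examples are the first 10% of training indices:

```python
import torch

scores = analyzer.load_self_scores("ekfac")["all_modules"].flatten()

total_corrupt_size = int(0.1 * len(train_dataset))
corrupted_indices = set(range(total_corrupt_size))

# Inspect examples in order of decreasing self-influence and count how many
# of the known corrupted examples appear in the inspected subset.
ranked = torch.argsort(scores, descending=True).tolist()
for inspect_fraction in (0.10, 0.20):
    num_inspected = int(inspect_fraction * len(train_dataset))
    found = sum(1 for idx in ranked[:num_inspected] if idx in corrupted_indices)
    print(f"Inspected {inspect_fraction:.0%}: recovered {found / total_corrupt_size:.0%} of mislabeled examples")
```
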
51 changes: 38 additions & 13 deletions examples/cifar/analyze.py
@@ -9,8 +9,10 @@

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.arguments import FactorArguments, ScoreArguments
from kronfluence.task import Task
from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments
from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]
@@ -38,19 +40,30 @@ def parse_args():
help="A path that is storing the final checkpoint of the model.",
)

parser.add_argument(
"--factor_strategy",
type=str,
default="ekfac",
help="Strategy to compute influence factors.",
)
parser.add_argument(
"--query_batch_size",
type=int,
default=1000,
help="Batch size for computing query gradients.",
)
parser.add_argument(
"--factor_strategy",
type=str,
default="ekfac",
help="Strategy to compute influence factors.",
"--use_half_precision",
action="store_true",
default=False,
help="Whether to use half precision for computing factors and scores.",
)
parser.add_argument(
"--profile",
action="store_true",
default=False,
help="Boolean flag to profile computations.",
)

args = parser.parse_args()

if args.checkpoint_dir is not None:
Expand All @@ -71,12 +84,12 @@ def compute_train_loss(
if not sample:
return F.cross_entropy(logits, labels, reduction="sum")
with torch.no_grad():
probs = torch.nn.functional.softmax(logits, dim=-1)
probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
sampled_labels = torch.multinomial(
probs,
num_samples=1,
).flatten()
return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")
return F.cross_entropy(logits, sampled_labels, reduction="sum")

def compute_measurement(
self,
@@ -125,31 +138,43 @@ def main():
analysis_name="cifar10",
model=model,
task=task,
profile=args.profile,
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(num_workers=4)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

# Compute influence factors.
factors_name = args.factor_strategy
factor_args = FactorArguments(strategy=args.factor_strategy)
if args.use_half_precision:
factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.bfloat16)
factors_name += "_half"
analyzer.fit_all_factors(
factors_name=args.factor_strategy,
factors_name=factors_name,
factor_args=factor_args,
dataset=train_dataset,
per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=False,
)

# Compute pairwise scores.
score_args = ScoreArguments()
scores_name = factor_args.strategy
if args.use_half_precision:
score_args = all_low_precision_score_arguments(dtype=torch.bfloat16)
scores_name += "_half"
analyzer.compute_pairwise_scores(
scores_name=args.factor_strategy,
factors_name=args.factor_strategy,
scores_name=scores_name,
score_args=score_args,
factors_name=factors_name,
query_dataset=eval_dataset,
query_indices=list(range(2000)),
train_dataset=train_dataset,
per_device_query_batch_size=args.query_batch_size,
overwrite_output_dir=False,
)
scores = analyzer.load_pairwise_scores(args.factor_strategy)["all_modules"]
scores = analyzer.load_pairwise_scores(scores_name)["all_modules"]
logging.info(f"Scores shape: {scores.shape}")


13 changes: 10 additions & 3 deletions examples/cifar/detect_mislabeled_dataset.py
@@ -39,7 +39,12 @@ def parse_args():
default="ekfac",
help="Strategy to compute influence factors.",
)

parser.add_argument(
"--profile",
action="store_true",
default=False,
help="Boolean flag to profile computations.",
)
args = parser.parse_args()

if args.checkpoint_dir is not None:
@@ -75,6 +80,7 @@ def main():
analysis_name="mislabeled",
model=model,
task=task,
profile=args.profile,
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(num_workers=4)
@@ -89,14 +95,15 @@ def main():
factor_args=factor_args,
overwrite_output_dir=False,
)

# Compute self-influence scores.
analyzer.compute_self_scores(
scores_name=args.factor_strategy,
factors_name=args.factor_strategy,
train_dataset=train_dataset,
overwrite_output_dir=True,
overwrite_output_dir=False,
)
scores = analyzer.load_pairwise_scores(args.factor_strategy)["all_modules"]
scores = analyzer.load_self_scores(args.factor_strategy)["all_modules"]

total_corrupt_size = int(args.corrupt_percentage * len(train_dataset))
corrupted_indices = list(range(int(args.corrupt_percentage * len(train_dataset))))
Binary file added examples/cifar/figure/half_precision.png
Binary file added examples/cifar/figure/mislabel.png
