Commit

Lint fix

pomonam committed Jun 27, 2024
1 parent bec08c9 commit 7278052
Showing 5 changed files with 27 additions and 12 deletions.
examples/glue/evaluate_lds.py (4 changes: 2 additions & 2 deletions)
@@ -10,8 +10,8 @@


 def evaluate_correlations(scores: torch.Tensor) -> float:
-    margins = torch.from_numpy(torch.load(open(f"files/margins.pt", "rb")))
-    masks = torch.from_numpy(torch.load(open(f"files/masks.pt", "rb"))).float()
+    margins = torch.from_numpy(torch.load(open("files/margins.pt", "rb")))
+    masks = torch.from_numpy(torch.load(open("files/masks.pt", "rb"))).float()

     val_indices = np.arange(277)
     preds = masks @ scores.T
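
The change above (and the matching one in examples/swag/evaluate_lds.py below) drops the `f` prefix from strings that contain no `{...}` placeholders. Pyflakes and Ruff flag this pattern as F541 ("f-string is missing placeholders"), since the prefix does nothing without interpolation. A minimal sketch of the rule, using a hypothetical `split` variable for illustration:

```python
path = f"files/margins.pt"   # flagged (F541): no placeholders, the f-prefix is inert
path = "files/margins.pt"    # preferred plain string literal

split = "margins"            # hypothetical variable, for illustration only
path = f"files/{split}.pt"   # fine: the f-string actually interpolates
```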
examples/openwebtext/README.md (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
+```bash
+python analyze.py --query_batch_size 32 \
+    --train_batch_size 64 \
+    --checkpoint_dir ./checkpoints \
+    --factor_strategy ekfac
+```
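
The snippet above is the entire new README: a minimal invocation of the example's influence-analysis script. Read from the flag names, `--query_batch_size` and `--train_batch_size` set the batch sizes used for query and training gradients, `--checkpoint_dir` points at the trained checkpoint, and `--factor_strategy ekfac` selects EKFAC factors; this reading is inferred from the names rather than stated in the diff.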
examples/openwebtext/analyze.py (18 changes: 13 additions & 5 deletions)
@@ -6,15 +6,21 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from transformers import default_data_collator, AutoTokenizer
+from transformers import default_data_collator

-from examples.openwebtext.pipeline import get_openwebtext_dataset, get_custom_dataset, construct_llama3
+from examples.openwebtext.pipeline import (
+    construct_llama3,
+    get_custom_dataset,
+    get_openwebtext_dataset,
+)
 from examples.wikitext.pipeline import construct_gpt2, get_wikitext_dataset
 from kronfluence.analyzer import Analyzer, prepare_model
 from kronfluence.arguments import FactorArguments, ScoreArguments
 from kronfluence.task import Task
-from kronfluence.utils.common.factor_arguments import all_low_precision_factor_arguments, \
-    extreme_reduce_memory_factor_arguments
+from kronfluence.utils.common.factor_arguments import (
+    all_low_precision_factor_arguments,
+    extreme_reduce_memory_factor_arguments,
+)
 from kronfluence.utils.common.score_arguments import all_low_precision_score_arguments
 from kronfluence.utils.dataset import DataLoaderKwargs

@@ -102,7 +108,9 @@ def compute_train_loss(
             labels = batch["labels"]
             shift_labels = labels[..., 1:].contiguous()
             reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            summed_loss = F.cross_entropy(reshaped_shift_logits, shift_labels.view(-1), reduction="sum", ignore_index=-100)
+            summed_loss = F.cross_entropy(
+                reshaped_shift_logits, shift_labels.view(-1), reduction="sum", ignore_index=-100
+            )
         else:
             reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
             with torch.no_grad():
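
Both analyze.py hunks are behavior-preserving: the imports are regrouped into parenthesized, alphabetized blocks (isort's multi-line style), the unused `AutoTokenizer` import is dropped, and the `F.cross_entropy` call is wrapped across lines to fit the line-length limit with identical arguments. For readers skimming the diff, a minimal self-contained sketch of what that call computes, with toy shapes assumed for illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)   # (batch, seq_len, vocab), toy sizes
labels = torch.randint(0, 11, (2, 5))
labels[0, :2] = -100             # -100 marks positions to ignore

# Next-token prediction: logits at position t score the token at t + 1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Sum (not mean) of per-token losses; ignore_index skips masked positions.
summed_loss = F.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1),
    reduction="sum",
    ignore_index=-100,
)
print(summed_loss)
```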
examples/openwebtext/pipeline.py (7 changes: 4 additions & 3 deletions)
@@ -61,7 +61,7 @@ def get_custom_dataset(
 ) -> data.Dataset:
     data_kwargs = {
         "path": "json",
-        "data_files": f"./data/data.json",
+        "data_files": "./data/data.json",
         "num_proc": 4,
     }
     raw_datasets = load_dataset(**data_kwargs)["train"]
@@ -75,7 +75,9 @@ def tokenize_function(examples):
         attention_mask = prompt_results["attention_mask"] + completion_results["attention_mask"][1:]
         data_dict["input_ids"] = input_ids
         data_dict["labels"] = copy.deepcopy(input_ids)
-        data_dict["labels"][:len(prompt_results["input_ids"])] = [-100 for _ in range(len(prompt_results["input_ids"]))]
+        data_dict["labels"][: len(prompt_results["input_ids"])] = [
+            -100 for _ in range(len(prompt_results["input_ids"]))
+        ]
         data_dict["attention_mask"] = attention_mask
         return data_dict

@@ -98,4 +100,3 @@

     model = construct_llama3()
     print(Analyzer.get_module_summary(model))
-
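
The pipeline.py hunks are likewise formatting-only: `[: len(...)]` is Black's slice spacing, the list comprehension is wrapped, and the trailing blank line at end of file is removed (pycodestyle W391). The underlying logic, masking prompt tokens with -100 so only completion tokens contribute to the loss, can be sketched as follows, with hypothetical token ids:

```python
import copy

prompt_ids = [101, 2023, 2003]           # hypothetical prompt token ids
completion_ids = [102, 1037, 3231, 102]  # hypothetical completion token ids

# Concatenate, skipping the completion's leading token, as in the diff.
input_ids = prompt_ids + completion_ids[1:]
labels = copy.deepcopy(input_ids)

# Mask every prompt position; cross_entropy's ignore_index=-100 skips them.
labels[: len(prompt_ids)] = [-100 for _ in range(len(prompt_ids))]

print(input_ids)  # [101, 2023, 2003, 1037, 3231, 102]
print(labels)     # [-100, -100, -100, 1037, 3231, 102]
```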
examples/swag/evaluate_lds.py (4 changes: 2 additions & 2 deletions)
@@ -10,8 +10,8 @@


 def evaluate_correlations(scores: torch.Tensor) -> float:
-    margins = torch.from_numpy(torch.load(open(f"files/margins.pt", "rb")))
-    masks = torch.from_numpy(torch.load(open(f"files/masks.pt", "rb"))).float()
+    margins = torch.from_numpy(torch.load(open("files/margins.pt", "rb")))
+    masks = torch.from_numpy(torch.load(open("files/masks.pt", "rb"))).float()

     val_indices = np.arange(277)
     preds = masks @ scores.T
