Commit
Fix bug in loss computation
pomonam committed Jul 13, 2024
1 parent 141109f commit 5e76f33
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions examples/openwebtext/data/data.json
@@ -24,7 +24,7 @@
"completion": " hydrogen and oxygen atoms."
},
{
"prompt": "Water is composed of",
"completion": " hydrogen and oxygen atoms."
"prompt": "물을 이루는 원소는",
"completion": " 산소와 탄소이다."
}
]
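For reference, each record in data.json pairs a "prompt" with a "completion"; this hunk swaps what appears to be a duplicated English entry for a Korean one asking, roughly, "Which elements make up water?" with the completion " They are oxygen and carbon." Below is a minimal sketch, not the example's actual preprocessing, of how such a record could be turned into model inputs: prompt tokens are masked with -100 so only completion tokens contribute to the loss, matching the ignore_index=-100 convention in task.py. The encode_record helper and the small stand-in tokenizer are assumptions for illustration.

import json

import torch
from transformers import AutoTokenizer

MAX_LENGTH = 512  # mirrors pipeline.py
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")  # small stand-in model

def encode_record(record: dict) -> dict:
    # Tokenize prompt and completion separately so prompt positions can be masked.
    prompt_ids = tokenizer(record["prompt"], add_special_tokens=True)["input_ids"]
    completion_ids = tokenizer(record["completion"], add_special_tokens=False)["input_ids"]
    input_ids = (prompt_ids + completion_ids)[:MAX_LENGTH]
    labels = ([-100] * len(prompt_ids) + completion_ids)[:MAX_LENGTH]
    return {
        "input_ids": torch.tensor([input_ids]),
        "attention_mask": torch.ones(1, len(input_ids), dtype=torch.long),
        "labels": torch.tensor([labels]),
    }

with open("examples/openwebtext/data/data.json") as f:
    records = json.load(f)
batch = encode_record(records[-1])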
3 changes: 1 addition & 2 deletions examples/openwebtext/fit_factors.py
@@ -24,6 +24,7 @@ def parse_args():
parser.add_argument(
"--factors_name",
type=str,
required=True,
help="Name of the factor.",
)
parser.add_argument(
@@ -83,8 +84,6 @@ def main():
)
factor_args.covariance_module_partitions = 2
factor_args.lambda_module_partitions = 4

# For better numerical precision.
factor_args.covariance_data_partitions = 4
factor_args.lambda_data_partitions = 4
analyzer.fit_all_factors(
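The first hunk makes --factors_name a required argument, so omitting the flag now fails at parse time instead of passing None further into the factor-fitting run; the partition settings in the second hunk presumably trade extra passes for lower peak memory rather than affecting precision. A short argparse sketch of the new failure mode, with a made-up factors name for illustration:

import argparse

parser = argparse.ArgumentParser(description="Fit influence factors.")
parser.add_argument(
    "--factors_name",
    type=str,
    required=True,
    help="Name of the factor.",
)

args = parser.parse_args(["--factors_name", "example_factors"])  # hypothetical value
print(args.factors_name)  # example_factors

# parser.parse_args([]) now exits with:
#   error: the following arguments are required: --factors_name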
1 change: 1 addition & 0 deletions examples/openwebtext/pipeline.py
@@ -8,6 +8,7 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
# MODEL_NAME = "EleutherAI/pythia-70m"
MAX_LENGTH = 512


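MODEL_NAME and MAX_LENGTH are the constants the rest of the example reads from this module, and the commented-out Pythia line documents a smaller model for quick local runs. The snippet below is an assumed sketch, not pipeline.py's actual helpers, of how these constants are typically consumed with the transformers imports shown in the hunk; the small Pythia model avoids the gated Llama-3 download.

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "EleutherAI/pythia-70m"  # or "meta-llama/Meta-Llama-3-8B" for the real run
MAX_LENGTH = 512

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

encoded = tokenizer(
    "Water is composed of",
    truncation=True,
    max_length=MAX_LENGTH,
    return_tensors="pt",
)
logits = model(**encoded).logits  # shape: (1, sequence_length, vocab_size)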
14 changes: 7 additions & 7 deletions examples/openwebtext/task.py
@@ -19,21 +19,21 @@ def compute_train_loss(
logits = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
).logits
).logits.float()
logits = logits[..., :-1, :].contiguous()
logits = logits.view(-1, logits.size(-1))

labels = batch["labels"][..., 1:].contiguous()
if not sample:
labels = batch["labels"]
shift_labels = labels[..., 1:].contiguous()
summed_loss = F.cross_entropy(logits, shift_labels.view(-1), reduction="sum", ignore_index=-100)
summed_loss = F.cross_entropy(logits, labels.view(-1), reduction="sum", ignore_index=-100)
else:
with torch.no_grad():
probs = torch.nn.functional.softmax(logits.detach(), dim=-1)
sampled_labels = torch.multinomial(
probs,
num_samples=1,
).flatten()
masks = labels.view(-1) == -100
sampled_labels[masks] = -100
summed_loss = F.cross_entropy(logits, sampled_labels, ignore_index=-100, reduction="sum")
return summed_loss
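The substance of the fix lives in this hunk: logits at position t are scored against the token at position t + 1, so the truncated logits[..., :-1, :] must be paired with labels shifted by one; positions labelled -100 (prompt or padding tokens) must stay out of the summed loss; and upcasting with .float() keeps the summed reduction numerically stable when the model itself runs in half precision. The toy example below illustrates that alignment and masking on random data; it is an illustration of the convention, not the example's own test code.

import torch
import torch.nn.functional as F

vocab_size = 11
labels = torch.tensor([[-100, 2, -100, 9]])  # index 0 is dropped by the shift;
                                             # the remaining -100 is ignored by the loss
logits = torch.randn(1, labels.size(1), vocab_size, dtype=torch.bfloat16)  # stand-in model output

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)  # tensor([   2, -100,    9])

summed_loss = F.cross_entropy(
    shift_logits.float(),  # upcast before the summed reduction
    shift_labels,
    reduction="sum",
    ignore_index=-100,
)
print(summed_loss)  # sum of the two unmasked token losses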

@@ -45,15 +45,15 @@ def compute_measurement(
logits = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
).logits
).logits.float()
shift_labels = batch["labels"][..., 1:].contiguous().view(-1)
logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
return F.cross_entropy(logits, shift_labels, ignore_index=-100, reduction="sum")

def get_influence_tracked_modules(self) -> List[str]:
total_modules = []

# You can uncomment the following lines if you would like to compute influence also on attention layers.
# You can uncomment the following lines if you would like to compute influence on attention layers.
# for i in range(32):
# total_modules.append(f"model.layers.{i}.self_attn.q_proj")
# total_modules.append(f"model.layers.{i}.self_attn.k_proj")
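The tracked-module list returned by this function is truncated above; the commented-out block shows the dotted-name pattern for attention projections in a 32-layer model such as Llama-3-8B. The hypothetical helper below only illustrates that naming pattern and how candidate names can be cross-checked against model.named_modules(); it is not the function's actual return value.

from typing import List

def attention_projection_modules(num_layers: int = 32) -> List[str]:
    # Builds dotted module names matching the commented-out pattern above.
    total_modules: List[str] = []
    for i in range(num_layers):
        total_modules.append(f"model.layers.{i}.self_attn.q_proj")
        total_modules.append(f"model.layers.{i}.self_attn.k_proj")
    return total_modules

print(attention_projection_modules(num_layers=1))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj']

# To verify names against a loaded model:
#   valid = {name for name, _ in model.named_modules()}
#   assert all(m in valid for m in attention_projection_modules())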
