Skip to content

Commit

Permalink
Add torchao base example
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Jan 16, 2025
1 parent b2cce71 commit e1a1304
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions benchmarks/fp8/torchao/non_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import evaluate
import torch
from functools import partial
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -104,8 +105,6 @@ def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=No
return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler




def evaluate_model(model, dataloader, metric, accelerator=None):
"Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
model.eval()
Expand All @@ -120,19 +119,31 @@ def evaluate_model(model, dataloader, metric, accelerator=None):
return metric.compute()


def module_filter_func(module, *args):
def module_filter_func(module, fqn, first_layer_name=None, last_layer_name=None):
    """Decide whether a submodule should be converted to float8 training.

    Args:
        module: The candidate submodule.
        fqn: Fully-qualified name of ``module`` within the parent model.
        first_layer_name: FQN of the first linear layer, which is excluded.
        last_layer_name: FQN of the last linear layer, which is excluded.

    Returns:
        ``False`` for ``nn.Linear`` layers whose ``in_features`` or
        ``out_features`` are not multiples of 16, and for the first/last
        linear layers; ``True`` otherwise.
    """
    if isinstance(module, torch.nn.Linear) and (
        module.in_features % 16 or module.out_features % 16
    ):
        # float8 kernels require dimensions divisible by 16.
        return False
    # For stability reasons, the first and last linear layers are skipped;
    # otherwise the model may not train or converge properly.
    return fqn not in (first_layer_name, last_layer_name)


def train_baseline():
set_seed(42)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
first_linear = None
last_linear = None
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if first_linear is None:
first_linear = name
last_linear = name

func = partial(module_filter_func, first_layer_name=first_linear, last_layer_name=last_linear)
model.to("cuda")
convert_to_float8_training(model, module_filter_fn=module_filter_func)
convert_to_float8_training(model, module_filter_fn=func)
base_model_results = evaluate_model(model, eval_dataloader, METRIC)
model.train()

Expand Down

0 comments on commit e1a1304

Please sign in to comment.