Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor diffs module to use ModelDiffRunner class #820

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/modeldiffs/criteo1tb/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
Criteo1TbDlrmSmallWorkload as JaxWorkload
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner



def key_transform(k):
Expand Down Expand Up @@ -74,11 +75,11 @@ def sd_transform(sd):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=key_transform,
sd_transform=sd_transform,
out_transform=None)
out_transform=None).run()
38 changes: 38 additions & 0 deletions tests/modeldiffs/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,41 @@ def out_diff(jax_workload,

print(f'Max fprop difference between jax and pytorch: {max_diff}')
print(f'Min fprop difference between jax and pytorch: {min_diff}')


class ModelDiffRunner:
def __init__(self, jax_workload,
pytorch_workload,
jax_model_kwargs,
pytorch_model_kwargs,
key_transform=None,
sd_transform=None,
out_transform=None) -> None:
"""Initializes the instance based on diffing logic.

Args:
jax_workload: Workload implementation using JAX
pytorch_workload: Workload implementation using PyTorch
jax_model_kwargs: Arguments to be used for model_fn in jax workload
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
key_transform: Transformation function for keys.
sd_transform: Transformation function for State Dictionary.
out_transform: Transformation function for the output.
"""

self.jax_workload = jax_workload
self.pytorch_workload = pytorch_workload
self.jax_model_kwargs = jax_model_kwargs
self.pytorch_model_kwargs = pytorch_model_kwargs
self.key_transform = key_transform
self.sd_transform = sd_transform
self.out_transform = out_transform

def run(self):
out_diff(self.jax_workload,
self.pytorch_workload,
self.jax_model_kwargs,
self.pytorch_model_kwargs,
self.key_transform,
self.sd_transform,
self.out_transform)
Loading