# Reconstructed from a whitespace-collapsed patch. This class is the addition
# made to tests/modeldiffs/diff.py, where `out_diff` is defined. The same patch
# changes tests/modeldiffs/criteo1tb/compare.py to import ModelDiffRunner
# instead of out_diff and to replace its direct `out_diff(...)` call with
# `ModelDiffRunner(...).run()`, passing the same keyword arguments.
class ModelDiffRunner:
    """Hold the inputs of a JAX-vs-PyTorch model diff and run it on demand.

    The constructor only stores its arguments; no comparison work happens
    until `run()` is called, which hands everything to the module-level
    `out_diff` function.

    Example (the call site used by the criteo1tb comparison script):

        ModelDiffRunner(
            jax_workload=jax_workload,
            pytorch_workload=pytorch_workload,
            jax_model_kwargs=jax_model_kwargs,
            pytorch_model_kwargs=pytorch_model_kwargs,
            key_transform=key_transform,
            sd_transform=sd_transform,
            out_transform=None).run()
    """

    def __init__(self,
                 jax_workload,
                 pytorch_workload,
                 jax_model_kwargs,
                 pytorch_model_kwargs,
                 key_transform=None,
                 sd_transform=None,
                 out_transform=None) -> None:
        """Store the arguments that `run()` later forwards to `out_diff`.

        Args:
          jax_workload: Workload implementation using JAX.
          pytorch_workload: Workload implementation using PyTorch.
          jax_model_kwargs: Arguments to be used for model_fn in the JAX
            workload.
          pytorch_model_kwargs: Arguments to be used for model_fn in the
            PyTorch workload.
          key_transform: Optional transformation function for keys.
          sd_transform: Optional transformation function for the state
            dictionary.
          out_transform: Optional transformation function for the output.
        """
        self.jax_workload = jax_workload
        self.pytorch_workload = pytorch_workload
        self.jax_model_kwargs = jax_model_kwargs
        self.pytorch_model_kwargs = pytorch_model_kwargs
        self.key_transform = key_transform
        self.sd_transform = sd_transform
        self.out_transform = out_transform

    def run(self) -> None:
        """Run the diff by calling `out_diff` with the stored arguments.

        `out_diff` reports its findings by printing (its visible tail prints
        the max and min fprop difference between jax and pytorch); nothing is
        returned from this method.
        """
        # Forward by keyword rather than by position so that each stored value
        # stays bound to the parameter of the same name even if `out_diff`'s
        # parameter order changes. These are the keyword names the original
        # direct `out_diff(...)` call site used.
        out_diff(
            jax_workload=self.jax_workload,
            pytorch_workload=self.pytorch_workload,
            jax_model_kwargs=self.jax_model_kwargs,
            pytorch_model_kwargs=self.pytorch_model_kwargs,
            key_transform=self.key_transform,
            sd_transform=self.sd_transform,
            out_transform=self.out_transform)