diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py
index 4df529faa285..ad93b5310409 100644
--- a/tests/test_optimizer/test_dist_adafactor.py
+++ b/tests/test_optimizer/test_dist_adafactor.py
@@ -281,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     base_optim.zero_grad()
     dist_optim.zero_grad()
 
-    for p, tp_p in zip(base_param_group, tp_param_group):
+    base_params = base_model.parameters()
+    tp_params = tp_model.parameters()
+    for p, tp_p in zip(base_params, tp_params):
         param_is_distributed = is_distributed_tensor(tp_p)
         if param_is_distributed:
             shard_spec = get_sharding_spec(tp_p)
diff --git a/tests/test_optimizer/test_dist_came.py b/tests/test_optimizer/test_dist_came.py
index b800f189418c..d662bc6748f8 100644
--- a/tests/test_optimizer/test_dist_came.py
+++ b/tests/test_optimizer/test_dist_came.py
@@ -138,7 +138,9 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
     base_optim.zero_grad()
     dist_optim.zero_grad()
 
-    for p, tp_p in zip(base_param_group, tp_param_group):
+    base_params = base_model.parameters()
+    tp_params = tp_model.parameters()
+    for p, tp_p in zip(base_params, tp_params):
         param_is_distributed = is_distributed_tensor(tp_p)
         if param_is_distributed:
             shard_spec = get_sharding_spec(tp_p)
@@ -156,6 +158,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
             # No TP bias
             pass
         correctness_verify(p.data, tp_p.data, dtype)
+    clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
 
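
Both hunks switch the correctness check from the previously captured `base_param_group` / `tp_param_group` lists to fresh iterators over `base_model.parameters()` and `tp_model.parameters()`, presumably so the comparison always sees the parameter tensors currently registered on the two models rather than a list captured earlier. Below is a minimal standalone sketch of that pattern, not the actual test: it assumes plain `torch.nn` modules, stands in `SGD` for the distributed Adafactor/CAME optimizers, uses `torch.testing.assert_close` in place of `correctness_verify`, and omits the tensor-parallel sharding-spec / gather handling and the `clear_layout_converter()` cleanup from the real files.

```python
# Minimal sketch of the comparison pattern the diff moves to (assumptions noted above).
import copy

import torch
import torch.nn as nn

base_model = nn.Linear(8, 8)
tp_model = copy.deepcopy(base_model)  # stand-in for the tensor-parallel wrapped model

base_optim = torch.optim.SGD(base_model.parameters(), lr=0.1)
dist_optim = torch.optim.SGD(tp_model.parameters(), lr=0.1)

# One training step on identical inputs so both models should stay in sync.
x = torch.randn(4, 8)
base_model(x).sum().backward()
tp_model(x).sum().backward()
base_optim.step()
dist_optim.step()
base_optim.zero_grad()
dist_optim.zero_grad()

# Iterate the models' live parameters() at comparison time, rather than a list
# captured before the step, so the check reads the tensors that were actually updated.
for p, tp_p in zip(base_model.parameters(), tp_model.parameters()):
    torch.testing.assert_close(p.data, tp_p.data)
```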