Commit ba73d39

Fix param fetching bug
Edenzzzz committed Nov 29, 2024
1 parent 36c8597 commit ba73d39
Showing 2 changed files with 7 additions and 2 deletions.
tests/test_optimizer/test_dist_adafactor.py (3 additions, 1 deletion)
@@ -281,7 +281,9 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
base_optim.zero_grad()
dist_optim.zero_grad()

for p, tp_p in zip(base_param_group, tp_param_group):
base_params = base_model.parameters()
tp_params = tp_model.parameters()
for p, tp_p in zip(base_params, tp_params):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)
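For context, the fix replaces iteration over the previously built base_param_group and tp_param_group with the models' own parameters() iterators, so the base and tensor-parallel parameters pair up one-to-one in registration order. Below is a minimal sketch of that comparison pattern; it is not taken from the repository, and the toy models, tolerances, and assert_params_close helper are assumptions for illustration only.

import torch
import torch.nn as nn

def assert_params_close(base_model: nn.Module, tp_model: nn.Module, dtype=torch.float32):
    # Pull parameters straight from the two models; both are assumed to register
    # their parameters in the same order, so zip pairs them correctly.
    for p, tp_p in zip(base_model.parameters(), tp_model.parameters()):
        # Stand-in for the correctness_verify(p.data, tp_p.data, dtype) check used in the tests.
        torch.testing.assert_close(p.data.to(dtype), tp_p.data.to(dtype), rtol=1e-3, atol=1e-3)

# Usage with two identically initialized toy models.
base = nn.Linear(4, 4)
tp = nn.Linear(4, 4)
tp.load_state_dict(base.state_dict())
assert_params_close(base, tp)

In the real tests the tensor-parallel copy may hold sharded tensors, which is what the is_distributed_tensor branch above handles.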
tests/test_optimizer/test_dist_came.py (4 additions, 1 deletion)
@@ -138,7 +138,9 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
base_optim.zero_grad()
dist_optim.zero_grad()

for p, tp_p in zip(base_param_group, tp_param_group):
base_params = base_model.parameters()
tp_params = tp_model.parameters()
for p, tp_p in zip(base_params, tp_params):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)
@@ -156,6 +158,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
# No TP bias
pass
correctness_verify(p.data, tp_p.data, dtype)

clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
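The is_distributed_tensor / get_sharding_spec branch in these tests compares a tensor-parallel shard against the corresponding slice of the full parameter. The snippet below is a hypothetical, single-process sketch of that idea using an even torch.chunk split; the shard_of helper and the fixed rank/world_size values are assumptions, not the repository's helpers.

import torch

def shard_of(full: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    # Even split along `dim`; a real sharding spec can describe more general layouts.
    return torch.chunk(full, world_size, dim=dim)[rank]

full = torch.arange(16.0).reshape(4, 4)
shard = shard_of(full, rank=1, world_size=2, dim=0)  # rows 2..3 of the full tensor
torch.testing.assert_close(shard, full[2:4])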
