Skip to content

Commit

Permalink
Fix invalid check of recorded parameter orders in zero stage3. (#2550)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
inkcherry and tjruwase authored Nov 30, 2022
1 parent ffcf384 commit aeda7f9
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,18 @@ def reset_step(self) -> None:
f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")

         if not self.is_complete_trace():  # not self.trace_complete:
-            # Make sure that recorded parameter and submodule orders are
-            # identical across ranks
+            # Make sure that recorded submodule orders are identical across ranks
             assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
-            assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order])
-            assert_ints_same_as_other_ranks(
-                [p.step_id_last_used_at for p in self.__param_order])
 
             if self.is_record_trace():
                 # Successfully recorded a trace
                 self.construct_parameter_trace_from_module_trace()
+                # Make sure that recorded parameter orders are identical across ranks
+                assert_ints_same_as_other_ranks(
+                    [p.param.ds_id for p in self.__param_order])
+                assert_ints_same_as_other_ranks(
+                    [p.step_id_last_used_at for p in self.__param_order])
+
                 self.__submodule_order = tuple(self.__submodule_order)  # freeze
                 self.__param_order = tuple(self.__param_order)  # freeze
                 self.__trace_mode = ZeRoTraceMode.COMPLETE
Expand Down

0 comments on commit aeda7f9

Please sign in to comment.