From 15be9f2f2de1703f84436a301e81830351238359 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 10 Apr 2024 12:45:36 -0700 Subject: [PATCH] Pass process group to mark_as_sharded (#7) --- src/olmo_core/distributed/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 56e8db2f..2c00b363 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -997,7 +997,7 @@ def _get_torch_fsdp_state_dict_for_checkpoint(model: nn.Module) -> Dict[str, tor unsharded_shape=tuple(og_shape), unsharded_flattened_offsets=tuple(all_offsets) ) flat_tensor = ShardedFlatTensor(param.data.detach()) - flat_tensor.mark_as_sharded(shard_spec) + flat_tensor.mark_as_sharded(shard_spec, process_group=handle.process_group) param_to_flat_tensor[param] = flat_tensor else: raise NotImplementedError(