diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py
index 56e8db2f..2c00b363 100644
--- a/src/olmo_core/distributed/checkpoint.py
+++ b/src/olmo_core/distributed/checkpoint.py
@@ -997,7 +997,7 @@ def _get_torch_fsdp_state_dict_for_checkpoint(model: nn.Module) -> Dict[str, tor
                 unsharded_shape=tuple(og_shape), unsharded_flattened_offsets=tuple(all_offsets)
             )
             flat_tensor = ShardedFlatTensor(param.data.detach())
-            flat_tensor.mark_as_sharded(shard_spec)
+            flat_tensor.mark_as_sharded(shard_spec, process_group=handle.process_group)
             param_to_flat_tensor[param] = flat_tensor
         else:
             raise NotImplementedError(
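
For context on why the process group needs to be threaded through here: `mark_as_sharded` records how the flat shards map back to the unsharded parameter, and any later gather has to run over the same group FSDP actually sharded over. If no group is passed, torch.distributed collectives fall back to the default (world) group; under hybrid sharding that group also contains ranks holding other replicas, so a gather over it would assemble a tensor of the wrong size. Below is a minimal, hypothetical sketch of that gather step, assuming equal-size (zero-padded) flat shards; `gather_full_tensor` is an illustrative helper, not part of the olmo_core API.

from typing import Optional

import torch
import torch.distributed as dist


def gather_full_tensor(
    local_shard: torch.Tensor,
    unsharded_numel: int,
    process_group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    # Hypothetical helper, not olmo_core API. Every rank in the *sharding*
    # group holds an equal-size (zero-padded) flat shard, so a plain
    # all_gather over that group suffices.
    world_size = dist.get_world_size(group=process_group)
    shards = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(shards, local_shard, group=process_group)
    # Concatenate the shards and drop trailing padding to recover the
    # flattened unsharded parameter.
    return torch.cat(shards)[:unsharded_numel]

With `process_group=None`, `get_world_size` and `all_gather` use the default group, so `world_size` and the recorded shard offsets no longer match whenever the model was sharded over a smaller subgroup. Passing `handle.process_group` at `mark_as_sharded` time, as the diff does, pins the ShardedFlatTensor to the group it was actually sharded over.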