From 1a5b59dc3b45299b1e1b7830fc9e20c0f538d80d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 16 Apr 2024 17:41:44 -0700 Subject: [PATCH] fix test --- src/test/distributed/checkpoint_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index ae71c2ea..5dc73453 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -449,7 +449,12 @@ def run_save_and_load_fsdp_model(dir, model_factory, model_data_factory, pre_ini # Check optimizer state. for p1, p2 in zip(fsdp_model.parameters(), fsdp_model2.parameters()): - torch.testing.assert_close(optim.state[p1], optim2.state[p2]) + if p1.numel() > 0: + torch.testing.assert_close(optim.state[p1], optim2.state[p2]) + else: + for key in ("exp_avg", "exp_avg_sq"): + assert key not in optim.state or optim.state[p1][key].numel() == 0 + assert key not in optim2.state or optim2.state[p2][key].numel() == 0 # Check unsharding model state. full_model_state = unshard_model_state(dir)