Skip to content

Commit

Permalink
handle exception in throttled case too
Browse files: browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jan 21, 2025
1 parent 14cae66 commit 5aeca1a
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/olmo_core/distributed/checkpoint/filesystem.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -213,7 +213,13 @@ def write_items(buckets: List[List[WriteItem]]) -> List[WriteResult]:
for bucket in buckets:
file_name = gen_file_name()
path = f"{self.path}/{file_name}"
results.extend(_write_items(path, file_name, bucket, planner))
try:
results.extend(_write_items(path, file_name, bucket, planner))
except BaseException:
# NOTE: we might get an error here that can't be pickled, which causes a different failure
# later when PyTorch tries to reduce that error across ranks. So here we just make
# sure we're raising a simple error type that can be pickled.
raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
return results

results: List[WriteResult]
Expand All @@ -229,15 +235,8 @@ def write_items(buckets: List[List[WriteItem]]) -> List[WriteResult]:
futures = []
for bucket in buckets:
futures.append(executor.submit(write_items, [bucket]))

for f in as_completed(futures):
try:
results.extend(f.result())
except BaseException:
# NOTE: we might get an error here that can't be pickled, which causes a different failure
# later when PyTorch tries to reduce that error across ranks. So here we just make
# sure we're raising a simple error type that can be pickled.
raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
results.extend(f.result())

fut: Future[List[WriteResult]] = Future()
fut.set_result(results)
Expand Down

0 comments on commit 5aeca1a

Please sign in to comment.