Commit: pipeline fix

soldni committed Oct 24, 2024
1 parent e22cea5 commit 10ae4f1
Showing 4 changed files with 15 additions and 22 deletions.
classifiers/scripts/fineweb_40b.sh (3 changes: 1 addition & 2 deletions)
@@ -2,9 +2,8 @@
 
 DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/40b-split/*/*zstd'
 NUM_NODES=1
 MODEL_NAME="HuggingFaceFW/fineweb-edu-classifier"
-CLUSTER="ai2/jupiter*"
 BATCH_SIZE=1024
+CLUSTER="ai2/neptune*"
 PRIORITY="high"
-
 
classifiers/scripts/fineweb_dclm07.sh (7 changes: 0 additions & 7 deletions)
@@ -8,13 +8,6 @@ CLUSTER="ai2/jupiter*"
 BATCH_SIZE=1024
 PRIORITY="urgent"
 
-# Test Values
-# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/40b-split/20b-01/*zstd'
-# NUM_NODES=1
-# BATCH_SIZE=1024
-# CLUSTER="ai2/neptune*"
-# PRIORITY="high"
-
 # Generate a hash for the run name by combining model name and documents
 RUN_HASH=$(echo -n "${MODEL_NAME}${DOCUMENTS}" | md5sum | awk '{print $1}')
 RUN_NAME="fineweb_classifier_${RUN_HASH:0:8}"
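
For context on the surviving lines: the run name is derived deterministically from the model name and document glob, so re-running the same configuration reuses the same name. A minimal Python sketch of the equivalent derivation (the helper name derive_run_name is illustrative, and the example values are borrowed from fineweb_40b.sh above, not from this script):

import hashlib

def derive_run_name(model_name: str, documents: str) -> str:
    # Mirrors: RUN_HASH=$(echo -n "${MODEL_NAME}${DOCUMENTS}" | md5sum | awk '{print $1}')
    run_hash = hashlib.md5(f"{model_name}{documents}".encode("utf-8")).hexdigest()
    # Mirrors: RUN_NAME="fineweb_classifier_${RUN_HASH:0:8}"
    return f"fineweb_classifier_{run_hash[:8]}"

# Example (values borrowed from fineweb_40b.sh):
print(derive_run_name(
    "HuggingFaceFW/fineweb-edu-classifier",
    "s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/40b-split/*/*zstd",
))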
classifiers/src/dolma_classifiers/inference.py (24 changes: 12 additions & 12 deletions)
@@ -126,6 +126,7 @@ class AttributeRow(NamedTuple):
 
 
 def writer_worker(
+    error_event: Event,
     scores_queue: QueueType[AttributeRow | None],
     output_paths_queue: QueueType[OutputPath],
     source_destination_mapping: dict[str, str],
@@ -158,7 +159,6 @@ def writer_worker(
                 files_writers[source] = smart_open.open(destination_path, "wt", encoding="utf-8")
                 console_logger.info(f"Opened {destination_path} for writing")
 
-
             for source, attributes in group_by_source.items():
                 files_writers[source].write(
                     encoder.encode_lines(attributes).decode("utf-8")
@@ -181,9 +181,9 @@
                     f.close()
                     console_logger.info(f"Closed {source_destination_mapping[path.source]}")
                     progress_logger.increment(files=1)
-                elif path is not None and path.count > counts[path.source]:
+                elif path is not None and counts[path.source] > path.count:
                     raise RuntimeError(
-                        f"More documents ({path.count}) than expected ({counts[path.source]}) " +
+                        f"More documents ({counts[path.source]}) than expected ({path.count}) " +
                         f"for source {path.source}. This should not happen!"
                     )
                 elif path is not None:
@@ -194,6 +194,9 @@
                     # more documents still to be written for this source; put it back
                     output_paths_queue.put(path)
                     total_count = 0
+    except Exception as e:
+        console_logger.error(f"Writer process encountered an error: {e}")
+        error_event.set()
     finally:
         for f in files_writers.values():
             f.close()
@@ -247,7 +250,7 @@ def process_documents(
 
     writer_process_error = Event()
     writer_process = Process(
-        target=writer_worker_wrapper,
+        target=writer_worker,
         kwargs=dict(
             scores_queue=scores_queue,
             output_paths_queue=output_paths_queue,
@@ -278,7 +281,12 @@ def process_documents(
         collate_fn=partial(collate_batch, pad_token_id=getattr(classifier.tokenizer, "pad_token_id", 0)),
     )
 
+    counts = defaultdict(int)
+
     for batch in data_loader:
+        for s in batch.sources:
+            counts[s] += 1
+
         if writer_process_error.is_set():
             raise RuntimeError("Writer process encountered an error")
 
@@ -299,14 +307,6 @@
 
     cleanup()
 
-def writer_worker_wrapper(error_event: Event, **kwargs):
-    try:
-        writer_worker(**kwargs)
-    except Exception as e:
-        console_logger = get_logger("writer_worker_wrapper")
-        console_logger.error(f"Writer process encountered an error: {e}")
-        error_event.set()
-
 
 def longest_common_sequence(paths: list[str]) -> str:
     # Split each string by "/"
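
Two fixes land in this file. First, the over-count check was inverted: the old condition path.count > counts[path.source] fired whenever the expected total exceeded what had been written so far, which is the normal in-progress state; the corrected condition raises only when more documents were written than expected. Second, the writer_worker_wrapper indirection is removed: writer_worker now receives the Event directly and sets it in its own except block, while process_documents tallies documents per source and polls the event to fail fast. A self-contained sketch of that parent/child error-signaling pattern, using only the standard library (the names writer and scores_queue and the toy negative-value invariant are illustrative, not the repository's real logic):

import multiprocessing as mp

def writer(error_event, scores_queue):
    # Consume items until the None sentinel; on any failure, flag the
    # parent instead of dying silently (mirrors writer_worker's except block).
    try:
        while True:
            item = scores_queue.get()
            if item is None:
                break
            if item < 0:  # stand-in for a real invariant, e.g. the counts check
                raise RuntimeError(f"unexpected item: {item}")
    except Exception as e:
        print(f"Writer process encountered an error: {e}")
        error_event.set()

if __name__ == "__main__":
    error_event = mp.Event()
    scores_queue = mp.Queue()
    proc = mp.Process(target=writer, kwargs=dict(error_event=error_event, scores_queue=scores_queue))
    proc.start()
    for score in (1, 2, -3):  # -3 trips the toy invariant in the child
        if error_event.is_set():  # parent polls, like the loop in process_documents
            raise RuntimeError("Writer process encountered an error")
        scores_queue.put(score)
    scores_queue.put(None)  # sentinel: no more work
    proc.join()
    if error_event.is_set():  # final check catches late failures
        raise RuntimeError("Writer process encountered an error")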
classifiers/src/dolma_classifiers/utils.py (3 changes: 2 additions & 1 deletion)
@@ -35,7 +35,8 @@ def setup() -> tuple[int, int]:
 
 
 def cleanup():
-    dist.destroy_process_group()
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 def sanitize_model_name(model_name: str, suffix_data: Any = None) -> str:
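
The guard makes cleanup() safe to call unconditionally: torch.distributed.destroy_process_group() raises if no process group exists, for example when a job fails before setup() completes. A short sketch of how the guarded teardown composes with a try/finally entry point (main here is hypothetical; the repository's real entry point may differ):

import torch.distributed as dist

def cleanup():
    # Idempotent teardown: destroy_process_group() errors out when no
    # process group was ever initialized, so check first.
    if dist.is_initialized():
        dist.destroy_process_group()

def main():
    try:
        # ... setup() and the actual work would go here ...
        pass
    finally:
        cleanup()  # safe even if init_process_group() never ran

if __name__ == "__main__":
    main()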
