
Commit

simplified writer
soldni committed Oct 24, 2024
1 parent 08a44f2 commit d89e24e
Showing 2 changed files with 85 additions and 16 deletions.
classifiers/README.md (2 changes: 1 addition & 1 deletion)

````diff
@@ -15,7 +15,7 @@ Run [Huggingface FineWeb classifier](https://huggingface.co/HuggingFaceFW/finewe
 
 ```bash
 python -m dolma_classifiers.inference \
-    -s 's3://ai2-llm/pretraining-data/sources/dclm/v0/documents/40b-split/*/*zstd' \
+    -s 's3://ai2-llm/pretraining-data/sources/dclm/v0/documents/40b-split/20b-01/*zstd' \
     -m HuggingFaceFW/fineweb-edu-classifier
 ```
````
classifiers/src/dolma_classifiers/inference.py (99 changes: 84 additions & 15 deletions)

```diff
@@ -142,7 +142,7 @@ def writer_worker(
     console_logger = get_logger("writer_worker")
 
     files_writers = {}
-    with KeyedExitStack() as stack:
+    try:
         encoder = msgspec.json.Encoder()
         counts = defaultdict(int)
         total_count = 0
@@ -156,35 +156,104 @@ def writer_worker(
             if element is None:
                 break
 
-            if element.source not in stack:
+            if element.source not in files_writers:
                 destination_path = source_destination_mapping[element.source]
-                stack.push(
-                    element.source,
-                    smart_open.open(destination_path, "wt", encoding="utf-8")
-                )
+                files_writers[element.source] = smart_open.open(destination_path, "wt", encoding="utf-8")
                 console_logger.info(f"Opened {destination_path} for writing")
 
-            stack[element.source].write(
+            files_writers[element.source].write(
                 encoder.encode_lines(element.attributes).decode("utf-8")
             )
             progress_logger.increment(docs=len(element.attributes))
             counts[element.source] += len(element.attributes)
             total_count += len(element.attributes)
 
             if total_count > log_every:
-                # we iterate at most once over the output queue (this avoids infinite loops if elements are popped back into the queue)
-                for _ in range(output_paths_queue.qsize()):
                 # we at most close one file per log_every documents
                 try:
                     # get the paths from the output queue (these have been fully processed)
-                    path = output_paths_queue.get()
-                    if path.count > counts[path.source]:
-                        # more documents still to be written for this source; put it back
-                        output_paths_queue.put(path)
-                        break
+                    path = output_paths_queue.get_nowait()
                 except Empty:
                     path = None
 
                 if path is not None and path.count == counts[path.source]:
-                    # if path is not None:
                     # I've finished processing this source; close the file
-                    stack.pop(path.source).close()
+                    f = files_writers.pop(path.source)
+                    f.close()
                     console_logger.info(f"Closed {source_destination_mapping[path.source]}")
                     progress_logger.increment(files=1)
+                elif path is not None:
+                    console_logger.info(
+                        f"Tried to close {source_destination_mapping[path.source]}, " +
+                        f"but only seen {counts[path.source]}/{path.count} documents"
+                    )
+                    # more documents still to be written for this source; put it back
+                    output_paths_queue.put(path)
+    finally:
+        for f in files_writers.values():
+            f.close()
+
+
+# def writer_worker(
+#     scores_queue: QueueType[AttributeRow | None],
+#     output_paths_queue: QueueType[OutputPath],
+#     source_destination_mapping: dict[str, str],
+#     log_every: int = 10_000,
+# ):
+
+#     progress_logger = ProgressLogger(log_every=log_every, wandb_logger=WandbLogger())
+#     console_logger = get_logger("writer_worker")
+
+#     files_writers = {}
+#     with KeyedExitStack() as stack:
+#         encoder = msgspec.json.Encoder()
+#         counts = defaultdict(int)
+#         total_count = 0
+
+#         while True:
+#             if scores_queue.qsize() == 0:
+#                 time.sleep(0.1)
+#                 continue
+
+#             element = scores_queue.get()
+#             if element is None:
+#                 break
+
+#             if element.source not in stack:
+#                 destination_path = source_destination_mapping[element.source]
+#                 stack.push(
+#                     element.source,
+#                     smart_open.open(destination_path, "wt", encoding="utf-8")
+#                 )
+#                 console_logger.info(f"Opened {destination_path} for writing")
+
+#             stack[element.source].write(
+#                 encoder.encode_lines(element.attributes).decode("utf-8")
+#             )
+#             progress_logger.increment(docs=len(element.attributes))
+#             counts[element.source] += len(element.attributes)
+#             total_count += len(element.attributes)
+
+#             if total_count > log_every:
+#                 # we at most close one file per log_every documents
+#                 try:
+#                     # get the paths from the output queue (these have been fully processed)
+#                     path = output_paths_queue.get_nowait()
+#                 except Empty:
+#                     path = None
+
+#                 if path is not None and path.count == counts[path.source]:
+#                     # I've finished processing this source; close the file
+#                     # print(len(stack))
+#                     f = stack.pop(path.source)
+#                     # print(len(stack))
+#                     # breakpoint()
+#                     f.close()
+#                     console_logger.info(f"Closed {source_destination_mapping[path.source]}")
+#                     progress_logger.increment(files=1)
+#                 elif path is not None:
+#                     # more documents still to be written for this source; put it back
+#                     output_paths_queue.put(path)
+
 
 def process_documents(
```
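For readers without the rest of the repository: `KeyedExitStack` is a project-local helper that the old `writer_worker` depended on, and its implementation is not part of this diff. The sketch below is therefore an assumption, reconstructed purely from the call sites visible above (`push`, `pop`, membership tests, indexing, and use as a context manager); the real class in `dolma_classifiers` may differ.

```python
# Hypothetical reconstruction of the project-local KeyedExitStack helper,
# inferred only from its call sites in the old writer_worker above.
# The actual implementation in dolma_classifiers may differ.
class KeyedExitStack:
    """A context manager holding closeable resources addressable by key."""

    def __init__(self):
        self._resources = {}

    def push(self, key, resource):
        # register a closeable resource (e.g., an open file) under a key
        self._resources[key] = resource

    def pop(self, key):
        # detach a resource; the caller becomes responsible for closing it
        return self._resources.pop(key)

    def __contains__(self, key):
        return key in self._resources

    def __getitem__(self, key):
        return self._resources[key]

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # close whatever is still registered when the block exits
        for resource in self._resources.values():
            resource.close()
        return False
```

Read against this sketch, the commit's simplification is straightforward: a plain dict plus `try`/`finally` provides the same keyed lookup and guaranteed cleanup without the extra abstraction.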
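To see the simplified pattern in isolation, here is a minimal, runnable sketch. `AttributeRow`, `OutputPath`, and `writer_loop` are hypothetical stand-ins modeled on how the queues are used above, with plain `json` and `open` in place of `msgspec` and `smart_open`; it illustrates the close-when-complete protocol rather than the project's actual code.

```python
# Minimal sketch of the simplified writer pattern: one lazily-opened
# writer per source, closed as soon as that source's expected document
# count has been written, with try/finally guaranteeing cleanup.
import json
from collections import defaultdict
from dataclasses import dataclass
from queue import Empty, Queue


@dataclass
class AttributeRow:  # stand-in for the real type used by writer_worker
    source: str
    attributes: list


@dataclass
class OutputPath:  # stand-in: emitted once a source is fully processed
    source: str
    count: int  # total documents expected for this source


def writer_loop(scores_queue, output_paths_queue, source_destination_mapping):
    files_writers = {}
    counts = defaultdict(int)
    try:
        while True:
            element = scores_queue.get()
            if element is None:  # sentinel: upstream workers are done
                break

            # open the destination lazily, on the first row for a source
            if element.source not in files_writers:
                destination = source_destination_mapping[element.source]
                files_writers[element.source] = open(destination, "wt", encoding="utf-8")

            files_writers[element.source].write(
                "".join(json.dumps(row) + "\n" for row in element.attributes)
            )
            counts[element.source] += len(element.attributes)

            # close at most one fully-processed source per iteration
            try:
                path = output_paths_queue.get_nowait()
            except Empty:
                continue
            if path.count == counts[path.source]:
                files_writers.pop(path.source).close()
            else:
                output_paths_queue.put(path)  # not done yet; check again later
    finally:
        # never leak open handles, even if the loop raises
        for f in files_writers.values():
            f.close()


if __name__ == "__main__":
    import os
    import tempfile

    scores, done = Queue(), Queue()
    out = os.path.join(tempfile.mkdtemp(), "source-a.jsonl")

    scores.put(AttributeRow(source="a", attributes=[{"score": 0.9}]))
    done.put(OutputPath(source="a", count=1))
    scores.put(None)

    writer_loop(scores, done, {"a": out})
    with open(out, encoding="utf-8") as f:
        print(f.read())  # -> {"score": 0.9}
```

The design choice shows up in the `finally` block: every writer is either still in `files_writers` (and closed there) or was already closed explicitly once its source reached its expected count, so no file handle leaks even if the loop raises.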
