From b50456a3493bc7cb5a93d532a916de744490ba8d Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Thu, 24 Oct 2024 22:01:29 -0700
Subject: [PATCH] sources

---
 classifiers/scripts/nvidia-deberta-100b.sh |  14 +-
 sources/stackexchange/requirements.txt     |   4 +
 sources/stackexchange/v0.py                | 178 ++++++++++++++++++++++++++++
 3 files changed, 195 insertions(+), 1 deletion(-)

diff --git a/classifiers/scripts/nvidia-deberta-100b.sh b/classifiers/scripts/nvidia-deberta-100b.sh
index 675037e2..5c29e669 100644
--- a/classifiers/scripts/nvidia-deberta-100b.sh
+++ b/classifiers/scripts/nvidia-deberta-100b.sh
@@ -2,11 +2,23 @@
 
 DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100*/*.jsonl.zstd'
+
+# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100b/*_dclm_shard_0000*.jsonl.zstd'
+# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100b/*_dclm_shard_0001*.jsonl.zstd'
+# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100b/*_dclm_shard_0002*.jsonl.zstd'
+
+# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100b-extras/*_dclm_shard_0000*.jsonl.zstd'
+# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100b-extras/*_dclm_shard_0001*.jsonl.zstd'
+# DOCUMENTS='s3://ai2-llm/pretraining-data/sources/dclm/v0/documents/100b-extras/*_dclm_shard_0002*.jsonl.zstd'
+
+
 NUM_NODES=4
+# NUM_NODES=1
 MODEL_NAME="nvidia/quality-classifier-deberta"
 CLUSTER="ai2/jupiter*"
 BATCH_SIZE=512
 PRIORITY="high"
+# PRIORITY="urgent"
 
 # Generate a hash for the run name by combining model name and documents
 RUN_HASH=$(echo -n "${MODEL_NAME}${DOCUMENTS}" | md5sum | awk '{print $1}')
 
@@ -42,4 +54,4 @@ gantry run \
     --shared-memory 10GiB \
     --install "pip install -e classifiers/" \
     --yes \
-    -- /bin/bash -c "huggingface-cli download ${MODEL_NAME} && torchrun --nnodes "${NUM_NODES}:${NUM_NODES}" --nproc-per-node 8 --rdzv_id 12347 --rdzv_backend static --rdzv_endpoint "\${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" --node_rank "\${BEAKER_REPLICA_RANK}" --rdzv_conf 'read_timeout=420' -m dolma_classifiers.inference --source-prefix ${DOCUMENTS} --batch-size ${BATCH_SIZE} --use-wandb --wandb-project 'dolma-classifiers' --wandb-entity ai2-llm --model-name ${MODEL_NAME} --num-workers 4 --model-compile --max-length 1024"
+    -- /bin/bash -c "huggingface-cli download ${MODEL_NAME} && torchrun --nnodes "${NUM_NODES}:${NUM_NODES}" --nproc-per-node 8 --rdzv_id 12347 --rdzv_backend static --rdzv_endpoint "\${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" --node_rank "\${BEAKER_REPLICA_RANK}" --rdzv_conf 'read_timeout=3600' -m dolma_classifiers.inference --source-prefix ${DOCUMENTS} --batch-size ${BATCH_SIZE} --use-wandb --wandb-project 'dolma-classifiers' --wandb-entity ai2-llm --model-name ${MODEL_NAME} --num-workers 4 --model-compile --max-length 1024"
diff --git a/sources/stackexchange/requirements.txt b/sources/stackexchange/requirements.txt
index 327400a2..ed1201d6 100644
--- a/sources/stackexchange/requirements.txt
+++ b/sources/stackexchange/requirements.txt
@@ -1,2 +1,6 @@
 smart-open>=7.0.4
 py7zr
+libarchive-c
+lxml
+pyarrow
+tqdm
diff --git a/sources/stackexchange/v0.py b/sources/stackexchange/v0.py
index e69de29b..2e0d0214 100644
--- a/sources/stackexchange/v0.py
+++ b/sources/stackexchange/v0.py
@@ -0,0 +1,178 @@
+import argparse
+import io
+import os
+import sys
+from contextlib import ExitStack
+from io import BytesIO
+from pathlib import Path
+from typing import Any, BinaryIO, Dict, Iterator, Optional
+
+import libarchive
+import py7zr
+import pyarrow as pa
+import pyarrow.parquet as pq
+from lxml import etree
+from tqdm import tqdm
+
+os.environ["PYTHONBREAKPOINT"] = "ipdb.set_trace"
+
+
+def get_7z_uncompressed_size(sz_path, entry_name):
+    with py7zr.SevenZipFile(sz_path, mode="r") as z:
+        for entry in z.list():
+            if entry.filename == entry_name:
+                return entry.uncompressed
+    raise FileNotFoundError(f"File {entry_name} not found in archive {sz_path}")
+
+
+def stream_xml_from_7z(
+    archive_path: str, filename: str, target_xpath: str = "//*", block_size: int = 8192
+) -> Iterator[etree._Element]:
+    """
+    Stream XML nodes from a file within a 7z archive, parsing them lazily.
+
+    Args:
+        archive_path (str): Path to the 7z archive
+        filename (str): Name of the XML file within the archive
+        target_xpath (str, optional): XPath expression to filter nodes. Defaults to "//*".
+        block_size (int, optional): Size of blocks to read. Defaults to 8192.
+
+    Yields:
+        lxml.etree._Element: XML nodes matching the target_xpath
+
+    Raises:
+        FileNotFoundError: If archive or file within archive is not found
+        ValueError: If file is not valid XML
+    """
+    # Initialize the XML parser that will receive chunks of data
+    parser = etree.XMLPullParser(events=("end",), recover=True)
+
+    with ExitStack() as stack:
+        archive = stack.enter_context(libarchive.file_reader(archive_path))
+        # Find the target file in the archive
+        for entry in archive:
+            if entry.pathname != filename:
+                continue
+
+            archive_name = os.path.basename(archive_path)
+            pbar = tqdm(
+                total=get_7z_uncompressed_size(archive_path, filename),
+                desc=f"Bytes {archive_name}::{filename}",
+                unit="B",
+                unit_scale=True,
+            )
+
+            # Create a buffer for reading blocks
+            buffer = BytesIO()
+
+            # Read the file in blocks
+            for block in entry.get_blocks(block_size):
+                buffer.write(block)
+                pbar.update(len(block))
+
+                # Feed the current block to the parser
+                parser.feed(block)
+
+                # Process any completed elements
+                for event, element in parser.read_events():
+                    # Only process 'end' events for complete elements
+                    if event == "end":
+                        # Check if the element matches our xpath
+                        if element.xpath(target_xpath):
+                            yield element
+                        # Clear element to free memory
+                        element.clear()
+
+            # Process any remaining data
+            parser.feed(b"")  # Signal EOF to parser
+            for event, element in parser.read_events():
+                if event == "end" and element.xpath(target_xpath):
+                    yield element
+                    element.clear()
+
+            return  # Exit after processing the target file
+
+    # If we get here, the file wasn't found
+    raise FileNotFoundError(f"File {filename} not found in archive {archive_path}")
+
+
+def process_file(
+    archive_path: str,
+    output_dir: str,
+    entry_name: str,
+    batch_size: int = 100000,
+):
+    entry_prefix, _ = os.path.basename(entry_name.lower()).split(".", 1)
+    output_dir = os.path.join(output_dir, entry_prefix)
+    archive_name = os.path.basename(archive_path)
+
+    os.makedirs(output_dir, exist_ok=True)
+    data = []
+    schema = None
+
+    with ExitStack() as stack:
+        xml_elements = stream_xml_from_7z(archive_path, entry_name)
+        files_pbar = tqdm(desc=f"Files {archive_name}::{entry_name}")
+        elements_pbar = tqdm(xml_elements, desc=f"Rows {archive_name}::{entry_name}")
+
+        for element in elements_pbar:
+            if not element.attrib:
+                continue
+
+            data.append(dict(element.attrib))
+
+            if schema is None:
+                schema = pa.Table.from_pylist(data).schema
+
+            if len(data) >= batch_size:
+                table = pa.Table.from_pylist(data, schema=schema)
+                pq.write_table(
+                    table,
+                    os.path.join(output_dir, f"{entry_prefix}_{files_pbar.n:06d}.parquet"),
+                )
+                data = []
+                files_pbar.update(1)
+
+    # Write any remaining data
+    if data:
+        table = pa.Table.from_pylist(data, schema=schema)
+        pq.write_table(
+            table,
+            os.path.join(output_dir, f"{entry_prefix}_{files_pbar.n:06d}.parquet"),
+        )
+        files_pbar.update(1)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert Stack Exchange 7z XML dumps to Parquet format")
+    parser.add_argument("archive_path", help="Path to the 7z archive")
+    parser.add_argument("output_dir", help="Directory where Parquet files will be saved")
+    parser.add_argument(
+        "--batch-size", type=int, default=100000, help="Number of rows to process at once (default: 100000)"
+    )
+
+    args = parser.parse_args()
+
+    if os.path.isdir(args.archive_path):
+        archive_paths = [
+            os.path.join(args.archive_path, p) for p in os.listdir(args.archive_path) if p.endswith("7z")
+        ]
+        output_paths = [os.path.join(args.output_dir, os.path.basename(p)) for p in archive_paths]
+    else:
+        archive_paths = [args.archive_path]
+        output_paths = [args.output_dir]
+
+    for archive_path, output_path in tqdm(
+        zip(archive_paths, output_paths), desc="Archives", total=len(archive_paths)
+    ):
+        for entry_name in ["Posts.xml", "Comments.xml"]:
+            process_file(
+                archive_path=archive_path,
+                output_dir=output_path,
+                entry_name=entry_name,
+                batch_size=args.batch_size,
+            )
+
+
+if __name__ == "__main__":
+    main()
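
Usage sketch (not part of the patch; the dump and output paths below are hypothetical). The converter added in sources/stackexchange/v0.py accepts either a single Stack Exchange .7z dump or a directory of dumps, streams Posts.xml and Comments.xml out of each archive, and writes Parquet shards of --batch-size rows under <output_dir>/posts/ and <output_dir>/comments/:

    # install the source-specific dependencies first
    pip install -r sources/stackexchange/requirements.txt

    # single archive -> one output directory
    python sources/stackexchange/v0.py /data/dumps/history.stackexchange.com.7z /data/parquet/history --batch-size 100000

    # directory of *.7z dumps -> one output subdirectory per archive
    python sources/stackexchange/v0.py /data/dumps /data/parquet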