From 734afa37325fddf35a5926abe37ae735e8b5e9a8 Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Tue, 21 Nov 2023 23:21:04 -0800
Subject: [PATCH] Add attribute correlations (#68)

* Add spans

* Add attributes_heatmap

* Use cache

* Add LineStatsCC

---------

Co-authored-by: Luca Soldaini
---
 scripts/attributes_heatmap.py |  78 +++++++++++++++++++++++
 scripts/dolma_stats.py        | 115 ++++++++++++++++++++++++++++++++++
 2 files changed, 193 insertions(+)
 create mode 100644 scripts/attributes_heatmap.py

diff --git a/scripts/attributes_heatmap.py b/scripts/attributes_heatmap.py
new file mode 100644
index 00000000..c97e0b8c
--- /dev/null
+++ b/scripts/attributes_heatmap.py
@@ -0,0 +1,78 @@
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+
+COLNAME_TO_LABEL = {
+    "gopher_spans": "Gopher Rules",
+    "decontamination_spans": "Decontamination",
+    "hatespeech_spans": "Hate Speech",
+    "pii_spans": "PII",
+    "dedupe_paragraphs_spans": "Deduplication",
+}
+
+
+if os.path.exists("corr.csv"):
+    corr = pd.read_csv("corr.csv", index_col=0)
+else:
+    # A line is e.g.
+    # {"gopher_spans": [], "decontamination_spans": [], "hatespeech_spans": [], "pii_spans": [], "dedupe_paragraphs_spans": [[0, 615, 1.0], [615, 1214, 1.0], [1214, 1853, 1.0], [1853, 2417, 1.0], [2417, 2849, 1.0]]}
+    df = pd.read_json(
+        # "/home/niklas/dolma/tmp.jsonl/cc_en_head-0000.json", lines=True
+        "cc_en_head_stats10.jsonl", lines=True
+    )
+    ### Matching based on the entire doc ###
+    # Turn each non-empty span list into True and each empty one into False,
+    # then compute pairwise Pearson correlations between the attributes
+    # (DataFrame.map needs pandas>=2.1; on older versions use applymap)
+    corr = df.map(lambda x: bool(x)).corr(method='pearson')
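+
+    # For intuition (hypothetical rows, not taken from the data): two documents
+    #   {"pii_spans": [], "dedupe_paragraphs_spans": [[0, 615, 1.0]]}
+    #   {"pii_spans": [[3, 9, 1.0]], "dedupe_paragraphs_spans": []}
+    # become the indicator rows (False, True) and (True, False); Pearson is then
+    # computed between these per-document indicator columns, so the heatmap
+    # measures how often two taggers fire on the same documents, not whether
+    # their spans overlap.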
+
+    ### Matching based on spans ###
+    """
+    matrix = np.zeros((len(df.columns), len(df.columns)))
+    columns = df.columns
+    for _, row in df.iterrows():
+        # Iterate over the columns
+        for i, col1 in enumerate(columns):
+            for j, col2 in enumerate(columns):
+                # If the columns are the same, skip
+                if col1 == col2: continue
+                # Increment if any spans overlap
+                # e.g. [0, 615, 1.0] & [614, 1214, 1.0] -> 1
+                # while [0, 615, 1.0] & [700, 1214, 1.0] -> 0
+                matrix[i, j] += float(
+                    any(
+                        span1[0] < span2[1] and span2[0] < span1[1]
+                        for span1 in row[col1]
+                        for span2 in row[col2]
+                    )
+                )
+
+    corr = matrix / len(df)
+    corr *= 100
+    # Add the column names
+    corr = pd.DataFrame(corr, columns=columns, index=columns)
+    """
+
+# Plot the heatmap
+plt.figure(figsize=(36, 24))
+# Mask the upper triangle (incl. the diagonal) so only the lower triangle is drawn
+mask = np.triu(np.ones_like(corr, dtype=bool))
+heatmap = sns.heatmap(
+    corr.rename(columns=COLNAME_TO_LABEL, index=COLNAME_TO_LABEL),
+    mask=mask,
+    vmin=corr.values.min(),
+    vmax=corr.values[~mask].max(),  # Max over the unmasked cells, i.e. ignoring the diagonal of 1.0s
+    annot=True,
+    cmap='Blues',
+    linewidths=0.5,
+    annot_kws={"fontsize": 32},
+    cbar=False,  # No colorbar
+)
+
+heatmap.set_xticklabels(heatmap.get_xmajorticklabels(), fontsize=32)  # , fontweight="bold"
+heatmap.set_yticklabels(heatmap.get_ymajorticklabels(), fontsize=32)  # , fontweight="bold"
+
+corr.to_csv("corr.csv")
+plt.savefig('attributes_heatmap_docbased_9mdocs.pdf', dpi=450, bbox_inches='tight')
+plt.savefig('attributes_heatmap_docbased_9mdocs.png', dpi=450, bbox_inches='tight')
diff --git a/scripts/dolma_stats.py b/scripts/dolma_stats.py
index 9c55e497..d27a2dcf 100644
--- a/scripts/dolma_stats.py
+++ b/scripts/dolma_stats.py
@@ -639,6 +639,121 @@ class v15_cc_c4_cleaned(cc_v1_c4_cleaned):
     stats = "s3://ai2-llm/stats/olmo-mix/v15/cc/v1_c4_cleaned/**/*.gz"
     decontamination_key: str = 'perplexity_suite_v3_option2'
 
+
+@Registry.add
+class LineStatsCC(cc_v1_c4_cleaned):
+    # Selection of documents:
+    # import random; print([random.randint(0, 1334) for _ in range(10)])
+    documents = [
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0700.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0724.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0788.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-1286.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0600.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0752.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0239.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-1270.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0786.json.gz",
+        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0857.json.gz",
+    ]
+    stats = [
+        "./cc_en_head-0700-stats.json.gz",
+        "./cc_en_head-0724-stats.json.gz",
+        "./cc_en_head-0788-stats.json.gz",
+        "./cc_en_head-1286-stats.json.gz",
+        "./cc_en_head-0600-stats.json.gz",
+        "./cc_en_head-0752-stats.json.gz",
+        "./cc_en_head-0239-stats.json.gz",
+        "./cc_en_head-1270-stats.json.gz",
+        "./cc_en_head-0786-stats.json.gz",
+        "./cc_en_head-0857-stats.json.gz",
+    ]
+    decontamination_key: str = 'decontamination'
+
+    @classmethod
+    def cli(cls, num_workers: int = 1, debug: bool = False, **process_single_kwargs: Any) -> None:
+        cls._run_parallel_processor(
+            stats_root=cls.stats,
+            num_workers=num_workers,
+            debug=debug,
+            **process_single_kwargs,
+        )
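+
+    # Path layout (illustration; the segment swap mirrors process_single below):
+    #   .../v1-c4-cleaned/documents/cc_en_head/cc_en_head-0700.json.gz
+    #   .../v1-c4-cleaned/attributes/gopher_rules/cc_en_head/cc_en_head-0700.json.gz
+    # Document and attribute files are line-aligned, i.e. line i of every
+    # attribute file holds the spans tagged for line i of the document file,
+    # which is why a plain zip() suffices to merge them.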
+
+    @classmethod
+    def process_single(
+        cls, source_path: str, destination_path: str, queue: "Queue[Union[Tuple[int, ...], None]]", **kwargs: Any
+    ):
+        attributes = [
+            source_path.replace("/documents/", "/attributes/gopher_rules/"),
+            source_path.replace("/documents/", f"/attributes/{cls.decontamination_key}/"),
+            source_path.replace("/documents/", "/attributes/hatespeech_nsfw_cc_v3/"),
+            source_path.replace("/documents/", "/attributes/pii_detection/"),
+            source_path.replace("/documents/", "/attributes/dedupe_paragraphs/"),
+        ]
+
+        doc_decoder = msgspec.json.Decoder(InputSpec)
+        attr_decoder = msgspec.json.Decoder(OutputSpec)
+        documents = 0
+        interval = 10_000
+
+        with ExitStack() as stack:
+            doc_file = stack.enter_context(smart_open.open(source_path, "rb"))
+            out_file = stack.enter_context(smart_open.open(destination_path, "wt"))
+
+            try:
+                atts_files = [stack.enter_context(smart_open.open(path, "rb")) for path in attributes]
+            except Exception:
+                # Skip shards for which one of the attribute files is missing
+                return
+
+            for doc_line, *attr_lines in zip(doc_file, *atts_files):
+                doc = doc_decoder.decode(doc_line)
+                attrs = {}
+                for line in attr_lines:
+                    attrs.update(attr_decoder.decode(line).attributes)
+                out_line = {}
+
+                # Gopher stats
+                gopher_removal = cls.gopher_rules(attrs)
+                out_line["gopher_spans"] = gopher_removal
+
+                # Decontamination stats
+                decontamination_removal = attrs.get("bff_duplicate_paragraph_spans_decontamination", [])
+                out_line["decontamination_spans"] = decontamination_removal
+
+                # Jigsaw hate speech / NSFW stats
+                jigsaw_match: List[Tuple[int, int, float]] = []
+                nsfw = attrs.get("hatespeech_nsfw_cc_v3__jigsaw_nsfw_sencence_v2____label__nsfw", [])
+                for span in nsfw:
+                    if span[2] > 0.4:
+                        bisect.insort(jigsaw_match, (span[0], span[1], 1.0))
+
+                toxic = attrs.get("hatespeech_nsfw_cc_v3__jigsaw_hatespeech_sentence_v2____label__toxic", [])
+                for span in toxic:
+                    if span[2] > 0.4:
+                        bisect.insort(jigsaw_match, (span[0], span[1], 1.0))
+
+                jigsaw_match = cls._merge_spans(jigsaw_match)
+                out_line["hatespeech_spans"] = jigsaw_match
+
+                # PII stats
+                pii_removal = (
+                    attrs.get("pii_detection__pii_regex_with_counts_fast_v2__EMAIL_ADDRESS", [])
+                    + attrs.get("pii_detection__pii_regex_with_counts_fast_v2__PHONE_NUMBER", [])
+                    + attrs.get("pii_detection__pii_regex_with_counts_fast_v2__IP_ADDRESS", [])
+                )
+                out_line["pii_spans"] = pii_removal
+
+                # Duplicates stats
+                dups = [p for p in attrs.get("bff_duplicate_paragraph_spans", []) if p[1] - p[0] > 0]
+                out_line["dedupe_paragraphs_spans"] = dups
+
+                documents += 1
+
+                if documents % interval == 0:
+                    cls.increment_progressbar(queue, documents=interval)
+
+                out_file.write(json.dumps(out_line) + "\n")
+
+        cls.increment_progressbar(queue, files=1, documents=documents % interval)
+
 
 class C4InputSpec(InputSpec):
     metadata: Dict[str, Any] = msgspec.field(default_factory=dict)
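
A rough end-to-end sketch of how the two scripts connect. The calls below follow the code in this patch, but the surrounding wiring is an assumption: it presumes dolma_stats.py is importable from the working directory, S3 credentials for the ai2-llm bucket are configured, and the per-shard outputs are merged by hand into the filename attributes_heatmap.py expects:

    # Tag the ten sampled cc_en_head shards; this should produce the local
    # ./cc_en_head-*-stats.json.gz files listed in LineStatsCC.stats
    from dolma_stats import LineStatsCC
    LineStatsCC.cli(num_workers=10)

    # Merge the per-shard stats into the JSONL that attributes_heatmap.py reads
    import glob
    import gzip
    import shutil
    with open("cc_en_head_stats10.jsonl", "wb") as out:
        for path in sorted(glob.glob("cc_en_head-*-stats.json.gz")):
            with gzip.open(path, "rb") as f:
                shutil.copyfileobj(f, out)

Running attributes_heatmap.py afterwards computes the document-level correlations, caches them to corr.csv, and writes the heatmap PDF/PNG.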