Add attribute correlations (#68)
* Add spans

* Add attributes_heatmap

* Use cache

* Add LineStatsCC

---------

Co-authored-by: Luca Soldaini <[email protected]>
Muennighoff and soldni authored Nov 22, 2023
1 parent 00504ff commit 734afa3
Showing 2 changed files with 193 additions and 0 deletions.
78 changes: 78 additions & 0 deletions scripts/attributes_heatmap.py
@@ -0,0 +1,78 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

COLNAME_TO_LABEL = {
    "gopher_spans": "Gopher Rules",
    "decontamination_spans": "Decontamination",
    "hatespeech_spans": "Hate Speech",
    "pii_spans": "PII",
    "dedupe_paragraphs_spans": "Deduplication",
}


if os.path.exists("corr.csv"):
    corr = pd.read_csv("corr.csv", index_col=0)
else:
    # A line is e.g.
    # {"gopher_span": [], "decontamination_span": [], "hatespeech_span": [], "pii_span": [], "dedupe_paragraphs_span": [[0, 615, 1.0], [615, 1214, 1.0], [1214, 1853, 1.0], [1853, 2417, 1.0], [2417, 2849, 1.0]]}
    df = pd.read_json(
        # "/home/niklas/dolma/tmp.jsonl/cc_en_head-0000.json", lines=True
        "cc_en_head_stats10.jsonl", lines=True
    )
    ### Matching based on the entire doc ###
    # Turn non-empty span lists into True and empty ones into False, then
    # compute correlations between the attributes for the heatmap below
    corr = df.map(lambda x: bool(x)).corr(method='pearson')

    ### Matching based on spans ###
    """
    matrix = np.zeros((len(df.columns), len(df.columns)))
    columns = df.columns
    for _, row in df.iterrows():
        # Iterate over the columns
        for i, col1 in enumerate(columns):
            for j, col2 in enumerate(columns):
                # If the columns are the same, skip
                if col1 == col2: continue
                # Increment if the spans overlap,
                # e.g. [0, 615, 1.0] & [614, 1214, 1.0] -> 1
                # while [0, 615, 1.0] & [700, 1214, 1.0] -> 0
                matrix[i, j] += float(
                    any(
                        span1[0] <= span2[0] and span1[1] >= span2[0]
                        for span1 in row[col1]
                        for span2 in row[col2]
                    )
                )
    corr = matrix / len(df)
    corr *= 100
    # Add the column names
    corr = pd.DataFrame(corr, columns=columns, index=columns)
    """

# Plot the heatmap
plt.figure(figsize=(36, 24))
# define the mask to set the values in the upper triangle to True
mask = np.triu(np.ones_like(corr, dtype=bool))
heatmap = sns.heatmap(
    corr.rename(columns=COLNAME_TO_LABEL, index=COLNAME_TO_LABEL),
    mask=mask,
    vmin=corr.values.min(),
    vmax=corr.values[~mask].max(),  # Max over the unmasked lower triangle, ignoring the diagonal of 1s
    annot=True,
    cmap='Blues',
    linewidths=0.5,
    annot_kws={"fontsize": 32},
    cbar=False,  # No legend
)

heatmap.set_xticklabels(heatmap.get_xmajorticklabels(), fontsize=32)#, fontweight="bold")
heatmap.set_yticklabels(heatmap.get_ymajorticklabels(), fontsize=32)#, fontweight="bold")

corr.to_csv("corr.csv")
plt.savefig('attributes_heatmap_docbased_9mdocs.pdf', dpi=450, bbox_inches='tight')
plt.savefig('attributes_heatmap_docbased_9mdocs.png', dpi=450, bbox_inches='tight')
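For reference, a minimal sketch of what the document-level matching computes, on made-up toy rows (the column names mirror the script; the data is hypothetical): each attribute is reduced to a per-document boolean, "does this document have at least one span", and each cell of corr is the Pearson (phi) correlation between two such indicators.

import pandas as pd

# Hypothetical toy data: two attribute columns over four documents
toy = pd.DataFrame({
    "gopher_spans": [[], [[0, 615, 1.0]], [[10, 20, 1.0]], []],
    "pii_spans": [[], [[5, 9, 1.0]], [], []],
})

# Empty span lists become False, non-empty ones True; Pearson correlation on
# booleans is the phi coefficient (DataFrame.map requires pandas >= 2.1)
print(toy.map(bool).corr(method="pearson"))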
115 changes: 115 additions & 0 deletions scripts/dolma_stats.py
@@ -639,6 +639,121 @@ class v15_cc_c4_cleaned(cc_v1_c4_cleaned):
stats = "s3://ai2-llm/stats/olmo-mix/v15/cc/v1_c4_cleaned/**/*.gz"
decontamination_key: str = 'perplexity_suite_v3_option2'

@Registry.add
class LineStatsCC(cc_v1_c4_cleaned):
    # Selection of documents:
    # import random; print([random.randint(0, 1334) for _ in range(10)])
    documents = [
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0700.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0724.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0788.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-1286.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0600.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0752.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0239.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-1270.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0786.json.gz",
        "s3://ai2-llm/pretraining-data/sources/common-crawl/v1-c4-cleaned/documents/cc_en_head/cc_en_head-0857.json.gz",
    ]
    stats = [
        "./cc_en_head-0700-stats.json.gz",
        "./cc_en_head-0724-stats.json.gz",
        "./cc_en_head-0788-stats.json.gz",
        "./cc_en_head-1286-stats.json.gz",
        "./cc_en_head-0600-stats.json.gz",
        "./cc_en_head-0752-stats.json.gz",
        "./cc_en_head-0239-stats.json.gz",
        "./cc_en_head-1270-stats.json.gz",
        "./cc_en_head-0786-stats.json.gz",
        "./cc_en_head-0857-stats.json.gz",
    ]
    decontamination_key: str = 'decontamination'

    @classmethod
    def cli(cls, num_workers: int = 1, debug: bool = False, **process_single_kwargs: Any) -> None:
        cls._run_parallel_processor(
            stats_root=cls.stats,
            num_workers=num_workers,
            debug=debug,
            **process_single_kwargs,
        )

    @classmethod
    def process_single(
        cls, source_path: str, destination_path: str, queue: "Queue[Union[Tuple[int, ...], None]]", **kwargs: Any
    ):
        attributes = [
            source_path.replace("/documents/", "/attributes/gopher_rules/"),
            source_path.replace("/documents/", f"/attributes/{cls.decontamination_key}/"),
            source_path.replace("/documents/", "/attributes/hatespeech_nsfw_cc_v3/"),
            source_path.replace("/documents/", "/attributes/pii_detection/"),
            source_path.replace("/documents/", "/attributes/dedupe_paragraphs/"),
        ]

        doc_decoder = msgspec.json.Decoder(InputSpec)
        attr_decoder = msgspec.json.Decoder(OutputSpec)
        documents = 0
        interval = 10_000

        with ExitStack() as stack:
            doc_file = stack.enter_context(smart_open.open(source_path, "rb"))
            out_file = stack.enter_context(smart_open.open(destination_path, "wt"))

            try:
                atts_files = [stack.enter_context(smart_open.open(path, "rb")) for path in attributes]
            except Exception:
                # Skip shards for which any attribute file is missing or unreadable
                return

            for doc_line, *attr_lines in zip(doc_file, *atts_files):
                doc = doc_decoder.decode(doc_line)
                attrs = {}
                for line in attr_lines:
                    attrs.update(attr_decoder.decode(line).attributes)
                out_line = {}

                # Gopher stats
                gopher_removal = cls.gopher_rules(attrs)
                out_line["gopher_spans"] = gopher_removal

                # Decontamination stats
                decontamination_removal = attrs.get("bff_duplicate_paragraph_spans_decontamination", [])
                out_line["decontamination_spans"] = decontamination_removal

                # Jigsaw (hate speech / NSFW) stats
                jigsaw_match: List[Tuple[int, int, float]] = []
                nsfw = attrs.get("hatespeech_nsfw_cc_v3__jigsaw_nsfw_sencence_v2____label__nsfw", [])
                for span in nsfw:
                    if span[2] > 0.4:
                        bisect.insort(jigsaw_match, (span[0], span[1], 1.0))

                toxic = attrs.get("hatespeech_nsfw_cc_v3__jigsaw_hatespeech_sentence_v2____label__toxic", [])
                for span in toxic:
                    if span[2] > 0.4:
                        bisect.insort(jigsaw_match, (span[0], span[1], 1.0))

                jigsaw_match = cls._merge_spans(jigsaw_match)
                out_line["hatespeech_spans"] = jigsaw_match

                # PII stats
                pii_removal = (
                    attrs.get("pii_detection__pii_regex_with_counts_fast_v2__EMAIL_ADDRESS", [])
                    + attrs.get("pii_detection__pii_regex_with_counts_fast_v2__PHONE_NUMBER", [])
                    + attrs.get("pii_detection__pii_regex_with_counts_fast_v2__IP_ADDRESS", [])
                )
                out_line["pii_spans"] = pii_removal

                # Deduplication stats
                dups = [p for p in attrs.get("bff_duplicate_paragraph_spans", []) if p[1] - p[0] > 0]
                out_line["dedupe_paragraphs_spans"] = dups

                documents += 1

                if documents % interval == 0:
                    cls.increment_progressbar(queue, documents=interval)

                out_file.write(json.dumps(out_line) + "\n")

            cls.increment_progressbar(queue, files=1, documents=documents % interval)

class C4InputSpec(InputSpec):
    metadata: Dict[str, Any] = msgspec.field(default_factory=dict)
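The loop above relies on two helpers inherited from earlier in dolma_stats.py, cls.gopher_rules and cls._merge_spans, neither of which appears in this hunk. As a rough sketch of the span-merging step only, here is a hypothetical reimplementation (not the repository's actual code), assuming spans arrive sorted by start, which the bisect.insort calls above guarantee:

from typing import List, Tuple

def merge_spans(spans: List[Tuple[int, int, float]]) -> List[Tuple[int, int, float]]:
    # Merge overlapping or touching (start, end, score) spans, assuming input sorted by start
    merged: List[Tuple[int, int, float]] = []
    for start, end, score in spans:
        if merged and start <= merged[-1][1]:
            prev_start, prev_end, prev_score = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end), max(prev_score, score))
        else:
            merged.append((start, end, score))
    return merged

# e.g. merge_spans([(0, 615, 1.0), (610, 1214, 1.0), (2000, 2400, 1.0)])
#      -> [(0, 1214, 1.0), (2000, 2400, 1.0)]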
