Merge pull request #3385 from flairNLP/agnews_dataset
Add AGNews corpus
alanakbik authored Dec 18, 2023
2 parents c84776c + f180771 commit ddf3bb3
Showing 2 changed files with 70 additions and 0 deletions.
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
@@ -100,6 +100,7 @@

# Expose all document classification datasets
from .document_classification import (
AGNEWS,
AMAZON_REVIEWS,
COMMUNICATIVE_FUNCTIONS,
GERMEVAL_2018_OFFENSIVE_LANGUAGE,
@@ -314,6 +315,7 @@
"SentenceDataset",
"MongoDataset",
"StringDataset",
"AGNEWS",
"ANAT_EM",
"AZDZ",
"BC2GM",
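With the import and __all__ entry above, the new corpus is exposed at the package level. A minimal sketch of the resulting import path (an editorial illustration, not part of the diff; it assumes flair at this commit is installed):

from flair.datasets import AGNEWS  # now available alongside the other document classification corpora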
68 changes: 68 additions & 0 deletions flair/datasets/document_classification.py
@@ -907,6 +907,74 @@ def __init__(
super().__init__(data_folder, tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)


class AGNEWS(ClassificationCorpus):
"""The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics.
Labels: World, Sports, Business, Sci/Tech.
"""

def __init__(
self,
base_path: Optional[Union[str, Path]] = None,
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="partial",
**corpusargs,
):
"""Instantiates AGNews Classification Corpus with 4 classes.
:param base_path: Provide this only if you store the AGNEWS corpus in a specific folder, otherwise use default.
:param tokenizer: Custom tokenizer to use (default is SpaceTokenizer)
:param memory_mode: Set to 'partial' by default. Can also be 'full' or 'none'.
:param corpusargs: Other args for ClassificationCorpus.
"""
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

dataset_name = self.__class__.__name__.lower()

data_folder = base_path / dataset_name

# download data from the same source used by the Hugging Face implementation
agnews_path = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/"

original_filenames = ["train.csv", "test.csv", "classes.txt"]
new_filenames = ["train.txt", "test.txt"]

for original_filename in original_filenames:
cached_path(f"{agnews_path}{original_filename}", Path("datasets") / dataset_name / "original")

data_file = data_folder / new_filenames[0]
label_dict = []
label_path = original_filenames[-1]

# read the class names from classes.txt in order (label_dict[label - 1] gives the class name)
with open(data_folder / "original" / label_path) as f:
for line in f:
line = line.rstrip()
label_dict.append(line)

original_filenames = original_filenames[:-1]
if not data_file.is_file():
for original_filename, new_filename in zip(original_filenames, new_filenames):
with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open(
data_folder / new_filename, "w", encoding="utf-8"
) as write_fp:
csv_reader = csv.reader(
open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
)
for id_, row in enumerate(csv_reader):
label, title, description = row
# Original labels are 1-4; map each to its class name from classes.txt
# ('World', 'Sports', 'Business', 'Sci/Tech')
text = " ".join((title, description))

# write in the FastText-style format that ClassificationCorpus reads:
# "__label__<class> <text>"
new_label = "__label__"
new_label += label_dict[int(label) - 1]

write_fp.write(f"{new_label} {text}\n")

super().__init__(data_folder, label_type="topic", tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)


class STACKOVERFLOW(ClassificationCorpus):
"""Stackoverflow corpus classifying questions into one of 20 labels.
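For context, a minimal usage sketch of the corpus added above (an editorial illustration, not part of the commit; it assumes the download URL above is reachable and flair at this commit is installed). The conversion step writes FastText-style lines such as "__label__Sports <title> <description>" to train.txt, which ClassificationCorpus then loads:

from flair.datasets import AGNEWS

# downloads train.csv / test.csv / classes.txt on first use and converts them
# into FastText-formatted train.txt / test.txt under the flair cache folder
corpus = AGNEWS()

# labels are stored under the "topic" label type (see label_type="topic" above)
label_dictionary = corpus.make_label_dictionary(label_type="topic")
print(label_dictionary)

# inspect one training example and its topic label
sentence = corpus.train[0]
print(sentence)
print(sentence.get_labels("topic"))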
