Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable option to train on huggingface datasets. #17

Merged
merged 8 commits into from
Oct 17, 2023
2 changes: 1 addition & 1 deletion elpis/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from elpis.datasets.dataset import CleaningOptions, Dataset, ProcessingBatch
from elpis.datasets.preprocessing import process_batch
from elpis.datasets.processing import create_dataset, prepare_dataset
from elpis.datasets.processing import prepare_dataset, create_dataset

__all__ = [
"CleaningOptions",
Expand Down
24 changes: 15 additions & 9 deletions elpis/datasets/clean_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,37 @@
def clean_text(
text: str,
words_to_remove: Optional[List[str]] = None,
punctuation_to_explode: str = "",
punctuation_to_remove: str = "",
characters_to_explode: str = "",
characters_to_remove: str = "",
to_lower=True,
) -> str:
"""Cleans the text based on the supplied options.

Parameters:
text: The text to clean.
options: The cleaning options.
words_to_remove: Words to remove from the text.
characters_to_remove: A string of chars to remove from the text.
characters_to_explode: A string of chars to replace with spaces in the text.
to_lower: True iff the resulting text should be converted to lower case.
Converts to uppercase if False.

Returns:
The cleaned text
"""
words = text.upper().split()
words = text.split()

if words_to_remove is not None:
words = filter(lambda word: word not in words_to_remove, words)

if punctuation_to_explode != "":
words = map(lambda word: explode(word, punctuation_to_explode), words)
if characters_to_explode != "":
words = map(lambda word: explode(word, characters_to_explode), words)

if punctuation_to_remove != "":
words = map(lambda word: collapse(word, punctuation_to_remove), words)
if characters_to_remove != "":
words = map(lambda word: collapse(word, characters_to_remove), words)

result = " ".join(words).strip()
return remove_consecutive_spaces(result)
result = remove_consecutive_spaces(result)
return result.lower() if to_lower else result.upper()


def explode(text: str, pattern: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions elpis/datasets/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def clean_annotation(
transcript = clean_text(
text=annotation.transcript,
words_to_remove=cleaning_options.words_to_remove,
punctuation_to_explode=cleaning_options.punctuation_to_explode,
punctuation_to_remove=cleaning_options.punctuation_to_remove,
characters_to_explode=cleaning_options.punctuation_to_explode,
characters_to_remove=cleaning_options.punctuation_to_remove,
)
result = copy(annotation)
result.transcript = transcript
Expand Down
135 changes: 126 additions & 9 deletions elpis/datasets/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,31 @@
from pathlib import Path
from typing import Any, Dict, List

from datasets import Audio, DatasetDict, load_dataset
from datasets import (
Audio,
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from loguru import logger
from transformers import AutoFeatureExtractor, AutoTokenizer

from elpis.datasets.clean_text import clean_text
from elpis.models.job import Job

LOGGING_TRANSCRIPT_SAMPLE = 2


def create_dataset(
def create_dataset(job: Job) -> DatasetDict | IterableDatasetDict:
if Path(job.data_args.dataset_name_or_path).is_dir():
return create_local_dataset(job)

return create_hf_dataset(job)


def create_local_dataset(
job: Job,
test_size: float = 0.2,
) -> DatasetDict:
Expand Down Expand Up @@ -65,18 +80,64 @@ def resolve_audio_path(row: Dict[str, Any]) -> Dict[str, Any]:
return dataset


def create_hf_dataset(job: Job) -> DatasetDict | IterableDatasetDict:
data_args = job.data_args

dataset = DatasetDict()
if data_args.stream_dataset:
dataset = IterableDatasetDict()

if job.training_args.do_train:
dataset["train"] = load_dataset(
data_args.dataset_name_or_path,
data_args.dataset_config_name,
streaming=data_args.stream_dataset,
split=data_args.train_split_name,
token=data_args.token,
)

if data_args.audio_column_name not in dataset["train"].column_names:
raise ValueError(
f"audio_column_name '{data_args.audio_column_name}' not found"
f" in dataset '{data_args.dataset_name_or_path}'."
" Make sure to set `audio_column_name` to the correct audio column - one of"
f" {', '.join(dataset['train'].column_names)}."
)

if data_args.text_column_name not in dataset["train"].column_names:
raise ValueError(
f"text_column_name {data_args.text_column_name} not found"
f" in dataset '{data_args.dataset_name_or_path}'. "
"Make sure to set `text_column_name` to the correct text column - one of "
f"{', '.join(dataset['train'].column_names)}."
)

if job.training_args.do_eval:
dataset["eval"] = load_dataset(
data_args.dataset_name_or_path,
data_args.dataset_config_name,
split=data_args.eval_split_name,
token=data_args.token,
streaming=data_args.stream_dataset,
)

return dataset


def prepare_dataset(
job: Job,
tokenizer: AutoTokenizer,
feature_extractor: AutoFeatureExtractor,
dataset: DatasetDict,
) -> DatasetDict:
dataset: DatasetDict | IterableDatasetDict,
) -> DatasetDict | IterableDatasetDict:
"""Runs some preprocessing over the given dataset.

Parameters:
dataset: The dataset on which to apply the preprocessing
processor: The processor to apply over the dataset
"""
dataset = clean_dataset(job, dataset)
dataset = constrain_to_max_samples(job, dataset)

# Load the audio data and resample if necessary.
dataset = dataset.cast_column(
Expand Down Expand Up @@ -114,20 +175,76 @@ def is_audio_in_length_range(length: int):

with job.training_args.main_process_first(desc="dataset map preprocessing"):
worker_count = job.data_args.preprocessing_num_workers

kwargs = {}
if not job.data_args.stream_dataset:
kwargs = {
"num_proc": worker_count,
"desc": "Dataset Preprocessing",
}

dataset = dataset.map(
_prepare_dataset,
remove_columns=next(iter(dataset.values())).column_names,
num_proc=worker_count,
desc="preprocess datasets",
**kwargs,
)

# filter data that is shorter than min_input_length
dataset = dataset.filter(
is_audio_in_length_range,
num_proc=worker_count,
input_columns=["input_length"],
is_audio_in_length_range, input_columns=["input_length"], **kwargs
)

logger.info(f"Test encoding labels: {dataset['train'][0]['labels']}")

return dataset


def constrain_to_max_samples(
job: Job, dataset: DatasetDict | IterableDatasetDict
) -> DatasetDict | IterableDatasetDict:
max_train_samples = job.data_args.max_train_samples
max_eval_samples = job.data_args.max_eval_samples

def take(n: int, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
if job.data_args.stream_dataset:
return dataset.take(n) # type: ignore
return dataset.select(range(n)) # type: ignore

if job.training_args.do_train and max_train_samples is not None:
dataset["train"] = take(max_train_samples, dataset["train"]) # type: ignore

if job.training_args.do_eval and max_eval_samples is not None:
dataset["eval"] = take(max_eval_samples, dataset["eval"]) # type: ignore

return dataset


def clean_dataset(
job: Job, dataset: DatasetDict | IterableDatasetDict
) -> DatasetDict | IterableDatasetDict:
if not job.data_args.do_clean:
return dataset

text_column = job.data_args.text_column_name

def clean(batch: Dict[str, Any]):
characters_to_remove = "".join(job.data_args.chars_to_remove or [])
characters_to_explode = "".join(job.data_args.chars_to_explode or [])

batch[text_column] = (
clean_text(
batch[text_column],
words_to_remove=job.data_args.words_to_remove,
characters_to_remove=characters_to_remove,
characters_to_explode=characters_to_explode,
to_lower=job.data_args.do_lower_case or True,
)
+ " " # Note: not sure why this is necessary, but saw in hf docs.
)

return batch

with job.training_args.main_process_first(desc="Dataset cleaning."):
dataset = dataset.map(clean)

return dataset
12 changes: 11 additions & 1 deletion elpis/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from elpis.models.annotation import Annotation
from elpis.models.elan_options import ElanOptions, ElanTierSelector
from elpis.models.job import DataArguments, Job, ModelArguments
from elpis.models.vocab import VOCAB_FILE, Vocab

__all__ = ["Annotation", "ElanOptions", "ElanTierSelector", "Vocab", "VOCAB_FILE"]
__all__ = [
"Annotation",
"ElanOptions",
"ElanTierSelector",
"Job",
"Vocab",
"VOCAB_FILE",
"DataArguments",
"ModelArguments",
]
55 changes: 45 additions & 10 deletions elpis/models/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional

from transformers import HfArgumentParser, TrainingArguments

Expand Down Expand Up @@ -133,6 +133,12 @@ class DataArguments:
"help": "The configuration name of the dataset to use (via the datasets library)."
},
)
stream_dataset: bool = field(
default=False,
metadata={
"help": "Whether to stream the dataset as opposed to downloading it all at once."
},
)
train_split_name: str = field(
default="train+validation",
metadata={
Expand Down Expand Up @@ -186,11 +192,33 @@ class DataArguments:
)
},
)
chars_to_ignore: Optional[List[str]] = list_field(
do_clean: bool = field(
default=True,
metadata={"help": "True if the dataset should be cleaned before use."},
)
words_to_remove: Optional[List[str]] = list_field(
default=[],
metadata={
"help": "A list of words to remove from the transcripts during dataset cleaning."
},
)
chars_to_remove: Optional[List[str]] = list_field(
default=[],
metadata={
"help": "A list of characters to remove from the transcripts during dataset cleaning."
},
)
chars_to_explode: Optional[List[str]] = list_field(
default=[],
metadata={
"help": "A list of characters to replace with spaces in the transcripts during dataset cleaning."
},
)
do_lower_case: Optional[bool] = field(
default=None,
metadata={"help": "A list of characters to remove from the transcripts."},
metadata={"help": "Whether the target text should be lower cased."},
)
eval_metrics: List[str] = list_field(
eval_metrics: List[str] = list_field( # type: ignore
default=DEFAULT_METRICS,
metadata={
"help": "A list of metrics the model should be evaluated on. E.g. `('wer', 'cer')`"
Expand Down Expand Up @@ -270,10 +298,6 @@ class DataArguments:
)
},
)
do_lower_case: Optional[bool] = field(
default=None,
metadata={"help": "Whether the target text should be lower cased."},
)


@dataclass
Expand All @@ -289,12 +313,12 @@ def parser():
return HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) # type: ignore

@classmethod
def from_args(cls) -> Job:
def from_args(cls, args=None) -> Job:
(
model_args,
data_args,
training_args,
) = Job.parser().parse_args_into_dataclasses()
) = Job.parser().parse_args_into_dataclasses(args)
return cls(
model_args=model_args, data_args=data_args, training_args=training_args
)
Expand All @@ -309,3 +333,14 @@ def from_json(cls, file: Path) -> Job:
return cls(
model_args=model_args, data_args=data_args, training_args=training_args
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Job:
(
model_args,
data_args,
training_args,
) = Job.parser().parse_dict(data)
return cls(
model_args=model_args, data_args=data_args, training_args=training_args
)
2 changes: 2 additions & 0 deletions elpis/models/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Set

from datasets import DatasetDict

VOCAB_FILE = "vocab.json"


Expand Down
Loading