Skip to content

Commit

Permalink
Feature/improvement (#100)
Browse files Browse the repository at this point in the history
* support configuring filter_deduplicate

* support normal evaluation

* support configuring more parameters

* dont compute metrics

* predict loss only in evaluation

* hard code to fix evaluation error

* remove label_names

* remove_unused_columns=None

* custom prediction step

* bugfix

* pop labels

* print eval loss

* fix

* fix

* no_grad

* set do_eval

* bugfix

* set evaluation_strategy no when valid_ds is empty
  • Loading branch information
SeanLee97 authored Sep 30, 2024
1 parent 9f5780b commit 71c925c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
30 changes: 21 additions & 9 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def compute_mlm_loss(self, logits, mask_target_labels):
ignore_index=self.pad_token_id,
)

def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs: bool = False):
""" Compute loss for AnglE.
:param model: Huggingface model.
Expand Down Expand Up @@ -859,6 +859,11 @@ def compute_loss(self, model, inputs, return_outputs=False):

return (loss, outputs) if return_outputs else loss

@torch.no_grad()
def prediction_step(self, model, inputs, *args, **kwargs):
    """Evaluation step that produces only the loss.

    Delegates to ``compute_loss`` under ``torch.no_grad()`` and returns
    ``(loss, None, None)`` so the surrounding Trainer skips logits/label
    gathering entirely during evaluation.
    """
    loss = self.compute_loss(model, inputs, return_outputs=False)
    return (loss, None, None)


class AngleESETrainer(AngleTrainer):
"""
Expand Down Expand Up @@ -1412,13 +1417,15 @@ def detect_dataset_format(self, ds: Dataset):
def fit(self,
train_ds: Dataset,
valid_ds: Optional[Dataset] = None,
valid_ds_for_callback: Optional[Dataset] = None,
batch_size: int = 32,
output_dir: Optional[str] = None,
epochs: int = 1,
learning_rate: float = 1e-5,
warmup_steps: int = 1000,
logging_steps: int = 10,
eval_steps: Optional[int] = None,
eval_steps: int = 1000,
evaluation_strategy: str = 'steps',
save_steps: int = 100,
save_strategy: str = 'steps',
save_total_limit: int = 10,
Expand All @@ -1439,13 +1446,17 @@ def fit(self,
:param train_ds: Dataset. tokenized train dataset. Required.
:param valid_ds: Optional[Dataset]. tokenized valid dataset. Default None.
:param valid_ds_for_callback: Optional[Dataset]. tokenized valid dataset for callback use.
The dataset format should be `DatasetFormats.A`. The Spearman's correlation will be computed
after each epoch training and the best model will be saved. Default None.
:param batch_size: int. Default 32.
:param output_dir: Optional[str]. save dir. Default None.
:param epochs: int. Default 1.
:param learning_rate: float. Default 1e-5.
:param warmup_steps: int. Default 1000.
:param logging_steps: int. Default 10.
:param eval_steps: Optional[int]. Default None.
:param eval_steps: int. Default 1000.
:param evaluation_strategy: str. Default 'steps'.
:param save_steps: int. Default 100.
:param save_strategy: str. Default steps.
:param save_total_limit: int. Default 10.
Expand Down Expand Up @@ -1491,16 +1502,16 @@ def fit(self,
trainer_kwargs = {}

callbacks = None
if valid_ds is not None:
if valid_ds_for_callback is not None:
# check format
for obj in valid_ds:
for obj in valid_ds_for_callback:
if obj['extra']['dataset_format'] != DatasetFormats.A:
raise ValueError('Currently only support evaluation for DatasetFormats.A.')
break
best_ckpt_dir = None
if output_dir is not None:
best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint')
evaluate_callback = EvaluateCallback(self, valid_ds,
evaluate_callback = EvaluateCallback(self, valid_ds_for_callback,
partial(self.evaluate, batch_size=batch_size),
save_dir=best_ckpt_dir,
push_to_hub=push_to_hub,
Expand All @@ -1519,7 +1530,7 @@ def fit(self,
model=self.backbone,
dataset_format=self.detect_dataset_format(train_ds),
train_dataset=train_ds,
eval_dataset=None,
eval_dataset=valid_ds,
loss_kwargs=loss_kwargs,
tokenizer=self.tokenizer,
args=TrainingArguments(
Expand All @@ -1530,14 +1541,15 @@ def fit(self,
learning_rate=learning_rate,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,
save_strategy=save_strategy,
evaluation_strategy=evaluation_strategy if valid_ds is not None else 'no',
eval_steps=eval_steps,
save_steps=save_steps,
output_dir=output_dir,
save_total_limit=save_total_limit,
load_best_model_at_end=False,
ddp_find_unused_parameters=False if self.gpu_count > 1 else None,
label_names=AnglE.special_columns,
remove_unused_columns=False,
**argument_kwargs,
),
callbacks=callbacks,
Expand Down
38 changes: 38 additions & 0 deletions angle_emb/angle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,20 @@
help='Specify huggingface datasets subset name for valid set, default None')
parser.add_argument('--valid_split_name', type=str, default='train',
help='Specify huggingface datasets split name for valid set, default `train`')
parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None,
help='Specify huggingface datasets name or local file path for callback valid set. '
'The dataset format should be `DatasetFormats.A`. Default None.')
parser.add_argument('--valid_subset_name_for_callback', type=str, default=None,
help='Specify huggingface datasets subset name for valid set for callback use, default None')
parser.add_argument('--valid_split_name_for_callback', type=str, default='train',
help='Specify huggingface datasets split name for valid set for callback use, default `train`')
parser.add_argument('--prompt_template', type=str, default=None,
help='Specify prompt_template like "xxx: {text}", default None.'
'This prompt will be applied for all text columns.'
'If you want to specify different prompts for different text columns,'
'please handle it in the preprocessing step.')
parser.add_argument('--filter_duplicate', type=int, default=1, choices=[0, 1],
                    help='Specify filter_duplicate, choices [0, 1], default 1')
parser.add_argument('--save_dir', type=str, default=None,
help='Specify save dir, default None')
parser.add_argument('--seed', type=int, default=-1,
Expand Down Expand Up @@ -84,6 +93,11 @@
parser.add_argument('--max_steps', type=int, default=-1,
help='Specify max steps, default -1 (Automatically calculated from epochs)')
parser.add_argument('--save_steps', type=int, default=100, help='Specify save_steps, default 100')
parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch'],
help='Specify save_strategy, default steps')
parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000')
parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch'],
help='Specify evaluation_strategy, default steps')
parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32')
parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512')
parser.add_argument('--streaming', action='store_true', default=False,
Expand Down Expand Up @@ -227,6 +241,25 @@ def main():
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
num_proc=args.workers)

valid_ds_for_callback = None
if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None:
logger.info('Validation for callback detected, processing validation...')
if os.path.exists(args.valid_name_or_path_for_callback):
valid_ds_for_callback = load_dataset(
'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers)
else:
if args.valid_subset_name_for_callback is not None:
valid_ds_for_callback = load_dataset(
args.valid_name_or_path_for_callback,
args.valid_subset_name_for_callback,
num_proc=args.workers)
else:
valid_ds_for_callback = load_dataset(
args.valid_name_or_path_for_callback, num_proc=args.workers)
valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
num_proc=args.workers)

argument_kwargs = {}
if args.push_to_hub:
assert args.hub_model_id is not None, 'Please specify hub_model_id via --hub_model_id xxx'
Expand Down Expand Up @@ -254,11 +287,15 @@ def main():
model.fit(
train_ds=train_ds,
valid_ds=valid_ds,
valid_ds_for_callback=valid_ds_for_callback,
output_dir=args.save_dir,
batch_size=args.batch_size,
epochs=args.epochs,
learning_rate=args.learning_rate,
save_steps=args.save_steps,
save_strategy=args.save_strategy,
eval_steps=args.eval_steps,
evaluation_strategy=args.evaluation_strategy,
warmup_steps=args.warmup_steps,
logging_steps=args.logging_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
Expand All @@ -271,6 +308,7 @@ def main():
'angle_tau': args.angle_tau,
},
fp16=args.fp16,
filter_duplicate=args.filter_duplicate,
argument_kwargs=argument_kwargs,
apply_ese=args.apply_ese,
trainer_kwargs=trainer_kwargs,
Expand Down

0 comments on commit 71c925c

Please sign in to comment.