diff --git a/CHANGELOG.md b/CHANGELOG.md index ef600e43..8276edc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `io.copy_dir()` function. - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`. - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint. +- Added a callback for sending Slack notifications. ### Changed diff --git a/src/olmo_core/internal/common.py b/src/olmo_core/internal/common.py index 5e0ae19c..1c2d426a 100644 --- a/src/olmo_core/internal/common.py +++ b/src/olmo_core/internal/common.py @@ -92,6 +92,7 @@ def build_launch_config( BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"), BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"), BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"), + BeakerEnvSecret(name="SLACK_WEBHOOK_URL", secret="SLACK_WEBHOOK_URL"), ], setup_steps=[ # Clone repo. diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 017b98f5..4d1e9ee7 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -36,6 +36,7 @@ LMEvaluatorCallbackConfig, ProfilerCallback, SchedulerCallback, + SlackNotifierCallback, WandBCallback, ) from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all @@ -171,6 +172,7 @@ def build_common_components( ), eval_interval=1000, ), + "slack_notifier": SlackNotifierCallback(name=run_name, enabled=False), } return CommonComponents( diff --git a/src/olmo_core/train/callbacks/__init__.py b/src/olmo_core/train/callbacks/__init__.py index f37129f2..0d7883d2 100644 --- a/src/olmo_core/train/callbacks/__init__.py +++ b/src/olmo_core/train/callbacks/__init__.py @@ -21,6 +21,7 @@ from .profiler import ProfilerCallback from .scheduler import SchedulerCallback from .sequence_length_scheduler import SequenceLengthSchedulerCallback +from .slack_notifier import SlackNotificationSetting, SlackNotifierCallback from .speed_monitor import SpeedMonitorCallback from .wandb import WandBCallback @@ -43,6 +44,8 @@ "GradClipperCallback", "MatrixNormalizerCallback", "ProfilerCallback", + "SlackNotifierCallback", + "SlackNotificationSetting", "SchedulerCallback", "SequenceLengthSchedulerCallback", "SpeedMonitorCallback", diff --git a/src/olmo_core/train/callbacks/slack_notifier.py b/src/olmo_core/train/callbacks/slack_notifier.py new file mode 100644 index 00000000..57c42d56 --- /dev/null +++ b/src/olmo_core/train/callbacks/slack_notifier.py @@ -0,0 +1,117 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import requests + +from olmo_core.config import StrEnum +from olmo_core.distributed.utils import get_rank +from olmo_core.exceptions import OLMoEnvironmentError + +from .callback import Callback + +SLACK_WEBHOOK_URL_ENV_VAR = "SLACK_WEBHOOK_URL" + + +class SlackNotificationSetting(StrEnum): + """ + Defines the notifications settings for the Slack notifier callback. + """ + + all = "all" + """ + Send all types notifications. + """ + + end_only = "end_only" + """ + Only send a notification when the experiment ends (successfully or with a failure). + """ + + failure_only = "failure_only" + """ + Only send a notification when the experiment fails. + """ + + none = "none" + """ + Don't send any notifcations. + """ + + +@dataclass +class SlackNotifierCallback(Callback): + name: Optional[str] = None + """ + A name to give the run. + """ + + notifications: SlackNotificationSetting = SlackNotificationSetting.end_only + """ + The notification settings. + """ + + enabled: bool = True + """ + Set to false to disable this callback. + """ + + webhook_url: Optional[str] = None + """ + The webhook URL to post. If not set, will check the environment variable ``SLACK_WEBHOOK_URL``. + """ + + def post_attach(self): + if not self.enabled or get_rank() != 0: + return + + if self.webhook_url is None and SLACK_WEBHOOK_URL_ENV_VAR not in os.environ: + raise OLMoEnvironmentError(f"missing env var '{SLACK_WEBHOOK_URL_ENV_VAR}'") + + def pre_train(self): + if not self.enabled or get_rank() != 0: + return + + if self.notifications == SlackNotificationSetting.all: + self._post_message("started") + + def post_train(self): + if not self.enabled or get_rank() != 0: + return + + if self.notifications in ( + SlackNotificationSetting.all, + SlackNotificationSetting.end_only, + ): + if self.trainer.is_canceled: + self._post_message("canceled") + else: + self._post_message("completed successfully") + + def on_error(self, exc: BaseException): + if not self.enabled or get_rank() != 0: + return + + if self.notifications in ( + SlackNotificationSetting.all, + SlackNotificationSetting.end_only, + SlackNotificationSetting.failure_only, + ): + self._post_message(f"failed with error:\n{exc}") + + def _post_message(self, msg: str): + webhook_url = self.webhook_url or os.environ.get(SLACK_WEBHOOK_URL_ENV_VAR) + if webhook_url is None: + raise OLMoEnvironmentError(f"missing env var '{SLACK_WEBHOOK_URL_ENV_VAR}'") + + progress = ( + f"*Progress:*\n" + f"- step: {self.step:,d}\n" + f"- epoch: {self.trainer.epoch}\n" + f"- tokens: {self.trainer.global_train_tokens_seen:,d}" + ) + if self.name is not None: + msg = f"Run `{self.name}` {msg}\n{progress}" + else: + msg = f"Run {msg}\n{progress}" + requests.post(webhook_url, json={"text": msg})