Skip to content

Commit

Permalink
Add a callback for sending Slack notifications (#125)
Browse files Browse the repository at this point in the history
Requires the env var `SLACK_WEBHOOK_URL`.
  • Loading branch information
epwalsh authored Dec 20, 2024
1 parent 6d60464 commit 9e0992b
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `io.copy_dir()` function.
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.

### Changed

Expand Down
1 change: 1 addition & 0 deletions src/olmo_core/internal/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def build_launch_config(
BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"),
BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"),
BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"),
BeakerEnvSecret(name="SLACK_WEBHOOK_URL", secret="SLACK_WEBHOOK_URL"),
],
setup_steps=[
# Clone repo.
Expand Down
2 changes: 2 additions & 0 deletions src/olmo_core/internal/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
LMEvaluatorCallbackConfig,
ProfilerCallback,
SchedulerCallback,
SlackNotifierCallback,
WandBCallback,
)
from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all
Expand Down Expand Up @@ -171,6 +172,7 @@ def build_common_components(
),
eval_interval=1000,
),
"slack_notifier": SlackNotifierCallback(name=run_name, enabled=False),
}

return CommonComponents(
Expand Down
3 changes: 3 additions & 0 deletions src/olmo_core/train/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .profiler import ProfilerCallback
from .scheduler import SchedulerCallback
from .sequence_length_scheduler import SequenceLengthSchedulerCallback
from .slack_notifier import SlackNotificationSetting, SlackNotifierCallback
from .speed_monitor import SpeedMonitorCallback
from .wandb import WandBCallback

Expand All @@ -43,6 +44,8 @@
"GradClipperCallback",
"MatrixNormalizerCallback",
"ProfilerCallback",
"SlackNotifierCallback",
"SlackNotificationSetting",
"SchedulerCallback",
"SequenceLengthSchedulerCallback",
"SpeedMonitorCallback",
Expand Down
117 changes: 117 additions & 0 deletions src/olmo_core/train/callbacks/slack_notifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
from dataclasses import dataclass
from typing import Optional

import requests

from olmo_core.config import StrEnum
from olmo_core.distributed.utils import get_rank
from olmo_core.exceptions import OLMoEnvironmentError

from .callback import Callback

# Name of the environment variable consulted for the Slack webhook URL when the
# callback's ``webhook_url`` field is not set.
SLACK_WEBHOOK_URL_ENV_VAR = "SLACK_WEBHOOK_URL"


class SlackNotificationSetting(StrEnum):
    """
    Selects which run events the Slack notifier callback reports.
    """

    all = "all"
    """
    Send every type of notification (start, end, and failure).
    """

    end_only = "end_only"
    """
    Notify only when the experiment ends, whether it succeeded or failed.
    """

    failure_only = "failure_only"
    """
    Notify only when the experiment fails.
    """

    none = "none"
    """
    Don't send any notifications.
    """


@dataclass
class SlackNotifierCallback(Callback):
    """
    Posts run status updates (started / completed / canceled / failed) to a Slack webhook.

    Messages are only sent from rank 0 and only while ``enabled`` is true. Which events are
    reported is controlled by ``notifications``. Delivery is best-effort: a network failure
    or a non-2xx response from the webhook is logged as a warning and never interrupts
    training or masks the error being reported.
    """

    name: Optional[str] = None
    """
    A name to give the run.
    """

    notifications: SlackNotificationSetting = SlackNotificationSetting.end_only
    """
    The notification settings.
    """

    enabled: bool = True
    """
    Set to false to disable this callback.
    """

    webhook_url: Optional[str] = None
    """
    The webhook URL to post. If not set, will check the environment variable ``SLACK_WEBHOOK_URL``.
    """

    request_timeout: float = 10.0
    """
    Timeout in seconds for the HTTP request to the webhook, so that an unresponsive
    endpoint cannot block the training process indefinitely.
    """

    def post_attach(self):
        """
        Fail fast if no webhook URL is configured.

        :raises OLMoEnvironmentError: If neither ``webhook_url`` nor the environment variable
            ``SLACK_WEBHOOK_URL`` provides a non-empty URL.
        """
        if not self._should_notify():
            return

        if self._resolve_webhook_url() is None:
            raise OLMoEnvironmentError(f"missing env var '{SLACK_WEBHOOK_URL_ENV_VAR}'")

    def pre_train(self):
        """
        Announce the start of the run (only with the ``all`` setting).
        """
        if not self._should_notify():
            return

        if self.notifications == SlackNotificationSetting.all:
            self._post_message("started")

    def post_train(self):
        """
        Announce that the run finished, distinguishing a canceled run from a completed one.
        """
        if not self._should_notify():
            return

        if self.notifications in (
            SlackNotificationSetting.all,
            SlackNotificationSetting.end_only,
        ):
            if self.trainer.is_canceled:
                self._post_message("canceled")
            else:
                self._post_message("completed successfully")

    def on_error(self, exc: BaseException):
        """
        Announce that the run failed, including the text of ``exc``.
        """
        if not self._should_notify():
            return

        if self.notifications in (
            SlackNotificationSetting.all,
            SlackNotificationSetting.end_only,
            SlackNotificationSetting.failure_only,
        ):
            self._post_message(f"failed with error:\n{exc}")

    def _should_notify(self) -> bool:
        """
        Return ``True`` if this process is allowed to post: the callback is enabled and
        this is rank 0.
        """
        # Short-circuits so that ``get_rank()`` is not called for a disabled callback.
        return self.enabled and get_rank() == 0

    def _resolve_webhook_url(self) -> Optional[str]:
        """
        Return the configured webhook URL, or ``None`` if there isn't one.

        The ``webhook_url`` field takes precedence over the environment variable. An empty
        string from either source is treated the same as a missing value.
        """
        return self.webhook_url or os.environ.get(SLACK_WEBHOOK_URL_ENV_VAR) or None

    def _post_message(self, msg: str):
        """
        Post ``msg`` to the webhook, followed by a summary of the current training progress.

        :param msg: The status text, e.g. ``"started"``. It is prefixed with the run name
            when one is set.

        :raises OLMoEnvironmentError: If no webhook URL is configured.
        """
        # Function-scope import: ``logging`` is not part of this module's import block.
        import logging

        log = logging.getLogger(__name__)

        webhook_url = self._resolve_webhook_url()
        if webhook_url is None:
            raise OLMoEnvironmentError(f"missing env var '{SLACK_WEBHOOK_URL_ENV_VAR}'")

        progress = (
            f"*Progress:*\n"
            f"- step: {self.step:,d}\n"
            f"- epoch: {self.trainer.epoch}\n"
            f"- tokens: {self.trainer.global_train_tokens_seen:,d}"
        )
        if self.name is not None:
            msg = f"Run `{self.name}` {msg}\n{progress}"
        else:
            msg = f"Run {msg}\n{progress}"

        try:
            response = requests.post(
                webhook_url, json={"text": msg}, timeout=self.request_timeout
            )
        except requests.RequestException as e:
            # Only the exception type is logged: the exception text can contain the
            # webhook URL, which is a secret.
            log.warning("Failed to post Slack notification (%s)", type(e).__name__)
            return

        if not response.ok:
            log.warning("Slack webhook responded with HTTP status %d", response.status_code)

0 comments on commit 9e0992b

Please sign in to comment.