Skip to content

Commit

Permalink
Add status.warn so warnings are displayed on Colab
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701288451
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 29, 2024
1 parent 7bfb131 commit a410acc
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
4 changes: 2 additions & 2 deletions kauldron/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

import dataclasses
from typing import Literal, Optional
import warnings

import flax.struct
import jax.numpy as jnp
from kauldron import kontext
from kauldron.metrics import base
from kauldron.metrics import base_state
from kauldron.typing import Bool, Float, typechecked # pylint: disable=g-multiple-import,g-importing-member
from kauldron.utils.status_utils import status # pylint: disable=g-importing-member


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down Expand Up @@ -97,7 +97,7 @@ def merge(self, other: base_state.AverageState) -> base_state.AverageState:
assert isinstance(other, Norm.State)

if self.parent.axis is None and self.parent.aggregation_type is None:
warnings.warn(
status.warn(
"When setting axis=None in kd.metrics.Norm and running a TreeReduce"
" over it, Norm will average the norms of individual leaves, rather"
" than computing the norm as if everything was concatenated. Please"
Expand Down
7 changes: 2 additions & 5 deletions kauldron/summaries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@

import abc
from typing import Any
import warnings

from etils import epy
from kauldron import kontext
from kauldron.typing import Float, UInt8, typechecked # pylint: disable=g-multiple-import,g-importing-member
from kauldron.utils.status_utils import status # pylint: disable=g-importing-member

Images = Float["*b h w c"] | UInt8["*b h w c"]

Expand All @@ -44,9 +43,7 @@ def __init_subclass__(cls, **kwargs):
"Migrate to the new kd.metrics.Metric based summaries. "
"See kd.summaries.images.ShowImages for an example."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
if epy.is_notebook():
print(f"WARNING: {msg}")
status.warn(msg, DeprecationWarning, stacklevel=2)
super().__init_subclass__(**kwargs)

@abc.abstractmethod
Expand Down
24 changes: 19 additions & 5 deletions kauldron/utils/status_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""

import functools
import warnings

from absl import logging
from etils import epy
Expand Down Expand Up @@ -84,11 +85,11 @@ def xid(self) -> int:

def log_status(self, msg: str) -> None:
"""Log a message (from lead host), will be displayed on the XM UI."""
self.log(msg, _stacklevel_increment=1)
self.log(msg, stacklevel=2)
if self.on_xmanager and self.is_lead_host:
self.wu.set_notes(msg)

def log(self, msg: str, *, _stacklevel_increment: int = 0) -> None:
def log(self, msg: str, *, stacklevel: int = 1) -> None:
"""Print a message.
* On Colab: Use `print`
Expand All @@ -98,16 +99,29 @@ def log(self, msg: str, *, _stacklevel_increment: int = 0) -> None:
Args:
msg: the message to print.
_stacklevel_increment: If wrapping this function, indicate the number of
frame to skip so logging display the correct caller site.
stacklevel: If wrapping this function, indicate the number of frame to
skip so logging display the correct caller site.
"""
if (
not epy.is_notebook()
):
logging.info(msg, stacklevel=2 + _stacklevel_increment)
logging.info(msg, stacklevel=1 + stacklevel)
return
else:
print(msg, flush=True) # Colab or local

def warn(
self,
msg: str,
category: type[Warning] | None = None,
*,
stacklevel: int = 1,
) -> None:
"""Print a warning."""
warnings.warn(msg, category, stacklevel=1 + stacklevel)
if epy.is_notebook():
category = category or Warning
print(f"{category.__name__}: {msg}")


status = _Status()

0 comments on commit a410acc

Please sign in to comment.