diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index ba92f3d3..905967bb 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -389,8 +389,9 @@ def fit( the directory name will be `{class name}_{timestamp}`. with_timestamp: Flag to add timestamp string to the last of directory name. - logging_steps: number of steps to log metrics. - logging_strategy: what logging strategy to use. + logging_steps: Number of steps to log metrics. This will be ignored + if logging_strategy is EPOCH. + logging_strategy: Logging strategy to use. logger_adapter: LoggerAdapterFactory object. show_progress: Flag to show progress bar for iterations. save_interval: Interval to save parameters. diff --git a/d3rlpy/logging/__init__.py b/d3rlpy/logging/__init__.py index b609cd78..6d87e78e 100644 --- a/d3rlpy/logging/__init__.py +++ b/d3rlpy/logging/__init__.py @@ -3,3 +3,4 @@ from .noop_adapter import * from .tensorboard_adapter import * from .utils import * +from .wandb_adapter import * diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py index 373986eb..aed2ef4e 100644 --- a/d3rlpy/logging/wandb_adapter.py +++ b/d3rlpy/logging/wandb_adapter.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Optional -from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol +from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol __all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"] @@ -14,7 +14,11 @@ class LoggerWanDBAdapter(LoggerAdapter): experiment_name (str): Name of the experiment. """ - def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None): + def __init__( + self, + project: Optional[str] = None, + experiment_name: Optional[str] = None, + ): try: import wandb except ImportError as e: @@ -29,16 +33,21 @@ def before_write_metric(self, epoch: int, step: int) -> None: """Callback executed before writing metric.""" pass - def write_metric(self, epoch: int, step: int, name: str, value: float) -> None: + def write_metric( + self, epoch: int, step: int, name: str, value: float + ) -> None: """Writes metric to WandB.""" - self.run.log({name: value, 'epoch': epoch}, step=step) + self.run.log({name: value, "epoch": epoch}, step=step) def after_write_metric(self, epoch: int, step: int) -> None: """Callback executed after writing metric.""" pass def save_model(self, epoch: int, algo: SaveProtocol) -> None: - """Saves models to Weights & Biases. Not implemented for WandB.""" + """Saves models to Weights & Biases. + + Not implemented for WandB. + """ # Implement saving model to wandb if needed pass @@ -50,8 +59,8 @@ def close(self) -> None: class WanDBAdapterFactory(LoggerAdapterFactory): r"""WandB Logger Adapter Factory class. - This class creates instances of the WandB Logger Adapter for experiment tracking. - + This class creates instances of the WandB Logger Adapter for experiment + tracking. """ _project: str @@ -61,7 +70,6 @@ def __init__(self, project: Optional[str] = None) -> None: Args: project (Optional[str], optional): The name of the WandB project. Defaults to None. - """ super().__init__() self._project = project @@ -74,6 +82,7 @@ def create(self, experiment_name: str) -> LoggerAdapter: Returns: LoggerAdapter: Instance of the WandB Logger Adapter. - """ - return LoggerWanDBAdapter(project=self._project, experiment_name=experiment_name) + return LoggerWanDBAdapter( + project=self._project, experiment_name=experiment_name + )