Skip to content

Commit

Permalink
Fix formats
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Feb 18, 2024
1 parent 3d46d71 commit 97d9c97
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
5 changes: 3 additions & 2 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,9 @@ def fit(
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: number of steps to log metrics.
logging_strategy: what logging strategy to use.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
Expand Down
1 change: 1 addition & 0 deletions d3rlpy/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .noop_adapter import *
from .tensorboard_adapter import *
from .utils import *
from .wandb_adapter import *
29 changes: 19 additions & 10 deletions d3rlpy/logging/wandb_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional
from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol

__all__ = ["LoggerWanDBAdapter", "WanDBAdapterFactory"]

Expand All @@ -14,7 +14,11 @@ class LoggerWanDBAdapter(LoggerAdapter):
experiment_name (str): Name of the experiment.
"""

def __init__(self, project: Optional[str] = None, experiment_name: Optional[str] = None):
def __init__(
self,
project: Optional[str] = None,
experiment_name: Optional[str] = None,
):
try:
import wandb
except ImportError as e:
Expand All @@ -29,16 +33,21 @@ def before_write_metric(self, epoch: int, step: int) -> None:
"""Callback executed before writing metric."""
pass

def write_metric(self, epoch: int, step: int, name: str, value: float) -> None:
def write_metric(
self, epoch: int, step: int, name: str, value: float
) -> None:
"""Writes metric to WandB."""
self.run.log({name: value, 'epoch': epoch}, step=step)
self.run.log({name: value, "epoch": epoch}, step=step)

def after_write_metric(self, epoch: int, step: int) -> None:
"""Callback executed after writing metric."""
pass

def save_model(self, epoch: int, algo: SaveProtocol) -> None:
"""Saves models to Weights & Biases. Not implemented for WandB."""
"""Saves models to Weights & Biases.
Not implemented for WandB.
"""
# Implement saving model to wandb if needed
pass

Expand All @@ -50,8 +59,8 @@ def close(self) -> None:
class WanDBAdapterFactory(LoggerAdapterFactory):
r"""WandB Logger Adapter Factory class.
This class creates instances of the WandB Logger Adapter for experiment tracking.
This class creates instances of the WandB Logger Adapter for experiment
tracking.
"""

_project: str
Expand All @@ -61,7 +70,6 @@ def __init__(self, project: Optional[str] = None) -> None:
Args:
project (Optional[str], optional): The name of the WandB project. Defaults to None.
"""
super().__init__()
self._project = project
Expand All @@ -74,6 +82,7 @@ def create(self, experiment_name: str) -> LoggerAdapter:
Returns:
LoggerAdapter: Instance of the WandB Logger Adapter.
"""
return LoggerWanDBAdapter(project=self._project, experiment_name=experiment_name)
return LoggerWanDBAdapter(
project=self._project, experiment_name=experiment_name
)

0 comments on commit 97d9c97

Please sign in to comment.