Support custom logging API for wandb alternatives (ex. ClearML) #264

arcyleung commented Dec 12, 2024

Proposal: send the `log_entries` from `train_step_logs()` in `/src/nanotron/trainer.py` to a generic reporter class; currently only wandb is supported.

Before:

```python
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
    wandb.log(
        {
            **{log_item.tag: log_item.scalar_value for log_item in log_entries},
            "iteration_step": self.iteration_step,
        }
    )
```

After:

```python
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and self.reporter is not None:
    metrics = {log_item.tag: log_item.scalar_value for log_item in log_entries}
    self.reporter.log(metrics, self.iteration_step)
```
The reporter interface and example implementations (a sketch of how `self.reporter` could be constructed follows them):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

class Reporter(ABC):
    @abstractmethod
    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Report a dict of scalar metrics for the given training step."""
        pass
```
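For illustration, a new backend only has to implement `log()`. A minimal stdout reporter (hypothetical, not part of the proposal, reusing the imports above) shows the extension point:

```python
import json

class ConsoleReporter(Reporter):
    """Prints one JSON line per step; useful for debugging or piping to a file."""

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        print(json.dumps({"step": step, **metrics}))
```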

```python
class WandBReporter(Reporter):
    def __init__(self, project: str, run_name: str, config: Optional[Dict] = None):
        try:
            import wandb
        except ImportError:
            raise ImportError("Please install wandb to use WandBReporter")

        wandb.init(
            project=project,
            name=run_name,
            config=config,
        )
        self.wandb = wandb

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Log metrics to WandB."""
        self.wandb.log({
            **metrics,
            "iteration_step": step,
        })
```

```python
class ClearMLReporter(Reporter):
    def __init__(
        self,
        project_name: str,
        task_name: str,
        config: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
        auto_connect_frameworks: bool = True,
        auto_connect_arg_parser: bool = True,
    ):
        try:
            from clearml import Task
        except ImportError:
            raise ImportError("Please install clearml to use ClearMLReporter")

        self.task = Task.init(
            project_name=project_name,
            task_name=task_name,
            auto_connect_frameworks=auto_connect_frameworks,
            auto_connect_arg_parser=auto_connect_arg_parser,
        )

        if tags:
            self.task.add_tags(tags)

        # Log config as hyperparameters if provided
        if config:
            self.task.connect(config)
        self.logger = self.task.get_logger()

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        # Each metric becomes a series under a single "Training Metrics" plot
        for name, value in metrics.items():
            self.logger.report_scalar(
                title="Training Metrics",
                series=name,
                value=value,
                iteration=step,
            )
```

```python
class MLFlowReporter(Reporter):
    def __init__(self, experiment_name: str):
        try:
            import mlflow
        except ImportError:
            raise ImportError("Please install mlflow to use MLFlowReporter")

        mlflow.set_experiment(experiment_name)
        self.mlflow = mlflow

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        # mlflow.log_metrics accepts a dict of scalars directly
        self.mlflow.log_metrics(metrics, step=step)
```

```python
class TensorBoardReporter(Reporter):
    def __init__(self, log_dir: str):
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        for name, value in metrics.items():
            self.writer.add_scalar(name, value, step)
```
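The trainer could then pick a backend at init time. A minimal sketch, assuming a hypothetical `logging_backend` config field and `build_reporter` helper (the field names below are illustrative, not nanotron's actual schema, apart from `general.project` and `general.run` which the current wandb code already uses):

```python
def build_reporter(config) -> Optional[Reporter]:
    # Hypothetical selection logic; "logging_backend" and "log_dir" are
    # illustrative config fields, not nanotron's actual schema.
    backend = getattr(config.general, "logging_backend", None)
    if backend == "wandb":
        return WandBReporter(project=config.general.project, run_name=config.general.run)
    if backend == "clearml":
        return ClearMLReporter(project_name=config.general.project, task_name=config.general.run)
    if backend == "mlflow":
        return MLFlowReporter(experiment_name=config.general.project)
    if backend == "tensorboard":
        return TensorBoardReporter(log_dir=config.general.log_dir)
    return None  # reporting disabled
```

The trainer would then set `self.reporter = build_reporter(self.config)` once during init, keeping `train_step_logs()` backend-agnostic.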