
log: add histogram metrics for gradients #424

Merged: 9 commits merged into takuseno:master from wandb-log-improvements on Oct 15, 2024

Conversation

@hasan-yaman (Contributor) commented:

Inspired by wandb.watch.
Add support for exporting histogram metrics.
Export gradients to see vanishing / exploding gradients.
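
To make the idea concrete, here is a minimal standalone sketch in plain PyTorch/NumPy (not the d3rlpy adapter API touched by this PR) of what a per-parameter gradient histogram is:

import numpy as np
import torch
import torch.nn as nn

# toy network standing in for an algorithm's torch modules
model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# one histogram per parameter, summarizing the gradient distribution
for name, parameter in model.named_parameters():
    grad = parameter.grad.detach().cpu().numpy().ravel()
    counts, edges = np.histogram(grad, bins=10)
    print(name, counts, (edges[0], edges[-1]))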

@takuseno (Owner) left a comment:

@hasan-yaman Thank you for another PR! This feature is really nice. Let me share my thoughts here:

  • Having two different methods for the same purpose is a little redundant.
  • write_histogram is currently only used to monitor gradients.

Here is my proposal. We can remove write_histogram and keep only watch_model. watch_model is called at every update step. For wandb, it calls self.run.watch the first time and does nothing after that. For FileAdapter and TensorboardAdapter, it computes gradient histograms there every gradient_logging_steps steps. What do you think?
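
A rough sketch of that contract (illustrative only; helper names like get_torch_modules and get_gradients follow the snippets quoted below, and the CSV payload is not necessarily what the merged adapters write):

import os
import numpy as np

class WanDBAdapter:
    # wandb: register the modules once; wandb records gradient histograms itself
    def __init__(self, run) -> None:
        self.run = run
        self._is_model_watched = False

    def watch_model(self, epoch, step, logging_steps, algo) -> None:
        if logging_steps is not None and not self._is_model_watched:
            self.run.watch(
                tuple(algo.impl.modules.get_torch_modules().values()),
                log="gradients",
                log_freq=logging_steps,
            )
            self._is_model_watched = True

class FileAdapter:
    # file/tensorboard: recompute gradient histograms every logging_steps updates
    def __init__(self, logdir) -> None:
        self._logdir = logdir

    def watch_model(self, epoch, step, logging_steps, algo) -> None:
        if logging_steps is not None and step % logging_steps == 0:
            for name, grad in algo.impl.modules.get_gradients():
                counts, _ = np.histogram(grad, bins=10)
                path = os.path.join(self._logdir, f"{name}_grad.csv")
                with open(path, "a") as f:
                    f.write(",".join(map(str, [epoch, step, *counts])) + "\n")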

@hasan-yaman (Contributor, Author) commented:

@takuseno Thanks for the comments!
This way the code looks simpler.

@@ -520,6 +525,8 @@ def fitter(
        # save hyperparameters
        save_config(self, logger)

        logger.watch_model(0, 0, gradient_logging_steps, self)
@hasan-yaman (Contributor, Author) commented:

I don't like it, but it is required for wandb's watch. Without this line, wandb doesn't track the first epoch.

@hasan-yaman (Contributor, Author) commented:

Can be fixed via #424 (comment)

@takuseno (Owner) left a comment:

Thanks for the change! I noticed that, currently, parameter names conflict without their parent module names. I left some suggestions to resolve this.

    ) -> None:
        if logging_steps is not None and step % logging_steps == 0:
            for name, grad in algo.impl.modules.get_gradients():
                path = os.path.join(self._logdir, f"{name}.csv")
@takuseno (Owner) suggested:

path = os.path.join(self._logdir, f"{name}_grad.csv")


@@ -388,6 +391,19 @@ def reset_optimizer_states(self) -> None:
            if isinstance(v, torch.optim.Optimizer):
                v.state = collections.defaultdict(dict)

    def get_torch_modules(self) -> List[nn.Module]:
        torch_modules: List[nn.Module] = []
        for v in asdict_without_copy(self).values():
@takuseno (Owner) suggested:

It's nice to return both key and value for get_gradients:

torch_modules = {}
for k, v in asdict_without_copy(self).items():
    if isinstance(v, nn.Module):
        torch_modules[k] = v
return torch_modules


        return torch_modules

    def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]:
        for module in self.get_torch_modules():
@takuseno (Owner) suggested:

Concatenate the module name and the parameter names, otherwise the names conflict:

for module_name, module in self.get_torch_modules().items():
...
        yield f"{module_name}.{name}", parameter.grad.cpu().detach().numpy()
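
Putting the two suggestions together, get_gradients would end up roughly like this (a sketch, assuming get_torch_modules now returns a Dict[str, nn.Module]):

def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]:
    for module_name, module in self.get_torch_modules().items():
        for name, parameter in module.named_parameters():
            # skip parameters that have not received a gradient yet
            if parameter.grad is None:
                continue
            # prefix with the parent module name so identically named
            # parameters from different modules no longer collide
            yield f"{module_name}.{name}", parameter.grad.cpu().detach().numpy()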


Comment on lines +69 to +74:
            self.run.watch(
                tuple(algo.impl.modules.get_torch_modules().values()),
                log="gradients",
                log_freq=logging_steps,
            )
            self._is_model_watched = True
@hasan-yaman (Contributor, Author) commented:

Instead of watch, we can use

self.run.log({"name": wandb.Histogram(...) })

Not sure which direction I should choose.
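
For reference, that alternative would look roughly like this inside watch_model (a sketch; the grads/ key prefix is just an example):

if logging_steps is not None and step % logging_steps == 0:
    self.run.log(
        {
            f"grads/{name}": wandb.Histogram(grad)
            for name, grad in algo.impl.modules.get_gradients()
        },
        step=step,
    )

The trade-off: run.watch delegates the histogram bookkeeping to wandb, while explicit wandb.Histogram logging keeps every adapter on the same get_gradients code path.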

@takuseno (Owner) commented:

For now, I'm okay with either way. The current implementation also looks good to me.

@takuseno (Owner) commented:

@hasan-yaman Thanks for the update! The implementation looks good to me. It seems that CI complains about some typing issues. Once they're fixed, let's merge your PR 😄

@hasan-yaman (Contributor, Author) commented:

@takuseno I couldn't understand and fix the typing issues. Am I missing something?

@takuseno (Owner) commented:

@hasan-yaman Can you simply add type: ignore just to skip errors? I can fix them later once we merge this. In this PR, I'd like to suppress errors.

@hasan-yaman (Contributor, Author) commented:

@takuseno done!

@takuseno (Owner) left a comment:

LGTM. Thank you for your contribution!

@takuseno merged commit 3d51ee7 into takuseno:master on Oct 15, 2024
4 checks passed
@hasan-yaman deleted the wandb-log-improvements branch on October 15, 2024 at 12:07
@takuseno (Owner) commented:

Actually, I'm thinking that it might be good to record gradients every epoch by default. Is there any concern if we do that?
