
AttributeError: 'SWA' object has no attribute '_optimizer_step_pre_hooks'. Did you mean: '_optimizer_step_code'? #36

pbenner opened this issue May 12, 2023 · 0 comments

pbenner commented May 12, 2023

The PyTorch Optimizer class has changed in recent releases, which leads to the following error:

[...]
stepping every 16 training passes, cycling lr every 1 epochs
checkin at 2 epochs to match lr scheduler
Traceback (most recent call last):
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 141, in <module>
    run_cv(X, y, f'eval-{task}-{target}.txt', n_splits)
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 95, in run_cv
    model = train_model()
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 62, in train_model
    model.fit(epochs=1000, losscurve=False)
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/model.py", line 228, in fit
    self.train()
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/model.py", line 140, in train
    self.optimizer.step()
  File "/home/pbenner/.local/opt/anaconda3/envs/crysfeat/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 69, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/pbenner/.local/opt/anaconda3/envs/crysfeat/lib/python3.10/site-packages/torch/optim/optimizer.py", line 271, in wrapper
    for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
AttributeError: 'SWA' object has no attribute '_optimizer_step_pre_hooks'. Did you mean: '_optimizer_step_code'?
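
For context, the failure mode can be reproduced with any Optimizer subclass whose __init__ skips super().__init__(), which is what SWA does: as of PyTorch 2.0, step() on Optimizer subclasses is wrapped so that it dispatches registered pre/post step hooks stored on the instance, and the hook dictionaries (_optimizer_step_pre_hooks, _optimizer_step_post_hooks) are only created by the base class __init__. A minimal sketch, with illustrative names rather than the actual CrabNet code:

import torch
from torch.optim import Optimizer

# Stand-in for SWA: an Optimizer subclass that wraps another optimizer
# and, like SWA, never calls super().__init__().
class Wrapper(Optimizer):
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.defaults = optimizer.defaults
        self.param_groups = optimizer.param_groups
        self.state = optimizer.state

    def step(self, closure=None):
        return self.optimizer.step(closure)

p = torch.nn.Parameter(torch.zeros(1))
opt = Wrapper(torch.optim.SGD([p], lr=0.1))
p.grad = torch.ones(1)
opt.step()  # PyTorch >= 2.0: AttributeError: 'Wrapper' object has no
            # attribute '_optimizer_step_pre_hooks'; older releases step fine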

The following patch fixed the issue:

diff --git a/utils/optim.py b/utils/optim.py
index 33008dd..18224ea 100644
--- a/utils/optim.py
+++ b/utils/optim.py
@@ -1,6 +1,7 @@
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from itertools import chain
 from torch.optim import Optimizer
+from typing import Callable, Dict
 import torch
 import warnings
 import numpy as np
@@ -116,6 +117,8 @@ class SWA(Optimizer):
         self.optimizer = optimizer
 
         self.defaults = self.optimizer.defaults
+        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
+        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
         self.param_groups = self.optimizer.param_groups
         self.state = defaultdict(dict)
         self.opt_state = self.optimizer.state
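
The two OrderedDict assignments should be harmless on older PyTorch releases, where the step() wrapper never reads these attributes, so the patch stays backward compatible. With the patch applied, a quick smoke test along these lines steps without raising (sketch only: the constructor signature assumed here is the torchcontrib-style SWA(optimizer, ...) wrapper from utils/optim.py and may differ):

import torch
from utils.optim import SWA  # the module patched above

model = torch.nn.Linear(4, 1)
base = torch.optim.Adam(model.parameters(), lr=1e-3)
opt = SWA(base)  # manual-mode SWA wrapping the base optimizer

model(torch.randn(2, 4)).sum().backward()
opt.step()  # previously raised AttributeError on PyTorch >= 2.0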