diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py index b7e07491..14ca1a02 100644 --- a/opacus/grad_sample/grad_sample_module.py +++ b/opacus/grad_sample/grad_sample_module.py @@ -207,7 +207,7 @@ def add_hooks( ) self.autograd_grad_sample_hooks.append( - module.register_backward_hook( + module.register_full_backward_hook( partial( self.capture_backprops_hook, loss_reduction=loss_reduction,