In `wdgrl.py`, `gradient_penalty(critic, h_s, h_t)`: the `interpolates` tensor created in line 29 has shape (3 x batch_size x feature_size), so the gradients also have shape (3 x batch_size x feature_size). When `gradients.norm(2, dim=1)` is computed in line 35, dimension 1 therefore refers to the batch dimension. Is this correct? Intuitively, I would have taken the norm across the feature dimension.
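A quick way to check the shape question: the gradients with respect to `interpolates` always have the same shape as `interpolates` itself, so the behaviour of `norm(2, dim=1)` can be checked on the inputs directly. The snippet below is a minimal sketch with made-up sizes and random tensors standing in for `h_s` and `h_t` (assumptions, not taken from `wdgrl.py`). It compares `torch.stack`, which produces the (3 x batch_size x feature_size) layout described above, with `torch.cat` along the batch dimension.

```python
import torch

# Hypothetical sizes; real values depend on the data and feature extractor.
batch_size, feature_size = 4, 8
h_s = torch.randn(batch_size, feature_size)  # stand-in for source features
h_t = torch.randn(batch_size, feature_size)  # stand-in for target features

# One interpolation coefficient per sample, broadcast over the features.
alpha = torch.rand(batch_size, 1)
interpolates = h_s + alpha * (h_t - h_s)

# torch.stack adds a new leading dimension: (3, batch_size, feature_size).
stacked = torch.stack([interpolates, h_s, h_t])
# torch.cat joins along dim 0: (3 * batch_size, feature_size).
concatenated = torch.cat([interpolates, h_s, h_t])

# On the stacked tensor, dim=1 is the batch dimension: the norm collapses
# the batch and leaves one value per (copy, feature).
norm_stacked = stacked.norm(2, dim=1)
# On the concatenated tensor, dim=1 is the feature dimension: the norm
# gives one value per sample.
norm_concat = concatenated.norm(2, dim=1)

print(stacked.shape, norm_stacked.shape)      # torch.Size([3, 4, 8]) torch.Size([3, 8])
print(concatenated.shape, norm_concat.shape)  # torch.Size([12, 8]) torch.Size([12])
```

With the stacked layout, each entry of `norm_stacked` mixes values from every sample in the batch, so it is not a per-sample gradient norm.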
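I think the code is wrong. In my opinion, line 29 should be replaced with `interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()`. `torch.cat` joins the three tensors along the batch dimension, giving shape (3 * batch_size x feature_size), so `dim=1` in line 35 then refers to the feature dimension.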
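For illustration, here is a sketch of a gradient penalty with the proposed `torch.cat` fix applied. It is not the code from `wdgrl.py`. The WGAN-GP-style random interpolation and the `(norm - 1) ** 2` penalty are assumptions about what the function computes, and the critics used below are hypothetical. The key point is that the norm is taken per sample over the feature dimension.

```python
import torch
from torch import nn
from torch.autograd import grad


def gradient_penalty(critic: nn.Module, h_s: torch.Tensor, h_t: torch.Tensor) -> torch.Tensor:
    """Sketch of a WGAN-GP-style penalty with per-sample gradient norms."""
    # One interpolation coefficient per sample, broadcast over the features.
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    interpolates = h_s + alpha * (h_t - h_s)
    # Concatenate along the batch dimension: (3 * batch_size, feature_size),
    # so dim=1 below is the feature dimension.
    interpolates = torch.cat([interpolates, h_s, h_t]).requires_grad_()

    preds = critic(interpolates)
    # create_graph=True keeps the penalty differentiable w.r.t. the critic.
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     create_graph=True)[0]
    # L2 norm of each sample's gradient, taken over the features.
    gradient_norm = gradients.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()


# Usage with a hypothetical critic and random features.
torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
h_s, h_t = torch.randn(4, 8), torch.randn(4, 8)
gp = gradient_penalty(critic, h_s, h_t)
gp.backward()  # gradients flow back into the critic's parameters

# Sanity check: a linear critic f(x) = w . x + b has gradient w for every
# sample, so the per-sample penalty must equal (||w|| - 1) ** 2.
linear = nn.Linear(8, 1)
gp_linear = gradient_penalty(linear, h_s, h_t)
expected = (linear.weight.norm() - 1) ** 2
```

The linear-critic check is a useful test here. With the stacked layout and `dim=1`, the norm runs over the batch and scales with the square root of the batch size, so it would not match `(||w|| - 1) ** 2`.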