In my opinion, the BERT pretraining batch loss in the function _get_batch_loss_bert is not correct. Here are the details:
The CrossEntropyLoss is initialized with the default reduction 'mean':
loss = nn.CrossEntropyLoss()
In the function _get_batch_loss_bert, the MLM and NSP losses are computed with this same loss instance:
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
Since reduction='mean', the result of loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) is already a scalar tensor, so the element-wise product with mlm_weights_X cannot weight the individual prediction positions; it only broadcasts the already-averaged loss against the weights, which makes the MLM loss computation wrong.
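For reference, a minimal sketch of one possible fix, assuming the variable names and signature of the d2l function quoted above: construct the loss with reduction='none' so it returns one value per prediction position, and apply mlm_weights_X before averaging.

```python
from torch import nn

# Unreduced loss: returns one loss value per element instead of a scalar mean
loss = nn.CrossEntropyLoss(reduction='none')

def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X,
                         valid_lens_x, pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass returns encodings plus MLM and NSP predictions
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1), pred_positions_X)
    # Per-position MLM losses, shape (batch_size * num_pred_positions,)
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
    # Zero out padded prediction positions, then average over the real ones
    mlm_l = (mlm_l * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)
    # With reduction='none' the NSP loss is one value per example; take the mean
    nsp_l = loss(nsp_Y_hat, nsp_y).mean()
    return mlm_l, nsp_l, mlm_l + nsp_l
```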
lyconghk changed the title from "The mlm loss computation in the function _get_batch_loss_bert is wrong in d2l pytorch code" to "The mlm loss computation in the function _get_batch_loss_bert seems wrong in d2l pytorch code" on Jan 6, 2024.
Agree with you @lyconghk. Have you come up with any better solution for applying mlm_weights_X in the mlm_l calculation?
The weight parameter of PyTorch's CrossEntropyLoss does not seem to support mlm_weights_X the way MXNet does: PyTorch's weight argument rescales classes, while MXNet accepts a per-sample weight. I guess that is why the PyTorch version of _get_batch_loss_bert calculates mlm_l this way; it tries to reduce the impact of padded tokens on mlm_l, but it does not use mlm_weights_X in a correct way.
How about just using torch.nn.functional to calculate the two cross-entropy losses, MLM and NSP, as in the sketch below?
And remove the input parameter loss from the function _get_batch_loss_bert.
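A minimal sketch of that idea; dropping the shared loss parameter changes the signature, which is otherwise assumed from the d2l code:

```python
import torch.nn.functional as F

def _get_batch_loss_bert(net, vocab_size, tokens_X, segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1), pred_positions_X)
    # Unreduced per-position MLM losses
    mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size),
                            mlm_Y.reshape(-1), reduction='none')
    # Mask out padded prediction positions before averaging
    mlm_l = (mlm_l * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)
    # The NSP loss can keep the default 'mean' reduction
    nsp_l = F.cross_entropy(nsp_Y_hat, nsp_y)
    return mlm_l, nsp_l, mlm_l + nsp_l
```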