Skip to content

Commit

Permalink
fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu-g committed Nov 18, 2023
1 parent 79a1ae6 commit a24490c
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/kwja/modules/functions/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def compute_multi_label_token_mean_loss(
if input_.isnan().any().item() is True:
return torch.tensor(float("nan"), dtype=input_.dtype, device=input_.device)
else:
target = torch.where(mask, target, torch.zeros_like(target))
losses = nn.functional.binary_cross_entropy(input_, target.float(), reduction="none") # (b, seq, num_features)
# features の軸は和をとる
losses = (losses * mask).sum(dim=2) # (b, seq)
Expand Down

0 comments on commit a24490c

Please sign in to comment.