Skip to content

Commit

Permalink
fix: smooth bce loss generics
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn committed Jan 29, 2025
1 parent 44f40ba commit 32f4bfc
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .bce_with_logits import BCEWithLogitsLoss


class SmoothBCEWithLogitsLoss(BaseLoss[list[Tensor], Tensor]):
class SmoothBCEWithLogitsLoss(BaseLoss[Tensor, Tensor]):
supported_tasks: list[TaskType] = [
TaskType.SEGMENTATION,
TaskType.CLASSIFICATION,
Expand Down Expand Up @@ -65,6 +65,7 @@ def forward(self, predictions: Tensor, target: Tensor) -> Tensor:
@rtype: Tensor
@return: A scalar tensor.
"""
print(predictions[0].shape)
if predictions.shape != target.shape:
raise RuntimeError(
f"Target tensor dimension ({target.shape}) and predictions tensor "
Expand Down

0 comments on commit 32f4bfc

Please sign in to comment.