Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
ir2718 committed Oct 28, 2024
1 parent 31eeaaa commit 60c0583
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions tests/losses/test_tcm_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.losses import (
MultipleLosses,
ContrastiveLoss,
ThresholdConsistentMarginLoss,
)
Expand All @@ -16,12 +17,15 @@ class TestThresholdConsistentMarginLoss(unittest.TestCase):
def test_tcm_loss(self):
torch.manual_seed(3459)
for dtype in TEST_DTYPES:
loss_func = ThresholdConsistentMarginLoss(
base_loss=ContrastiveLoss(
distance=CosineSimilarity(),
pos_margin=0.9,
neg_margin=0.4,
)
loss_func = MultipleLosses(
losses=[
ContrastiveLoss(
distance=CosineSimilarity(),
pos_margin=0.9,
neg_margin=0.4,
),
ThresholdConsistentMarginLoss()
]
)
embs = torch.tensor(
[
Expand Down Expand Up @@ -49,11 +53,11 @@ def test_tcm_loss(self):
correct_loss = torch.tensor(1.0045).to(dtype)

with torch.no_grad():
res = loss_func.compute_loss(embs, labels, None, embs, labels)
res = loss_func.forward(embs, labels)
rtol = 1e-2 if dtype == torch.float16 else 1e-5
atol = 1e-4
self.assertTrue(
torch.isclose(
res["loss"]["losses"], correct_loss, rtol=rtol, atol=atol
res, correct_loss, rtol=rtol, atol=atol
)
)

0 comments on commit 60c0583

Please sign in to comment.