
Commit 899a230

Merge pull request #703 from KevinMusgrave/dev
v2.6.0
Kevin Musgrave authored Jul 24, 2024
2 parents adfb78c + ef1bd06 commit 899a230
Showing 4 changed files with 22 additions and 7 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/base_test_workflow.yml
@@ -18,8 +18,8 @@ jobs:
             pytorch-version: 1.6
             torchvision-version: 0.7
           - python-version: 3.9
-            pytorch-version: 2.1
-            torchvision-version: 0.16
+            pytorch-version: 2.3
+            torchvision-version: 0.18

     steps:
     - uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
     - name: Install dependencies
       run: |
         pip install .[with-hooks-cpu]
-        pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
+        pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
         pip install --upgrade protobuf==3.20.1
         pip install six
         pip install packaging
2 changes: 1 addition & 1 deletion setup.py
@@ -39,7 +39,7 @@
     ],
     python_requires=">=3.0",
     install_requires=[
-        "numpy",
+        "numpy < 2.0",
         "scikit-learn",
         "tqdm",
         "torch >= 1.6.0",
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.5.0"
__version__ = "2.6.0"
19 changes: 17 additions & 2 deletions src/pytorch_metric_learning/utils/distributed.py
@@ -1,3 +1,5 @@
+import warnings
+
 import torch

 from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -93,15 +95,28 @@ def __init__(self, loss, efficient=False):

     def forward(
         self,
-        emb,
+        embeddings,
         labels=None,
         indices_tuple=None,
         ref_emb=None,
         ref_labels=None,
         enqueue_mask=None,
     ):
+        if not is_distributed():
+            warnings.warn(
+                "DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
+            )
+            return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)
+
         world_size = torch.distributed.get_world_size()
-        common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
+        common_args = [
+            embeddings,
+            labels,
+            indices_tuple,
+            ref_emb,
+            ref_labels,
+            world_size,
+        ]
         if isinstance(self.loss, CrossBatchMemory):
             return self.forward_cross_batch(*common_args, enqueue_mask)
         return self.forward_regular_loss(*common_args)
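For context, a minimal usage sketch (not part of this commit): the loss class, batch size, and embedding dimension below are illustrative. It shows the renamed embeddings argument and the new behavior where calling DistributedLossWrapper outside a torch.distributed run emits a warning and returns the wrapped loss computed on the local batch.

# Hypothetical usage sketch; ContrastiveLoss and the tensor shapes are illustrative.
import torch

from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# Wrap a metric-learning loss for distributed (DDP-style) training.
loss_fn = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())

embeddings = torch.randn(32, 128)     # 32 embeddings, each 128-dimensional
labels = torch.randint(0, 10, (32,))  # integer class labels

# As of this change, calling the wrapper outside torch.distributed warns
# and returns the wrapped loss computed on the local batch.
loss = loss_fn(embeddings, labels)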

