diff --git a/.github/workflows/base_test_workflow.yml b/.github/workflows/base_test_workflow.yml
index 52fd4429..02367536 100644
--- a/.github/workflows/base_test_workflow.yml
+++ b/.github/workflows/base_test_workflow.yml
@@ -18,8 +18,8 @@ jobs:
           pytorch-version: 1.6
           torchvision-version: 0.7
         - python-version: 3.9
-          pytorch-version: 2.1
-          torchvision-version: 0.16
+          pytorch-version: 2.3
+          torchvision-version: 0.18

     steps:
     - uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
     - name: Install dependencies
       run: |
         pip install .[with-hooks-cpu]
-        pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
+        pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
         pip install --upgrade protobuf==3.20.1
         pip install six
         pip install packaging
diff --git a/setup.py b/setup.py
index 66326d46..2e8a8dee 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,7 @@
     ],
     python_requires=">=3.0",
     install_requires=[
-        "numpy",
+        "numpy < 2.0",
         "scikit-learn",
         "tqdm",
         "torch >= 1.6.0",
diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py
index 50062f87..e5e59e38 100644
--- a/src/pytorch_metric_learning/__init__.py
+++ b/src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "2.5.0"
+__version__ = "2.6.0"
diff --git a/src/pytorch_metric_learning/utils/distributed.py b/src/pytorch_metric_learning/utils/distributed.py
index 93946eed..40dddf90 100644
--- a/src/pytorch_metric_learning/utils/distributed.py
+++ b/src/pytorch_metric_learning/utils/distributed.py
@@ -1,3 +1,5 @@
+import warnings
+
 import torch

 from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -93,15 +95,28 @@ def __init__(self, loss, efficient=False):

     def forward(
         self,
-        emb,
+        embeddings,
         labels=None,
         indices_tuple=None,
         ref_emb=None,
         ref_labels=None,
         enqueue_mask=None,
     ):
+        if not is_distributed():
+            warnings.warn(
+                "DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
+            )
+            return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)
+
         world_size = torch.distributed.get_world_size()
-        common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
+        common_args = [
+            embeddings,
+            labels,
+            indices_tuple,
+            ref_emb,
+            ref_labels,
+            world_size,
+        ]
         if isinstance(self.loss, CrossBatchMemory):
             return self.forward_cross_batch(*common_args, enqueue_mask)
         return self.forward_regular_loss(*common_args)
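
For context, a minimal sketch of the new non-distributed fallback in DistributedLossWrapper. It assumes a hypothetical single-process script; the ContrastiveLoss choice, tensor shapes, and variable names are illustrative and not part of this diff:

import torch
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# Wrap any metric-learning loss; `efficient` keeps its default of False.
loss_fn = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())

embeddings = torch.randn(32, 128)     # illustrative embedding batch
labels = torch.randint(0, 10, (32,))  # illustrative labels

# Without an initialized torch.distributed process group, the wrapper now
# emits a warning and returns self.loss(embeddings, labels, ...) directly,
# instead of calling torch.distributed.get_world_size() and erroring out.
loss = loss_fn(embeddings, labels)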