Skip to content

Commit

Permalink
change imports
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinMusgrave committed Dec 11, 2024
1 parent c46dd33 commit 99a97f7
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 9 deletions.
8 changes: 4 additions & 4 deletions docs/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ datasets.base_dataset.BaseDataset(
## CUB-200-2011

```python
datasets.cub.CUB(*args, **kwargs)
datasets.CUB(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -75,7 +75,7 @@ train_and_test_dataset = CUB(root="data",
## Cars196

```python
datasets.cars196.Cars196(*args, **kwargs)
datasets.Cars196(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -110,7 +110,7 @@ train_and_test_dataset = Cars196(root="data",
## INaturalist2018

```python
datasets.inaturalist2018.INaturalist2018(*args, **kwargs)
datasets.INaturalist2018(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -146,7 +146,7 @@ train_and_test_dataset = INaturalist2018(root="data",
## StanfordOnlineProducts

```python
datasets.sop.StanfordOnlineProducts(*args, **kwargs)
datasets.StanfordOnlineProducts(*args, **kwargs)
```

**Defined splits**:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.8.0"
__version__ = "2.8.1"
5 changes: 5 additions & 0 deletions src/pytorch_metric_learning/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base_dataset import BaseDataset
from .cars196 import Cars196
from .cub import CUB
from .inaturalist2018 import INaturalist2018
from .sop import StanfordOnlineProducts
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
device_from_environ = os.environ.get("TEST_DEVICE", "cuda")
with_collect_stats = os.environ.get("WITH_COLLECT_STATS", "false")
TEST_DATASETS = os.environ.get("TEST_DATASETS", "false")

TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]
TEST_DEVICE = torch.device(device_from_environ)
Expand Down
5 changes: 4 additions & 1 deletion tests/datasets/test_cars196.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.cars196 import Cars196
from pytorch_metric_learning.datasets import Cars196
from .. import TEST_DATASETS


class TestCars196(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.CARS_196_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_Cars196(self):
train_test_data = Cars196(
root=TestCars196.CARS_196_ROOT, split="train+test", download=True
Expand All @@ -34,6 +36,7 @@ def test_Cars196(self):
self.assertTrue(len(train_data) == 8054)
self.assertTrue(len(test_data) == 8131)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CARS_196_dataloader(self):
test_data = Cars196(
root=TestCars196.CARS_196_ROOT,
Expand Down
5 changes: 4 additions & 1 deletion tests/datasets/test_cub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.cub import CUB
from pytorch_metric_learning.datasets import CUB
from .. import TEST_DATASETS


class TestCUB(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.CUB_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CUB(self):
train_test_data = CUB(root=TestCUB.CUB_ROOT, split="train+test", download=True)
train_data = CUB(root=TestCUB.CUB_ROOT, split="train", download=True)
Expand All @@ -28,6 +30,7 @@ def test_CUB(self):
self.assertTrue(len(train_data) == 5864)
self.assertTrue(len(test_data) == 5924)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CUB_dataloader(self):
test_data = CUB(
root=TestCUB.CUB_ROOT,
Expand Down
5 changes: 4 additions & 1 deletion tests/datasets/test_inaturalist2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018
from pytorch_metric_learning.datasets import INaturalist2018
from .. import TEST_DATASETS


class TestINaturalist2018(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.INATURALIST2018_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_INaturalist2018(self):
train_test_data = INaturalist2018(
root=TestINaturalist2018.INATURALIST2018_ROOT,
Expand All @@ -36,6 +38,7 @@ def test_INaturalist2018(self):
self.assertTrue(len(train_data) == 325846)
self.assertTrue(len(test_data) == 136093)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_INaturalist2018_dataloader(self):
test_data = INaturalist2018(
root=TestINaturalist2018.INATURALIST2018_ROOT,
Expand Down
5 changes: 4 additions & 1 deletion tests/datasets/test_sop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts
from pytorch_metric_learning.datasets import StanfordOnlineProducts
from .. import TEST_DATASETS


class TestStanfordOnlineProducts(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.SOP_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_SOP(self):
train_test_data = StanfordOnlineProducts(
root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True
Expand All @@ -34,6 +36,7 @@ def test_SOP(self):
self.assertTrue(len(train_data) == 59551)
self.assertTrue(len(test_data) == 60502)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_SOP_dataloader(self):
test_data = StanfordOnlineProducts(
root=TestStanfordOnlineProducts.SOP_ROOT,
Expand Down

0 comments on commit 99a97f7

Please sign in to comment.