diff --git a/requirements.txt b/requirements.txt index 50cb09d..f9bfb53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ h5py keras-cv-attention-models>=1.3.5 matplotlib numba -numpy +numpy<=1.26.4 open_clip_torch==2.24.* pandas regex @@ -20,4 +20,4 @@ tqdm dreamsim==0.1.3 git+https://github.com/openai/CLIP.git git+https://github.com/serre-lab/Harmonization.git -transformers==4.40.1 \ No newline at end of file +transformers==4.40.1 diff --git a/thingsvision/_version.py b/thingsvision/_version.py index 492f7d9..763cefd 100644 --- a/thingsvision/_version.py +++ b/thingsvision/_version.py @@ -1 +1 @@ -__version__ = "2.6.7" +__version__ = "2.6.8" diff --git a/thingsvision/core/extraction/extractors.py b/thingsvision/core/extraction/extractors.py index 5646125..72d78ee 100644 --- a/thingsvision/core/extraction/extractors.py +++ b/thingsvision/core/extraction/extractors.py @@ -1,13 +1,13 @@ import os +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import timm -import torchvision - import tensorflow as tf import tensorflow.keras.applications as tensorflow_models +import timm import torch +import torchvision try: from torch.hub import load_state_dict_from_url @@ -104,6 +104,11 @@ def get_default_transformation( apply_center_crop: bool = True, ) -> Any: if self.weights: + warnings.warn( + message="\nInput arguments are ignored because transforms are automatically inferred from model weights.\n", + category=UserWarning, + stacklevel=2, + ) transforms = self.weights.transforms() else: transforms = super().get_default_transformation( @@ -141,6 +146,24 @@ def load_model_from_source(self) -> None: f"\nCould not find {self.model_name} in timm library.\nChoose a different model.\n" ) + def get_default_transformation( + self, + mean, + std, + resize_dim: int = 256, + crop_dim: int = 224, + apply_center_crop: bool = True, + ) -> Any: + warnings.warn( + message="\nInput arguments are ignored because automatically infers transforms from model config.\n", + category=UserWarning, + stacklevel=2, + ) + data_config = timm.data.resolve_model_data_config(self.model) + transforms = timm.data.create_transform(**data_config, is_training=False) + + return transforms + class KerasExtractor(TensorFlowExtractor): def __init__(