Skip to content

Commit

Permalink
Merge pull request #176 from ViCCo-Group/timm_extractor_transforms_fix
Browse files Browse the repository at this point in the history
Fixed TimmExtractor default transforms
  • Loading branch information
LukasMut authored Jun 18, 2024
2 parents 9f31bdb + 34e04c2 commit ad903f3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ h5py
keras-cv-attention-models>=1.3.5
matplotlib
numba
numpy
numpy<=1.26.4
open_clip_torch==2.24.*
pandas
regex
Expand All @@ -20,4 +20,4 @@ tqdm
dreamsim==0.1.3
git+https://github.com/openai/CLIP.git
git+https://github.com/serre-lab/Harmonization.git
transformers==4.40.1
transformers==4.40.1
2 changes: 1 addition & 1 deletion thingsvision/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.6.7"
__version__ = "2.6.8"
29 changes: 26 additions & 3 deletions thingsvision/core/extraction/extractors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import timm
import torchvision

import tensorflow as tf
import tensorflow.keras.applications as tensorflow_models
import timm
import torch
import torchvision

try:
from torch.hub import load_state_dict_from_url
Expand Down Expand Up @@ -104,6 +104,11 @@ def get_default_transformation(
apply_center_crop: bool = True,
) -> Any:
if self.weights:
warnings.warn(
message="\nInput arguments are ignored because transforms are automatically inferred from model weights.\n",
category=UserWarning,
stacklevel=2,
)
transforms = self.weights.transforms()
else:
transforms = super().get_default_transformation(
Expand Down Expand Up @@ -141,6 +146,24 @@ def load_model_from_source(self) -> None:
f"\nCould not find {self.model_name} in timm library.\nChoose a different model.\n"
)

def get_default_transformation(
self,
mean,
std,
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
) -> Any:
warnings.warn(
message="\nInput arguments are ignored because <timm> automatically infers transforms from model config.\n",
category=UserWarning,
stacklevel=2,
)
data_config = timm.data.resolve_model_data_config(self.model)
transforms = timm.data.create_transform(**data_config, is_training=False)

return transforms


class KerasExtractor(TensorFlowExtractor):
def __init__(
Expand Down

0 comments on commit ad903f3

Please sign in to comment.