Skip to content

Commit

Permalink
TimmExtractor default transformations are now loaded from timm
Browse files Browse the repository at this point in the history
  • Loading branch information
a1247418 authored Jun 18, 2024
1 parent 9f31bdb commit ef21329
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions thingsvision/core/extraction/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@ def load_model_from_source(self) -> None:
f"\nCould not find {self.model_name} in timm library.\nChoose a different model.\n"
)

def get_default_transformation(
self,
mean,
std,
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
) -> Any:
data_config = timm.data.resolve_model_data_config(self.model)
transforms = timm.data.create_transform(**data_config, is_training=False)

return transforms


class KerasExtractor(TensorFlowExtractor):
def __init__(
Expand Down

0 comments on commit ef21329

Please sign in to comment.