integrated cls token into extraction pipeline
LukasMut committed Aug 4, 2023
1 parent 3079f3c commit b626d98
Showing 9 changed files with 56 additions and 33 deletions.
6 changes: 6 additions & 0 deletions tests/helper.py
@@ -120,6 +120,12 @@
"pretrained": True,
"source": "ssl",
},
"dino-vit-base-p8": {
"modules": ["norm"],
"pretrained": True,
"source": "ssl",
"extract_cls_token": True,
},
# Harmonization models
"Harmonization": {
"modules": ["visual"],
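For reference, a minimal usage sketch (not part of this commit) of how the new extract_cls_token flag in the test configuration above would be passed via model_parameters; the dataset paths and batch size are placeholders, and the surrounding calls assume the standard thingsvision extraction workflow:

import torch
from thingsvision import get_extractor
from thingsvision.utils.data import ImageDataset, DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

# request the [cls]-token representation via model_parameters,
# mirroring the test configuration added above
extractor = get_extractor(
    model_name="dino-vit-base-p8",
    source="ssl",
    device=device,
    pretrained=True,
    model_parameters={"extract_cls_token": True},
)

dataset = ImageDataset(
    root="path/to/images",  # placeholder path
    out_path="path/to/output",  # placeholder path
    backend=extractor.get_backend(),
    transforms=extractor.get_transformations(),
)
batches = DataLoader(
    dataset=dataset, batch_size=32, backend=extractor.get_backend()
)

# with extract_cls_token=True, features extracted from the "norm" module
# should be the [cls]-token embeddings, i.e., one vector per image
features = extractor.extract_features(
    batches=batches,
    module_name="norm",
    flatten_acts=False,
)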
32 changes: 21 additions & 11 deletions thingsvision/core/extraction/extractors.py
@@ -1,5 +1,6 @@
import os
from typing import Any, Dict
from dataclasses import field
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import tensorflow as tf
@@ -13,9 +14,10 @@
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url

from thingsvision.utils.checkpointing import get_torch_home

from .tensorflow import TensorFlowExtractor
from .torch import PyTorchExtractor
from thingsvision.utils.checkpointing import get_torch_home

# neccessary to prevent gpu memory conflicts between torch and tf
gpus = tf.config.list_physical_devices("GPU")
@@ -41,8 +43,10 @@ def __init__(
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
model_parameters = (
model_parameters if model_parameters else {"weights": "DEFAULT"},
@@ -111,8 +115,10 @@ def __init__(
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
super().__init__(
model_name=model_name,
@@ -140,8 +146,10 @@ def __init__(
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
model_parameters = (
model_parameters if model_parameters else {"weights": "imagenet"}
@@ -287,8 +295,10 @@ def __init__(
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
super().__init__(
model_name=model_name,
@@ -319,7 +329,7 @@ def _download_and_save_model(self, model_url: str, output_model_filepath: str):
return converted_model

def _replace_module_prefix(
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
):
"""
Remove prefixes in a state_dict needed when loading models that are not VISSL
6 changes: 5 additions & 1 deletion thingsvision/core/extraction/helpers.py
@@ -38,14 +38,18 @@ def create_custom_extractor(
import thingsvision.custom_models.harmonization as harmonization

custom_model = getattr(harmonization, model_name)
elif model_name == "DreamSim":
import thingsvision.custom_models.dreamsim as dreamsim

custom_model = getattr(dreamsim, model_name)
else:
import thingsvision.custom_models as custom_models

if hasattr(custom_models, model_name):
custom_model = getattr(custom_models, model_name)
else:
raise ValueError(
f"\nCould not find {model_name} among available custom models.\nChoose a different model.\n"
f"\nCould not find {model_name} among available custom models.\nChoose a different model that is available.\n"
)
model_parameters = model_parameters if model_parameters else {}
custom_model = custom_model(device, model_parameters)
8 changes: 5 additions & 3 deletions thingsvision/core/extraction/tensorflow.py
@@ -1,6 +1,6 @@
import os
from dataclasses import field
from typing import Any, List
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np

@@ -21,9 +21,11 @@ def __init__(
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Any = field(default_factory=lambda: {}),
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
model: Any = None,
preprocess: Any = None,
preprocess: Optional[Callable] = None,
) -> None:
super().__init__(device, preprocess)
self.model_name = model_name
23 changes: 11 additions & 12 deletions thingsvision/core/extraction/torch.py
@@ -1,5 +1,5 @@
from dataclasses import field
from typing import Any, Callable, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
@@ -18,9 +18,11 @@ def __init__(
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Any = field(default_factory=lambda: {}),
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
model: Any = None,
preprocess: Any = None,
preprocess: Optional[Callable] = None,
) -> None:
super().__init__(device, preprocess)
self.model_name = model_name
@@ -31,10 +33,9 @@ def __init__(
self.activations = {}
self.hook_handle = None

if 'extract_cls_token' in self.model_parameters:
self.extract_cls_token = self.model_parameters['extract_cls_token']
else:
self.extract_cls_token = False
if isinstance(self.model_parameters, dict):
if "extract_cls_token" in self.model_parameters:
self.extract_cls_token = self.model_parameters["extract_cls_token"]

if not self.model:
self.load_model()
@@ -73,11 +74,6 @@ def hook(model, input, output) -> None:
act = output[0]
else:
act = output

if self.extract_cls_token:
# we extract only the representations of the first token (cls token)
act = act[:, 0]

try:
self.activations[name] = act.clone().detach()
except AttributeError:
@@ -111,6 +107,9 @@ def _extract_batch(
batch = batch.to(self.device)
_ = self.forward(batch)
act = self.activations[module_name]
if hasattr(self, "extract_cls_token"):
# we are only interested in the representations of the first token, i.e., [cls] token
act = act[:, 0, :]
if flatten_acts:
if self.model_name.lower().startswith("clip"):
act = self.flatten_acts(act, batch, module_name)
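For context, a minimal sketch (not part of this commit) of the slicing that _extract_batch now performs when extract_cls_token is set, assuming a ViT module output of shape (batch_size, num_tokens, embed_dim) with the [cls] token at index 0:

import torch

# hypothetical activation tensor: batch of 8 images, 197 tokens (196 patches + [cls]), 768 dims
act = torch.randn(8, 197, 768)

# keep only the first token along the token axis, i.e., the [cls] representation
cls_act = act[:, 0, :]
assert cls_act.shape == (8, 768)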
1 change: 0 additions & 1 deletion thingsvision/custom_models/__init__.py
@@ -1,6 +1,5 @@
from .alexnet_ecoset import Alexnet_ecoset
from .alexnet_salobjsub import AlexNet_SalObjSub
from .dreamsim import DreamSim
from .inception_ecoset import Inception_ecoset
from .official_clip import clip
from .openclip import OpenCLIP
1 change: 1 addition & 0 deletions thingsvision/custom_models/dreamsim/__init__.py
@@ -0,0 +1 @@
from .dreamsim import DreamSim
10 changes: 6 additions & 4 deletions thingsvision/custom_models/dreamsim/dreamsim.py
@@ -23,10 +23,13 @@ def __init__(self, model_type, device) -> None:

self.model_type = model_type
self.device = device
model_dir = os.path.join(get_torch_home(), 'dreamsim')
model_dir = os.path.join(get_torch_home(), "dreamsim")
self.model, _ = dreamsim(
pretrained=True, dreamsim_type=model_type, normalize_embeds=False,
device=device, cache_dir=model_dir
pretrained=True,
dreamsim_type=model_type,
normalize_embeds=False,
device=device,
cache_dir=model_dir,
)

def forward(self, x: Tensor) -> Tensor:
@@ -59,4 +62,3 @@ def preprocess(pil_img) -> Callable:

def create_model(self) -> Tuple[nn.Module, Callable]:
return DreamSimModel(self.variant, self.device), self.preprocess

2 changes: 1 addition & 1 deletion thingsvision/custom_models/vgg16bn_ecoset.py
@@ -13,7 +13,7 @@ def __init__(self, device, parameters) -> None:

def create_model(self) -> Any:
model = models.vgg16_bn(weights=None, num_classes=565)
path_to_weights = "https://osf.io/z5uf3/download"
path_to_weights = "https://osf.io/fe7s5/download"
state_dict = torch.hub.load_state_dict_from_url(
path_to_weights, map_location=self.device, file_name="VGG16bn_ecoset"
)
