Skip to content

Commit

Permalink
fixed bug in dino models
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasd4 committed Aug 7, 2023
1 parent 04f232b commit 0f33ce9
Show file tree
Hide file tree
Showing 4 changed files with 1,199 additions and 48 deletions.
116 changes: 68 additions & 48 deletions thingsvision/core/extraction/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .tensorflow import TensorFlowExtractor
from .torch import PyTorchExtractor
from thingsvision.utils.models.dino import vit_base, vit_small

# necessary to prevent GPU memory conflicts between torch and tf
gpus = tf.config.list_physical_devices("GPU")
Expand All @@ -38,15 +39,15 @@

class TorchvisionExtractor(PyTorchExtractor):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
model_parameters = (
model_parameters if model_parameters else {"weights": "DEFAULT"},
Expand Down Expand Up @@ -91,12 +92,12 @@ def load_model_from_source(self) -> None:
)

def get_default_transformation(
self,
mean,
std,
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
self,
mean,
std,
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
) -> Any:
if self.weights:
transforms = self.weights.transforms()
Expand All @@ -110,15 +111,15 @@ def get_default_transformation(

class TimmExtractor(PyTorchExtractor):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
super().__init__(
model_name=model_name,
Expand All @@ -141,15 +142,15 @@ def load_model_from_source(self) -> None:

class KerasExtractor(TensorFlowExtractor):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
model_parameters = (
model_parameters if model_parameters else {"weights": "imagenet"}
Expand Down Expand Up @@ -226,21 +227,25 @@ class SSLExtractor(PyTorchExtractor):
"repository": "facebookresearch/dino:main",
"arch": "dino_vits16",
"type": "hub",
"checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
},
"dino-vit-small-p8": {
"repository": "facebookresearch/dino:main",
"arch": "dino_vits8",
"type": "hub",
"checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
},
"dino-vit-base-p16": {
"repository": "facebookresearch/dino:main",
"arch": "dino_vitb16",
"type": "hub",
"checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
},
"dino-vit-base-p8": {
"repository": "facebookresearch/dino:main",
"arch": "dino_vitb8",
"type": "hub",
"checkpoint_url": "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
},
"dino-xcit-small-12-p16": {
"repository": "facebookresearch/dino:main",
Expand Down Expand Up @@ -290,15 +295,15 @@ class SSLExtractor(PyTorchExtractor):
}

def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict[str, Union[str, bool, List[str]]] = field(
default_factory=lambda: {}
),
preprocess: Optional[Callable] = None,
) -> None:
super().__init__(
model_name=model_name,
Expand Down Expand Up @@ -329,7 +334,7 @@ def _download_and_save_model(self, model_url: str, output_model_filepath: str):
return converted_model

def _replace_module_prefix(
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
):
"""
Remove prefixes in a state_dict needed when loading models that are not VISSL
Expand Down Expand Up @@ -369,11 +374,26 @@ def load_model_from_source(self) -> None:
self.model.fc = torch.nn.Identity()
self.model.load_state_dict(model_state_dict, strict=True)
elif model_config["type"] == "hub":
self.model = torch.hub.load(
model_config["repository"], model_config["arch"]
)
if model_config["arch"] == "resnet50":
self.model.fc = torch.nn.Identity()
if self.model_name.startswith("dino-vit"):
if self.model_name == "dino-vit-small-p8":
model = vit_small(patch_size=8)
elif self.model_name == "dino-vit-small-p16":
model = vit_small(patch_size=16)
elif self.model_name == "dino-vit-base-p8":
model = vit_base(patch_size=8)
elif self.model_name == "dino-vit-base-p16":
model = vit_base(patch_size=16)
else:
raise ValueError()
state_dict = torch.hub.load_state_dict_from_url(model_config["checkpoint_url"])
model.load_state_dict(state_dict, strict=True)
self.model = model
else:
self.model = torch.hub.load(
model_config["repository"], model_config["arch"]
)
if model_config["arch"] == "resnet50":
self.model.fc = torch.nn.Identity()
else:
raise ValueError(f"\nUnknown model type.\n")
else:
Expand Down
1 change: 1 addition & 0 deletions thingsvision/utils/models/dino/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .vision_transformer import vit_base, vit_small, vit_tiny
Loading

0 comments on commit 0f33ce9

Please sign in to comment.