Commit 8f419c8 (1 parent: 26e4fe2), showing 49 changed files with 8,910 additions and 0 deletions.
Empty file.
@@ -0,0 +1,112 @@
# Factory functions that build classification backbones (timm CNNs and CLIP
# models), wrap them in an ExtractionWrapper exposing feature/output hooks,
# and move them to the configured device.
import backbone.foundation_models
import backbone.wrapper
import timm
import torch

import utils.conf


def resnet18(
    class_names: list[str] = None,
    num_classes: int = None,
    other_hooks: list = [],
    pretrained: bool = False,
    replace_output: bool = True,
) -> torch.nn.Module:
    # ResNet-18 from timm; logits are tapped at "fc", pooled features at "global_pool".
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=timm.create_model("resnet18", pretrained=pretrained),
        output_hook="fc",
        feature_hook="global_pool",
        other_hooks=other_hooks,
        replace_output=replace_output,
    ).to(utils.conf.get_device())
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "IN1k" if pretrained else "None"
    return model_backbone


def resnet50(
    class_names: list[str] = None,
    num_classes: int = None,
    other_hooks: list = [],
    pretrained: bool = False,
    replace_output: bool = True,
) -> torch.nn.Module:
    # Same wrapping as resnet18, with a ResNet-50 backbone.
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=timm.create_model("resnet50", pretrained=pretrained),
        output_hook="fc",
        feature_hook="global_pool",
        other_hooks=other_hooks,
        replace_output=replace_output,
    ).to(utils.conf.get_device())
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "IN1k" if pretrained else "None"
    return model_backbone


def efficientnet_b2(
    class_names: list[str] = None,
    num_classes: int = None,
    other_hooks: list = [],
    pretrained: bool = False,
    replace_output: bool = True,
) -> torch.nn.Module:
    # EfficientNet-B2 from timm; its head is named "classifier" rather than "fc".
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=timm.create_model("efficientnet_b2", pretrained=pretrained),
        output_hook="classifier",
        feature_hook="global_pool",
        other_hooks=other_hooks,
        replace_output=replace_output,
    ).to(utils.conf.get_device())
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "IN1k" if pretrained else "None"
    return model_backbone


def clip_vit_b32(
    class_names: list[str] = None,
    num_classes: int = None,
    other_hooks: list = [],
    pretrained: bool = False,
    replace_output: bool = False,
) -> torch.nn.Module:
    # CLIP ViT-B/32: classification goes through the text encoder, so the
    # output hook is "text_features" and the head is not replaced by default.
    device = utils.conf.get_device()
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=backbone.foundation_models.CLIP(
            class_names, backbone="ViT-B/32", random_init=pretrained, device=device
        ),
        output_hook="text_features",
        feature_hook="backbone.visual",
        other_hooks=other_hooks,
        replace_output=replace_output,
    ).to(device)
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "CLIP" if pretrained else "None"
    return model_backbone


def clip_resnet50(
    class_names: list[str] = None,
    num_classes: int = None,
    other_hooks: list = [],
    pretrained: bool = False,
    replace_output: bool = False,
) -> torch.nn.Module:
    # CLIP with a ResNet-50 image encoder; otherwise identical to clip_vit_b32.
    device = utils.conf.get_device()
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=backbone.foundation_models.CLIP(
            class_names, backbone="RN50", random_init=pretrained, device=device
        ),
        output_hook="text_features",
        feature_hook="backbone.visual",
        other_hooks=other_hooks,
        replace_output=replace_output,
    ).to(device)
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "CLIP" if pretrained else "None"
    return model_backbone
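
For orientation, a minimal usage sketch of these factories. It assumes the package imports above resolve and that ExtractionWrapper forwards like a regular nn.Module; the 10-class setup is hypothetical:

import torch
import utils.conf

# ImageNet-pretrained ResNet-18 with a fresh 10-way head replacing "fc" (hypothetical setup).
model = resnet18(num_classes=10, pretrained=True)
x = torch.randn(8, 3, 224, 224).to(utils.conf.get_device())
logits = model(x)            # forward through the wrapped backbone
print(model.pretraining)     # "IN1k"

The CLIP factories take class_names instead of a replaced head: the wrapper reads logits from the text encoder ("text_features"), which is why replace_output defaults to False there.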
Empty file.
@@ -0,0 +1,17 @@
def convert_model_to_fp32(model):
    # Cast all parameters, and any existing gradients, back to float32.
    for p in model.parameters():
        p.data = p.data.float()
        # `if p.grad:` would raise on multi-element tensors; test against None instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def refine_classname(class_names):
    # Turn raw label strings into prompt-friendly phrases: lowercase,
    # replace separators with spaces, and prepend the article "a"/"an".
    for i, class_name in enumerate(class_names):
        new_name = class_name.lower().replace("_", " ").replace("-", " ")
        new_name = (
            "an " + new_name
            if new_name[0] in ["a", "e", "i", "o", "u"]
            else "a " + new_name
        )
        class_names[i] = new_name
    return class_names
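
A short sanity check of both helpers, assuming only torch and the two functions above; the tiny Linear model is just for illustration:

import torch.nn as nn

model = nn.Linear(4, 2).half()            # parameters start in float16
convert_model_to_fp32(model)
print(next(model.parameters()).dtype)     # torch.float32

print(refine_classname(["golden_retriever", "airliner"]))
# ['a golden retriever', 'an airliner']

Note that refine_classname mutates its argument in place and also returns it.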
@@ -0,0 +1,53 @@ | ||
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. | ||
# All rights reserved. | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.nn.parameter import Parameter | ||
|
||
|
||
class AlphaModule(nn.Module): | ||
def __init__(self, shape): | ||
super(AlphaModule, self).__init__() | ||
if not isinstance(shape, tuple): | ||
shape = (shape,) | ||
self.alpha = Parameter( | ||
torch.rand(tuple([1] + list(shape))) * 0.1, requires_grad=True | ||
) | ||
|
||
def forward(self, x): | ||
return x * self.alpha | ||
|
||
def parameters(self, recurse: bool = True): | ||
yield self.alpha | ||
|
||
|
||
class ListModule(nn.Module): | ||
def __init__(self, *args): | ||
super(ListModule, self).__init__() | ||
self.idx = 0 | ||
for module in args: | ||
self.add_module(str(self.idx), module) | ||
self.idx += 1 | ||
|
||
def append(self, module): | ||
self.add_module(str(self.idx), module) | ||
self.idx += 1 | ||
|
||
def __getitem__(self, idx): | ||
if idx < 0: | ||
idx += self.idx | ||
if idx >= len(self._modules): | ||
raise IndexError("index {} is out of range".format(idx)) | ||
it = iter(self._modules.values()) | ||
for i in range(idx): | ||
next(it) | ||
return next(it) | ||
|
||
def __iter__(self): | ||
return iter(self._modules.values()) | ||
|
||
def __len__(self): | ||
return len(self._modules) |
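
A brief sketch of how these containers behave, assuming only torch and the two classes above; the layer sizes are arbitrary:

layers = ListModule(nn.Linear(8, 8), nn.ReLU())
layers.append(nn.Linear(8, 4))            # registered under the name "2"
x = torch.randn(2, 8)
for layer in layers:                      # __iter__ yields registered modules in order
    x = layer(x)
print(len(layers), x.shape)               # 3 torch.Size([2, 4])

scale = AlphaModule((4,))                 # alpha has shape (1, 4) for broadcasting
print(scale(x).shape)                     # torch.Size([2, 4])
print(next(scale.parameters()).shape)     # torch.Size([1, 4])

Since AlphaModule overrides parameters() to yield only alpha, an optimizer built from it sees exactly that one tensor.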