Skip to content

Commit

Permalink
file push 2/3
Browse files Browse the repository at this point in the history
  • Loading branch information
Confusezius committed Dec 10, 2024
1 parent 26e4fe2 commit 8f419c8
Show file tree
Hide file tree
Showing 49 changed files with 8,910 additions and 0 deletions.
Empty file added .gitattributes
Empty file.
540 changes: 540 additions & 0 deletions backbones/__init__.py

Large diffs are not rendered by default.

112 changes: 112 additions & 0 deletions backbones/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Optional

import timm
import torch

import backbone.foundation_models
import backbone.wrapper
import utils.conf


def resnet18(
    class_names: Optional[list[str]] = None,
    num_classes: Optional[int] = None,
    other_hooks: Optional[list] = None,
    pretrained: bool = False,
    replace_output: bool = True,
) -> torch.nn.Module:
    """Build a timm ResNet-18 wrapped in an ExtractionWrapper.

    Args:
        class_names: Unused here; accepted for signature parity with the
            CLIP builders in this module.
        num_classes: Number of output classes forwarded to the wrapper.
        other_hooks: Extra hook specs forwarded to the wrapper; defaults
            to no extra hooks.
        pretrained: If True, timm loads ImageNet-1k weights.
        replace_output: Forwarded to the wrapper (whether to replace the
            classifier head).

    Returns:
        The wrapped backbone, moved to the configured device.
    """
    # Default to None instead of a mutable [] so one shared list is not
    # reused (and potentially mutated) across calls.
    hooks = other_hooks if other_hooks is not None else []
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=timm.create_model("resnet18", pretrained=pretrained),
        output_hook="fc",
        feature_hook="global_pool",
        other_hooks=hooks,
        replace_output=replace_output,
    ).to(utils.conf.get_device())
    # Record weight provenance for downstream bookkeeping.
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "IN1k" if pretrained else "None"
    return model_backbone


def resnet50(
    class_names: Optional[list[str]] = None,
    num_classes: Optional[int] = None,
    other_hooks: Optional[list] = None,
    pretrained: bool = False,
    replace_output: bool = True,
) -> torch.nn.Module:
    """Build a timm ResNet-50 wrapped in an ExtractionWrapper.

    Args:
        class_names: Unused here; accepted for signature parity with the
            CLIP builders in this module.
        num_classes: Number of output classes forwarded to the wrapper.
        other_hooks: Extra hook specs forwarded to the wrapper; defaults
            to no extra hooks.
        pretrained: If True, timm loads ImageNet-1k weights.
        replace_output: Forwarded to the wrapper (whether to replace the
            classifier head).

    Returns:
        The wrapped backbone, moved to the configured device.
    """
    # Default to None instead of a mutable [] so one shared list is not
    # reused (and potentially mutated) across calls.
    hooks = other_hooks if other_hooks is not None else []
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=timm.create_model("resnet50", pretrained=pretrained),
        output_hook="fc",
        feature_hook="global_pool",
        other_hooks=hooks,
        replace_output=replace_output,
    ).to(utils.conf.get_device())
    # Record weight provenance for downstream bookkeeping.
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "IN1k" if pretrained else "None"
    return model_backbone


def efficientnet_b2(
    class_names: Optional[list[str]] = None,
    num_classes: Optional[int] = None,
    other_hooks: Optional[list] = None,
    pretrained: bool = False,
    replace_output: bool = True,
) -> torch.nn.Module:
    """Build a timm EfficientNet-B2 wrapped in an ExtractionWrapper.

    Args:
        class_names: Unused here; accepted for signature parity with the
            CLIP builders in this module.
        num_classes: Number of output classes forwarded to the wrapper.
        other_hooks: Extra hook specs forwarded to the wrapper; defaults
            to no extra hooks.
        pretrained: If True, timm loads ImageNet-1k weights.
        replace_output: Forwarded to the wrapper (whether to replace the
            classifier head).

    Returns:
        The wrapped backbone, moved to the configured device.
    """
    # Default to None instead of a mutable [] so one shared list is not
    # reused (and potentially mutated) across calls.
    hooks = other_hooks if other_hooks is not None else []
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=timm.create_model("efficientnet_b2", pretrained=pretrained),
        output_hook="classifier",
        feature_hook="global_pool",
        other_hooks=hooks,
        replace_output=replace_output,
    ).to(utils.conf.get_device())
    # Record weight provenance for downstream bookkeeping.
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "IN1k" if pretrained else "None"
    return model_backbone


def clip_vit_b32(
    class_names: Optional[list[str]] = None,
    num_classes: Optional[int] = None,
    other_hooks: Optional[list] = None,
    pretrained: bool = False,
    replace_output: bool = False,
) -> torch.nn.Module:
    """Build a CLIP ViT-B/32 foundation model wrapped in an ExtractionWrapper.

    Args:
        class_names: Class names handed to the CLIP wrapper (used for its
            text side).
        num_classes: Number of output classes forwarded to the wrapper.
        other_hooks: Extra hook specs forwarded to the wrapper; defaults
            to no extra hooks.
        pretrained: Provenance flag stored on the model and forwarded to
            the CLIP wrapper's `random_init` argument.
        replace_output: Forwarded to the wrapper; defaults to False for
            the CLIP builders.

    Returns:
        The wrapped backbone, moved to the configured device.
    """
    # Default to None instead of a mutable [] so one shared list is not
    # reused (and potentially mutated) across calls.
    hooks = other_hooks if other_hooks is not None else []
    device = utils.conf.get_device()
    # NOTE(review): `random_init=pretrained` reads inverted (random init when
    # pretrained=True) — confirm against backbone.foundation_models.CLIP.
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=backbone.foundation_models.CLIP(
            class_names, backbone="ViT-B/32", random_init=pretrained, device=device
        ),
        output_hook="text_features",
        feature_hook="backbone.visual",
        other_hooks=hooks,
        replace_output=replace_output,
    ).to(device)
    # Record weight provenance for downstream bookkeeping.
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "CLIP" if pretrained else "None"
    return model_backbone


def clip_resnet50(
    class_names: Optional[list[str]] = None,
    num_classes: Optional[int] = None,
    other_hooks: Optional[list] = None,
    pretrained: bool = False,
    replace_output: bool = False,
) -> torch.nn.Module:
    """Build a CLIP ResNet-50 foundation model wrapped in an ExtractionWrapper.

    Args:
        class_names: Class names handed to the CLIP wrapper (used for its
            text side).
        num_classes: Number of output classes forwarded to the wrapper.
        other_hooks: Extra hook specs forwarded to the wrapper; defaults
            to no extra hooks.
        pretrained: Provenance flag stored on the model and forwarded to
            the CLIP wrapper's `random_init` argument.
        replace_output: Forwarded to the wrapper; defaults to False for
            the CLIP builders.

    Returns:
        The wrapped backbone, moved to the configured device.
    """
    # Default to None instead of a mutable [] so one shared list is not
    # reused (and potentially mutated) across calls.
    hooks = other_hooks if other_hooks is not None else []
    device = utils.conf.get_device()
    # NOTE(review): `random_init=pretrained` reads inverted (random init when
    # pretrained=True) — confirm against backbone.foundation_models.CLIP.
    model_backbone = backbone.wrapper.ExtractionWrapper(
        num_classes=num_classes,
        backbone=backbone.foundation_models.CLIP(
            class_names, backbone="RN50", random_init=pretrained, device=device
        ),
        output_hook="text_features",
        feature_hook="backbone.visual",
        other_hooks=hooks,
        replace_output=replace_output,
    ).to(device)
    # Record weight provenance for downstream bookkeeping.
    model_backbone.pretrained = pretrained
    model_backbone.pretraining = "CLIP" if pretrained else "None"
    return model_backbone
Empty file added backbones/utils/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions backbones/utils/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def convert_model_to_fp32(model):
    """Cast every parameter of *model* (and its gradient, if any) to float32 in place."""
    for p in model.parameters():
        p.data = p.data.float()
        # Bug fix: `if p.grad:` evaluates tensor truthiness, which raises
        # "Boolean value of Tensor with more than one element is ambiguous"
        # for any multi-element gradient (and skips all-zero scalar grads).
        # Test for existence instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def refine_classname(class_names):
    """Rewrite each class name in place as a lowercase, article-prefixed phrase.

    Underscores and hyphens become spaces, and "a "/"an " is prepended
    based on the first letter. The input list is mutated and also returned.
    """

    def _with_article(raw):
        # Normalise separators and case, then pick the article by first letter.
        phrase = raw.lower().replace("_", " ").replace("-", " ")
        article = "an " if phrase[0] in "aeiou" else "a "
        return article + phrase

    # Slice assignment keeps the caller's list object (in-place semantics).
    class_names[:] = [_with_article(name) for name in class_names]
    return class_names
53 changes: 53 additions & 0 deletions backbones/utils/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class AlphaModule(nn.Module):
    """Learnable elementwise scaling.

    Holds a trainable tensor `alpha` of shape (1, *shape), initialised
    uniformly in [0, 0.1), and multiplies its input by it in `forward`.
    """

    def __init__(self, shape):
        super().__init__()
        dims = shape if isinstance(shape, tuple) else (shape,)
        # Leading singleton dim lets alpha broadcast over the batch axis.
        self.alpha = Parameter(torch.rand((1, *dims)) * 0.1, requires_grad=True)

    def forward(self, x):
        return x * self.alpha

    def parameters(self, recurse: bool = True):
        # Deliberately exposes only alpha, regardless of `recurse`.
        yield self.alpha


class ListModule(nn.Module):
    """A list-like container of sub-modules.

    Modules are registered under sequential numeric names so their
    parameters are tracked, and the container supports indexing,
    iteration, length, and `append`.
    """

    def __init__(self, *modules):
        super().__init__()
        self.idx = 0  # next numeric name to register under
        for module in modules:
            self.append(module)

    def append(self, module):
        """Register *module* under the next sequential index."""
        self.add_module(str(self.idx), module)
        self.idx += 1

    def __getitem__(self, idx):
        if idx < 0:
            idx += self.idx
        if idx >= len(self._modules):
            raise IndexError("index {} is out of range".format(idx))
        entries = list(self._modules.values())
        # Matches the original semantics: an index that is still negative
        # after adjustment falls through to the first entry.
        return entries[max(idx, 0)]

    def __iter__(self):
        return iter(self._modules.values())

    def __len__(self):
        return len(self._modules)
Loading

0 comments on commit 8f419c8

Please sign in to comment.